77
88import numpy as np
99import torch
10+ from huggingface_hub import hf_hub_download , list_repo_files
1011from numpy .typing import NDArray
1112
1213from squadro .evaluators .evaluator import default_device , Model
1314from squadro .ml .channels import get_channels
1415from squadro .state .evaluators .base import Evaluator
1516from squadro .state .evaluators .ml import ModelConfig
1617from squadro .state .state import State
17- from squadro .tools .constants import DATA_PATH , inf
18+ from squadro .tools .constants import inf
1819from squadro .tools .dates import get_file_modified_time , get_now , READABLE_DATE_FMT
1920from squadro .tools .logs import logger
2021from squadro .tools .state import get_grid_shape , state_to_index , get_reward
2122
# Hugging Face Hub repository that is queried for pre-trained models when an
# evaluator is created without a local model path.
HF_REPO_ID = "martin-shark/squadro"
24+
2225
2326class _RLEvaluator (Evaluator , ABC ):
2427 _default_dir = 'default'
2528 _weight_update_timestamp = defaultdict (lambda : 'unknown' )
29+ _online_models = ()
2630
def __init__(self, model_path: str | Path | None = None, dtype='json'):
    """
    :param model_path: Path to the directory where the model(s) is stored.
        When empty, every file located under the class' ``_default_dir`` in the
        Hugging Face Hub repo ``HF_REPO_ID`` is downloaded, and the local
        directory of the downloaded files is used instead.
    :param dtype: File extension used when building model file names.
    :raises FileNotFoundError: If the repo holds no file under ``_default_dir``.
    """
    # `_model_path` keeps the raw user input: a falsy value later signals that
    # the evaluator relies on the online pre-trained models.
    self._model_path = model_path
    self.dtype = dtype

    if self._model_path:
        self.model_path = Path(self._model_path)
        return

    logger.info("No model path specified, using online pre-trained models.")
    dir_path = self._default_dir
    # Match the directory itself ("<dir>/..."), not any repo entry whose name
    # merely starts with the same characters.
    prefix = f"{dir_path}/"
    hf_files = [f for f in list_repo_files(HF_REPO_ID) if f.startswith(prefix)]
    if not hf_files:
        # Explicit exception rather than `assert`, which is stripped under `-O`.
        raise FileNotFoundError(
            f"No files found in repo {HF_REPO_ID} for dir {dir_path}"
        )

    path = None
    for f in hf_files:
        path = hf_hub_download(repo_id=HF_REPO_ID, filename=f)
    # The directory of the last downloaded file is taken as the model directory.
    self.model_path = Path(path).parent
3348
34- @property
35- def model_path (self ) -> Path :
36- return Path (self ._model_path or DATA_PATH / self ._default_dir )
37-
38- def get_model_path (self , n_pawns : int ) -> Path :
39- path = self .model_path
40- if not self ._model_path :
41- path /= str (n_pawns )[0 ]
42- return path
43-
def get_weight_update_timestamp(self, n_pawns: int):
    """
    Return the recorded weight-update timestamp of the model file for `n_pawns`.

    :param n_pawns: Number of pawns identifying the model file.
    :return: The entry of `_weight_update_timestamp` stored under the model file path.
    """
    filepath = self.get_filepath(n_pawns)
    return self._weight_update_timestamp[filepath]
4651
@@ -54,7 +59,7 @@ def reload(cls):
5459 cls ._models = {}
5560
def get_filepath(self, n_pawns: int, model_path=None) -> str:
    """
    Build the path of the model file for `n_pawns`.

    :param n_pawns: Number of pawns, used in the file name ``model_<n_pawns>.<dtype>``.
    :param model_path: Optional directory overriding the instance's `model_path`.
    :return: The file path as a string.
    """
    directory = Path(model_path or self.model_path)
    return str(directory / f"model_{n_pawns}.{self.dtype}")
5964
6065 def clear (self ):
@@ -92,6 +97,14 @@ def dump(self, model_path: str | Path = None):
9297 def _dump (self , model , filepath : str ):
9398 ...
9499
def check_online_model(self, n_pawns: int) -> None:
    """
    Refuse to continue when the evaluator relies on online pre-trained models.

    Does nothing when a `model_path` was given at construction time.

    :param n_pawns: Number of pawns of the requested model, reported in the error.
    :raises ValueError: If no `model_path` was specified at construction time.
    """
    if self._model_path:
        return
    raise ValueError(
        "When `model_path` is not specified, a pre-trained model must be loaded."
        f" But there is no pre-trained model available for {self._default_dir}"
        f" with {n_pawns} pawns. Available number of pawns: {self._online_models}."
    )
107+
95108
96109class QLearningEvaluator (_RLEvaluator ):
97110 """
@@ -102,6 +115,7 @@ class QLearningEvaluator(_RLEvaluator):
102115 """
103116 _models = {}
104117 _default_dir = 'q_learning'
118+ _online_models = (2 , 3 )
105119
106120 @property
107121 def is_json (self ):
@@ -118,6 +132,7 @@ def get_model(self, n_pawns: int) -> dict: # noqa
118132 logger .info (f"Using Q table at { filepath } " )
119133 self ._weight_update_timestamp [filepath ] = get_file_modified_time (filepath )
120134 else :
135+ self .check_online_model (n_pawns )
121136 if self .is_json :
122137 self .models [n_pawns ] = {}
123138 else :
@@ -175,6 +190,7 @@ class DeepQLearningEvaluatorMultipleGrids(_RLEvaluator):
175190
176191 _models = {}
177192 _default_dir = 'deep_q_learning'
193+ _online_models = (3 , 4 , 5 )
178194
179195 def __init__ (
180196 self ,
@@ -286,6 +302,7 @@ def get_model(self, n_pawns: int, player: int = None) -> Model:
286302 self .models [key ] = model
287303 self ._weight_update_timestamp [filepath ] = get_file_modified_time (filepath )
288304 else :
305+ self .check_online_model (n_pawns )
289306 logger .warn (f"No file at { filepath } , creating new model" )
290307 self .models [key ] = Model (
291308 n_pawns = n_pawns ,
@@ -344,15 +361,11 @@ def is_separate_networks(self, n_pawns: int) -> bool:
344361 return self ._separate_networks
345362 n_pawns = int (str (n_pawns )[0 ])
346363 if self ._separate_networks .get (n_pawns ) is None :
347- model_path = self .get_model_path ( n_pawns )
364+ model_path = self .model_path
348365 files = os .listdir (model_path ) if os .path .exists (model_path ) else []
349- files = [
350- f .replace ('.pt' , '' ).replace ('model_' , '' )
351- for f in files if f .endswith ('.pt' )
352- ]
353- if set (files ) == {'0' , '1' }:
366+ if {f'model_{ n_pawns } _0.pt' , f'model_{ n_pawns } _1.pt' }.issubset (files ):
354367 self ._separate_networks [n_pawns ] = True
355- elif len (files ) == 1 :
368+ elif { f'model_ { n_pawns } .pt' }. issubset (files ):
356369 self ._separate_networks [n_pawns ] = False
357370 else :
358371 self ._separate_networks [n_pawns ] = super ().is_separate_networks (n_pawns )
0 commit comments