Skip to content

Commit 0acf709

Browse files
committed
Add loading from HF pre-trained model
1 parent bdce2cb commit 0acf709

File tree

2 files changed

+34
-21
lines changed

2 files changed

+34
-21
lines changed

squadro/state/evaluators/rl.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,45 @@
77

88
import numpy as np
99
import torch
10+
from huggingface_hub import hf_hub_download, list_repo_files
1011
from numpy.typing import NDArray
1112

1213
from squadro.evaluators.evaluator import default_device, Model
1314
from squadro.ml.channels import get_channels
1415
from squadro.state.evaluators.base import Evaluator
1516
from squadro.state.evaluators.ml import ModelConfig
1617
from squadro.state.state import State
17-
from squadro.tools.constants import DATA_PATH, inf
18+
from squadro.tools.constants import inf
1819
from squadro.tools.dates import get_file_modified_time, get_now, READABLE_DATE_FMT
1920
from squadro.tools.logs import logger
2021
from squadro.tools.state import get_grid_shape, state_to_index, get_reward
2122

23+
HF_REPO_ID = "martin-shark/squadro"
24+
2225

2326
class _RLEvaluator(Evaluator, ABC):
2427
_default_dir = 'default'
2528
_weight_update_timestamp = defaultdict(lambda: 'unknown')
29+
_online_models = ()
2630

2731
def __init__(self, model_path: str | Path | None = None, dtype='json'):
    """
    Set up the model directory and the serialization format.

    When no ``model_path`` is given, every file of the Hugging Face repo
    ``HF_REPO_ID`` whose name starts with ``_default_dir`` is downloaded,
    and the local directory of the last downloaded file becomes
    ``model_path``.

    :param model_path: Path to the directory where the model(s) is stored.
        If falsy, the online pre-trained models are used instead.
    :param dtype: File extension used when building model file names.
    :raises FileNotFoundError: If the repo holds no file for ``_default_dir``.
    """
    self._model_path = model_path
    self.dtype = dtype

    if self._model_path:
        self.model_path = Path(self._model_path)
        return

    logger.info("No model path specified, using online pre-trained models.")
    dir_path = self._default_dir
    hf_files = [f for f in list_repo_files(HF_REPO_ID) if f.startswith(dir_path)]
    # Explicit exception rather than `assert`: assertions are stripped when
    # Python runs with -O, which would let a missing directory go unnoticed.
    if not hf_files:
        raise FileNotFoundError(
            f"No files found in repo {HF_REPO_ID} for dir {dir_path}"
        )

    local_paths = [
        hf_hub_download(repo_id=HF_REPO_ID, filename=f) for f in hf_files
    ]
    # NOTE(review): assumes all matched files live in the same repo folder,
    # so the parent of any downloaded file is the model directory - confirm.
    self.model_path = Path(local_paths[-1]).parent
3348

34-
@property
35-
def model_path(self) -> Path:
36-
return Path(self._model_path or DATA_PATH / self._default_dir)
37-
38-
def get_model_path(self, n_pawns: int) -> Path:
39-
path = self.model_path
40-
if not self._model_path:
41-
path /= str(n_pawns)[0]
42-
return path
43-
4449
def get_weight_update_timestamp(self, n_pawns: int):
    """
    Look up the recorded weight-update timestamp of a model file.

    :param n_pawns: Number of pawns identifying the model file.
    :return: The entry of ``_weight_update_timestamp`` stored under the
        file path returned by ``get_filepath(n_pawns)``.
    """
    filepath = self.get_filepath(n_pawns)
    timestamps = self._weight_update_timestamp
    return timestamps[filepath]
4651

@@ -54,7 +59,7 @@ def reload(cls):
5459
cls._models = {}
5560

5661
def get_filepath(self, n_pawns: int, model_path=None) -> str:
    """
    Build the path of the model file for a given number of pawns.

    :param n_pawns: Number of pawns, embedded in the file name.
    :param model_path: Optional directory overriding ``self.model_path``.
    :return: ``<directory>/model_<n_pawns>.<dtype>`` as a string.
    """
    directory = Path(model_path) if model_path else Path(self.model_path)
    filename = f"model_{n_pawns}.{self.dtype}"
    return str(directory / filename)
5964

6065
def clear(self):
@@ -92,6 +97,14 @@ def dump(self, model_path: str | Path = None):
9297
def _dump(self, model, filepath: str):
9398
...
9499

100+
def check_online_model(self, n_pawns: int):
    """
    Reject a missing model when relying on online pre-trained models.

    Does nothing if a ``model_path`` was supplied at construction time.

    :param n_pawns: Number of pawns of the requested model; only used to
        build the error message.
    :raises ValueError: If no ``model_path`` was specified.
    """
    if self._model_path:
        return

    message = (
        "When `model_path` is not specified, a pre-trained model must be loaded."
        f" But there is no pre-trained model available for {self._default_dir}"
        f" with {n_pawns} pawns. Available number of pawns: {self._online_models}."
    )
    raise ValueError(message)
107+
95108

96109
class QLearningEvaluator(_RLEvaluator):
97110
"""
@@ -102,6 +115,7 @@ class QLearningEvaluator(_RLEvaluator):
102115
"""
103116
_models = {}
104117
_default_dir = 'q_learning'
118+
_online_models = (2, 3)
105119

106120
@property
107121
def is_json(self):
@@ -118,6 +132,7 @@ def get_model(self, n_pawns: int) -> dict: # noqa
118132
logger.info(f"Using Q table at {filepath}")
119133
self._weight_update_timestamp[filepath] = get_file_modified_time(filepath)
120134
else:
135+
self.check_online_model(n_pawns)
121136
if self.is_json:
122137
self.models[n_pawns] = {}
123138
else:
@@ -175,6 +190,7 @@ class DeepQLearningEvaluatorMultipleGrids(_RLEvaluator):
175190

176191
_models = {}
177192
_default_dir = 'deep_q_learning'
193+
_online_models = (3, 4, 5)
178194

179195
def __init__(
180196
self,
@@ -286,6 +302,7 @@ def get_model(self, n_pawns: int, player: int = None) -> Model:
286302
self.models[key] = model
287303
self._weight_update_timestamp[filepath] = get_file_modified_time(filepath)
288304
else:
305+
self.check_online_model(n_pawns)
289306
logger.warn(f"No file at {filepath}, creating new model")
290307
self.models[key] = Model(
291308
n_pawns=n_pawns,
@@ -344,15 +361,11 @@ def is_separate_networks(self, n_pawns: int) -> bool:
344361
return self._separate_networks
345362
n_pawns = int(str(n_pawns)[0])
346363
if self._separate_networks.get(n_pawns) is None:
347-
model_path = self.get_model_path(n_pawns)
364+
model_path = self.model_path
348365
files = os.listdir(model_path) if os.path.exists(model_path) else []
349-
files = [
350-
f.replace('.pt', '').replace('model_', '')
351-
for f in files if f.endswith('.pt')
352-
]
353-
if set(files) == {'0', '1'}:
366+
if {f'model_{n_pawns}_0.pt', f'model_{n_pawns}_1.pt'}.issubset(files):
354367
self._separate_networks[n_pawns] = True
355-
elif len(files) == 1:
368+
elif {f'model_{n_pawns}.pt'}.issubset(files):
356369
self._separate_networks[n_pawns] = False
357370
else:
358371
self._separate_networks[n_pawns] = super().is_separate_networks(n_pawns)

squadro/tests/evaluators/test_q_learning_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_eval(self):
1717
state = State(advancement=[[1, 2, 3], [1, 2, 4]], cur_player=0)
1818
with (
1919
TemporaryDirectory() as model_path,
20-
patch.object(self.evaluator, '_model_path', model_path),
20+
patch.object(self.evaluator, 'model_path', model_path),
2121
):
2222
json.dump(
2323
{'[[1, 2, 3], [1, 2, 4]], 0': .14},

0 commit comments

Comments
 (0)