Skip to content

Commit 27368ed

Browse files
committed
chg: Implemented pretrained weight loading in get_model + tests
1 parent f1df7de commit 27368ed

File tree

7 files changed

+52
-2
lines changed

7 files changed

+52
-2
lines changed

ltsm/common/base_training_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class TrainingConfig:
5252
"data": "ETTh1",
5353
"features": "MS",
5454
"prompt_data_path": "./weather.csv",
55+
"hf_hub_model": None
5556
}
5657

5758
def __init__(self, model_config: PretrainedConfig, **kwargs):
@@ -126,7 +127,7 @@ def __init__(self,
126127
self.config = config
127128

128129
if not model:
129-
self.model = get_model(config.model_config, config.train_params["model"], config.train_params["local_pretrain"])
130+
self.model = get_model(config.model_config, config.train_params["model"], config.train_params["local_pretrain"], config.train_params["hf_hub_model"])
130131

131132
if self.config.train_params["lora"]:
132133
peft_config = LoraConfig(

ltsm/models/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@ def register_model(module, module_name: str):
2929
register_model(DLinear, 'DLinear')
3030
register_model(Informer, 'Informer')
3131

32-
def get_model(config: PretrainedConfig, model_name: str, local_pretrain: str = None) -> PreTrainedModel:
32+
def get_model(config: PretrainedConfig, model_name: str, local_pretrain: str = None, hf_hub_model: str = None) -> PreTrainedModel:
3333
"""
3434
Factory method to create a model by name.
3535
3636
Args:
3737
config (PretrainedConfig): The configuration for the model.
3838
model_name (str): The name of the model to instantiate.
3939
local_pretrain (str): Path to locally pretrained weights. If None or the string "None", the model is instantiated from config without loading local weights.
40+
hf_hub_model (str): The Hugging Face Hub model name to load pretrained weights from. When not None, it takes precedence over local_pretrain.
4041
4142
Returns:
4243
torch.nn.Module: Instantiated model.
@@ -47,6 +48,10 @@ def get_model(config: PretrainedConfig, model_name: str, local_pretrain: str = N
4748
if model_name not in model_dict:
4849
raise ValueError(f"Model {model_name} is not registered. Available models: {list(model_dict.keys())}")
4950

51+
# Load pretrained weights if hf_hub_model is provided
52+
if hf_hub_model is not None:
53+
return model_dict[model_name].from_pretrained(hf_hub_model, config)
54+
5055
# Check for local pretraining
5156
if local_pretrain is None or local_pretrain == "None":
5257
return model_dict[model_name](config)

tests/models/init_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import pytest
2+
from transformers import PretrainedConfig, PreTrainedModel
3+
from ltsm.models import register_model, get_model, model_dict
4+
5+
def test_register_model(mocker):
6+
mock_model = mocker.MagicMock(spec=PreTrainedModel)
7+
register_model(mock_model, "MockModel1")
8+
assert "MockModel1" in model_dict
9+
assert model_dict["MockModel1"] == mock_model
10+
11+
with pytest.raises(AssertionError, match="Reader MockModel1 already registered"):
12+
register_model(mock_model, "MockModel1")
13+
14+
def test_get_model(mocker):
15+
mock_model = mocker.MagicMock(spec=PreTrainedModel)
16+
mock_config = mocker.MagicMock(spec=PretrainedConfig)
17+
register_model(mock_model, "MockModel2")
18+
19+
instance = get_model(mock_config, "MockModel2")
20+
mock_model.assert_called_once_with(mock_config)
21+
assert isinstance(instance, mocker.MagicMock)
22+
23+
def test_get_model_invalid_name():
24+
with pytest.raises(ValueError, match="Model NonExistentModel is not registered"):
25+
get_model(PretrainedConfig(), "NonExistentModel")
26+
27+
def test_get_model_local_pretrain(mocker):
28+
mock_from_pretrained = mocker.patch("transformers.PretrainedConfig.from_pretrained")
29+
mock_model = mocker.MagicMock(spec=PreTrainedModel)
30+
register_model(mock_model, "MockModel3")
31+
32+
mock_from_pretrained.return_value = mocker.MagicMock()
33+
instance = get_model(PretrainedConfig(), "MockModel3", local_pretrain="path/to/pretrained")
34+
mock_model.from_pretrained.assert_called_once_with("path/to/pretrained", mock_from_pretrained.return_value)
35+
assert isinstance(instance, mocker.MagicMock)
36+
37+
def test_get_model_hf_hub(mocker):
38+
mock_from_pretrained = mocker.patch("transformers.PreTrainedModel.from_pretrained")
39+
mock_model = mocker.MagicMock(spec=PreTrainedModel)
40+
register_model(mock_model, "MockModel4")
41+
42+
instance = get_model(PretrainedConfig(), "MockModel4", hf_hub_model="mock-hub-model")
43+
mock_model.from_pretrained.assert_called_once_with("mock-hub-model", PretrainedConfig())
44+
assert isinstance(instance, mocker.MagicMock)

0 commit comments

Comments
 (0)