
Commit 8c20edc

Merge pull request #33 from joshhan619/ltsm-stack
Pretrained weight loading and prompt data generation added to API
2 parents 4cbea68 + 7ab243a commit 8c20edc

File tree: 12 files changed, +499 −45 lines

ltsm/common/base_training_pipeline.py

Lines changed: 2 additions & 1 deletion
@@ -52,6 +52,7 @@ class TrainingConfig:
         "data": "ETTh1",
         "features": "MS",
         "prompt_data_path": "./weather.csv",
+        "hf_hub_model": None
     }
 
     def __init__(self, model_config: PretrainedConfig, **kwargs):
@@ -126,7 +127,7 @@ def __init__(self,
         self.config = config
 
         if not model:
-            self.model = get_model(config.model_config, config.train_params["model"], config.train_params["local_pretrain"])
+            self.model = get_model(config.model_config, config.train_params["model"], config.train_params["local_pretrain"], config.train_params["hf_hub_model"])
 
         if self.config.train_params["lora"]:
             peft_config = LoraConfig(
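With this change the training pipeline picks up an optional "hf_hub_model" entry from train_params and forwards it to get_model. A minimal usage sketch, assuming TrainingConfig folds keyword arguments into train_params (as the **kwargs signature above suggests); the Hub repo id is a placeholder, not one shipped with this PR:

# Sketch only: the Hub repo id below is hypothetical, and passing overrides as
# keyword arguments assumes TrainingConfig merges **kwargs into train_params.
from transformers import PretrainedConfig
from ltsm.common.base_training_pipeline import TrainingConfig

model_config = PretrainedConfig()          # stand-in; use the config of the registered model
config = TrainingConfig(
    model_config,
    model="LTSM",
    local_pretrain="None",                 # skip the local checkpoint branch
    hf_hub_model="your-org/ltsm-weights",  # placeholder: load pretrained weights from the Hub
)
# The pipeline then builds the model via:
#   get_model(config.model_config, train_params["model"],
#             train_params["local_pretrain"], train_params["hf_hub_model"])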

ltsm/data_provider/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -13,6 +13,7 @@
 )
 from .data_splitter import SplitterByTimestamp
 from .dataset import TSDataset, TSPromptDataset, TSTokenDataset
+from .prompt_generator import prompt_generate_split, prompt_normalization_split
 
 __all__ = {
     DatasetFactory,
@@ -29,5 +30,7 @@
     SplitterByTimestamp,
     TSDataset,
     TSPromptDataset,
-    TSTokenDataset
+    TSTokenDataset,
+    prompt_generate_split,
+    prompt_normalization_split
 }
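With the new re-exports, the prompt helpers can be imported from the package itself rather than the module path; a minimal sketch of the import surface (the function signatures live in the unrendered prompt_generator.py diff):

# Sketch only: shows the new package-level imports added by this PR.
from ltsm.data_provider import prompt_generate_split, prompt_normalization_split

Note that __all__ here is a set of the objects themselves rather than the conventional list of name strings; the diff keeps that pre-existing style unchanged.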

ltsm/data_provider/prompt_generator.py

Lines changed: 404 additions & 0 deletions
Large diffs are not rendered by default.
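The 404-line module itself is collapsed above, but the renamed test file further down exercises two of its entry points, prompt_save and save_data. A hedged sketch of the call shapes, inferred only from those tests (paths, data, and the final boolean flag are illustrative):

# Sketch only: argument order mirrors the calls in prompt_generator_test.py;
# the DataFrames, output directory, and dataset name are placeholders.
import pandas as pd
from ltsm.data_provider.prompt_generator import prompt_save, save_data

prompt_buf = {
    "train": pd.DataFrame({"f0": [1.0, 2.0]}),
    "val":   pd.DataFrame({"f0": [3.0]}),
    "test":  pd.DataFrame({"f0": [4.0]}),
}
# Persist one buffer per split in any format the tests cover: "pth.tar", "csv", or "npz".
prompt_save(prompt_buf, "./prompt_out", "weather", "csv", False)

# Lower-level helper: write a single DataFrame to a path in the chosen format.
save_data(pd.DataFrame([range(10)]), "./prompt_out/sample.npz", "npz")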

ltsm/models/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -29,14 +29,15 @@ def register_model(module, module_name: str):
 register_model(DLinear, 'DLinear')
 register_model(Informer, 'Informer')
 
-def get_model(config: PretrainedConfig, model_name: str, local_pretrain: str = None) -> PreTrainedModel:
+def get_model(config: PretrainedConfig, model_name: str, local_pretrain: str = None, hf_hub_model: str = None) -> PreTrainedModel:
     """
     Factory method to create a model by name.
 
     Args:
         config (PreTrainedConfig): The configuration for the model.
         model_name (str): The name of the model to instantiate.
         local_pretrain (bool): If True, load the model from a local pretraining path.
+        hf_hub_model (str): The Hugging Face Hub model name.
 
     Returns:
         torch.nn.Module: Instantiated model.
@@ -47,6 +48,10 @@ def get_model(config: PretrainedConfig, model_name: str, local_pretrain: str = N
     if model_name not in model_dict:
         raise ValueError(f"Model {model_name} is not registered. Available models: {list(model_dict.keys())}")
 
+    # Load pretrained weights if hf_hub_model is provided
+    if hf_hub_model is not None:
+        return model_dict[model_name].from_pretrained(hf_hub_model, config)
+
     # Check for local pretraining
     if local_pretrain is None or local_pretrain == "None":
         return model_dict[model_name](config)
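The factory now resolves weights in a fixed order: a Hugging Face Hub repo id wins if given, then a local pretraining path, and otherwise a fresh model is built from the config. A sketch of the three call shapes (the repo id and checkpoint path are placeholders):

# Sketch only: the repo id and checkpoint path are hypothetical.
from ltsm.models import get_model

model = get_model(config, "LTSM")                                       # fresh weights from config
model = get_model(config, "LTSM", local_pretrain="./checkpoints/ltsm")  # local checkpoint branch
model = get_model(config, "LTSM", hf_hub_model="your-org/ltsm-weights") # Hub weights via from_pretrained

One detail worth noting: from_pretrained receives config as a positional argument here; with the standard transformers PreTrainedModel.from_pretrained that lands in *model_args rather than the config keyword, so this relies on the registered model classes accepting it in that position.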

tests/common/base_training_pipeline_test.py

Lines changed: 4 additions & 2 deletions
@@ -22,7 +22,8 @@ def mock_pipeline(mocker):
         "train_ratio": 0.7,
         "val_ratio": 0.1,
         "downsample_rate": 1,
-        "do_anomaly": False
+        "do_anomaly": False,
+        "hf_hub_model": None
     }
     config.model_config = mocker.MagicMock()
     config.train_params["lora"] = False
@@ -50,7 +51,8 @@ def test_create_model_lora_enabled(mocker):
         "tmax": 10,
         "learning_rate": 1e-3,
         "model": "LTSM",
-        "local_pretrain": "None"
+        "local_pretrain": "None",
+        "hf_hub_model": None
     }
     config.model_config = mocker.MagicMock()
     config.train_params["lora"] = True

tests/prompt_reader/prompt_generate_split_test.py renamed to tests/data_provider/prompt_generator_test.py

Lines changed: 35 additions & 4 deletions
@@ -3,12 +3,12 @@
 import pandas as pd
 import numpy as np
 import torch
+from ltsm.data_provider.prompt_generator import save_data, prompt_save
 
 @pytest.fixture
-def setup(mocker, tmp_path):
+def setup_prompt(mocker, tmp_path):
     """set up the test environment"""
     mocker.patch.dict('sys.modules', {'tsfel': mocker.MagicMock()})
-    from ltsm.prompt_reader.stat_prompt.prompt_generate_split import prompt_save
 
     sample_prompt_buf = {
         'train': pd.DataFrame({
@@ -36,10 +36,10 @@ def setup(mocker, tmp_path):
     return prompt_save, sample_prompt_buf, output_path, data_name, ifTest
 
 @pytest.mark.parametrize("save_format", ["pth.tar", "csv", "npz"])
-def test_prompt_save(setup, save_format):
+def test_prompt_save(setup_prompt, save_format):
     """test if the prompt data is saved correctly in different formats and loaded back correctly
     """
-    prompt_save, sample_prompt_buf, output_path, data_name, ifTest = setup
+    prompt_save, sample_prompt_buf, output_path, data_name, ifTest = setup_prompt
     prompt_save(sample_prompt_buf, output_path, data_name, save_format, ifTest)
 
     for split in ["train", "val", "test"]:
@@ -75,3 +75,34 @@ def test_prompt_save(setup, save_format):
         if save_format != "csv":
             assert load_data.equals(prompt_data), f"Data mismatch: {load_data} vs {prompt_data}"
     print(f"All tests passed for {file_path}")
+
+
+@pytest.fixture
+def setup_save():
+    """input data for testing"""
+    data = pd.DataFrame([range(133)])
+    print(data.shape)
+    return data
+
+@pytest.mark.parametrize("save_format", ["pth.tar", "csv", "npz"])
+def test_save_data(tmpdir, setup_save, save_format):
+    """test save_data function: save data in different formats and load it back to check if the data is saved correctly"""
+    data_path = os.path.join(tmpdir, f"test_data.{save_format}")
+
+    save_data(setup_save, data_path, save_format)
+
+    if save_format == "pth.tar":
+        loaded_data = torch.load(data_path)
+    elif save_format == "csv":
+        loaded_data = pd.read_csv(data_path)
+        loaded_data.columns = loaded_data.columns.astype(int)
+    elif save_format == "npz":
+        loaded = np.load(data_path)
+        loaded_data = pd.DataFrame(data=loaded["data"])
+
+    assert isinstance(loaded_data, pd.DataFrame), "Loaded data should be a DataFrame"
+    assert loaded_data.shape == setup_save.shape, f"Shape mismatch: {loaded_data.shape} vs {setup_save.shape}"
+    assert loaded_data.columns.equals(setup_save.columns), "Columns mismatch"
+    assert np.allclose(loaded_data.values, setup_save.values, rtol=1e-8, atol=1e-8), "Data values mismatch"
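The round-trip the new test asserts is pure format dispatch: a DataFrame written as pth.tar, csv, or npz must reload with the same shape, columns, and values. The PR's actual save_data sits in the unrendered prompt_generator.py diff; the sketch below is only a minimal implementation consistent with what the test checks, not the merged code.

# Sketch only: a minimal save_data consistent with prompt_generator_test.py,
# not the implementation merged in this PR.
import numpy as np
import pandas as pd
import torch

def save_data(data: pd.DataFrame, data_path: str, save_format: str) -> None:
    """Persist a DataFrame so the test's loader (torch.load / read_csv / np.load) can read it back."""
    if save_format == "pth.tar":
        torch.save(data, data_path)              # test reloads with torch.load, which returns the DataFrame
    elif save_format == "csv":
        data.to_csv(data_path, index=False)      # test reloads with pd.read_csv and casts columns back to int
    elif save_format == "npz":
        np.savez(data_path, data=data.values)    # test reloads loaded["data"] into a DataFrame
    else:
        raise ValueError(f"Unsupported save format: {save_format}")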
