Skip to content

Commit 8239910

Browse files
authored
Merge pull request #26 from joshhan619/ltsm-stack
Informer model and baseline benchmark scripts
2 parents 1034501 + 236b9dd commit 8239910

File tree

13 files changed, with 353 additions and 94 deletions.

ltsm/data_pipeline/data_pipeline.py

Lines changed: 25 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
11
import numpy as np
22
import torch
33
import argparse
4+
import json
45
import random
56
import ipdb
67

78
from ltsm.data_provider.data_factory import get_datasets
8-
from ltsm.data_provider.data_loader import HF_Dataset
9+
from ltsm.data_provider.data_loader import HF_Dataset, HF_Timestamp_Dataset
910
from ltsm.data_pipeline.model_manager import ModelManager
1011

1112
import logging
@@ -72,11 +73,10 @@ def run(self):
7273
)
7374

7475
train_dataset, eval_dataset, test_datasets, _ = get_datasets(self.args)
75-
train_dataset, eval_dataset= HF_Dataset(train_dataset), HF_Dataset(eval_dataset)
76-
77-
if self.args.model == 'PatchTST' or self.args.model == 'DLinear':
78-
# Set the patch number to the size of the input sequence including the prompt sequence
79-
self.model_manager.args.seq_len = train_dataset[0]["input_data"].size()[0]
76+
if self.args.model == "Informer":
77+
train_dataset, eval_dataset = HF_Timestamp_Dataset(train_dataset), HF_Timestamp_Dataset(eval_dataset)
78+
else:
79+
train_dataset, eval_dataset= HF_Dataset(train_dataset), HF_Dataset(eval_dataset)
8080

8181
model = self.model_manager.create_model()
8282

@@ -103,16 +103,24 @@ def run(self):
103103

104104
# Testing settings
105105
for test_dataset in test_datasets:
106+
if self.args.model == "Informer":
107+
test_ds = HF_Timestamp_Dataset(test_dataset)
108+
else:
109+
test_ds = HF_Dataset(test_dataset)
110+
106111
trainer.compute_loss = self.model_manager.compute_loss
107112
trainer.prediction_step = self.model_manager.prediction_step
108113
test_dataset = HF_Dataset(test_dataset)
109114

110-
metrics = trainer.evaluate(test_dataset)
115+
metrics = trainer.evaluate(test_ds)
111116
trainer.log_metrics("Test", metrics)
112117
trainer.save_metrics("Test", metrics)
113118

114119
def get_args():
115120
parser = argparse.ArgumentParser(description='LTSM')
121+
122+
# Load JSON config file
123+
parser.add_argument('--config', type=str, help='Path to JSON configuration file')
116124

117125
# Basic Config
118126
parser.add_argument('--model_id', type=str, default='test_run', help='model id')
@@ -122,8 +130,9 @@ def get_args():
122130
parser.add_argument('--checkpoints', type=str, default='./checkpoints/')
123131

124132
# Data Settings
133+
parser.add_argument('--data', help='dataset type')
125134
parser.add_argument('--data_path', nargs='+', default='dataset/weather.csv', help='data files')
126-
parser.add_argument('--test_data_path_list', nargs='+', required=True, help='test data file')
135+
parser.add_argument('--test_data_path_list', nargs='+', help='test data file')
127136
parser.add_argument('--prompt_data_path', type=str, default='./weather.csv', help='prompt data file')
128137
parser.add_argument('--data_processing', type=str, default="standard_scaler", help='data processing method')
129138
parser.add_argument('--train_ratio', type=float, default=0.7, help='train data ratio')
@@ -153,7 +162,6 @@ def get_args():
153162
parser.add_argument('--model', type=str, default='model', help='model name, , options:[LTSM, LTSM_WordPrompt, LTSM_Tokenizer, DLinear, PatchTST, Informer]')
154163
parser.add_argument('--stride', type=int, default=8, help='stride')
155164
parser.add_argument('--tmax', type=int, default=10, help='tmax')
156-
parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
157165
parser.add_argument('--embed', type=str, default='timeF',
158166
help='time features encoding, options:[timeF, fixed, learned]')
159167
parser.add_argument('--activation', type=str, default='gelu', help='activation')
@@ -200,6 +208,14 @@ def get_args():
200208

201209
args, unknown = parser.parse_known_args()
202210

211+
if args.config:
212+
with open(args.config, 'r') as f:
213+
config = json.load(f)
214+
json_args = argparse.Namespace(**config)
215+
216+
for key, value in vars(json_args).items():
217+
setattr(args, key, value)
218+
203219
return args
204220

205221

ltsm/data_pipeline/model_manager.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,12 @@ def compute_loss(self, model, inputs, return_outputs=False):
126126
Returns:
127127
torch.Tensor or tuple: The computed loss, and optionally the outputs.
128128
"""
129-
outputs = model(inputs["input_data"])
129+
if self.args.model == 'Informer':
130+
input_data_mark = inputs["timestamp_input"].to(model.module.device)
131+
label_mark = inputs["timestamp_labels"].to(model.module.device)
132+
outputs = model(inputs["input_data"], input_data_mark, inputs["labels"], label_mark)
133+
else:
134+
outputs = model(inputs["input_data"])
130135
loss = nn.functional.mse_loss(outputs, inputs["labels"])
131136
return (loss, outputs) if return_outputs else loss
132137

@@ -146,7 +151,12 @@ def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys
146151
"""
147152
input_data = inputs["input_data"].to(model.module.device)
148153
labels = inputs["labels"].to(model.module.device)
149-
outputs = model(input_data)
154+
if self.args.model == 'Informer':
155+
input_data_mark = inputs["timestamp_input"].to(model.module.device)
156+
label_mark = inputs["timestamp_labels"].to(model.module.device)
157+
outputs = model(input_data, input_data_mark, labels, label_mark)
158+
else:
159+
outputs = model(input_data)
150160
loss = nn.functional.mse_loss(outputs, labels)
151161
return (loss, outputs, labels)
152162

@@ -160,6 +170,14 @@ def collate_fn(self, batch):
160170
Returns:
161171
dict: Collated batch with 'input_data' and 'labels' tensors.
162172
"""
173+
if self.args.model == 'Informer':
174+
return {
175+
'input_data': torch.from_numpy(np.stack([x['input_data'] for x in batch])).type(torch.float32),
176+
'labels': torch.from_numpy(np.stack([x['labels'] for x in batch])).type(torch.float32),
177+
'timestamp_input': torch.from_numpy(np.stack([x['timestamp_input'] for x in batch])).type(torch.float32),
178+
'timestamp_labels': torch.from_numpy(np.stack([x['timestamp_labels'] for x in batch])).type(torch.float32)
179+
}
180+
163181
return {
164182
'input_data': torch.from_numpy(np.stack([x['input_data'] for x in batch])).type(torch.float32),
165183
'labels': torch.from_numpy(np.stack([x['labels'] for x in batch])).type(torch.float32),

ltsm/data_provider/data_factory.py

Lines changed: 55 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ltsm.data_provider.data_splitter import SplitterByTimestamp
88
from ltsm.data_provider.tokenizer import processor_dict
99
from ltsm.data_provider.dataset import TSDataset, TSPromptDataset, TSTokenDataset
10+
from ltsm.data_provider.data_loader import Dataset_Custom, Dataset_ETT_hour, Dataset_ETT_minute
1011

1112
from typing import Tuple, List, Union, Dict
1213
import logging
@@ -341,37 +342,61 @@ def getDatasets(self)->Tuple[TSDataset, TSDataset, List[TSDataset]]:
341342
return train_ds, val_ds, test_ds_list
342343

343344
def get_datasets(args):
344-
ds_factory = DatasetFactory(
345-
data_paths=args.data_path,
346-
prompt_data_path=args.prompt_data_path,
347-
data_processing=args.data_processing,
348-
seq_len=args.seq_len,
349-
pred_len=args.pred_len,
350-
train_ratio=args.train_ratio,
351-
val_ratio=args.val_ratio,
352-
model=args.model,
353-
downsample_rate=args.downsample_rate,
354-
do_anomaly=args.do_anomaly
355-
)
356-
train_ds, val_ds, test_ds_list= ds_factory.getDatasets()
357-
358-
return train_ds, val_ds, test_ds_list, ds_factory.processor
345+
if "LTSM" in args.model:
346+
# Create datasets
347+
dataset_factory = DatasetFactory(
348+
data_paths=args.data_path,
349+
prompt_data_path=args.prompt_data_path,
350+
data_processing=args.data_processing,
351+
seq_len=args.seq_len,
352+
pred_len=args.pred_len,
353+
train_ratio=args.train_ratio,
354+
val_ratio=args.val_ratio,
355+
model=args.model,
356+
split_test_sets=False,
357+
downsample_rate=args.downsample_rate,
358+
do_anomaly=args.do_anomaly
359+
)
360+
train_dataset, val_dataset, test_datasets = dataset_factory.getDatasets()
361+
processor = dataset_factory.processor
362+
else:
363+
timeenc = 0 if args.embed != 'timeF' else 1
364+
Data = Dataset_Custom
365+
if args.data == "ETTh1" or args.data == "ETTh2":
366+
Data = Dataset_ETT_hour
367+
elif args.data == "ETTm1" or args.data == "ETTm2":
368+
Data = Dataset_ETT_minute
369+
370+
train_dataset = Data(
371+
data_path=args.data_path[0],
372+
split='train',
373+
size=[args.seq_len, args.pred_len],
374+
freq=args.freq,
375+
timeenc=timeenc,
376+
features=args.features
377+
)
378+
val_dataset = Data(
379+
data_path=args.data_path[0],
380+
split='val',
381+
size=[args.seq_len, args.pred_len],
382+
freq=args.freq,
383+
timeenc=timeenc,
384+
features=args.features
385+
)
386+
test_datasets = [Data(
387+
data_path=args.data_path[0],
388+
split='test',
389+
size=[args.seq_len, args.pred_len],
390+
freq=args.freq,
391+
timeenc=timeenc,
392+
features=args.features
393+
)]
394+
processor = train_dataset.scaler
395+
396+
return train_dataset, val_dataset, test_datasets, processor
359397

360398
def get_data_loaders(args):
361-
# Create datasets
362-
dataset_factory = DatasetFactory(
363-
data_paths=args.data_path,
364-
prompt_data_path=args.prompt_data_path,
365-
data_processing=args.data_processing,
366-
seq_len=args.seq_len,
367-
pred_len=args.pred_len,
368-
train_ratio=args.train_ratio,
369-
val_ratio=args.val_ratio,
370-
model=args.model,
371-
split_test_sets=False,
372-
do_anomaly=args.do_anomaly
373-
)
374-
train_dataset, val_dataset, test_datasets = dataset_factory.getDatasets()
399+
train_dataset, val_dataset, test_datasets, processor = get_datasets()
375400
print(f"Data loaded, train size {len(train_dataset)}, val size {len(val_dataset)}")
376401

377402
train_loader = DataLoader(
@@ -396,4 +421,4 @@ def get_data_loaders(args):
396421
num_workers=0,
397422
)
398423

399-
return train_loader, val_loader, test_loader, dataset_factory.processor
424+
return train_loader, val_loader, test_loader, processor

ltsm/data_provider/data_loader.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,42 @@ def inverse_transform(self, data):
3131
def add_data(self, df):
3232
return self.dataset.add_data(df)
3333

34-
def __getitem__(self, index):
35-
36-
seq_x, seq_y = self.dataset.__getitem__(index)
34+
def __getitem__(self, index):
35+
outputs = self.dataset.__getitem__(index)
36+
seq_x = outputs[0]
37+
seq_y = outputs[1]
3738

3839
return {
3940
"input_data": seq_x,
4041
"labels": seq_y
4142
}
43+
44+
class HF_Timestamp_Dataset(Dataset):
45+
def __init__(self, dataset):
46+
super().__init__()
47+
self.dataset = dataset
48+
49+
def __read_data__(self):
50+
return self.dataset.__read_data__()
51+
52+
def __len__(self):
53+
return self.dataset.__len__()
54+
55+
def inverse_transform(self, data):
56+
return self.dataset.inverse_transform(data)
57+
58+
def add_data(self, df):
59+
return self.dataset.add_data(df)
60+
61+
def __getitem__(self, index):
62+
seq_x, seq_y, seq_x_mark, seq_y_mark = self.dataset.__getitem__(index)
63+
64+
return {
65+
"input_data": seq_x,
66+
"labels": seq_y,
67+
"timestamp_input": seq_x_mark,
68+
"timestamp_labels": seq_y_mark
69+
}
4270

4371
class Dataset_ETT_hour(Dataset):
4472
def __init__(
@@ -131,8 +159,13 @@ def __getitem__(self, index):
131159
s_end = s_begin + self.seq_len
132160
r_begin = s_end
133161
r_end = r_begin + self.pred_len
134-
seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1]
135-
seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1]
162+
if self.enc_in > 1:
163+
seq_x = self.data_x[s_begin:s_end]
164+
seq_y = self.data_y[r_begin:r_end]
165+
else:
166+
seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1]
167+
seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1]
168+
136169
seq_x_mark = self.data_stamp[s_begin:s_end]
137170
seq_y_mark = self.data_stamp[r_begin:r_end]
138171

@@ -233,8 +266,13 @@ def __getitem__(self, index):
233266
s_end = s_begin + self.seq_len
234267
r_begin = s_end
235268
r_end = r_begin + self.pred_len
236-
seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1]
237-
seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1]
269+
if self.enc_in > 1:
270+
seq_x = self.data_x[s_begin:s_end]
271+
seq_y = self.data_y[r_begin:r_end]
272+
else:
273+
seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1]
274+
seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1]
275+
238276
seq_x_mark = self.data_stamp[s_begin:s_end]
239277
seq_y_mark = self.data_stamp[r_begin:r_end]
240278

@@ -345,8 +383,14 @@ def __getitem__(self, index):
345383
s_end = s_begin + self.seq_len
346384
r_begin = s_end
347385
r_end = r_begin + self.pred_len
348-
seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1]
349-
seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1]
386+
387+
if self.enc_in > 1:
388+
seq_x = self.data_x[s_begin:s_end]
389+
seq_y = self.data_y[r_begin:r_end]
390+
else:
391+
seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1]
392+
seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1]
393+
350394
seq_x_mark = self.data_stamp[s_begin:s_end]
351395
seq_y_mark = self.data_stamp[r_begin:r_end]
352396

tests/model/DLinear_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def config(tmp_path):
2020
"test_data_path_list": [str(data_path)],
2121
"prompt_data_path": str(prompt_data_path),
2222
"enc_in": 1,
23-
"seq_len": 336+133, # Equal to the sequence length + the length of prompt
24-
"train_epochs": 1000,
23+
"seq_len": 336, # Equal to the sequence length + the length of prompt
24+
"train_epochs": 100,
2525
"patience": 10,
2626
"lradj": 'TST',
2727
"pct_start": 0.2,

tests/model/Informer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def config(tmp_path):
2828
"dropout": 0.2,
2929
"fc_dropout": 0.2,
3030
"head_dropout": 0,
31-
"seq_len": 336+133, # Equal to the sequence length + the length of prompt
31+
"seq_len": 336,
3232
"patch_len": 16,
3333
"stride": 8,
3434
"des": 'Exp',

tests/model/PatchTST_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def config(tmp_path):
2727
"dropout": 0.2,
2828
"fc_dropout": 0.2,
2929
"head_dropout": 0,
30-
"seq_len": 336+133, # Equal to the sequence length + the length of prompt
30+
"seq_len": 336,
3131
"patch_len": 16,
3232
"stride": 8,
3333
"des": 'Exp',

tests/test_scripts/dlinear.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"model": "DLinear",
3+
"model_name_or_path": "gpt2-medium",
4+
"pred_len": 96,
5+
"gradient_accumulation_steps": 64,
6+
"seq_len": 336,
7+
"des": "Exp",
8+
"train_epochs": 100,
9+
"freeze": 0,
10+
"itr": 1,
11+
"learning_rate": 1e-3,
12+
"downsample_rate": 20,
13+
"output_dir": "output/dlinear/",
14+
"eval": 0,
15+
"features": "M"
16+
}

0 commit comments

Comments
 (0)