11import numpy as np
22import torch
33import argparse
4+ import json
45import random
56import ipdb
67
78from ltsm .data_provider .data_factory import get_datasets
8- from ltsm .data_provider .data_loader import HF_Dataset
9+ from ltsm .data_provider .data_loader import HF_Dataset , HF_Timestamp_Dataset
910from ltsm .data_pipeline .model_manager import ModelManager
1011
1112import logging
@@ -72,11 +73,10 @@ def run(self):
7273 )
7374
7475 train_dataset , eval_dataset , test_datasets , _ = get_datasets (self .args )
75- train_dataset , eval_dataset = HF_Dataset (train_dataset ), HF_Dataset (eval_dataset )
76-
77- if self .args .model == 'PatchTST' or self .args .model == 'DLinear' :
78- # Set the patch number to the size of the input sequence including the prompt sequence
79- self .model_manager .args .seq_len = train_dataset [0 ]["input_data" ].size ()[0 ]
76+ if self .args .model == "Informer" :
77+ train_dataset , eval_dataset = HF_Timestamp_Dataset (train_dataset ), HF_Timestamp_Dataset (eval_dataset )
78+ else :
79+ train_dataset , eval_dataset = HF_Dataset (train_dataset ), HF_Dataset (eval_dataset )
8080
8181 model = self .model_manager .create_model ()
8282
@@ -103,16 +103,24 @@ def run(self):
103103
104104 # Testing settings
105105 for test_dataset in test_datasets :
106+ if self .args .model == "Informer" :
107+ test_ds = HF_Timestamp_Dataset (test_dataset )
108+ else :
109+ test_ds = HF_Dataset (test_dataset )
110+
106111 trainer .compute_loss = self .model_manager .compute_loss
107112 trainer .prediction_step = self .model_manager .prediction_step
108113 test_dataset = HF_Dataset (test_dataset )
109114
110- metrics = trainer .evaluate (test_dataset )
115+ metrics = trainer .evaluate (test_ds )
111116 trainer .log_metrics ("Test" , metrics )
112117 trainer .save_metrics ("Test" , metrics )
113118
114119def get_args ():
115120 parser = argparse .ArgumentParser (description = 'LTSM' )
121+
122+ # Load JSON config file
123+ parser .add_argument ('--config' , type = str , help = 'Path to JSON configuration file' )
116124
117125 # Basic Config
118126 parser .add_argument ('--model_id' , type = str , default = 'test_run' , help = 'model id' )
@@ -122,8 +130,9 @@ def get_args():
122130 parser .add_argument ('--checkpoints' , type = str , default = './checkpoints/' )
123131
124132 # Data Settings
133+ parser .add_argument ('--data' , help = 'dataset type' )
125134 parser .add_argument ('--data_path' , nargs = '+' , default = 'dataset/weather.csv' , help = 'data files' )
126- parser .add_argument ('--test_data_path_list' , nargs = '+' , required = True , help = 'test data file' )
135+ parser .add_argument ('--test_data_path_list' , nargs = '+' , help = 'test data file' )
127136 parser .add_argument ('--prompt_data_path' , type = str , default = './weather.csv' , help = 'prompt data file' )
128137 parser .add_argument ('--data_processing' , type = str , default = "standard_scaler" , help = 'data processing method' )
129138 parser .add_argument ('--train_ratio' , type = float , default = 0.7 , help = 'train data ratio' )
@@ -153,7 +162,6 @@ def get_args():
153162 parser .add_argument ('--model' , type = str , default = 'model' , help = 'model name, , options:[LTSM, LTSM_WordPrompt, LTSM_Tokenizer, DLinear, PatchTST, Informer]' )
154163 parser .add_argument ('--stride' , type = int , default = 8 , help = 'stride' )
155164 parser .add_argument ('--tmax' , type = int , default = 10 , help = 'tmax' )
156- parser .add_argument ('--dropout' , type = float , default = 0.05 , help = 'dropout' )
157165 parser .add_argument ('--embed' , type = str , default = 'timeF' ,
158166 help = 'time features encoding, options:[timeF, fixed, learned]' )
159167 parser .add_argument ('--activation' , type = str , default = 'gelu' , help = 'activation' )
@@ -200,6 +208,14 @@ def get_args():
200208
201209 args , unknown = parser .parse_known_args ()
202210
211+ if args .config :
212+ with open (args .config , 'r' ) as f :
213+ config = json .load (f )
214+ json_args = argparse .Namespace (** config )
215+
216+ for key , value in vars (json_args ).items ():
217+ setattr (args , key , value )
218+
203219 return args
204220
205221
0 commit comments