@@ -136,8 +136,15 @@ def hyper_parameter(task=None, model_name=None, dataset_name=None, config_file=N
136136 # load config
137137 experiment_config = ConfigParser (task , model_name , dataset_name , config_file = config_file ,
138138 other_args = other_args )
139+ # exp_id
140+ exp_id = experiment_config .get ('exp_id' , None )
141+ if exp_id is None :
142+ exp_id = int (random .SystemRandom ().random () * 100000 )
143+ experiment_config ['exp_id' ] = exp_id
139144 # logger
140145 logger = get_logger (experiment_config )
146+ logger .info ('Begin ray-tune, task={}, model_name={}, dataset_name={}, exp_id={}' .
147+ format (str (task ), str (model_name ), str (dataset_name ), str (exp_id )))
141148 logger .info (experiment_config .config )
142149 # check space_file
143150 if space_file is None :
@@ -167,8 +174,11 @@ def train(config, checkpoint_dir=None, experiment_config=None,
167174 experiment_config [key ] = config [key ]
168175 experiment_config ['hyper_tune' ] = True
169176 logger = get_logger (experiment_config )
170- logger .info ('Begin pipeline, task={}, model_name={}, dataset_name={}'
171- .format (str (task ), str (model_name ), str (dataset_name )))
177+ # exp_id
178+ exp_id = int (random .SystemRandom ().random () * 100000 )
179+ experiment_config ['exp_id' ] = exp_id
180+ logger .info ('Begin pipeline, task={}, model_name={}, dataset_name={}, exp_id={}' .
181+ format (str (task ), str (model_name ), str (dataset_name ), str (exp_id )))
172182 logger .info ('running parameters: ' + str (config ))
173183 # load model
174184 model = get_model (experiment_config , data_feature )
@@ -215,9 +225,9 @@ def train(config, checkpoint_dir=None, experiment_config=None,
215225 # save best
216226 best_path = os .path .join (best_trial .checkpoint .value , "checkpoint" )
217227 model_state , optimizer_state = torch .load (best_path )
218- model_cache_file = './libcity/cache/model_cache/{}_{}.m' .format (
219- model_name , dataset_name )
220- ensure_dir ('./libcity/cache/model_cache' )
228+ model_cache_file = './libcity/cache/{}/ model_cache/{}_{}.m' .format (
229+ exp_id , model_name , dataset_name )
230+ ensure_dir ('./libcity/cache/{}/ model_cache' . format ( exp_id ) )
221231 torch .save ((model_state , optimizer_state ), model_cache_file )
222232
223233
0 commit comments