@@ -283,7 +283,7 @@ def clava_clustering(tables, relation_order, save_dir, configs):
283283 return tables , all_group_lengths_prob_dicts
284284
285285
286- def clava_training (tables , relation_order , save_dir , configs , device = "cuda" , initial_state_dict = None ):
286+ def clava_training (tables , relation_order , save_dir , configs , device = "cuda" , initial_state_file_path = None ):
287287 models = {}
288288 for parent , child in relation_order :
289289 print (f"Training { parent } -> { child } model from scratch" )
@@ -298,7 +298,7 @@ def clava_training(tables, relation_order, save_dir, configs, device="cuda", ini
298298 child ,
299299 configs ,
300300 device ,
301- initial_state_dict ,
301+ initial_state_file_path ,
302302 )
303303
304304 models [(parent , child )] = result
@@ -324,7 +324,7 @@ def child_training(
324324 child_name : str ,
325325 configs : dict [str , Any ],
326326 device : str = "cuda" ,
327- initial_state_dict : dict [ str , Tensor ] | None = None ,
327+ initial_state_file_path : Path | None = None ,
328328) -> dict [str , Any ]:
329329 if parent_name is None :
330330 y_col = "placeholder"
@@ -354,7 +354,7 @@ def child_training(
354354 configs ["diffusion" ]["lr" ],
355355 configs ["diffusion" ]["weight_decay" ],
356356 device = device ,
357- initial_state_dict = initial_state_dict ,
357+ initial_state_file_path = initial_state_file_path ,
358358 )
359359
360360 if parent_name is None :
@@ -398,7 +398,7 @@ def train_model(
398398 lr : float ,
399399 weight_decay : float ,
400400 device : str = "cuda" ,
401- initial_state_dict : dict [ str , Tensor ] | None = None ,
401+ initial_state_file_path : Path | None = None ,
402402) -> dict [str , Any ]:
403403 T = Transformations (** T_dict )
404404 dataset , label_encoders , column_orders = make_dataset_from_df (
@@ -443,8 +443,14 @@ def train_model(
443443 )
444444 diffusion .to (device )
445445
446- if initial_state_dict is not None :
447- diffusion .load_state_dict (initial_state_dict )
446+ print ("++++++++++++++++++++++++ BEFORE ++++++++++++++++++++++++++" )
447+ print (diffusion .state_dict ())
448+
449+ if initial_state_file_path is not None :
450+ diffusion .load_state_dict (torch .load (initial_state_file_path , weights_only = True ))
451+
452+ print ("++++++++++++++++++++++++ AFTER ++++++++++++++++++++++++++" )
453+ print (diffusion .state_dict ())
448454
449455 diffusion .train ()
450456
0 commit comments