File tree Expand file tree Collapse file tree 2 files changed +2
-8
lines changed
src/midst_toolkit/models/clavaddpm
tests/integration/models/clavaddpm Expand file tree Collapse file tree 2 files changed +2
-8
lines changed Original file line number Diff line number Diff line change @@ -443,15 +443,9 @@ def train_model(
443443 )
444444 diffusion .to (device )
445445
446- print ("++++++++++++++++++++++++ BEFORE ++++++++++++++++++++++++++" )
447- print (diffusion .state_dict ())
448-
449446 if initial_state_file_path is not None :
450447 diffusion .load_state_dict (torch .load (initial_state_file_path , weights_only = True ))
451448
452- print ("++++++++++++++++++++++++ AFTER ++++++++++++++++++++++++++" )
453- print (diffusion .state_dict ())
454-
455449 diffusion .train ()
456450
457451 trainer = Trainer (
Original file line number Diff line number Diff line change @@ -288,7 +288,7 @@ def test_train_single_table(tmp_path: Path):
288288 # if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
289289 # if the first layer is equal with minimal tolerance, all others should be equal as well
290290 assert all (
291- np .allclose (model_data [layer ].detach (), expected_model_data [layer ].detach (), atol = 0.05 )
291+ np .allclose (model_data [layer ].detach (), expected_model_data [layer ].detach (), atol = 0.08 )
292292 for layer in model_layers
293293 )
294294
@@ -356,7 +356,7 @@ def test_train_multi_table(tmp_path: Path):
356356 # if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
357357 # if the first layer is equal with minimal tolerance, all others should be equal as well
358358 assert all (
359- np .allclose (model_data [layer ].detach (), expected_model_data [layer ].detach (), atol = 0.05 )
359+ np .allclose (model_data [layer ].detach (), expected_model_data [layer ].detach (), atol = 0.08 )
360360 for layer in model_layers
361361 )
362362
You can’t perform that action at this time.
0 commit comments