File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed
tests/integration/models/clavaddpm Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -285,7 +285,10 @@ def test_train_single_table(tmp_path: Path):
285285 # expected_model_layers = list(expected_model_data.keys())
286286 # if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
287287 # if the first layer is equal with minimal tolerance, all others should be equal as well
288- assert all (np .allclose (model_data [layer ].detach (), expected_model_data [layer ].detach ()) for layer in model_layers )
288+ assert all (
289+ np .allclose (model_data [layer ].detach (), expected_model_data [layer ].detach (), atol = 0.05 )
290+ for layer in model_layers
291+ )
289292
290293 # TODO: Figure out if there is a good way of testing the synthetic data results
291294 # on multiple platforms. https://app.clickup.com/t/868f43wp0
@@ -348,7 +351,10 @@ def test_train_multi_table(tmp_path: Path):
348351
349352 # if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
350353 # if the first layer is equal with minimal tolerance, all others should be equal as well
351- assert all (np .allclose (model_data [layer ].detach (), expected_model_data [layer ].detach ()) for layer in model_layers )
354+ assert all (
355+ np .allclose (model_data [layer ].detach (), expected_model_data [layer ].detach (), atol = 0.05 )
356+ for layer in model_layers
357+ )
352358
353359 # # TODO: Figure out if there is a good way of testing the synthetic data results
354360 # # on multiple platforms. https://app.clickup.com/t/868f43wp0
You can’t perform that action at this time.
0 commit comments