Skip to content

Commit bae3ffc

Browse files
committed
Adding 0.05 tolerance
1 parent 9c7641b commit bae3ffc

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tests/integration/models/clavaddpm/test_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,10 @@ def test_train_single_table(tmp_path: Path):
285285
# expected_model_layers = list(expected_model_data.keys())
286286
# if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
287287
# if the first layer is equal with minimal tolerance, all others should be equal as well
288-
assert all(np.allclose(model_data[layer].detach(), expected_model_data[layer].detach()) for layer in model_layers)
288+
assert all(
289+
np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.05)
290+
for layer in model_layers
291+
)
289292

290293
# TODO: Figure out if there is a good way of testing the synthetic data results
291294
# on multiple platforms. https://app.clickup.com/t/868f43wp0
@@ -348,7 +351,10 @@ def test_train_multi_table(tmp_path: Path):
348351

349352
# if np.allclose(model_data[model_layers[0]].detach(), expected_model_data[expected_model_layers[0]].detach()):
350353
# if the first layer is equal with minimal tolerance, all others should be equal as well
351-
assert all(np.allclose(model_data[layer].detach(), expected_model_data[layer].detach()) for layer in model_layers)
354+
assert all(
355+
np.allclose(model_data[layer].detach(), expected_model_data[layer].detach(), atol=0.05)
356+
for layer in model_layers
357+
)
352358

353359
# # TODO: Figure out if there is a good way of testing the synthetic data results
354360
# # on multiple platforms. https://app.clickup.com/t/868f43wp0

0 commit comments

Comments
 (0)