Skip to content

Commit 01f026d

Browse files
committed
make optimal transport tests omre reliable
1 parent e95f181 commit 01f026d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/test_utils/test_optimal_transport.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ def test_jit_compile():
1717

1818

1919
def test_transport_cost_improves():
20-
x = keras.random.normal((100, 100), seed=0)
21-
y = keras.random.normal((100, 100), seed=1)
20+
x = keras.random.normal((100, 2), seed=0)
21+
y = keras.random.normal((100, 2), seed=1)
2222

2323
before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
2424

25-
x, y = optimal_transport(x, y, regularization=0.1, seed=0)
25+
x, y = optimal_transport(x, y, regularization=1e-3, seed=0)
2626

2727
after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
2828

0 commit comments

Comments
 (0)