Skip to content

Commit b0f71b9

Browse files
committed
re-add numItermax for ot pot test
1 parent 8f7ddb2 commit b0f71b9

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/test_utils/test_optimal_transport.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def test_assignment_aligns_with_pot():
6161
from ot.bregman import sinkhorn_log
6262
except (ImportError, ModuleNotFoundError):
6363
pytest.skip("Need to install POT to run this test.")
64+
return
6465

6566
x = keras.random.normal((16, 2), seed=0)
6667
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
@@ -71,7 +72,7 @@ def test_assignment_aligns_with_pot():
7172
M = x[:, None] - y[None, :]
7273
M = keras.ops.norm(M, axis=-1)
7374

74-
pot_plan = sinkhorn_log(a, b, M, reg=1e-3, stopThr=1e-7)
75+
pot_plan = sinkhorn_log(a, b, M, numItermax=10_000, reg=1e-3, stopThr=1e-7)
7576
pot_assignments = keras.random.categorical(keras.ops.log(pot_plan), num_samples=1, seed=0)
7677
pot_assignments = keras.ops.squeeze(pot_assignments, axis=-1)
7778

0 commit comments

Comments
 (0)