Skip to content

Commit 8f7ddb2

Browse files
fix faulty indexing with tensor for tensorflow backend
1 parent f25df63 commit 8f7ddb2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_utils/test_optimal_transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_assignment_is_optimal(method):
4646
y = keras.random.normal((16, 2), seed=0)
4747
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0)
4848

49-
x = y[p]
49+
x = keras.ops.take(y, p, axis=0)
5050

5151
_, _, assignments = optimal_transport(
5252
x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True

0 commit comments

Comments
 (0)