@@ -27,8 +27,8 @@ def test_shapes():
2727
2828
2929def test_transport_cost_improves ():
30- x = keras .random .normal ((32 , 8 ), seed = 0 )
31- y = keras .random .normal ((32 , 8 ), seed = 1 )
30+ x = keras .random .normal ((1024 , 2 ), seed = 0 )
31+ y = keras .random .normal ((1024 , 2 ), seed = 1 )
3232
3333 before_cost = keras .ops .sum (keras .ops .norm (x - y , axis = - 1 ))
3434
@@ -40,14 +40,26 @@ def test_transport_cost_improves():
4040
4141
4242def test_assignment_is_optimal ():
43- x = keras .ops .stack ([keras .ops .linspace (- 1 , 1 , 10 ), keras .ops .linspace (- 1 , 1 , 10 )])
44- y = keras .ops .copy (x )
45-
46- # we could shuffle x and y, but flipping is a more reliable permutation
47- y = keras .ops .flip (y , axis = 0 )
48-
49- x , y = optimal_transport (x , y , regularization = 1e-3 , seed = 0 , max_steps = 1000 )
50-
51- cost = keras .ops .sum (keras .ops .norm (x - y , axis = - 1 ))
52-
53- assert_allclose (cost , 0.0 )
43+ x = keras .ops .convert_to_tensor (
44+ [
45+ [- 1 , 2 ],
46+ [- 1 , 1 ],
47+ [- 1 , 0 ],
48+ [- 1 , - 1 ],
49+ [- 1 , - 2 ],
50+ ]
51+ )
52+ optimal_y = keras .ops .convert_to_tensor (
53+ [
54+ [1 , 2 ],
55+ [1 , 1 ],
56+ [1 , 0 ],
57+ [1 , - 1 ],
58+ [1 , - 2 ],
59+ ]
60+ )
61+ y = keras .random .shuffle (optimal_y , axis = 0 , seed = 0 )
62+
63+ x , y = optimal_transport (x , y , regularization = 0.1 , seed = 0 , max_steps = None , scale_regularization = False )
64+
65+ assert_allclose (x , y )
0 commit comments