99def test_jit_compile ():
1010 import jax
1111
12- x = keras .random .normal ((100 , 100 ), seed = 0 )
13- y = keras .random .normal ((100 , 100 ), seed = 1 )
12+ x = keras .random .normal ((128 , 8 ), seed = 0 )
13+ y = keras .random .normal ((128 , 8 ), seed = 1 )
1414
1515 ot = jax .jit (optimal_transport , static_argnames = ["regularization" , "seed" ])
16- ot (x , y , regularization = 0.1 , seed = 0 )
16+ ot (x , y , regularization = 1.0 , seed = 0 , max_steps = 10 )
17+
18+
19+ def test_shapes ():
20+ x = keras .random .normal ((128 , 8 ), seed = 0 )
21+ y = keras .random .normal ((128 , 8 ), seed = 1 )
22+
23+ ox , oy = optimal_transport (x , y , regularization = 1.0 , seed = 0 , max_steps = 10 )
24+
25+ assert keras .ops .shape (ox ) == keras .ops .shape (x )
26+ assert keras .ops .shape (oy ) == keras .ops .shape (y )
1727
1828
1929def test_transport_cost_improves ():
20- x = keras .random .normal ((100 , 2 ), seed = 0 )
21- y = keras .random .normal ((100 , 2 ), seed = 1 )
30+ x = keras .random .normal ((32 , 8 ), seed = 0 )
31+ y = keras .random .normal ((32 , 8 ), seed = 1 )
2232
2333 before_cost = keras .ops .sum (keras .ops .norm (x - y , axis = - 1 ))
2434
25- x , y = optimal_transport (x , y , regularization = 1e-3 , seed = 0 )
35+ x , y = optimal_transport (x , y , regularization = 0.1 , seed = 0 , max_steps = None )
2636
2737 after_cost = keras .ops .sum (keras .ops .norm (x - y , axis = - 1 ))
2838
@@ -36,7 +46,7 @@ def test_assignment_is_optimal():
3646 # we could shuffle x and y, but flipping is a more reliable permutation
3747 y = keras .ops .flip (y , axis = 0 )
3848
39- x , y = optimal_transport (x , y , regularization = 1e-6 , seed = 0 )
49+ x , y = optimal_transport (x , y , regularization = 1e-3 , seed = 0 , max_steps = 1000 )
4050
4151 cost = keras .ops .sum (keras .ops .norm (x - y , axis = - 1 ))
4252
0 commit comments