@@ -4729,9 +4729,14 @@ def spec_regex(s):
47294729 return str (s ).replace (r"(" , r"\(" ).replace (r")" , r"\)" )
47304730
47314731
4732- @jtu .with_config (jax_use_shardy_partitioner = False )
47334732class ShardingInTypesTest (jtu .JaxTestCase ):
47344733
4734+ def check_wsc_in_lowered (self , text ):
4735+ if config .use_shardy_partitioner .value :
4736+ self .assertIn ('sdy.sharding_constraint' , text )
4737+ else :
4738+ self .assertIn ('@Sharding' , text )
4739+
47354740 @jtu .with_user_mesh ((2 , 2 ), ('x' , 'y' ))
47364741 def test_basic_mul (self , mesh ):
47374742 np_inp = np .arange (16. ).reshape (8 , 2 )
@@ -4753,7 +4758,7 @@ def f(x):
47534758
47544759 lowered_text = f .lower (arr ).as_text ()
47554760 if config .use_shardy_partitioner .value :
4756- self .assertIn ( 'sdy.sharding_constraint' , lowered_text )
4761+ self .assertEqual ( lowered_text . count ( 'sdy.sharding_constraint' ), 3 )
47574762 else :
47584763 self .assertEqual (lowered_text .count ('@Sharding' ), 3 )
47594764
@@ -4834,7 +4839,7 @@ def f(x, y):
48344839 self .assertEqual (out .sharding , NamedSharding (mesh , out_spec ))
48354840
48364841 lowered = f .lower (arr1 , arr2 )
4837- self .assertIn ( '@Sharding' , lowered .as_text ())
4842+ self .check_wsc_in_lowered ( lowered .as_text ())
48384843
48394844 compiled_text = lowered .compile ().as_text ()
48404845 if collective_name is not None and compiled_text is not None :
@@ -4971,7 +4976,7 @@ def f(x):
49714976 self .assertEqual (out .sharding , NamedSharding (mesh , out_spec ))
49724977
49734978 lowered = f .lower (arr )
4974- self .assertIn ( '@Sharding' , lowered .as_text ())
4979+ self .check_wsc_in_lowered ( lowered .as_text ())
49754980
49764981 compiled_text = lowered .compile ().as_text ()
49774982 if reduce and compiled_text is not None :
@@ -5002,7 +5007,7 @@ def f(x):
50025007 self .assertEqual (out .sharding , NamedSharding (mesh , out_spec ))
50035008
50045009 lowered = f .lower (arr )
5005- self .assertIn ( '@Sharding' , lowered .as_text ())
5010+ self .check_wsc_in_lowered ( lowered .as_text ())
50065011
50075012 compiled_text = lowered .compile ().as_text ()
50085013 if reduce and compiled_text is not None :
@@ -5044,7 +5049,7 @@ def f(x):
50445049 self .assertEqual (out .sharding , NamedSharding (mesh , out_spec ))
50455050
50465051 lowered_text = f .lower (arr ).as_text ()
5047- self .assertIn ( '@Sharding' , lowered_text )
5052+ self .check_wsc_in_lowered ( lowered_text )
50485053
50495054 @parameterized .named_parameters (
50505055 ('2' , 2 ),
@@ -5068,7 +5073,7 @@ def f(x):
50685073 self .assertArraysEqual (out , np_inp ** pow )
50695074
50705075 lowered_text = f .lower (arr ).as_text ()
5071- self .assertIn ( '@Sharding' , lowered_text )
5076+ self .check_wsc_in_lowered ( lowered_text )
50725077
50735078 @jtu .with_user_mesh ((1 ,), 'x' )
50745079 def test_broadcasting_nary_error (self , mesh ):
@@ -5102,7 +5107,7 @@ def f(x):
51025107 self .assertEqual (out .sharding , s )
51035108
51045109 lowered_text = f .lower (arr ).as_text ()
5105- self .assertIn ( '@Sharding' , lowered_text )
5110+ self .check_wsc_in_lowered ( lowered_text )
51065111
51075112 @jtu .with_user_mesh ((2 , 2 ), ('x' , 'y' ))
51085113 def test_jnp_array (self , mesh ):
@@ -5137,7 +5142,7 @@ def f(x):
51375142 self .assertEqual (out .sharding , NamedSharding (mesh , P ('y' , 'z' , 'x' )))
51385143
51395144 lowered_text = f .lower (arr ).as_text ()
5140- self .assertIn ( '@Sharding' , lowered_text )
5145+ self .check_wsc_in_lowered ( lowered_text )
51415146
51425147 @jtu .with_user_mesh ((2 , 2 ), ('x' , 'y' ))
51435148 def test_broadcasted_iota_with_sharding (self , mesh ):
@@ -5182,7 +5187,7 @@ def f(x, y):
51825187 self .assertEqual (out .sharding , NamedSharding (mesh , P ('x' , None )))
51835188
51845189 lowered_text = f .lower (arr1 , arr2 ).as_text ()
5185- self .assertIn ( '@Sharding' , lowered_text )
5190+ self .check_wsc_in_lowered ( lowered_text )
51865191
51875192 @jax .jit
51885193 def g (x , y ):
@@ -5228,7 +5233,7 @@ def h(x, y):
52285233 self .assertEqual (out .sharding , NamedSharding (mesh , P ('x' , None , 'y' , None )))
52295234
52305235 lowered_text = h .lower (arr1 , arr2 ).as_text ()
5231- self .assertIn ( '@Sharding' , lowered_text )
5236+ self .check_wsc_in_lowered ( lowered_text )
52325237
52335238 @jax .jit
52345239 def h2 (x , y ):
@@ -5268,7 +5273,7 @@ def f(x, new_sharding):
52685273 self .assertArraysEqual (out , np_inp .reshape (dst_shape ) * 2 )
52695274
52705275 lowered_text = f .lower (arr , new_s ).as_text ()
5271- self .assertIn ( '@Sharding' , lowered_text )
5276+ self .check_wsc_in_lowered ( lowered_text )
52725277
52735278 def g (x ):
52745279 out = f (x , new_s )
@@ -5295,7 +5300,7 @@ def f(pred, on_true, on_false):
52955300 self .assertArraysEqual (out , arr1 )
52965301
52975302 lowered_text = f .lower (arr1 == arr2 , arr1 , arr2 ).as_text ()
5298- self .assertIn ( '@Sharding' , lowered_text )
5303+ self .check_wsc_in_lowered ( lowered_text )
52995304
53005305 arr3 = jax .device_put (np_inp , NamedSharding (mesh , P ('y' , 'x' )))
53015306 with self .assertRaisesRegex (
@@ -5383,7 +5388,7 @@ def f(x):
53835388
53845389 out = f (arr )
53855390 self .assertEqual (out .sharding , NamedSharding (mesh , P ('x' , None )))
5386- self .assertIn ( '@Sharding' , f .lower (arr ).as_text ())
5391+ self .check_wsc_in_lowered ( f .lower (arr ).as_text ())
53875392
53885393 def g (x ):
53895394 out = f (x )
@@ -5414,7 +5419,7 @@ def f(x):
54145419
54155420 out = f (arr )
54165421 self .assertEqual (out .sharding , NamedSharding (mesh , P ('x' , None )))
5417- self .assertIn ( '@Sharding' , f .lower (arr ).as_text ())
5422+ self .check_wsc_in_lowered ( f .lower (arr ).as_text ())
54185423 self .assertArraysEqual (out , np .squeeze (np_inp , axis = 2 ))
54195424
54205425 def g (x ):
@@ -5441,7 +5446,7 @@ def f(x, padding_config, spec):
54415446 out = f (arr , ((2 , 2 , 0 ),), P ('x' ))
54425447 self .assertArraysEqual (out , np .pad (np_inp , 2 ))
54435448 self .assertEqual (out .sharding , NamedSharding (mesh , P ('x' )))
5444- self .assertIn ( '@Sharding' , f .lower (arr , ((2 , 2 , 0 ),), P ('x' )).as_text ())
5449+ self .check_wsc_in_lowered ( f .lower (arr , ((2 , 2 , 0 ),), P ('x' )).as_text ())
54455450
54465451 out = f (arr , ((0 , 0 , 0 ),), P ('x' ))
54475452 self .assertArraysEqual (out , np_inp )
@@ -5489,7 +5494,7 @@ def f(x, y, method='jnp'):
54895494 out = f (arr1 , arr2 )
54905495 self .assertEqual (out .sharding , s )
54915496 self .assertArraysEqual (out , np .concatenate ([arr1 , arr2 ], axis = 1 ))
5492- self .assertIn ( '@Sharding' , f .lower (arr1 , arr2 ).as_text ())
5497+ self .check_wsc_in_lowered ( f .lower (arr1 , arr2 ).as_text ())
54935498
54945499 out = f (arr1 , arr2 , method = 'lax' )
54955500 self .assertEqual (out .sharding , s )
@@ -5568,8 +5573,7 @@ def f(x):
55685573 self .assertEqual (out1 .sharding , NamedSharding (mesh , P ('y' )))
55695574 self .assertArraysEqual (out2 , np .argmin (np_inp , axis = 1 ))
55705575 self .assertEqual (out2 .sharding , NamedSharding (mesh , P ('x' )))
5571-
5572- self .assertIn ('@Sharding' , f .lower (arr ).as_text ())
5576+ self .check_wsc_in_lowered (f .lower (arr ).as_text ())
55735577
55745578 @jtu .with_user_mesh ((2 , 2 ), ('x' , 'y' ), {mesh_lib .AxisTypes .Auto : ('x' , 'y' )})
55755579 def test_only_auto (self , mesh ):
@@ -5618,7 +5622,10 @@ def f(x, x2):
56185622 out = f (arr , arr2 )
56195623 self .assertEqual (out .sharding , NamedSharding (mesh2 , P ('x' ,)))
56205624 lowered_text = f .lower (arr , arr2 ).as_text ()
5621- self .assertTrue (lowered_text .count ("unspecified_dims" ) == 5 )
5625+ if config .use_shardy_partitioner .value :
5626+ self .assertTrue (lowered_text .count ("{?}" ) == 5 )
5627+ else :
5628+ self .assertTrue (lowered_text .count ("unspecified_dims" ) == 5 )
56225629
56235630 mesh3 = jtu .create_mesh ((2 , 2 ), ('x' , 'y' ),
56245631 axis_types = {mesh_lib .AxisTypes .User : 'y' ,
@@ -5629,7 +5636,12 @@ def f(x, x2):
56295636 out = f (arr , arr2 )
56305637 self .assertEqual (out .sharding , NamedSharding (mesh3 , P ('x' ,)))
56315638 lowered_text = f .lower (arr , arr2 ).as_text ()
5632- self .assertTrue (lowered_text .count ("unspecified_dims" ) == 4 )
5639+ print (lowered_text )
5640+ if config .use_shardy_partitioner .value :
5641+ self .assertTrue (lowered_text .count ("{?}" ) == 5 )
5642+ self .assertIn ('replicated={"y"}' , lowered_text )
5643+ else :
5644+ self .assertTrue (lowered_text .count ("unspecified_dims" ) == 4 )
56335645
56345646 with self .assertRaisesRegex (
56355647 ValueError ,
@@ -5784,7 +5796,7 @@ def f(x, sizes=(4, 4), axis=0):
57845796 return ys
57855797
57865798 f (arr )
5787- self .assertIn ( '@Sharding' , f .lower (arr ).as_text ())
5799+ self .check_wsc_in_lowered ( f .lower (arr ).as_text ())
57885800
57895801 with self .assertRaisesRegex (NotImplementedError , "split on sharded dims" ):
57905802 f (arr , sizes = (1 , 1 ), axis = 1 )
@@ -5864,6 +5876,31 @@ def g(x, y):
58645876 ValueError , "PartitionSpec cannot contain axis names.*Auto" ):
58655877 g (arr1 , arr2 )
58665878
5879+ @jtu .with_user_mesh ((2 , 2 , 2 ), ('x' , 'y' , 'z' ),
5880+ axis_types = {AxisTypes .User : ('x' , 'y' ),
5881+ AxisTypes .Auto : 'z' })
5882+ def test_out_sharding_mix_axis_types (self , mesh ):
5883+ np_inp = np .arange (16 ).reshape (4 , 2 , 2 )
5884+ s = NamedSharding (mesh , P ('x' , None , None ))
5885+ arr = jax .device_put (np_inp , s )
5886+
5887+ @jax .jit
5888+ def f (x ):
5889+ y = x * 2
5890+ self .assertEqual (y .sharding .spec , P ('x' , None , None ))
5891+ return y
5892+
5893+ out = f (arr )
5894+ self .assertEqual (out .sharding , NamedSharding (mesh , P ('x' ,)))
5895+ self .assertArraysEqual (out , np_inp * 2 )
5896+
5897+ lowered_text = f .lower (arr ).as_text ()
5898+ if config .use_shardy_partitioner .value :
5899+ self .assertTrue (lowered_text .count (
5900+ '[{"x"}, {?}, {?}], replicated={"y"}' ) == 3 )
5901+ else :
5902+ self .assertTrue (lowered_text .count ("unspecified_dims=[1,2]" ) == 3 )
5903+
58675904
58685905@jtu .pytest_mark_if_available ('multiaccelerator' )
58695906class PJitErrorTest (jtu .JaxTestCase ):
0 commit comments