@@ -861,9 +861,7 @@ def test_shmap_abstract_mesh_errors(self):
861861 @jtu .thread_unsafe_test ()
862862 def test_debug_print_jit (self , jit ):
863863 if config .use_shardy_partitioner .value :
864- self .skipTest (
865- 'TODO(b/364547005): debug prints not supported by Shardy yet'
866- )
864+ self .skipTest ('TODO(b/384938613): Failing under shardy' )
867865 mesh = Mesh (jax .devices (), ('i' ,))
868866
869867 @partial (shard_map , mesh = mesh , in_specs = P ('i' ), out_specs = P ('i' ))
@@ -2151,9 +2149,6 @@ def f():
21512149 self .assertAllClose (f (), np .arange (4 , dtype = np .int32 ).reshape (- 1 , 1 ))
21522150
21532151 def test_partial_auto_axis_index_degenerated_axis (self ):
2154- if config .use_shardy_partitioner .value :
2155- self .skipTest ('Shardy does not support full-to-shard.' )
2156-
21572152 mesh = jtu .create_mesh ((1 , 2 ), ('i' , 'j' ))
21582153 out_sharding = NamedSharding (mesh , P ('i' , None ))
21592154
@@ -2166,9 +2161,6 @@ def f():
21662161 self .assertAllClose (f (), np .arange (1 , dtype = np .int32 ).reshape (- 1 , 1 ))
21672162
21682163 def test_partial_auto_ppermute (self ):
2169- if config .use_shardy_partitioner .value :
2170- self .skipTest ('Shardy does not support full-to-shard.' )
2171-
21722164 mesh = jtu .create_mesh ((4 , 2 ), ('i' , 'j' ))
21732165 x = jnp .arange (8. )
21742166
@@ -2188,8 +2180,6 @@ def f(x):
21882180
21892181 # TODO(parkers,mattjj): get XLA to support this too
21902182 # def test_partial_auto_all_to_all(self):
2191- # if config.use_shardy_partitioner.value:
2192- # self.skipTest('Shardy does not support anything.')
21932183 #
21942184 # mesh = jtu.create_mesh((4, 2), ('i', 'j'))
21952185 # x = jnp.arange(128.).reshape(16, 8)
0 commit comments