@@ -724,7 +724,6 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
724724 aval_in , aval_out , x ):
725725 if prod ([size for n , size in mesh .shape .items () if n not in auto ]) == 1 :
726726 return x
727- manual_proto = pxla .manual_proto (aval_in , frozenset (mesh .axis_names ) - auto , mesh )
728727 axes = {name : i for i , ns in names .items () for name in ns }
729728 ns = _make_scoped_manual_sharding (ctx , mesh , axes )
730729 if dtypes .issubdtype (aval_in .dtype , dtypes .extended ):
@@ -734,6 +733,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
734733 unspecified = set (range (aval_in .ndim )) if auto else set ()
735734 sx = mlir .wrap_with_sharding_op (ctx , x , aval_in , shard_proto ,
736735 unspecified_dims = unspecified )
736+ manual_proto = pxla .manual_proto (aval_in , frozenset (mesh .axis_names ) - auto , mesh )
737737 return mlir .wrap_with_full_to_shard_op (ctx , sx , aval_out , manual_proto , unspecified )
738738
739739def _xla_unshard (ctx : mlir .LoweringRuleContext , mesh , auto , names ,
@@ -746,6 +746,8 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
746746 ns = sharding_impls .physical_sharding (aval_out , ns )
747747 aval_out = core .physical_aval (aval_out )
748748 unspecified = set (range (aval_out .ndim )) if auto else set ()
749+ if dtypes .issubdtype (aval_in .dtype , dtypes .extended ):
750+ aval_in = core .physical_aval (aval_in )
749751 manual_proto = pxla .manual_proto (aval_in , frozenset (mesh .axis_names ) - auto , mesh )
750752 sx = mlir .wrap_with_sharding_op (ctx , x , aval_in , manual_proto , unspecified_dims = unspecified )
751753 shard_proto = ns ._to_xla_hlo_sharding (aval_out .ndim ).to_proto ()
0 commit comments