@@ -1231,7 +1231,8 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
   return broadcast(x, (1,) * (rank - ndim))
 
 def reshape(operand: ArrayLike, new_sizes: Shape,
-            dimensions: Sequence[int] | None = None) -> Array:
+            dimensions: Sequence[int] | None = None,
+            sharding: NamedSharding | None = None) -> Array:
   """Wraps XLA's `Reshape
   <https://www.tensorflow.org/xla/operation_semantics#reshape>`_
   operator.
@@ -1285,7 +1286,8 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
 
   return reshape_p.bind(
       operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
-      dimensions=None if dims is None or same_dims else dims)
+      dimensions=None if dims is None or same_dims else dims,
+      sharding=sharding)
 
 def pad(operand: ArrayLike, padding_value: ArrayLike,
         padding_config: Sequence[tuple[int, int, int]]) -> Array:
@@ -4654,7 +4656,7 @@ def shape_as_value(shape: core.Shape):
   ]
   return concatenate(dims, dimension=0)
 
-def _reshape_shape_rule(operand, *, new_sizes, dimensions):
+def _reshape_shape_rule(operand, *, new_sizes, dimensions, sharding):
   if not all(d >= 0 for d in new_sizes):
     msg = 'reshape new_sizes must all be positive, got {}.'
     raise TypeError(msg.format(new_sizes))
@@ -4674,7 +4676,9 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions):
     raise TypeError(msg.format(dimensions, np.shape(operand)))
   return tuple(new_sizes)
 
-def _reshape_sharding_rule(operand, *, new_sizes, dimensions):
+def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding):
+  if sharding is not None:
+    return sharding
   filtered_spec = [
       (sh, sp) for sh, sp in zip(operand.shape, operand.sharding.spec)
       if sh != 1
@@ -4687,14 +4691,18 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions):
     else:
       sh, sp = next(fs)
       if n != sh:
-        raise NotImplementedError
+        raise ValueError(
+            'This reshape is not supported. Please specify the sharding of the'
+            ' output via the `sharding` argument of reshape.')
       new_spec.append(sp)
   return operand.sharding.with_spec(new_spec)
 
-def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions):
+def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions,
+                            sharding):
   if not dyn_shape:
     out_aval, effects = reshape_p.abstract_eval(
-        operand.aval, new_sizes=new_sizes, dimensions=dimensions)
+        operand.aval, new_sizes=new_sizes, dimensions=dimensions,
+        sharding=sharding)
     return [out_aval], effects
   else:
     # TODO(mattjj, necula): perform more checks like _reshape_shape_rule
@@ -4705,18 +4713,29 @@ def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions):
     return [out_aval], core.no_effects
 
 
-def _reshape_dtype_rule(operand, *, new_sizes, dimensions):
+def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding):
   return operand.dtype
 
-def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions):
+def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
   assert ad.is_undefined_primal(operand)
   if dimensions is None:
+    if config.sharding_in_types.value:
+      return [reshape(t, operand.aval.shape, sharding=operand.aval.sharding)]
     return [reshape(t, operand.aval.shape)]
   else:
-    return [transpose(reshape(t, np.take(operand.aval.shape, dimensions)),
+    if config.sharding_in_types.value:
+      t_s = operand.sharding.with_spec(
+          tuple(map(str, np.take(operand.aval.sharding.spec, dimensions))))
+    else:
+      t_s = None
+    return [transpose(reshape(t, np.take(operand.aval.shape, dimensions),
+                              sharding=t_s),
                       np.argsort(dimensions))]
 
-def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
+def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions,
+                        sharding):
+  if sharding is not None:
+    raise NotImplementedError
   operand, = batched_args
   bdim, = batch_dims
   operand = batching.moveaxis(operand, bdim, 0)
@@ -4725,20 +4744,22 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
   return reshape(operand, operand.shape[:1] + new_sizes, dimensions), 0
 
 
-def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
+def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding):
   aval_out, = ctx.avals_out
   if dimensions is not None:
     x = hlo.transpose(x, mlir.dense_int_array(dimensions))
   if dyn_shape:
     aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
   out = mlir.reshape(ctx, x, aval_out)
   if config.sharding_in_types.value:
+    if sharding is not None:
+      assert sharding == aval_out.sharding
     return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
   return [out]
 
 def _reshape_staging_rule(
-    trace, x, *dyn, new_sizes, dimensions):
-  params = dict(new_sizes=new_sizes, dimensions=dimensions)
+    trace, x, *dyn, new_sizes, dimensions, sharding):
+  params = dict(new_sizes=new_sizes, dimensions=dimensions, sharding=sharding)
   if not dyn:
     return trace.default_process_primitive(reshape_p, (x,), params)
   av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type)
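
Usage sketch: a minimal, hypothetical example of the new `sharding` keyword on `lax.reshape`, assuming the sharding-in-types mode this change targets is enabled, that a 1-D device mesh named 'x' is available, and that the leading array dimension divides evenly across the devices.

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Hypothetical 1-D mesh spanning all available devices.
    mesh = Mesh(np.array(jax.devices()), ('x',))
    x = jax.device_put(np.arange(16.0).reshape(8, 2),
                       NamedSharding(mesh, P('x', None)))

    # Where _reshape_sharding_rule cannot infer the output sharding on its own,
    # the caller can now supply it explicitly instead of hitting the ValueError
    # added above.
    y = jax.lax.reshape(x, (16,), sharding=NamedSharding(mesh, P('x')))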