@@ -2149,7 +2149,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
   if dtypes.issubdtype(dtype, dtypes.extended):
     return dtype._rules.full(fill_shape, fill_value, dtype)  # type: ignore[union-attr]

-  if (config.sharding_in_types.value and sharding is None and
+  if (config.sharding_in_types.value and sharding is None and shape is None and
       isinstance(x, Array)):
     sharding = x.aval.sharding
   else:
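Note on the extra `shape is None` guard: inheriting `x.aval.sharding` is only valid when the output keeps `x`'s shape; with an explicit `shape`, a layout built for `x` may no longer line up with the result. A minimal illustrative sketch (assuming a 1-D mesh whose device count divides the leading dim; the behavior shown is ordinary sharding propagation, which mirrors what the guarded branch does under the experimental sharding-in-types mode):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical setup: a 1-D mesh over whatever devices are available.
mesh = Mesh(jax.devices(), ("x",))
arr = jax.device_put(jnp.zeros((8, 4)), NamedSharding(mesh, P("x", None)))

# shape unset: the result keeps arr's shape, so inheriting arr's
# sharding is safe.
same = jnp.full_like(arr, 1.0)

# shape given: a (3, 5) result cannot reuse a layout built for (8, 4),
# so the new guard falls through instead of blindly inheriting it.
resized = jnp.full_like(arr, 1.0, shape=(3, 5))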
@@ -4577,6 +4577,9 @@ def _clamp_shape_rule(min, operand, max):
                     f"(), got max.shape={max.shape}, {operand.shape=}.")
   return operand.shape

+def _clamp_sharding_rule(min, operand, max):
+  return operand.sharding
+
 _clamp_dtype_rule = partial(naryop_dtype_rule, _input_dtype, [_any, _any, _any],
                             'clamp')

@@ -4617,7 +4620,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params):
     x = broadcast(x, min.shape)
   return clamp_p.bind(min, x, max), 0

-clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
+clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp',
+                             sharding_rule=_clamp_sharding_rule)
 ad.defjvp(clamp_p,
           lambda g, min, operand, max:
           select(bitwise_and(gt(min, operand), lt(min, max)),
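Forwarding the operand's sharding is sound here because clamp is elementwise in its operand (the shape rule above already pins the output shape to `operand.shape`, with min/max either scalar or operand-shaped). A small usage sketch of the property the rule encodes, using plain jit propagation as a stand-in for the experimental mode:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("x",))
xs = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P("x")))

# Elementwise in the operand, scalar bounds carry no layout of their own,
# so the output inherits the operand's layout.
ys = jax.jit(lambda a: jax.lax.clamp(0.0, a, 5.0))(xs)
assert ys.sharding == xs.sharding  # layout expected to be preserved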
@@ -5165,18 +5169,28 @@ def _rev_shape_rule(operand, *, dimensions):
     raise TypeError(msg.format(dimensions, operand.ndim))
   return operand.shape

+def _rev_sharding_rule(operand, *, dimensions):
+  # TODO(yashkatariya): Will lead to data movement. Maybe just error out and
+  # require the operand to be unsharded?
+  return operand.sharding
+
 def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
   operand, = batched_args
   bdim, = batch_dims
   new_dimensions = [i + 1 if i >= bdim else i for i in dimensions]
   return rev(operand, new_dimensions), bdim

-rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
+rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev',
+                           sharding_rule=_rev_sharding_rule)
 ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
 batching.primitive_batchers[rev_p] = _rev_batch_rule

 def _rev_lower(ctx, x, *, dimensions):
-  aval_out, = ctx.avals_out
-  return [hlo.reverse(x, mlir.dense_int_array(dimensions))]
+  aval_out, = ctx.avals_out
+  out = hlo.reverse(x, mlir.dense_int_array(dimensions))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]
 mlir.register_lowering(rev_p, _rev_lower)

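The data-movement concern in the TODO: reversing a sharded dimension of size n sends element i to position n-1-i, so each shard must land on the device holding the mirrored index range, while reversing an unsharded dimension stays shard-local. An illustrative sketch (assuming a 1-D mesh whose device count divides 8):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("x",))
a = jax.device_put(jnp.arange(16.0).reshape(8, 2),
                   NamedSharding(mesh, P("x", None)))

# Reversing the unsharded trailing dim is purely shard-local...
flipped_cols = jax.jit(lambda v: jax.lax.rev(v, (1,)))(a)

# ...but reversing the sharded leading dim maps row i to row 7 - i, so
# shards swap devices: the cross-device movement the TODO is flagging.
flipped_rows = jax.jit(lambda v: jax.lax.rev(v, (0,)))(a)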
@@ -5932,7 +5946,10 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
                     mlir.flatten_ir_values(operands),
                     dimension=mlir.i64_attr(dimension),
                     is_stable=ir.BoolAttr.get(is_stable))
-  scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]
+  scalar_s = (lambda a: a.sharding.with_spec(P())
+              if config.sharding_in_types.value else lambda _: None)
+  scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval))
+                  for aval in ctx.avals_in]
   scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
   comparator = sort.comparator.blocks.append(
       *util.flatten(zip(scalar_types, scalar_types)))
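The comparator region takes rank-0 arguments, hence the `shape=()` avals with a fully replicated `P()` spec. One parsing subtlety worth noting: Python reads `lambda a: X if C else Y` as `lambda a: (X if C else Y)`, so `scalar_s` as written is a single lambda that returns either the scalar sharding or the inner `lambda _: None` object when the flag is off, rather than a choice between two callables. If the latter was the intent (an assumption on my part), parenthesizing the first lambda would make the flag select the function:

# Assumed intent: pick one of two callables up front, so scalar_s(aval)
# yields a replicated scalar sharding when the mode is on and None when off.
scalar_s = ((lambda a: a.sharding.with_spec(P()))
            if config.sharding_in_types.value else (lambda _: None))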