@@ -2149,7 +2149,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
21492149 if dtypes .issubdtype (dtype , dtypes .extended ):
21502150 return dtype ._rules .full (fill_shape , fill_value , dtype ) # type: ignore[union-attr]
21512151
2152- if (config .sharding_in_types .value and sharding is None and
2152+ if (config .sharding_in_types .value and sharding is None and shape is None and
21532153 isinstance (x , Array )):
21542154 sharding = x .aval .sharding
21552155 else :
@@ -4577,6 +4577,9 @@ def _clamp_shape_rule(min, operand, max):
45774577 f"(), got max.shape={ max .shape } , { operand .shape = } ." )
45784578 return operand .shape
45794579
4580+ def _clamp_sharding_rule (min , operand , max ):
4581+ return operand .sharding
4582+
45804583_clamp_dtype_rule = partial (naryop_dtype_rule , _input_dtype , [_any , _any , _any ],
45814584 'clamp' )
45824585
@@ -4617,7 +4620,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params):
46174620 x = broadcast (x , min .shape )
46184621 return clamp_p .bind (min , x , max ), 0
46194622
4620- clamp_p = standard_primitive (_clamp_shape_rule , _clamp_dtype_rule , 'clamp' )
4623+ clamp_p = standard_primitive (_clamp_shape_rule , _clamp_dtype_rule , 'clamp' ,
4624+ sharding_rule = _clamp_sharding_rule )
46214625ad .defjvp (clamp_p ,
46224626 lambda g , min , operand , max :
46234627 select (bitwise_and (gt (min , operand ), lt (min , max )),
0 commit comments