Skip to content

Commit 7ed7e0b

Browse files
yashk2810Google-ML-Automation
authored andcommitted
[sharding_in_types] Add clamp_p sharding rule.
PiperOrigin-RevId: 720428881
1 parent ae705fe commit 7ed7e0b

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

jax/_src/lax/lax.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2149,7 +2149,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
21492149
if dtypes.issubdtype(dtype, dtypes.extended):
21502150
return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]
21512151

2152-
if (config.sharding_in_types.value and sharding is None and
2152+
if (config.sharding_in_types.value and sharding is None and shape is None and
21532153
isinstance(x, Array)):
21542154
sharding = x.aval.sharding
21552155
else:
@@ -4577,6 +4577,9 @@ def _clamp_shape_rule(min, operand, max):
45774577
f"(), got max.shape={max.shape}, {operand.shape=}.")
45784578
return operand.shape
45794579

4580+
def _clamp_sharding_rule(min, operand, max):
4581+
return operand.sharding
4582+
45804583
_clamp_dtype_rule = partial(naryop_dtype_rule, _input_dtype, [_any, _any, _any],
45814584
'clamp')
45824585

@@ -4617,7 +4620,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params):
46174620
x = broadcast(x, min.shape)
46184621
return clamp_p.bind(min, x, max), 0
46194622

4620-
clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
4623+
clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp',
4624+
sharding_rule=_clamp_sharding_rule)
46214625
ad.defjvp(clamp_p,
46224626
lambda g, min, operand, max:
46234627
select(bitwise_and(gt(min, operand), lt(min, max)),

0 commit comments

Comments
 (0)