Skip to content

Commit 7635605

Browse files
yashk2810 authored and Google-ML-Automation committed
Use with_spec where possible to clean up the code a bit
PiperOrigin-RevId: 699226058
1 parent c0811c9 commit 7635605

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

jax/_src/lax/lax.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3426,7 +3426,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
34263426
if config.sharding_in_types.value:
34273427
xs = x.aval.sharding
34283428
inverse_spec = tuple(xs.spec[o] for o in unsorted_axes)
3429-
ds = NamedSharding(xs.mesh, P(*inverse_spec))
3429+
ds = xs.with_spec(inverse_spec)
34303430
else:
34313431
ds = None
34323432
dot_general_out = dot_general(g, y, dims, precision=precision,
@@ -4116,7 +4116,7 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions,
41164116
orig_spec = iter(operand.sharding.spec)
41174117
new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))]
41184118
assert next(orig_spec, None) is None
4119-
return NamedSharding(operand.sharding.mesh, P(*new_spec))
4119+
return operand.sharding.with_spec(new_spec)
41204120

41214121
def _broadcast_in_dim_typecheck_rule(
41224122
_, operand, *dyn_shape, shape, broadcast_dimensions, sharding):
@@ -4593,7 +4593,7 @@ def _squeeze_sharding_rule(operand, *, dimensions):
45934593
dims_set = set(dimensions)
45944594
new_spec = tuple(s for i, s in enumerate(operand.sharding.spec)
45954595
if i not in dims_set)
4596-
return NamedSharding(operand.sharding.mesh, P(*new_spec))
4596+
return operand.sharding.with_spec(new_spec)
45974597

45984598
def _compute_squeeze_shape(shape, dimensions):
45994599
dims_set = set(dimensions)
@@ -4688,7 +4688,7 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions):
46884688
if n != sh:
46894689
raise NotImplementedError
46904690
new_spec.append(sp)
4691-
return NamedSharding(operand.sharding.mesh, P(*new_spec))
4691+
return operand.sharding.with_spec(new_spec)
46924692

46934693
def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions):
46944694
if not dyn_shape:
@@ -4791,7 +4791,7 @@ def _transpose_shape_rule(operand, *, permutation):
47914791
def _transpose_sharding_rule(operand, *, permutation):
47924792
o_spec = operand.sharding.spec
47934793
new_spec = [o_spec[old_idx] for old_idx in permutation]
4794-
return NamedSharding(operand.sharding.mesh, P(*new_spec))
4794+
return operand.sharding.with_spec(new_spec)
47954795

47964796
def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
47974797
operand, = batched_args
@@ -5165,7 +5165,7 @@ def _reduce_op_sharding_rule(operand, *, axes):
51655165
axes = frozenset(axes)
51665166
new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
51675167
if i not in axes))
5168-
return NamedSharding(operand.sharding.mesh, new_spec)
5168+
return operand.sharding.with_spec(new_spec)
51695169

51705170
reduce_sum_p = standard_primitive(
51715171
_reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
@@ -6237,7 +6237,7 @@ def _const(example, val):
62376237
def _one(x):
62386238
if config.sharding_in_types.value:
62396239
return full_like(x, shape=(), fill_value=1,
6240-
sharding=NamedSharding(x.sharding.mesh, P()))
6240+
sharding=x.sharding.with_spec(P()))
62416241
return full_like(x, shape=(), fill_value=1)
62426242

62436243
_twos: Callable = partial(full_like, fill_value=2)

0 commit comments

Comments (0)