@@ -3426,7 +3426,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
34263426 if config .sharding_in_types .value :
34273427 xs = x .aval .sharding
34283428 inverse_spec = tuple (xs .spec [o ] for o in unsorted_axes )
3429- ds = NamedSharding ( xs .mesh , P ( * inverse_spec ) )
3429+ ds = xs .with_spec ( inverse_spec )
34303430 else :
34313431 ds = None
34323432 dot_general_out = dot_general (g , y , dims , precision = precision ,
@@ -4116,7 +4116,7 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions,
41164116 orig_spec = iter (operand .sharding .spec )
41174117 new_spec = [next (orig_spec ) if i in bds else None for i in range (len (shape ))]
41184118 assert next (orig_spec , None ) is None
4119- return NamedSharding ( operand .sharding .mesh , P ( * new_spec ) )
4119+ return operand .sharding .with_spec ( new_spec )
41204120
41214121def _broadcast_in_dim_typecheck_rule (
41224122 _ , operand , * dyn_shape , shape , broadcast_dimensions , sharding ):
@@ -4593,7 +4593,7 @@ def _squeeze_sharding_rule(operand, *, dimensions):
45934593 dims_set = set (dimensions )
45944594 new_spec = tuple (s for i , s in enumerate (operand .sharding .spec )
45954595 if i not in dims_set )
4596- return NamedSharding ( operand .sharding .mesh , P ( * new_spec ) )
4596+ return operand .sharding .with_spec ( new_spec )
45974597
45984598def _compute_squeeze_shape (shape , dimensions ):
45994599 dims_set = set (dimensions )
@@ -4688,7 +4688,7 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions):
46884688 if n != sh :
46894689 raise NotImplementedError
46904690 new_spec .append (sp )
4691- return NamedSharding ( operand .sharding .mesh , P ( * new_spec ) )
4691+ return operand .sharding .with_spec ( new_spec )
46924692
46934693def _reshape_typecheck_rule (_ , operand , * dyn_shape , new_sizes , dimensions ):
46944694 if not dyn_shape :
@@ -4791,7 +4791,7 @@ def _transpose_shape_rule(operand, *, permutation):
47914791def _transpose_sharding_rule (operand , * , permutation ):
47924792 o_spec = operand .sharding .spec
47934793 new_spec = [o_spec [old_idx ] for old_idx in permutation ]
4794- return NamedSharding ( operand .sharding .mesh , P ( * new_spec ) )
4794+ return operand .sharding .with_spec ( new_spec )
47954795
47964796def _transpose_batch_rule (batched_args , batch_dims , * , permutation ):
47974797 operand , = batched_args
@@ -5165,7 +5165,7 @@ def _reduce_op_sharding_rule(operand, *, axes):
51655165 axes = frozenset (axes )
51665166 new_spec = P (* tuple (s for i , s in enumerate (operand .sharding .spec )
51675167 if i not in axes ))
5168- return NamedSharding ( operand .sharding .mesh , new_spec )
5168+ return operand .sharding .with_spec ( new_spec )
51695169
51705170reduce_sum_p = standard_primitive (
51715171 _reduce_op_shape_rule , partial (_reduce_number_dtype_rule , 'reduce_sum' ),
@@ -6237,7 +6237,7 @@ def _const(example, val):
62376237def _one (x ):
62386238 if config .sharding_in_types .value :
62396239 return full_like (x , shape = (), fill_value = 1 ,
6240- sharding = NamedSharding ( x .sharding .mesh , P ()))
6240+ sharding = x .sharding .with_spec ( P ()))
62416241 return full_like (x , shape = (), fill_value = 1 )
62426242
62436243_twos : Callable = partial (full_like , fill_value = 2 )
0 commit comments