@@ -5007,6 +5007,11 @@ def _reduce_shape_rule(*avals, computation, jaxpr, dimensions):
     raise ValueError(f'reduce found non-scalar initial value: {init_val_shapes}')
   return [tuple(np.delete(op.shape, dimensions)) for op in operand_avals]
 
+def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions):
+  operand_avals, _ = split_list(avals, [len(avals) // 2])
+  return [op.sharding.with_spec(tuple_delete(op.sharding.spec, dimensions))
+          for op in operand_avals]
+
 def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions):
   operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])
   operand_dtypes = [dtypes.canonicalize_dtype(op.dtype) for op in operand_avals]
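The new rule mirrors the shape rule directly above it: it drops the reduced dimensions from each operand's partition spec, so the result stays sharded along the surviving axes. A minimal pure-Python sketch of that bookkeeping, with a hypothetical delete_dims helper standing in for JAX's internal tuple_delete:

    # Sketch only: drop the spec entries for the reduced dimensions, the same
    # way the shape rule drops those dimensions from the shape.
    def delete_dims(spec, dims):
      dims = set(dims)
      return tuple(s for i, s in enumerate(spec) if i not in dims)

    # Reducing a rank-3 operand sharded as ('x', 'y', None) over axis 1
    # leaves a rank-2 result sharded as ('x', None).
    assert delete_dims(('x', 'y', None), (1,)) == ('x', None)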
@@ -5093,7 +5098,7 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions):
 reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p))
 reduce_p.def_abstract_eval(
     partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
-            _reduce_dtype_rule, _reduce_weak_type_rule))
+            _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule))
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
 ad.primitive_jvps[reduce_p] = _reduce_jvp_rule
 
@@ -5115,6 +5120,9 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
                                       *reducer.arguments,
                                       dim_var_values=ctx.dim_var_values)
     hlo.return_(mlir.flatten_ir_values(out_nodes))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, r, aval)
+            for r, aval in safe_zip(op.results, ctx.avals_out)]
   return op.results
 
 mlir.register_lowering(reduce_p, _reduce_lower)
@@ -5227,7 +5235,12 @@ def _argminmax_shape_rule(operand, *, axes, index_dtype):
   if operand.shape[axis] < 1:
     raise ValueError("argmin and argmax require non-empty reduced dimension. "
                      f"operand.shape={operand.shape} {axis=}")
-  return tuple(np.delete(operand.shape, axis))
+  return util.tuple_delete(operand.shape, axis)
+
+def _argminmax_sharding_rule(operand, *, axes, index_dtype):
+  axis, = axes
+  return operand.sharding.with_spec(
+      util.tuple_delete(operand.sharding.spec, axis))
 
 def _argminmax_dtype_rule(operand, *, axes, index_dtype):
   if not dtypes.issubdtype(index_dtype, np.integer):
@@ -5264,30 +5277,34 @@ def _compute_argminmax(value_comparator, get_identity,
   # value_comparator is either lax.lt (for argmin) or lax.gt
   # get_identity(operand.dtype) is inf for argmin or -inf for argmax
   axis, = axes
-  indices = broadcasted_iota(index_dtype, np.shape(operand), axis)
+  indices = broadcasted_iota(
+      index_dtype, np.shape(operand), axis,
+      _sharding=operand.sharding if config.sharding_in_types.value else None)
   res = reduce([operand, indices],
                [get_identity(operand.dtype), np.array(0, index_dtype)],
                _ArgMinMaxReducer(value_comparator),
                axes)
   return res[1]
 
 argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
-                              'argmin', weak_type_rule=_strip_weak_type)
+                              'argmin', weak_type_rule=_strip_weak_type,
+                              sharding_rule=_argminmax_sharding_rule)
 batching.defreducer(argmin_p, _get_min_identity)
 ad.defjvp_zero(argmin_p)
 
 argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
-                              'argmax', weak_type_rule=_strip_weak_type)
+                              'argmax', weak_type_rule=_strip_weak_type,
+                              sharding_rule=_argminmax_sharding_rule)
 batching.defreducer(argmax_p, _get_max_identity)
 ad.defjvp_zero(argmax_p)
 
-mlir.register_lowering(argmin_p, mlir.cache_lowering(mlir.lower_fun(
-  partial(_compute_argminmax, lt, _get_min_identity),
-  multiple_results=False)))
+mlir.register_lowering(argmin_p, mlir.cache_lowering(
+    mlir.lower_fun(partial(_compute_argminmax, lt, _get_min_identity),
+                   multiple_results=False)))
 
-mlir.register_lowering(argmax_p, mlir.cache_lowering(mlir.lower_fun(
-  partial(_compute_argminmax, gt, _get_max_identity),
-  multiple_results=False)))
+mlir.register_lowering(argmax_p, mlir.cache_lowering(
+    mlir.lower_fun(partial(_compute_argminmax, gt, _get_max_identity),
+                   multiple_results=False)))
 
 
 def _reduce_logical_shape_rule(operand, *, axes):
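_compute_argminmax lowers argmin/argmax as a pairwise (value, index) reduction over the operand and an iota of indices, which is why the iota above now inherits the operand's sharding when sharding_in_types is enabled: both reduce operands carry matching shardings. A pure-Python sketch of the same reduction trick (illustration only, not JAX's _ArgMinMaxReducer):

    from functools import reduce as py_reduce

    def argmin_1d(values):
      # Reduce (value, index) pairs, keeping the smaller value and breaking
      # ties toward the lower index, as the lax argmin reducer does.
      def pick(a, b):
        (va, ia), (vb, ib) = a, b
        return a if (va < vb or (va == vb and ia <= ib)) else b
      return py_reduce(pick, zip(values, range(len(values))))[1]

    assert argmin_1d([3.0, 1.0, 2.0, 1.0]) == 1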
@@ -5882,7 +5899,7 @@ def _rng_bit_generator_lowering(
 rng_bit_generator_p.def_abstract_eval(
     partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
             _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
-            _rng_bit_generator_weak_type_rule))
+            _rng_bit_generator_weak_type_rule, None))
 mlir.register_lowering(rng_bit_generator_p,
                        _rng_bit_generator_lowering)
 