@@ -1700,6 +1700,8 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array:
     scalar_zero = np.zeros((), dtype=aval.dtype)
   else:
     scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
+  if config.sharding_in_types.value:
+    return broadcast(scalar_zero, aval.shape, sharding=aval.sharding)
   return broadcast(scalar_zero, aval.shape)
 
 ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
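
Note on the hunk above: when the experimental sharding-in-types mode is on, the instantiated zero inherits the aval's sharding via the new `sharding=` argument to `broadcast`. A hedged sketch of the user-visible effect follows; the flag name, mesh, and gradient example are illustrative assumptions, not part of the commit. `zeros_like_shaped_array` is the instantiation path for symbolic-zero cotangents, so a gradient with respect to an unused argument exercises it:

# Hedged sketch, not from the commit. Assumes the config is exposed as the
# `jax_sharding_in_types` flag; mesh and axis names are illustrative.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update('jax_sharding_in_types', True)  # assumed flag name

mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
x = jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, P('x', None)))

# `b` is unused, so its cotangent is a symbolic zero that autodiff
# instantiates through ad_util.aval_zeros_likers -- the path patched above.
g = jax.grad(lambda a, b: a.sum(), argnums=1)(x, x)
print(g.sharding)  # expected to match x's sharding under the new code
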
@@ -4401,7 +4403,7 @@ def _concatenate_shape_rule(*operands, **kwargs):
     raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands])))
   shapes = [operand.shape[:dimension] + operand.shape[dimension+1:]
             for operand in operands]
-  if not shapes[:-1] == shapes[1:]:
+  if shapes[:-1] != shapes[1:]:
     msg = ("Cannot concatenate arrays with shapes that differ in dimensions "
            "other than the one being concatenated: concatenating along "
            "dimension {} for shapes {}.")
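
The comparison rewrite above is behavior-preserving: for lists, `not shapes[:-1] == shapes[1:]` and `shapes[:-1] != shapes[1:]` are equivalent, and the shifted pairwise comparison checks that all operand shapes agree outside the concatenated dimension. A self-contained illustration of that check; `check_concat_shapes` and the example shapes are made up for this note:

# Standalone illustration of the shape rule's check; `check_concat_shapes`
# is a hypothetical helper, not part of the diff.
def check_concat_shapes(shapes, dimension):
  # Drop the concatenated dimension from every shape ...
  trimmed = [s[:dimension] + s[dimension+1:] for s in shapes]
  # ... then require all remainders to be equal: comparing the list against
  # itself shifted by one is unequal iff some adjacent pair differs.
  if trimmed[:-1] != trimmed[1:]:
    raise TypeError(f"cannot concatenate shapes {shapes} along dim {dimension}")

check_concat_shapes([(2, 3), (4, 3)], 0)    # ok: shapes agree outside dim 0
# check_concat_shapes([(2, 3), (2, 4)], 0)  # raises: dim 1 differs
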
@@ -4412,6 +4414,13 @@ def _concatenate_shape_rule(*operands, **kwargs):
   ex_shape = operands[0].shape
   return ex_shape[:dimension] + (concat_size,) + ex_shape[dimension+1:]
 
+def _concatenate_sharding_rule(*operands, **kwargs):
+  if not all(o.sharding == operands[0].sharding for o in operands):
+    ss = ", ".join(str(o.sharding) for o in operands)
+    raise TypeError(
+        f"All operands should have the same sharding. Got shardings {ss}")
+  return operands[0].sharding
+
 def _concatenate_dtype_rule(*operands, **kwargs):
   check_same_dtypes('concatenate', *operands)
   return operands[0].dtype
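
The new sharding rule only needs objects exposing a `.sharding` attribute, so it can be exercised in isolation. A minimal sketch, using `SimpleNamespace` stand-ins and plain strings for shardings rather than real avals:

# Minimal isolation test of the rule above; the operands are stand-ins,
# and the sharding values are strings purely for illustration.
from types import SimpleNamespace

def _concatenate_sharding_rule(*operands, **kwargs):
  if not all(o.sharding == operands[0].sharding for o in operands):
    ss = ", ".join(str(o.sharding) for o in operands)
    raise TypeError(
        f"All operands should have the same sharding. Got shardings {ss}")
  return operands[0].sharding

row = SimpleNamespace(sharding="P('x', None)")
col = SimpleNamespace(sharding="P(None, 'x')")

assert _concatenate_sharding_rule(row, row) == "P('x', None)"  # match: ok
# _concatenate_sharding_rule(row, col)  # mismatch: raises TypeError
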
@@ -4452,14 +4461,19 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension):
   raise NotImplementedError  # TODO(mattjj)
 
 concatenate_p = standard_primitive(
-    _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate')
+    _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate',
+    sharding_rule=_concatenate_sharding_rule)
 ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
 ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
 batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
 pe.padding_rules[concatenate_p] = _concatenate_pad_rule
 
 def _concatenate_lower(ctx, *xs, dimension):
-  return [hlo.concatenate(xs, mlir.i64_attr(dimension))]
+  aval_out, = ctx.avals_out
+  out = hlo.concatenate(xs, mlir.i64_attr(dimension))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]
 mlir.register_lowering(concatenate_p, _concatenate_lower)
 
 
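Putting the pieces together: `standard_primitive` now receives the sharding rule, and the lowering stamps the output sharding on the HLO op (via `mlir.lower_sharding_under_shit`, the internal sharding-in-types lowering helper) when the mode is on. A hedged end-to-end sketch under the same assumed flag and illustrative mesh as above; mixing operand shardings would instead raise the `TypeError` from `_concatenate_sharding_rule` at trace time:

# Hedged end-to-end sketch, not from the commit; the flag name and mesh
# setup are assumptions carried over from the earlier note.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update('jax_sharding_in_types', True)  # assumed flag name
mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
sharding = NamedSharding(mesh, P('x', None))
a = jax.device_put(jnp.ones((4, 2)), sharding)
b = jax.device_put(jnp.zeros((4, 2)), sharding)

# Identically sharded operands: the rule propagates their common sharding.
out = jax.jit(lambda u, v: jnp.concatenate([u, v], axis=0))(a, b)
print(out.sharding)  # expected: the operands' NamedSharding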