@@ -1040,7 +1040,8 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
 
 def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                 precision: PrecisionLike = None,
-                preferred_element_type: DTypeLike | None = None) -> Array:
+                preferred_element_type: DTypeLike | None = None,
+                out_type=None) -> Array:
   """General dot product/contraction operator.
 
   Wraps XLA's `DotGeneral
@@ -1086,6 +1087,10 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
     by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
     non-contracting/non-batch dimensions.
   """
+  if out_type is not None and not isinstance(out_type, NamedSharding):
+    raise NotImplementedError(
+        '`out_type` argument of `dot_general` only supports NamedSharding '
+        'instances. Please file a bug if this is not enough for your use case.')
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
            api_util._ensure_index_tuple(rhs_contract))
@@ -1097,7 +1102,8 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
   return dot_general_p.bind(lhs, rhs,
                             dimension_numbers=(cdims, bdims),
                             precision=canonicalize_precision(precision),
-                            preferred_element_type=preferred_element_type)
+                            preferred_element_type=preferred_element_type,
+                            out_type=out_type)
 
 
 def ragged_dot(
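Note (added for review context): a hedged usage sketch of the new keyword, not taken from the patch itself. It assumes at least two local devices, a hypothetical one-axis mesh named 'x', and that the experimental sharding-in-types mode this change targets is enabled; treat it as a sketch of intent rather than a snippet guaranteed to run on every JAX version.

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:2])      # assumes >= 2 devices are available
mesh = Mesh(devices, ('x',))               # illustrative one-axis mesh

lhs = jnp.ones((8, 4))
rhs = jnp.ones((4, 8))
# Ask for the (8, 8) result to be sharded along 'x' on its first dimension.
out_sharding = NamedSharding(mesh, P('x', None))
out = lax.dot_general(lhs, rhs, (((1,), (0,)), ((), ())),
                      out_type=out_sharding)
# Passing anything other than a NamedSharding as out_type raises
# NotImplementedError, per the isinstance check added above.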
@@ -3002,7 +3008,11 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
     operand = hlo.real(operand)
     aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
-  return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
+  out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
+  if config.sharding_in_types.value:
+    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+  return [out]
 
 mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
 
@@ -3164,7 +3174,8 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
 
 
 def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None):
+                            preferred_element_type: DTypeLike | None,
+                            out_type):
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
              for d in (lhs_contracting, lhs_batch)):
@@ -3241,24 +3252,28 @@ def _check_specs_match(lhs_spec, rhs_spec, msg):
       raise TypeError(msg)
 
 def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None):
+                               preferred_element_type: DTypeLike | None,
+                               out_type):
   if lhs.sharding.mesh != rhs.sharding.mesh:
     raise ValueError(
         'Mesh of both lhs and rhs should match. Got lhs:'
         f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
 
+  if out_type is not None:
+    return out_type
+
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
   rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch)
   msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
-        f"to have the consistent sharding, got {lhs_batch_spec} and "
-        f"{rhs_batch_spec}.")
+         f"to have the consistent sharding, got {lhs_batch_spec} and "
+         f"{rhs_batch_spec}.")
   _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg)
 
   lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting)
   rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting)
   msg = ("dot_general requires contracting dimensions to have consistent "
-        f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
+         f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
   _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg)
 
   return _dot_general_sharding_computation(
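Aside: a minimal, illustrative-only mimic of the control flow added above, using plain Python tuples rather than JAX avals or shardings, to make the early return concrete — an explicit out_type is handed back verbatim, otherwise the output spec follows the batch / lhs non-contracting / rhs non-contracting ordering described in dot_general's docstring.

def sharding_rule_sketch(lhs_batch_spec, lhs_kept_spec, rhs_kept_spec,
                         out_type=None):
    # Caller-provided output type wins and is returned unchanged.
    if out_type is not None:
        return out_type
    # Otherwise: batch dims, then lhs non-contracting, then rhs non-contracting.
    return lhs_batch_spec + lhs_kept_spec + rhs_kept_spec

assert sharding_rule_sketch(('b',), ('m',), (None,)) == ('b', 'm', None)
assert sharding_rule_sketch(('b',), ('m',), (None,),
                            out_type=('b', None, 'm')) == ('b', None, 'm')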
@@ -3280,7 +3295,8 @@ def tuple_delete(tup, idx):
 
 
 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None):
+                            preferred_element_type: DTypeLike | None,
+                            out_type):
   del dimension_numbers  # unused
   # We're mostly matching XLA's logic here, namely in shape_inference.cc and
   # primitive_util.h's HigherPrecisionType, e.g.
@@ -3327,7 +3343,7 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
 
 def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               swap_ans=False):
+                               out_type, swap_ans=False):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   x_ndim = x.aval.ndim
   x_kept = remaining(range(x_ndim), x_contract, x_batch)
@@ -3347,12 +3363,14 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
   return x_bar
 
 def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None):
+                               preferred_element_type: DTypeLike | None,
+                               out_type):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
   y_bar = _dot_general_transpose_lhs(
       g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
-      preferred_element_type=preferred_element_type, swap_ans=True)
+      preferred_element_type=preferred_element_type, out_type=out_type,
+      swap_ans=True)
   if y_bar.dtype != y.aval.dtype:
     y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
   return y_bar
@@ -3366,6 +3384,7 @@ def _dot_batch_rule(
     batch_dims,
     *,
     dimension_numbers,
+    out_type,
     precision,
     preferred_element_type: DTypeLike | None,
     **_,
@@ -3395,12 +3414,16 @@ def _dot_batch_rule(
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
+  if out_type is not None:
+    raise NotImplementedError("vmap with out_type is not supported. "
+                              "Please open an issue.")
   batched_out = invoke_prim(
       lhs,
       rhs,
       new_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
+      out_type=out_type,
   )
   result_batch_dim = batching.shape_as_bdim(
       result_stack_dim,
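For orientation, a hedged sketch of the batching behaviour this hunk pins down: vmapping dot_general without out_type keeps working as before, while combining vmap with the new keyword is expected to hit the NotImplementedError added above. Shapes below are illustrative.

import jax
import jax.numpy as jnp
from jax import lax

def matmul(a, b):
    # Contract a's dim 1 with b's dim 0; no batch dims at the primitive level.
    return lax.dot_general(a, b, (((1,), (0,)), ((), ())))

out = jax.vmap(matmul)(jnp.ones((3, 2, 4)), jnp.ones((3, 4, 5)))
assert out.shape == (3, 2, 5)
# With out_type=<some NamedSharding>, the same vmap is expected to raise
# NotImplementedError under the batch rule above.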
@@ -3570,7 +3593,7 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
 
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
-                       platform: str = "default"):
+                       out_type, platform: str = "default"):
   def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                   dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@@ -3658,6 +3681,8 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
       **algorithm_kwarg,
   )
   if config.sharding_in_types.value:
+    if out_type is not None:
+      assert aval_out.sharding == out_type
     out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
     result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp)
   if accumulation_aval.dtype != aval_out.dtype:
@@ -3711,12 +3736,15 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
   return (m, n)
 
 def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
-                           precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
+                           precision, preferred_element_type: DTypeLike | None,
+                           **_) -> np.dtype:
   if not dtypes.issubdtype(group_sizes.dtype, np.integer):
     raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
   # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
-  return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
-                                 precision=precision, preferred_element_type=preferred_element_type)
+  return _dot_general_dtype_rule(
+      lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+      precision=precision, preferred_element_type=preferred_element_type,
+      out_type=None)
 
 
 def _ragged_dot_jvp_rule(
@@ -3855,6 +3883,7 @@ def _ragged_dot_batch_rule(
     *,
     precision,
     preferred_element_type: DTypeLike | None,
+    out_type,
     **_,
 ):
   invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
@@ -3868,6 +3897,7 @@ def _ragged_dot_batch_rule(
       dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
       precision=precision,
       preferred_element_type=preferred_element_type,
+      out_type=out_type,
   )
 
 