@@ -1040,8 +1040,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
 
 def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                 precision: PrecisionLike = None,
-                preferred_element_type: DTypeLike | None = None,
-                out_type=None) -> Array:
+                preferred_element_type: DTypeLike | None = None) -> Array:
   """General dot product/contraction operator.
 
   Wraps XLA's `DotGeneral
@@ -1087,10 +1086,6 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
   by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
   non-contracting/non-batch dimensions.
   """
-  if out_type is not None and not isinstance(out_type, NamedSharding):
-    raise NotImplementedError(
-        '`out_type` argument of `dot_general` only supports NamedSharding '
-        'instances. Please file a bug if this is not enough for your use case.')
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
            api_util._ensure_index_tuple(rhs_contract))
@@ -1102,8 +1097,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
   return dot_general_p.bind(lhs, rhs,
                             dimension_numbers=(cdims, bdims),
                             precision=canonicalize_precision(precision),
-                            preferred_element_type=preferred_element_type,
-                            out_type=out_type)
+                            preferred_element_type=preferred_element_type)
 
 
 def ragged_dot(
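
The hunks above leave the public `dot_general` wrapper with the signature `dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None)`. A minimal usage sketch of that public API (not part of this commit; shapes and values are made up for illustration):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((2, 3, 4))   # batch=2, m=3, k=4
rhs = jnp.ones((2, 4, 5))   # batch=2, k=4, n=5
# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dnums = (((2,), (1,)), ((0,), (0,)))
out = lax.dot_general(lhs, rhs, dnums, preferred_element_type=jnp.float32)
print(out.shape)  # (2, 3, 5): batch dims, then lhs kept dims, then rhs kept dims
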
@@ -3008,11 +3002,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
     operand = hlo.real(operand)
     aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
-  out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
-  if config.sharding_in_types.value:
-    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
-    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
-  return [out]
+  return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
 
 mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
 
@@ -3174,8 +3164,7 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
 
 
 def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None,
-                            out_type):
+                            preferred_element_type: DTypeLike | None):
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
              for d in (lhs_contracting, lhs_batch)):
@@ -3252,28 +3241,24 @@ def _check_specs_match(lhs_spec, rhs_spec, msg):
     raise TypeError(msg)
 
 def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None,
-                               out_type):
+                               preferred_element_type: DTypeLike | None):
   if lhs.sharding.mesh != rhs.sharding.mesh:
     raise ValueError(
         'Mesh of both lhs and rhs should match. Got lhs:'
         f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
 
-  if out_type is not None:
-    return out_type
-
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
   rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch)
   msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
-         f"to have the consistent sharding, got {lhs_batch_spec} and "
-         f"{rhs_batch_spec}.")
+         f"to have the consistent sharding, got {lhs_batch_spec} and "
+         f"{rhs_batch_spec}.")
   _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg)
 
   lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting)
   rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting)
   msg = ("dot_general requires contracting dimensions to have consistent "
-         f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
+         f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
   _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg)
 
   return _dot_general_sharding_computation(
@@ -3295,8 +3280,7 @@ def tuple_delete(tup, idx):
 
 
 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None,
-                            out_type):
+                            preferred_element_type: DTypeLike | None):
   del dimension_numbers  # unused
   # We're mostly matching XLA's logic here, namely in shape_inference.cc and
   # primitive_util.h's HigherPrecisionType, e.g.
@@ -3343,7 +3327,7 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
 
 def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type, swap_ans=False):
+                               swap_ans=False):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   x_ndim = x.aval.ndim
   x_kept = remaining(range(x_ndim), x_contract, x_batch)
@@ -3363,14 +3347,12 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
   return x_bar
 
 def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None,
-                               out_type):
+                               preferred_element_type: DTypeLike | None):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
   y_bar = _dot_general_transpose_lhs(
     g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
-    preferred_element_type=preferred_element_type, out_type=out_type,
-    swap_ans=True)
+    preferred_element_type=preferred_element_type, swap_ans=True)
   if y_bar.dtype != y.aval.dtype:
     y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
   return y_bar
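
The two transpose rules implement the usual matmul cotangent identities; for a plain 2-D contraction the lhs cotangent is `g @ y.T`. A quick sanity check through `jax.grad` (illustrative only, values assumed):

import jax
import jax.numpy as jnp

x = jnp.ones((3, 4))
y = jnp.arange(20.0).reshape(4, 5)
g = jnp.ones((3, 5))                        # cotangent of sum(x @ y)
grad_x = jax.grad(lambda x: jnp.sum(x @ y))(x)
print(bool(jnp.allclose(grad_x, g @ y.T)))  # True
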
@@ -3384,7 +3366,6 @@ def _dot_batch_rule(
     batch_dims,
     *,
     dimension_numbers,
-    out_type,
     precision,
     preferred_element_type: DTypeLike | None,
     **_,
@@ -3414,16 +3395,12 @@ def _dot_batch_rule(
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
-  if out_type is not None:
-    raise NotImplementedError("vmap with out_type is not supported. "
-                              "Please open an issue.")
   batched_out = invoke_prim(
       lhs,
       rhs,
       new_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
-      out_type=out_type,
   )
   result_batch_dim = batching.shape_as_bdim(
       result_stack_dim,
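
`_dot_batch_rule` is what makes `jax.vmap` of `dot_general` work, by folding the mapped axes into batch dimensions of the underlying primitive. A hedged sketch of the user-facing behaviour (shapes are assumptions for illustration):

import jax
import jax.numpy as jnp
from jax import lax

def matmul(a, b):
  # plain (m, k) x (k, n) contraction, no batch dims
  return lax.dot_general(a, b, (((1,), (0,)), ((), ())))

a = jnp.ones((7, 3, 4))   # leading axis of size 7 is the vmapped axis
b = jnp.ones((7, 4, 5))
out = jax.vmap(matmul)(a, b)
print(out.shape)  # (7, 3, 5)
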
@@ -3593,7 +3570,7 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
 
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
-                       out_type, platform: str = "default"):
+                       platform: str = "default"):
   def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                   dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@@ -3681,8 +3658,6 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
       **algorithm_kwarg,
   )
   if config.sharding_in_types.value:
-    if out_type is not None:
-      assert aval_out.sharding == out_type
     out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
     result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp)
   if accumulation_aval.dtype != aval_out.dtype:
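
One way to observe what `_dot_general_lower` emits, without touching internals, is to lower a jitted matmul and read the StableHLO text. This sketch uses the public `jax.jit(...).lower(...).as_text()` API; the exact printed IR depends on the JAX/XLA version:

import jax
import jax.numpy as jnp

a = jnp.ones((8, 16), dtype=jnp.bfloat16)
b = jnp.ones((16, 4), dtype=jnp.bfloat16)
print(jax.jit(lambda a, b: a @ b).lower(a, b).as_text())  # contains a dot_general op
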
@@ -3736,15 +3711,12 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
   return (m, n)
 
 def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
-                           precision, preferred_element_type: DTypeLike | None,
-                           **_) -> np.dtype:
+                           precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
   if not dtypes.issubdtype(group_sizes.dtype, np.integer):
     raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
   # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
-  return _dot_general_dtype_rule(
-      lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
-      precision=precision, preferred_element_type=preferred_element_type,
-      out_type=None)
+  return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+                                 precision=precision, preferred_element_type=preferred_element_type)
 
 
 def _ragged_dot_jvp_rule(
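
For context on the shapes whose dtype the rule above defers to `_dot_general_dtype_rule`: `lhs` is `(m, k)`, `rhs` stacks one `(k, n)` matrix per group, and `group_sizes` assigns rows of `lhs` to groups. A hedged sketch via the public `lax.ragged_dot` (shapes and values are illustrative assumptions):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((10, 4))                                # m=10, k=4
rhs = jnp.ones((3, 4, 5))                              # g=3 groups of (k=4, n=5)
group_sizes = jnp.array([3, 5, 2], dtype=jnp.int32)    # sums to m
out = lax.ragged_dot(lhs, rhs, group_sizes, preferred_element_type=jnp.float32)
print(out.shape, out.dtype)  # (10, 5) float32
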
@@ -3883,7 +3855,6 @@ def _ragged_dot_batch_rule(
     *,
     precision,
     preferred_element_type: DTypeLike | None,
-    out_type,
     **_,
 ):
   invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
@@ -3897,7 +3868,6 @@ def _ragged_dot_batch_rule(
       dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
       precision=precision,
       preferred_element_type=preferred_element_type,
-      out_type=out_type,
   )
 
 