@@ -1983,81 +1983,6 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation):
   return tt_dialect.trans(x, permutation)


-def _check_dot_operands(
-    x_type: ir.RankedTensorType, y_type: ir.RankedTensorType, options: Any
-):
-  # TODO(slebedev): Ensure that the dtypes are supported by CUDA.
-  return
-
-
-def _dot(
-    x: ir.Value,
-    y: ir.Value,
-    acc: ir.Value | None = None,
-    *,
-    allow_tf32: bool = True,
-    max_num_imprecise_acc: int | None = None,
-    out_type: ir.Type | None = None,
-) -> ir.Value:
-  if out_type is None:
-    out_type = ir.F32Type.get()
-  elif isinstance(out_type, ir.BF16Type):
-    raise NotImplementedError(f"unsupported output type: {out_type}")
-
-  x_type = ir.RankedTensorType(x.type)
-  y_type = ir.RankedTensorType(y.type)
-  if min(*x_type.shape, *y_type.shape) < 16:
-    raise ValueError("all dimensions of x and y must be >= 16")
-  if x_type.element_type != y_type.element_type:
-    raise ValueError(
-        "x and y must have the same element type, but got:"
-        f" {x_type.element_type} and {y_type.element_type}"
-    )
-
-  _check_dot_operands(x_type, y_type, object())
-
-  element_type = x_type.element_type
-  if isinstance(element_type, ir.IntegerType):
-    if element_type.width != 8:
-      raise TypeError(f"unsupported element type: {element_type}")
-    element_type = ir.IntegerType.get_signless(32)
-  elif isinstance(element_type, (ir.F32Type, ir.BF16Type)):
-    element_type = ir.F32Type.get()
-  else:
-    element_type = out_type
-
-  if element_type != out_type:
-    raise TypeError(
-        f"output type {out_type} does not match element type {element_type}"
-    )
-
-  m, _ = x_type.shape
-  _, n = y_type.shape
-
-  if acc is None:
-    acc = _full(ir.RankedTensorType.get([m, n], element_type), 0)
-
-  if max_num_imprecise_acc is None:
-    if isinstance(element_type, ir.FloatType) and element_type.width == 8:
-      # TODO(slebedev): Fill in from options.
-      raise NotImplementedError
-    else:
-      max_num_imprecise_acc = 0
-
-  # Ideally, replace all allow_tf32 usages with InputPrecision directly.
-  input_precision = tt_dialect.InputPrecision.IEEE
-  if allow_tf32:
-    input_precision = tt_dialect.InputPrecision.TF32
-
-  return tt_dialect.dot(
-      x,
-      y,
-      acc,
-      max_num_imprecise_acc=max_num_imprecise_acc,
-      input_precision=input_precision
-  )
-
-
 _TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT)

@@ -2081,27 +2006,63 @@ def _dot_general_lowering(
   if b_contract_dim == 1:
     b = tt_dialect.trans(b, (1, 0))

-  if precision is None:
-    allow_tf32 = True
+  a_aval, b_aval = ctx.avals_in
+  [out_aval] = ctx.avals_out
+
+  if precision is None or (precision == lax.DotAlgorithmPreset.DEFAULT):
+    precision = (lax.Precision.DEFAULT, lax.Precision.DEFAULT)
+
+  if isinstance(precision, lax.DotAlgorithmPreset):
+    match precision:
+      case lax.DotAlgorithmPreset.TF32_TF32_F32:
+        input_precision = tt_dialect.InputPrecision.TF32
+      case lax.DotAlgorithmPreset.TF32_TF32_F32_X3:
+        input_precision = tt_dialect.InputPrecision.TF32x3
+      case lax.DotAlgorithmPreset.F32_F32_F32:
+        input_precision = tt_dialect.InputPrecision.IEEE
+      case (
+          lax.DotAlgorithmPreset.F16_F16_F16
+          | lax.DotAlgorithmPreset.F16_F16_F32
+          | lax.DotAlgorithmPreset.BF16_BF16_BF16
+          | lax.DotAlgorithmPreset.BF16_BF16_F32
+      ):
+        input_precision = None
+      case _:
+        raise NotImplementedError(f"Unsupported dot algorithm: {precision}.")
+
+    a = _cast(a, a_aval.dtype, precision.supported_lhs_types[0])
+    b = _cast(b, b_aval.dtype, precision.supported_rhs_types[0])
+    acc_dtype = precision.accumulation_type
+  elif isinstance(precision, tuple):
+    a_precision, b_precision = precision
+    if a_precision in _TF32_PRECISIONS or b_precision in _TF32_PRECISIONS:
+      input_precision = tt_dialect.InputPrecision.TF32
+    elif a_aval.dtype == jnp.float32:
+      input_precision = tt_dialect.InputPrecision.IEEE
+    else:
+      input_precision = None
+
+    acc_dtype = out_aval.dtype
+    if acc_dtype != jnp.int32 and acc_dtype != jnp.float16:
+      acc_dtype = jnp.float32
   else:
-    prec_a, prec_b = precision
-    allow_tf32 = prec_a in _TF32_PRECISIONS or prec_b in _TF32_PRECISIONS
+    raise NotImplementedError(f"Unsupported dot precision: {precision}.")

-  [out_aval] = ctx.avals_out
-  out_dtype = acc_dtype = out_aval.dtype
-  if acc_dtype != jnp.int32 and acc_dtype != jnp.float16:
-    acc_dtype = jnp.dtype(jnp.float32)
-
-  return _cast(
-      _dot(
-          a,
-          b,
-          allow_tf32=allow_tf32,
-          out_type=_dtype_to_ir_type(acc_dtype),
-      ),
-      acc_dtype,
-      out_dtype,
-  )
+  a_type = ir.RankedTensorType(a.type)
+  b_type = ir.RankedTensorType(b.type)
+  if min(*a_type.shape, *b_type.shape) < 16:
+    raise ValueError("all dimensions of a and b must be >= 16")
+  if a_type.element_type != b_type.element_type:
+    raise ValueError(
+        "a and b must have the same element type, but got:"
+        f" {a_type.element_type} and {b_type.element_type}"
+    )
+
+  m, _ = a_type.shape
+  _, n = b_type.shape
+  acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0)
+  acc = tt_dialect.dot(a, b, acc, input_precision=input_precision)
+  return _cast(acc, acc_dtype, out_aval.dtype)


 def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes):
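
A minimal sketch of how this lowering path might be exercised from user code, assuming a GPU with the Triton backend; the kernel name and the 64x64 shapes are illustrative only, not part of the change:

# Sketch: a Pallas matmul kernel that passes a DotAlgorithmPreset through the
# dot_general precision parameter, which the hunk above maps to tt.dot with
# InputPrecision.TF32. All dimensions must be >= 16 for this lowering.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def matmul_kernel(a_ref, b_ref, o_ref):
  o_ref[...] = jnp.dot(
      a_ref[...],
      b_ref[...],
      precision=jax.lax.DotAlgorithmPreset.TF32_TF32_F32,
  )


a = jnp.ones((64, 64), jnp.float32)
b = jnp.ones((64, 64), jnp.float32)
out = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float32),
)(a, b)
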
@@ -2623,7 +2584,8 @@ def _i64_constant(v: int) -> ir.Value:
   return arith_dialect.constant(ir.IntegerType.get_signless(64), v)


-def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type:
+def _dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type:
+  dtype = jnp.dtype(dtype)
   if jnp.issubdtype(dtype, np.integer):
     # All integer types in Triton are signless.
     return ir.IntegerType.get_signless(dtype.itemsize * 8)
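
For context on the widened signature, a small sketch of the normalization the new jnp.dtype(dtype) line performs; the inputs below are illustrative DTypeLike values:

# Sketch: a scalar type, a string, and an np.dtype all normalize to the same
# np.dtype, which is what allows callers to pass values such as
# precision.accumulation_type or jnp.float32 without wrapping them first.
import jax.numpy as jnp
import numpy as np

for d in (jnp.float32, "float32", np.dtype("float32")):
  assert jnp.dtype(d) == np.dtype("float32")
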