
Commit 3ec55c7

chr1sj0nes authored and Google-ML-Automation committed
[pallas:triton] Add support for DotAlgorithmPreset precision arguments to dot.
PiperOrigin-RevId: 704208558
1 parent 7062325 commit 3ec55c7
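
With this change, `pl.dot` inside a Pallas kernel accepts `jax.lax.DotAlgorithmPreset` values (in addition to `jax.lax.Precision` pairs), and the Triton lowering maps each preset to a Triton input precision and accumulation dtype. A minimal usage sketch, modeled on the test added in this commit; the shapes, dtypes, and preset choice are illustrative:

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((32, 64), jnp.float32),
    grid=1,
)
def dot_kernel(x_ref, y_ref, o_ref):
  # The precision argument now also accepts DotAlgorithmPreset values.
  o_ref[()] = pl.dot(
      x_ref[()],
      y_ref[()],
      precision=jax.lax.DotAlgorithmPreset.BF16_BF16_F32,
  )

x = jnp.ones((32, 16), jnp.bfloat16)
y = jnp.ones((16, 64), jnp.bfloat16)
out = dot_kernel(x, y)  # requires a GPU (Triton) backend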

File tree

2 files changed: +97, -95 lines

jax/_src/pallas/triton/lowering.py

Lines changed: 57 additions & 95 deletions
@@ -1983,81 +1983,6 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation):
   return tt_dialect.trans(x, permutation)
 
 
-def _check_dot_operands(
-    x_type: ir.RankedTensorType, y_type: ir.RankedTensorType, options: Any
-):
-  # TODO(slebedev): Ensure that the dtypes are supported by CUDA.
-  return
-
-
-def _dot(
-    x: ir.Value,
-    y: ir.Value,
-    acc: ir.Value | None = None,
-    *,
-    allow_tf32: bool = True,
-    max_num_imprecise_acc: int | None = None,
-    out_type: ir.Type | None = None,
-) -> ir.Value:
-  if out_type is None:
-    out_type = ir.F32Type.get()
-  elif isinstance(out_type, ir.BF16Type):
-    raise NotImplementedError(f"unsupported output type: {out_type}")
-
-  x_type = ir.RankedTensorType(x.type)
-  y_type = ir.RankedTensorType(y.type)
-  if min(*x_type.shape, *y_type.shape) < 16:
-    raise ValueError("all dimensions of x and y must be >= 16 ")
-  if x_type.element_type != y_type.element_type:
-    raise ValueError(
-        "x and y must have the same element type, but got:"
-        f" {x_type.element_type} and {y_type.element_type}"
-    )
-
-  _check_dot_operands(x_type, y_type, object())
-
-  element_type = x_type.element_type
-  if isinstance(element_type, ir.IntegerType):
-    if element_type.width != 8:
-      raise TypeError(f"unsupported element type: {element_type}")
-    element_type = ir.IntegerType.get_signless(32)
-  elif isinstance(element_type, (ir.F32Type, ir.BF16Type)):
-    element_type = ir.F32Type.get()
-  else:
-    element_type = out_type
-
-  if element_type != out_type:
-    raise TypeError(
-        f"output type {out_type} does not match element type {element_type}"
-    )
-
-  m, _ = x_type.shape
-  _, n = y_type.shape
-
-  if acc is None:
-    acc = _full(ir.RankedTensorType.get([m, n], element_type), 0)
-
-  if max_num_imprecise_acc is None:
-    if isinstance(element_type, ir.FloatType) and element_type.width == 8:
-      # TODO(slebedev): Fill in from options.
-      raise NotImplementedError
-    else:
-      max_num_imprecise_acc = 0
-
-  # Ideally, replace all allow_tf32 usages with InputPrecision directly.
-  input_precision = tt_dialect.InputPrecision.IEEE
-  if allow_tf32:
-    input_precision = tt_dialect.InputPrecision.TF32
-
-  return tt_dialect.dot(
-      x,
-      y,
-      acc,
-      max_num_imprecise_acc=max_num_imprecise_acc,
-      input_precision=input_precision
-  )
-
-
 _TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT)
 
 
@@ -2081,27 +2006,63 @@ def _dot_general_lowering(
   if b_contract_dim == 1:
     b = tt_dialect.trans(b, (1, 0))
 
-  if precision is None:
-    allow_tf32 = True
+  a_aval, b_aval = ctx.avals_in
+  [out_aval] = ctx.avals_out
+
+  if precision is None or (precision == lax.DotAlgorithmPreset.DEFAULT):
+    precision = (lax.Precision.DEFAULT, lax.Precision.DEFAULT)
+
+  if isinstance(precision, lax.DotAlgorithmPreset):
+    match precision:
+      case lax.DotAlgorithmPreset.TF32_TF32_F32:
+        input_precision = tt_dialect.InputPrecision.TF32
+      case lax.DotAlgorithmPreset.TF32_TF32_F32_X3:
+        input_precision = tt_dialect.InputPrecision.TF32x3
+      case lax.DotAlgorithmPreset.F32_F32_F32:
+        input_precision = tt_dialect.InputPrecision.IEEE
+      case (
+          lax.DotAlgorithmPreset.F16_F16_F16
+          | lax.DotAlgorithmPreset.F16_F16_F32
+          | lax.DotAlgorithmPreset.BF16_BF16_BF16
+          | lax.DotAlgorithmPreset.BF16_BF16_F32
+      ):
+        input_precision = None
+      case _:
+        raise NotImplementedError(f"Unsupported dot algorithm: {precision}.")
+
+    a = _cast(a, a_aval.dtype, precision.supported_lhs_types[0])
+    b = _cast(b, b_aval.dtype, precision.supported_rhs_types[0])
+    acc_dtype = precision.accumulation_type
+  elif isinstance(precision, tuple):
+    a_precision, b_precision = precision
+    if a_precision in _TF32_PRECISIONS or b_precision in _TF32_PRECISIONS:
+      input_precision = tt_dialect.InputPrecision.TF32
+    elif a_aval.dtype == jnp.float32:
+      input_precision = tt_dialect.InputPrecision.IEEE
+    else:
+      input_precision = None
+
+    acc_dtype = out_aval.dtype
+    if acc_dtype != jnp.int32 and acc_dtype != jnp.float16:
+      acc_dtype = jnp.float32
   else:
-    prec_a, prec_b = precision
-    allow_tf32 = prec_a in _TF32_PRECISIONS or prec_b in _TF32_PRECISIONS
+    raise NotImplementedError(f"Unsupported dot precision: {precision}.")
 
-  [out_aval] = ctx.avals_out
-  out_dtype = acc_dtype = out_aval.dtype
-  if acc_dtype != jnp.int32 and acc_dtype != jnp.float16:
-    acc_dtype = jnp.dtype(jnp.float32)
-
-  return _cast(
-      _dot(
-          a,
-          b,
-          allow_tf32=allow_tf32,
-          out_type=_dtype_to_ir_type(acc_dtype),
-      ),
-      acc_dtype,
-      out_dtype,
-  )
+  a_type = ir.RankedTensorType(a.type)
+  b_type = ir.RankedTensorType(b.type)
+  if min(*a_type.shape, *b_type.shape) < 16:
+    raise ValueError("all dimensions of a and b must be >= 16 ")
+  if a_type.element_type != b_type.element_type:
+    raise ValueError(
+        "a and b must have the same element type, but got:"
+        f" {a_type.element_type} and {b_type.element_type}"
+    )
+
+  m, _ = a_type.shape
+  _, n = b_type.shape
+  acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0)
+  acc = tt_dialect.dot(a, b, acc, input_precision=input_precision)
+  return _cast(acc, acc_dtype, out_aval.dtype)
 
 
 def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes):
@@ -2623,7 +2584,8 @@ def _i64_constant(v: int) -> ir.Value:
   return arith_dialect.constant(ir.IntegerType.get_signless(64), v)
 
 
-def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type:
+def _dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type:
+  dtype = jnp.dtype(dtype)
   if jnp.issubdtype(dtype, np.integer):
     # All integer types in Triton are signless.
     return ir.IntegerType.get_signless(dtype.itemsize * 8)
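
The `_dtype_to_ir_type` change in the last hunk widens the helper to accept any `jax.typing.DTypeLike` value, since the new lowering hands it `precision.accumulation_type`, which may not already be a `jnp.dtype` instance. A small standalone sketch of the normalization it now relies on (not the module's code):

import jax.numpy as jnp

# jnp.dtype() canonicalizes scalar types, strings, and dtype instances alike,
# which is what lets the helper take any DTypeLike value.
assert jnp.dtype(jnp.float32) == jnp.dtype("float32")
assert jnp.dtype(jnp.dtype("float32")) == jnp.dtype("float32")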

tests/pallas/pallas_test.py

Lines changed: 40 additions & 0 deletions
@@ -687,6 +687,46 @@ def f(x):
     self.assertEqual(f(x), 2.)
     self.assertEqual(trace_count, 1)
 
+  @parameterized.parameters(
+      ("float32", None),
+      ("float32", jax.lax.Precision.DEFAULT),
+      ("float32", jax.lax.Precision.HIGH),
+      ("float32", jax.lax.Precision.HIGHEST),
+      ("float32", jax.lax.DotAlgorithmPreset.DEFAULT),
+      ("float32", jax.lax.DotAlgorithmPreset.F16_F16_F32),
+      ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32),
+      ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32),
+      ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3),
+      ("float32", jax.lax.DotAlgorithmPreset.F32_F32_F32),
+      ("bfloat16", None),
+      ("bfloat16", jax.lax.Precision.DEFAULT),
+      ("bfloat16", jax.lax.Precision.HIGHEST),
+      ("bfloat16", jax.lax.DotAlgorithmPreset.DEFAULT),
+      ("bfloat16", jax.lax.DotAlgorithmPreset.BF16_BF16_F32),
+  )
+  def test_dot_precision(self, dtype, precision):
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("`DotAlgorithmPreset` only supported on GPU.")
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((32, 64), jnp.float32),
+        grid=1,
+    )
+    def dot_kernel(x_ref, y_ref, o_ref):
+      o_ref[()] = pl.dot(x_ref[()], y_ref[()], precision=precision)
+
+    key0, key1 = random.split(random.key(0))
+    x = random.normal(key0, (32, 16), dtype=dtype)
+    y = random.normal(key1, (16, 64), dtype=dtype)
+    expected = jnp.dot(
+        x,
+        y,
+        precision=jax.lax.Precision.HIGHEST,
+        preferred_element_type=jnp.float32,
+    )
+    self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3)
+
 
 class PallasCallInterpretTest(PallasCallTest):
   INTERPRET = True
