
Commit f10d3eb

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Allow contracting ops into FMAs
Using FMAs can significantly increase ALU throughput, and contraction can only increase precision. We use this capability to reduce the number of operations needed to evaluate the softmax part of attention.

PiperOrigin-RevId: 701226007
1 parent ab79066 commit f10d3eb
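For context on the commit message: the softmax rewrite relies on the identity exp(x) = 2**(x * log2(e)). Once the multiplication by log2(e) is explicit, it sits right next to the running-max subtraction, and the contraction flags added below let the backend fuse that pair into a single FMA. A standalone NumPy sketch of the identity (illustrative only, not code from the commit):

```python
# exp(x) == 2**(x * log2(e)) -- the identity behind switching the kernel
# from jnp.exp to jnp.exp2 plus one scaling multiply.
import math
import numpy as np

x = np.linspace(-5, 5, 11, dtype=np.float32)
log2e = np.float32(math.log2(math.e))
np.testing.assert_allclose(np.exp(x), np.exp2(x * log2e), rtol=1e-5)
```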

File tree

3 files changed (+51 -18 lines)

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 8 additions & 1 deletion
@@ -1199,6 +1199,13 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x):
   return a.exp(approx=ctx.module_ctx.approx_math)
 
 
+@register_lowering_rule(lax.exp2_p)
+def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
+  [x_aval] = ctx.avals_in
+  a = _ensure_fa(x, x_aval.dtype)
+  return a.exp2(approx=ctx.module_ctx.approx_math)
+
+
 @register_lowering_rule(lax.reduce_sum_p)
 def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
   [x_aval] = ctx.avals_in
@@ -1216,7 +1223,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
         raise NotImplementedError
       if not jnp.issubdtype(x_aval.dtype, jnp.floating):
         raise NotImplementedError
-      return x.reduce(arith_dialect.addf, axes[0])
+      return x.reduce("add", axes[0])
     case _:
       raise NotImplementedError(f"Unsupported layout {x.layout}")

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 36 additions & 13 deletions
@@ -700,7 +700,7 @@ def __neg__(self):
 
   def __add__(self, other):
     if ir.FloatType.isinstance(self.mlir_dtype):
-      return self._pointwise(arith.addf, other)
+      return self._pointwise(addf, other)
     elif ir.IntegerType.isinstance(self.mlir_dtype):
       return self._pointwise(arith.addi, other)
     else:
@@ -711,7 +711,7 @@ def __radd__(self, other):
 
   def __mul__(self, other):
     if ir.FloatType.isinstance(self.mlir_dtype):
-      return self._pointwise(arith.mulf, other)
+      return self._pointwise(mulf, other)
     elif ir.IntegerType.isinstance(self.mlir_dtype):
       return self._pointwise(arith.muli, other)
     else:
@@ -722,15 +722,15 @@ def __rmul__(self, other):
 
   def __sub__(self, other):
     if ir.FloatType.isinstance(self.mlir_dtype):
-      return self._pointwise(arith.subf, other)
+      return self._pointwise(subf, other)
    elif ir.IntegerType.isinstance(self.mlir_dtype):
       return self._pointwise(arith.subi, other)
     else:
       return NotImplemented
 
   def __rsub__(self, other):
     if ir.FloatType.isinstance(self.mlir_dtype):
-      return self._pointwise(lambda s, o: arith.subf(o, s), other)
+      return self._pointwise(lambda s, o: subf(o, s), other)
     elif ir.IntegerType.isinstance(self.mlir_dtype):
       return self._pointwise(lambda s, o: arith.subi(o, s), other)
     else:
@@ -904,16 +904,20 @@ def exp(self, *, approx: bool = False):
     if not ir.FloatType.isinstance(self.mlir_dtype):
       raise NotImplementedError
     if approx:
-      f32 = ir.F32Type.get()
-      if self.mlir_dtype != f32:
-        raise NotImplementedError
-      log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634))
-      def fast_exp(x):
-        scaled = arith.mulf(x, log2e)
-        return llvm.inline_asm(f32, [scaled], "ex2.approx.ftz.f32 $0, $1;", "=f,f")
-      return self._pointwise(self._lift_fast_instr(fast_exp))
+      dtype = self.mlir_dtype
+      log2e = arith.constant(dtype, ir.FloatAttr.get(dtype, 1.4426950408889634))
+      return (self * log2e).exp2(approx=approx)
     return self._pointwise(mlir_math.exp)
 
+  def exp2(self, *, approx: bool = False):
+    if not ir.FloatType.isinstance(self.mlir_dtype):
+      raise NotImplementedError
+    if approx:
+      if not ir.F32Type.isinstance(self.mlir_dtype):
+        raise NotImplementedError(self.mlir_dtype)
+      return self._pointwise(self._lift_fast_instr("ex2.approx.ftz.f32"))
+    return self._pointwise(mlir_math.exp2)
+
   def sin(self, *, approx: bool = False):
     if not ir.FloatType.isinstance(self.mlir_dtype):
       raise NotImplementedError
@@ -1125,7 +1129,7 @@ def upcast_to_bf16(reg, high):
   # NOTE: scratch can be reused immediately once this function returns.
   def reduce_sum(self, scratch) -> ir.Value:
     if ir.FloatType.isinstance(self.mlir_dtype):
-      op = arith.addf
+      op = addf
     elif ir.IntegerType.isinstance(self.mlir_dtype):
       op = arith.addi
     else:
@@ -1167,6 +1171,13 @@ def reduce_sum(self, scratch) -> ir.Value:
   def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis):
     if isinstance(op, str):
       match op:
+        case "add":
+          if ir.FloatType.isinstance(self.mlir_dtype):
+            op = addf
+          elif ir.IntegerType.isinstance(self.mlir_dtype):
+            op = arith.addi
+          else:
+            raise NotImplementedError(self.mlir_dtype)
         case "max":
           if ir.F32Type.isinstance(self.mlir_dtype):
             op = self._lift_fast_instr("max.NaN.f32")
@@ -1653,3 +1664,15 @@ def tree_unflatten(cls, aux, flat_registers):
     layout, reg_shape, is_signed = aux
     registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape)
     return cls(_registers=registers, _layout=layout, _is_signed=is_signed)
+
+
+# We allow contractions, to potentially take advantage of FMA instructions.
+# They can change the results, but the precision should only increase.
+def addf(a: ir.Value, b: ir.Value):
+  return arith.addf(a, b, fastmath=arith.FastMathFlags.contract)
+
+def subf(a: ir.Value, b: ir.Value):
+  return arith.subf(a, b, fastmath=arith.FastMathFlags.contract)
+
+def mulf(a: ir.Value, b: ir.Value):
+  return arith.mulf(a, b, fastmath=arith.FastMathFlags.contract)
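The new `addf`/`subf`/`mulf` helpers are the crux of the change: `FastMathFlags.contract` tells the compiler it may fuse a neighbouring multiply and add/subtract into one FMA, which rounds once instead of twice, hence "the precision should only increase". A standalone NumPy illustration of the single- vs. double-rounding difference (an explanatory sketch of the reasoning, not code from the commit):

```python
# Emulate mul-then-add (two roundings) vs. a fused multiply-add (one rounding)
# in float32. The fused form keeps low-order bits that the separate ops lose.
import numpy as np

a = np.float32(1.0000001)   # rounds to 1 + 2**-23
b = np.float32(1.0000001)
c = np.float32(-1.0000002)  # rounds to -(1 + 2**-22)

separate = a * b + c  # a*b is rounded to float32 before the add -> exactly 0.0
# One rounding at the end, emulated by doing the arithmetic exactly in float64:
fused = np.float32(np.float64(a) * np.float64(b) + np.float64(c))  # ~1.42e-14
print(separate, fused)
```

Without the `contract` flag the compiler must preserve the two separate roundings; with it, the backend is free to emit a single FMA instruction, which is both faster and at least as accurate.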

jax/experimental/pallas/ops/gpu/attention_mgpu.py

Lines changed: 7 additions & 4 deletions
@@ -16,6 +16,7 @@
 import dataclasses
 import functools
 import itertools
+import math
 import jax
 from jax import lax
 from jax._src import test_util as jtu  # noqa: F401
@@ -118,11 +119,13 @@ def compute_qk(acc_ref):
           plgpu.barrier_arrive(k_consumed_barrier)
 
         # Softmax
-        m_ij = jnp.maximum(m_i, qk.max(axis=1))
-        alpha = jnp.exp(m_i - m_ij)
+        # We keep m scaled by log2e to use FMA instructions when computing p.
+        log2e = math.log2(math.e)
+        m_ij = jnp.maximum(m_i, qk.max(axis=1) * log2e)
+        alpha = jnp.exp2(m_i - m_ij)
         m_i = m_ij
-        p = jnp.exp(qk - lax.broadcast_in_dim(m_ij, (block_q, block_kv), [0]))
-        acc *= lax.broadcast_in_dim(alpha, (block_q, head_dim), [0])
+        p = jnp.exp2(qk * log2e - lax.broadcast_in_dim(m_ij, qk.shape, [0]))
+        acc *= lax.broadcast_in_dim(alpha, acc.shape, [0])
         l_i *= alpha
         p16 = p.astype(dtype)
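The new comment ("We keep m scaled by log2e...") captures why this helps: with the running max stored pre-scaled by log2(e), computing `p` needs `qk * log2e` followed by a subtraction, exactly the multiply-plus-subtract shape that the contraction flags in `fragmented_array.py` let the backend fuse into one FMA per element, and the exponential itself becomes a plain `exp2`. A standalone NumPy check of the algebra (not part of the kernel; shapes and values are arbitrary):

```python
# exp(qk - m) == exp2(qk * log2(e) - m * log2(e)): the scaled-max formulation
# used by the kernel produces the same softmax numerator.
import math
import numpy as np

rng = np.random.default_rng(0)
qk = rng.normal(size=(4, 8)).astype(np.float32)   # stand-in for a QK tile
m = qk.max(axis=1, keepdims=True)                 # unscaled running row max
log2e = math.log2(math.e)

p_ref = np.exp(qk - m)                            # original formulation
p_fma = np.exp2(qk * log2e - m * log2e)           # log2e-scaled formulation
np.testing.assert_allclose(p_ref, p_fma, rtol=1e-5)
```

In approximate mode, `exp` already lowers to `ex2.approx` with a log2(e) scaling multiply; by keeping `m` pre-scaled, that multiply merges with the max subtraction into a single FMA instead of remaining a separate instruction.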
