Commit db158e6

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Improve the implementation of max and exp
Both are very important for FlashAttention and both were poorly mapped to PTX.

For exp, we really do not care about denormals when running in approximate mode, since they would produce results so close to 1 that it really doesn't matter.

For max, LLVM ended up generating a whole bunch of comparisons and selects and failed to take advantage of the max instructions present in GPUs.

Both of those changes _significantly_ improve the performance of Mosaic attention kernels for heads smaller than 256 (when the pointwise part dominates the execution time). In one example I looked at, the utilization jumps from 55% to 64%.

PiperOrigin-RevId: 701042779
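As a rough, self-contained sketch of the idea behind the change (not the code in the diff below), both lowerings boil down to emitting the PTX instruction directly through llvm.inline_asm instead of going through the generic arith/math ops. The snippet assumes the MLIR Python bindings Mosaic GPU already uses (jaxlib.mlir) and an active ir.Context / ir.InsertionPoint inside a lowered kernel; the helper names fast_max_f32 and fast_exp_f32 are made up for illustration.

# Illustration only: hypothetical helpers mirroring the commit, assuming the
# jaxlib.mlir Python bindings and an active ir.Context / ir.InsertionPoint.
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith, llvm


def fast_max_f32(x: ir.Value, y: ir.Value) -> ir.Value:
  # Map a two-operand f32 max straight onto the PTX max.NaN.f32 instruction,
  # instead of letting LLVM expand arith.maximumf into compares and selects.
  f32 = ir.F32Type.get()
  return llvm.inline_asm(f32, [x, y], "max.NaN.f32 $0, $1, $2;", "=f,f,f")


def fast_exp_f32(x: ir.Value) -> ir.Value:
  # exp(x) == 2**(x * log2(e)). The .ftz variant flushes denormals to zero,
  # which for ex2 only affects inputs whose result is already ~1.
  f32 = ir.F32Type.get()
  log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634))
  scaled = arith.mulf(x, log2e)
  return llvm.inline_asm(f32, [scaled], "ex2.approx.ftz.f32 $0, $1;", "=f,f")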
1 parent b801539 commit db158e6

File tree: 2 files changed (+42, -21 lines)

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 1 addition & 1 deletion
@@ -1230,7 +1230,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
        raise NotImplementedError
      if not jnp.issubdtype(x_aval.dtype, jnp.floating):
        raise NotImplementedError
-      return x.reduce(arith_dialect.maxnumf, axes[0])
+      return x.reduce("max", axes[0])
    case _:
      raise NotImplementedError(f"Unsupported layout {x.layout}")

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 41 additions & 20 deletions
@@ -879,7 +879,10 @@ def _compare(self, other, *, f_pred, si_pred, ui_pred):
 
   def max(self, other):
     if ir.FloatType.isinstance(self.mlir_dtype):
-      return self._pointwise(arith.maximumf, other)
+      maximumf = arith.maximumf
+      if ir.F32Type.isinstance(self.mlir_dtype):
+        maximumf = self._lift_fast_instr("max.NaN.f32")
+      return self._pointwise(maximumf, other)
     elif ir.IntegerType.isinstance(self.mlir_dtype):
       return self._pointwise(
           arith.maxsi if self.is_signed else arith.maxui, other
@@ -907,8 +910,8 @@ def exp(self, *, approx: bool = False):
       log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634))
       def fast_exp(x):
         scaled = arith.mulf(x, log2e)
-        return llvm.inline_asm(f32, [scaled], "ex2.approx.f32 $0, $1;", "=f,f")
-      return self._pointwise(self._lift_fast_unary(fast_exp))
+        return llvm.inline_asm(f32, [scaled], "ex2.approx.ftz.f32 $0, $1;", "=f,f")
+      return self._pointwise(self._lift_fast_instr(fast_exp))
     return self._pointwise(mlir_math.exp)
 
   def sin(self, *, approx: bool = False):
@@ -917,7 +920,7 @@ def sin(self, *, approx: bool = False):
     if approx and self.mlir_dtype != ir.F32Type.get():
       raise NotImplementedError
     return self._pointwise(
-        self._lift_fast_unary("sin.approx.f32") if approx else mlir_math.sin
+        self._lift_fast_instr("sin.approx.f32") if approx else mlir_math.sin
     )
 
   def cos(self, *, approx: bool = False):
@@ -926,7 +929,7 @@ def cos(self, *, approx: bool = False):
     if approx and self.mlir_dtype != ir.F32Type.get():
       raise NotImplementedError
     return self._pointwise(
-        self._lift_fast_unary("cos.approx.f32") if approx else mlir_math.cos
+        self._lift_fast_instr("cos.approx.f32") if approx else mlir_math.cos
     )
 
   def tanh(self, *, approx: bool = False):
@@ -935,7 +938,7 @@ def tanh(self, *, approx: bool = False):
     if approx and self.mlir_dtype != ir.F32Type.get():
       raise NotImplementedError
     return self._pointwise(
-        self._lift_fast_unary("tanh.approx.f32") if approx else mlir_math.tanh
+        self._lift_fast_instr("tanh.approx.f32") if approx else mlir_math.tanh
     )
 
   def rsqrt(self, *, approx: bool = False):
@@ -944,31 +947,36 @@ def rsqrt(self, *, approx: bool = False):
     if approx and self.mlir_dtype != ir.F32Type.get():
       raise NotImplementedError
     return self._pointwise(
-        self._lift_fast_unary("rsqrt.approx.f32") if approx else mlir_math.rsqrt
+        self._lift_fast_instr("rsqrt.approx.f32") if approx else mlir_math.rsqrt
     )
 
   @staticmethod
-  def _lift_fast_unary(
+  def _lift_fast_instr(
       instr: str | Callable[[ir.Value], ir.Value],
   ) -> Callable[[ir.Value], ir.Value]:
-    def fast_instr(x):
+    def fast_instr(*args):
       f32 = ir.F32Type.get()
-      if x.type == f32:
+      arg_ty = args[0].type
+      assert all(a.type == arg_ty for a in args)
+      if arg_ty == f32:
         if isinstance(instr, str):
-          return llvm.inline_asm(f32, [x], instr + " $0, $1;", "=f,f")
+          args_ptx = ", ".join(f"${i}" for i in range(len(args) + 1))
+          return llvm.inline_asm(
+              f32, args, f"{instr} {args_ptx};", "=f" + ",f" * len(args)
+          )
         else:
-          return instr(x)
-      elif ir.VectorType.isinstance(x.type):
+          return instr(*args)
+      elif ir.VectorType.isinstance(arg_ty):
         index = ir.IndexType.get()
-        result = llvm.mlir_undef(x.type)
-        [vec_len] = ir.VectorType(x.type).shape
+        result = llvm.mlir_undef(arg_ty)
+        [vec_len] = ir.VectorType(arg_ty).shape
         for i in range(vec_len):
-          v = vector.extractelement(x, position=c(i, index))
-          vr = fast_instr(v)
+          vs = [vector.extractelement(a, position=c(i, index)) for a in args]
+          vr = fast_instr(*vs)
           result = vector.insertelement(vr, result, position=c(i, index))
         return result
       else:
-        raise NotImplementedError(x.type)
+        raise NotImplementedError(arg_ty)
     return fast_instr
 
   def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None):
@@ -1156,7 +1164,20 @@ def reduce_sum(self, scratch) -> ir.Value:
     utils.warpgroup_barrier()  # Make sure everyone is done using scratch.
     return result
 
-  def reduce(self, op, axis):
+  def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis):
+    if isinstance(op, str):
+      match op:
+        case "max":
+          if ir.F32Type.isinstance(self.mlir_dtype):
+            op = self._lift_fast_instr("max.NaN.f32")
+          elif ir.FloatType.isinstance(self.mlir_dtype):
+            op = arith.maximumf
+          elif ir.IntegerType.isinstance(self.mlir_dtype):
+            op = arith.maxsi if self.is_signed else arith.maxui
+          else:
+            raise NotImplementedError(self.mlir_dtype)
+        case _:
+          raise ValueError(f"Unrecognized reduction operator: {op}")
     if self.layout != WGMMA_LAYOUT:
       raise NotImplementedError(self.layout)
     if axis != 1:
@@ -1421,7 +1442,7 @@ def load_tiled(
     tiled_shape = ref_ty.shape
     if len(tiled_shape) % 2:
       raise ValueError("Tiled reference must have even rank")
-    tiling = Tiling((tiled_shape[len(tiled_shape) // 2:],))
+    tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],))
     shape = tiling.untile_shape(tiled_shape)
     registers = np.full(layout.registers_shape(shape), None, dtype=object)
     reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type)
