Skip to content

Commit 11090be

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Add an optimization barrier
The barrier is a no-op at runtime, but appears as a side-effecting op to LLVM which prevents it from moving the (even pure) computations that involve the supplied arrays past the barrier. PiperOrigin-RevId: 702709125
1 parent 3895e03 commit 11090be

File tree

7 files changed

+147
-56
lines changed

7 files changed

+147
-56
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,6 +1681,12 @@ def _bitcast_convert_type_lowering_rule(
16811681
)
16821682

16831683

1684+
@register_lowering_rule(lax.optimization_barrier_p)
1685+
def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args):
1686+
args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in))
1687+
return mgpu.optimization_barrier(*args)
1688+
1689+
16841690
def _bcast(
16851691
x: ir.Value,
16861692
y: ir.Value,

jax/experimental/mosaic/gpu/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT,
4444
WGSplatFragLayout as WGSplatFragLayout,
4545
WGStridedFragLayout as WGStridedFragLayout,
46+
optimization_barrier as optimization_barrier,
4647
)
4748
from .utils import (
4849
BarrierRef as BarrierRef,

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,3 +1866,108 @@ def subf(a: ir.Value, b: ir.Value):
18661866

18671867
def mulf(a: ir.Value, b: ir.Value):
18681868
return arith.mulf(a, b, fastmath=arith.FastMathFlags.contract)
1869+
1870+
1871+
def optimization_barrier(*arrays: mgpu.FragmentedArray):
1872+
"""Acts as an optimization barrier for LLVM.
1873+
1874+
Passing arrays through this function will make sure that they are computed
1875+
before any side-effecting operations that follow this barrier.
1876+
"""
1877+
index = ir.IndexType.get()
1878+
i32 = ir.IntegerType.get_signless(32)
1879+
1880+
regs = []
1881+
reg_dtypes = []
1882+
reg_constraints = []
1883+
ptx_lines = ["// Optimization barrier"]
1884+
repack_fns = []
1885+
# We unpack each array into a flat list of registers, and prepare the
1886+
# functions that invert the transform in repack_fns.
1887+
for array in arrays:
1888+
ptx_lines.append("// Next array")
1889+
reg_ty = array.registers.flat[0].type
1890+
dtype = array.mlir_dtype
1891+
num_prev_cstr = len(reg_constraints)
1892+
if ir.F32Type.isinstance(dtype):
1893+
if ir.VectorType.isinstance(reg_ty):
1894+
[vec_len] = ir.VectorType(reg_ty).shape
1895+
array_regs = [ # pylint: disable=g-complex-comprehension
1896+
vector.extractelement(reg, position=c(pos, index))
1897+
for reg in array.registers.flat
1898+
for pos in range(vec_len)
1899+
]
1900+
def _repack(regs, reg_ty=reg_ty):
1901+
reg = llvm.mlir_undef(reg_ty)
1902+
[vec_len] = ir.VectorType(reg_ty).shape
1903+
for i_elem in range(vec_len):
1904+
reg = llvm.insertelement(
1905+
reg, next(regs), arith.constant(i32, i_elem)
1906+
)
1907+
return reg
1908+
repack_fns.append(_repack)
1909+
else:
1910+
array_regs = list(array.registers.flat)
1911+
repack_fns.append(lambda regs: next(regs))
1912+
reg_constraint = "f"
1913+
elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype):
1914+
if not ir.VectorType.isinstance(reg_ty):
1915+
raise NotImplementedError(array.mlir_dtype)
1916+
[vec_len] = ir.VectorType(reg_ty).shape
1917+
if vec_len != 2:
1918+
raise NotImplementedError(vec_len)
1919+
i32_reg_ty = ir.VectorType.get((1,), i32)
1920+
array_regs = [
1921+
vector.extractelement(
1922+
vector.bitcast(i32_reg_ty, reg), position=c(0, index)
1923+
)
1924+
for reg in array.registers.flat
1925+
]
1926+
reg_constraint = "r"
1927+
def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty):
1928+
return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs)))
1929+
repack_fns.append(_repack)
1930+
else:
1931+
raise NotImplementedError(array.mlir_dtype)
1932+
regs += array_regs
1933+
reg_dtypes += [array_regs[0].type] * len(array_regs)
1934+
reg_constraints += [f"={reg_constraint}"] * len(array_regs)
1935+
reg_constraints += [reg_constraint] * len(array_regs)
1936+
ptx_lines += [
1937+
f"mov.b32 ${i}, ${len(array_regs)+i}"
1938+
for i in range(num_prev_cstr, num_prev_cstr + len(array_regs))
1939+
]
1940+
reg_constraints = ",".join(reg_constraints)
1941+
ptx = ";\n\t".join(ptx_lines) + ";"
1942+
struct_ty = ir.Type.parse(
1943+
f"!llvm.struct<({','.join(map(str, reg_dtypes))})>"
1944+
)
1945+
result_struct = llvm.inline_asm(
1946+
struct_ty, regs, ptx, reg_constraints,
1947+
asm_dialect=0, has_side_effects=True,
1948+
)
1949+
regs = [
1950+
llvm.extractvalue(dtype, result_struct, [i])
1951+
for i, dtype in enumerate(reg_dtypes)
1952+
]
1953+
i32 = ir.IntegerType.get_signless(32)
1954+
results = []
1955+
regs_it = iter(regs)
1956+
for array, repack_fn in zip(arrays, repack_fns, strict=True):
1957+
num_regs = array.registers.size
1958+
reg_ty = array.registers.flat[0].type
1959+
if ir.VectorType.isinstance(reg_ty):
1960+
reg_ty = ir.VectorType(reg_ty)
1961+
new_registers = np.empty((num_regs,), dtype=object)
1962+
for i_vreg in range(num_regs):
1963+
reg = repack_fn(regs_it)
1964+
assert reg.type == reg_ty, (reg.type, reg_ty)
1965+
new_registers[i_vreg] = reg
1966+
results.append(
1967+
FragmentedArray(
1968+
_registers=new_registers.reshape(array.registers.shape),
1969+
_layout=array.layout,
1970+
_is_signed=array.is_signed,
1971+
)
1972+
)
1973+
return results[0] if len(arrays) == 1 else results

jax/experimental/mosaic/gpu/wgmma.py

Lines changed: 7 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from jaxlib.mlir.dialects import arith
2424
from jaxlib.mlir.dialects import llvm
2525
from jaxlib.mlir.dialects import vector
26+
from jaxlib.mlir.dialects import nvvm
2627
import numpy as np
2728

2829
import jax.experimental.mosaic.gpu as mgpu
@@ -445,58 +446,13 @@ def wgmma(
445446
def wgmma_fence(array: mgpu.FragmentedArray):
446447
"""Fences the array construction from WGMMA instructions.
447448
448-
This is a little workaround to force LLVM to initialize the PTX registers
449-
before the wgmma.fence.sync.aligned instruction. Otherwise, LLVM treats
450-
in-register computation as pure and can move it after the fence, which is
451-
explicitly disallowed by the PTX programming model.
449+
LLVM treats in-register computation as pure and can move it after the fence,
450+
which is explicitly disallowed by the PTX programming model. For that reason,
451+
we insert an LLVM optimization barrier before the fence.
452452
"""
453-
i32 = ir.IntegerType.get_signless(32)
454-
index = ir.IndexType.get()
455-
dtype = array.mlir_dtype
456-
src_vec_ty = ir.VectorType(array.registers.flat[0].type)
457-
assert src_vec_ty.shape == [2]
458-
459-
if dtype == ir.F32Type.get():
460-
regs = [ # pylint: disable=g-complex-comprehension
461-
vector.extractelement(reg, position=c(pos, index))
462-
for reg in array.registers.flat
463-
for pos in range(2)
464-
]
465-
reg_dtype = dtype
466-
reg_constraints_list = ["=f"] * len(regs) + ["f"] * len(regs)
467-
ptx_lines = [f"mov.f32 ${i}, ${len(regs)+i}" for i in range(len(regs))]
468-
elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get():
469-
regs = [_as_i32_reg(reg) for reg in array.registers.flat]
470-
reg_dtype = i32
471-
reg_constraints_list = ["=r"] * len(regs) + ["r"] * len(regs)
472-
ptx_lines = [f"mov.b32 ${i}, ${len(regs)+i}" for i in range(len(regs))]
473-
else:
474-
raise NotImplementedError(dtype)
475-
reg_constraints = ",".join(reg_constraints_list)
476-
# Copy over the registers. ptxas should be able to remove the moves.
477-
ptx_lines.append("wgmma.fence.sync.aligned")
478-
ptx = ";\n".join(ptx_lines) + ";\n"
479-
dtype_str = str(reg_dtype)
480-
struct_ty = ir.Type.parse(
481-
f"!llvm.struct<({','.join(dtype_str for _ in regs)})>"
482-
)
483-
acc_struct = llvm.inline_asm(
484-
struct_ty, regs, ptx, reg_constraints,
485-
asm_dialect=0, has_side_effects=True,
486-
)
487-
regs = [
488-
llvm.extractvalue(reg_dtype, acc_struct, [i]) for i in range(len(regs))
489-
]
490-
if dtype == ir.F32Type.get():
491-
registers = _as_fragmented_reg_ndarray(
492-
regs, array.mlir_dtype, array.registers.shape
493-
)
494-
elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get():
495-
regs = [_unpack_i32(src_vec_ty, r) for r in regs]
496-
registers = np.asarray(regs, dtype=object).reshape(array.registers.shape)
497-
else:
498-
raise NotImplementedError(dtype)
499-
return mgpu.FragmentedArray(_registers=registers, _layout=array.layout, _is_signed=array.is_signed)
453+
array = mgpu.optimization_barrier(array)
454+
nvvm.wgmma_fence_aligned()
455+
return array
500456

501457

502458
def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]):

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@
3131
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline
3232
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
3333
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
34+
from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota
3435
from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem
3536
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
3637
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
3738
from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout
3839
from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast
39-
from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota
4040
from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers
4141
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem
4242
from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma

jax/experimental/pallas/ops/gpu/attention_mgpu.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,19 @@ def compute_qk(acc_ref):
129129
l_i *= alpha
130130
p16 = p.astype(dtype)
131131

132-
plgpu.barrier_wait(v_barriers.at[slot])
133-
perform_schedule_barrier()
134-
135-
l_i += p.sum(axis=1)
132+
def end_softmax_barriers():
133+
plgpu.barrier_arrive(schedule_barrier) # Done with softmax!
134+
plgpu.barrier_wait(v_barriers.at[slot])
135+
plgpu.barrier_wait(schedule_barrier) # Wait until TensorCore is free.
136+
# Can't fully explain why, but empirically the ordering here influences
137+
# the performance of the final kernel quite significantly.
138+
if head_dim <= 128:
139+
l_i += p.sum(axis=1)
140+
acc, l_i, m_i, p16 = lax.optimization_barrier((acc, l_i, m_i, p16))
141+
end_softmax_barriers()
142+
else:
143+
end_softmax_barriers()
144+
l_i += p.sum(axis=1)
136145

137146
# PV
138147
def compute_pv(acc_ref):

tests/mosaic/gpu_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,6 +1674,20 @@ def kernel(ctx, inp, out, smem):
16741674
)(x)
16751675
np.testing.assert_array_equal(result, reference)
16761676

1677+
@parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16)
1678+
def test_optimization_barrier(self, dtype):
1679+
def kernel(ctx, inp, out, smem):
1680+
del ctx, smem
1681+
arr = mgpu.FragmentedArray.load_strided(inp)
1682+
arr2 = arr * 2
1683+
arr, arr2 = mgpu.optimization_barrier(arr, arr2)
1684+
(arr + arr2).store_untiled(out)
1685+
1686+
x = jnp.arange(256, dtype=dtype)
1687+
1688+
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, None)
1689+
np.testing.assert_array_equal(f(x), x * 3)
1690+
16771691

16781692
class ProfilerTest(TestCase):
16791693

0 commit comments

Comments
 (0)