
Commit e56afa6

[Gluon][Tutorial] Optimize attention kernel (#7238)
* Pass `m_i` through tensor memory
* Use `FMUL2` instructions in correction partition
* For `DHEAD=128`, subtile `p` production in the softmax partition
  * This allows us to scrounge a few extra registers for the correction partition (200->192)
  * This allows ptxas to overlap the f32 to f16 downcast with the exp
* Reduce the number of instructions needed to apply causal masking
  * Each element is produced with IADD3, conditional, and FSEL
1 parent 94ff6af commit e56afa6
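
For context on the bullets above: the kernel keeps the usual online-softmax running statistics per query row, a running max `m_i` (in base 2, since `qk_scale` is multiplied by `1.44269504 = log2(e)` and the kernel uses `exp2`), a running denominator `l_i`, and an output accumulator that is rescaled by `alpha = exp2(m_i - m_ij)` whenever a block raises the max. The sketch below is a minimal NumPy rendition of that bookkeeping, not the kernel code; the names `online_softmax_step` and `v_block` are illustrative only. The `FMUL2` bullet targets the `o * alpha[:, None]` rescale, and the `m_i` bullet refers to how `m_ij` is handed from the softmax partition to the correction partition.

import numpy as np

def online_softmax_step(o_acc, l_i, m_i, qk_block, v_block, qk_scale):
    # New running max for each row (matches m_ij in the softmax partition).
    m_ij = np.maximum(m_i, qk_block.max(axis=1) * qk_scale)
    # Block probabilities in base 2, matching the kernel's exp2 formulation.
    p = np.exp2(qk_block * qk_scale - m_ij[:, None])
    # Correction factor for the running output accumulator; this fp32
    # elementwise multiply is what the FMUL2 change packs two lanes at a time.
    alpha = np.exp2(m_i - m_ij)
    o_acc = o_acc * alpha[:, None] + p @ v_block
    l_i = l_i * alpha + p.sum(axis=1)
    return o_acc, l_i, m_ij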

3 files changed, +127 -45 lines changed

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 2 additions & 0 deletions
@@ -155,6 +155,8 @@ def shared_load(self, mem_desc, layout):
         return ttgl.tensor(handle, ret_ty)
 
     def shared_store(self, mem_desc, value):
+        assert value.shape == mem_desc.shape, f"source shape {value.shape} and destination shape {mem_desc.shape} must match"
+        assert value.dtype == mem_desc.dtype, f"source dtype {value.dtype} and destination dtype {mem_desc.dtype} must match"
         self.builder.create_local_store(mem_desc.handle, value.handle)
 
     def shared_dealloc(self, mem_desc):

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -130,6 +130,8 @@ def load(self, layout, _semantic: GluonSemantic) -> ttgl.tensor:
     def store(self, value, pred=True, _semantic: GluonSemantic = None) -> None:
         pred = _unwrap_if_constexpr(pred)
         pred = _semantic.to_tensor(pred)
+        assert value.shape == self.shape, f"source shape {value.shape} does not match destination shape {self.shape}"
+        assert value.dtype == self.dtype, f"source dtype {value.dtype} does not match destination dtype {self.dtype}"
         _semantic.builder.create_tmem_store(self.handle, value.handle, pred.handle)
 
     @builtin

python/tutorials/gluon/01-attention-forward.py

Lines changed: 123 additions & 45 deletions
@@ -241,14 +241,14 @@ def __init__(self, channel, phase, index):
 
     @gluon.jit
     def acquire(self):
-        smem, ready_bar = self.channel.acquire_producer(self.index, self.phase)
+        mem, ready_bar = self.channel.acquire_producer(self.index, self.phase)
         self.index, self.phase = self.channel.increment(self.index, self.phase)
-        return smem, ready_bar, self
+        return mem, ready_bar, self
 
     @gluon.jit
     def emplace(self, value):
-        smem, ready_bar, self = self.acquire()
-        smem.store(value)
+        mem, ready_bar, self = self.acquire()
+        mem.store(value)
         mbarrier.arrive(ready_bar, count=1)
         return self
 
@@ -265,14 +265,14 @@ def __init__(self, channel, phase, index):
 
     @gluon.jit
     def acquire(self):
-        smem, empty_bar = self.channel.acquire_consumer(self.index, self.phase)
+        mem, empty_bar = self.channel.acquire_consumer(self.index, self.phase)
         self.index, self.phase = self.channel.increment(self.index, self.phase)
-        return smem, empty_bar, self
+        return mem, empty_bar, self
 
     @gluon.jit
     def get(self, layout: gl.constexpr):
-        smem, empty_bar, self = self.acquire()
-        value = smem.load(layout)
+        mem, empty_bar, self = self.acquire()
+        value = mem.load(layout)
         mbarrier.arrive(empty_bar, count=1)
         return value, self
 
@@ -399,9 +399,9 @@ class AttentionConfig:
     dtype: gl.constexpr
     num_warps: gl.constexpr
 
-    SPLIT_N_FACTOR: gl.constexpr
+    SPLIT_D_FACTOR: gl.constexpr
     SPLIT_M: gl.constexpr
-    SPLIT_N: gl.constexpr
+    SPLIT_D: gl.constexpr
 
     q_shape: gl.constexpr
     k_shape: gl.constexpr
@@ -416,8 +416,11 @@ class AttentionConfig:
     qk_layout: gl.constexpr
     o_layout: gl.constexpr
     o_splitn_layout: gl.constexpr
+    mi_2d_layout: gl.constexpr
 
-    def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num_warps, SPLIT_N_FACTOR):
+    mi_use_tmem: gl.constexpr
+
+    def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num_warps, SPLIT_D_FACTOR):
         self.qk_scale = qk_scale
         self.Z = Z
         self.H = H
@@ -428,9 +431,9 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num
         self.dtype = gl.constexpr(dtype)
         self.num_warps = gl.constexpr(num_warps)
 
-        self.SPLIT_N_FACTOR = SPLIT_N_FACTOR
+        self.SPLIT_D_FACTOR = SPLIT_D_FACTOR
         self.SPLIT_M = self.BLOCK_M // 2
-        self.SPLIT_N = self.BLOCK_N // self.SPLIT_N_FACTOR
+        self.SPLIT_D = self.HEAD_DIM // self.SPLIT_D_FACTOR
 
         self.q_shape = gl.constexpr([self.SPLIT_M, self.HEAD_DIM])
         self.k_shape = gl.constexpr([self.BLOCK_N, self.HEAD_DIM])
@@ -447,8 +450,11 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num
         self.qk_layout = gl.constexpr(get_tmem_32x32b_reg_layout(qk_instr_shape, self.qk_shape, self.num_warps))
         self.o_layout = gl.constexpr(get_tmem_32x32b_reg_layout(o_instr_shape, self.o_shape, self.num_warps))
         self.o_splitn_layout = gl.constexpr(
-            get_tmem_32x32b_reg_layout((o_instr_shape[0], o_instr_shape[1] // self.SPLIT_N_FACTOR, o_instr_shape[2]),
-                                       (self.o_shape[0], self.o_shape[1] // self.SPLIT_N_FACTOR), self.num_warps))
+            get_tmem_32x32b_reg_layout((o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR, o_instr_shape[2]),
+                                       (self.o_shape[0], self.o_shape[1] // self.SPLIT_D_FACTOR), self.num_warps))
+        self.mi_2d_layout = gl.constexpr(gl.BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1]))
+
+        self.mi_use_tmem = gl.constexpr(True)
 
     @gluon.jit
     def get_program(self):
@@ -539,7 +545,7 @@ class InnerLoopInfo:
     qk_mma_ctx: MMAContext
     o_mma_ctx: MMAContext
     p_chnl: TensorMemoryChannel
-    mi_chnl: SharedMemoryChannel
+    mi_chnl: TensorMemoryChannel
    li_smem: gl.shared_memory_descriptor
    q_smem: gl.shared_memory_descriptor
 
@@ -552,14 +558,25 @@ def create(config, tile):
         o_mma_ctx.channel.initialize_for_consumer()
         o_mma_ctx.channel.mem.index(0).store(tile.acc)
 
-        p_chnl = TensorMemoryChannel._borrow(qk_mma_ctx.channel.mem, config.qk_shape, config.dtype,
-                                             config.p_tmem_layout, num_buffers=1, num_consumers=1)
+        # QK and PV MMAs are serialized, which enables borrowing QK's memory.
+        borrow_tmem = qk_mma_ctx.channel.mem.index(0)
+        p_tmem = borrow_tmem.slice(0, config.BLOCK_N // 2)
+        mi_tmem = borrow_tmem.slice(config.BLOCK_N // 2, 1)
+        mi_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=False)
+
+        p_chnl = TensorMemoryChannel._borrow(p_tmem, config.qk_shape, config.dtype, config.p_tmem_layout, num_buffers=1,
+                                             num_consumers=1)
         p_chnl.initialize_for_producer()
 
-        mi_chnl = SharedMemoryChannel.create([config.SPLIT_M], gl.float32, gl.constexpr(mbarrier.MBarrierLayout()),
-                                             num_buffers=1)
+        if config.mi_use_tmem:
+            mi_chnl = TensorMemoryChannel._borrow(mi_tmem, [config.SPLIT_M, 1], gl.float32, mi_layout, num_buffers=1)
+            m_i = gl.convert_layout(tile.m_i.expand_dims(1), config.mi_2d_layout)
+        else:
+            mi_chnl = SharedMemoryChannel.create([config.SPLIT_M], gl.float32, gl.constexpr(mbarrier.MBarrierLayout()),
+                                                 num_buffers=1)
+            m_i = tile.m_i
+        mi_chnl.mem.index(0).store(m_i)
         mi_chnl.initialize_for_producer()
-        mi_chnl.mem.index(0).store(tile.m_i)
 
         li_smem = gl.allocate_shared_memory(gl.float32, [config.SPLIT_M], gl.constexpr(mbarrier.MBarrierLayout()))
         li_smem.store(tile.l_i)
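
Why the borrow in `InnerLoopInfo.create` has room for both `p` and `m_i`: the back-of-the-envelope sketch below is my reading of the slice sizes above (`BLOCK_N // 2` columns for `p`, one column for `m_i`), not something the diff states explicitly. The borrowed QK accumulator is fp32, so it spans `BLOCK_N` 32-bit tensor-memory columns per row, while the fp16 `p` tile packs two values per column.

# Rough column-budget arithmetic (an assumption, values illustrative).
BLOCK_N = 128
qk_cols = BLOCK_N       # fp32 accumulator: one value per 32-bit column
p_cols = BLOCK_N // 2   # fp16 p: two values per 32-bit column
mi_cols = 1             # one fp32 per row for the running max
assert p_cols + mi_cols <= qk_cols
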
@@ -662,21 +679,66 @@ def _attn_fwd_mma(config, #
     mbarrier.invalidate(qk_p_bar)
 
 
+@gluon.jit
+def _add_f32x2(a, b):
+    return gl.inline_asm_elementwise(
+        """
+        {
+            .reg .b64 ra, rb, rc;
+            mov.b64 ra, { $2, $3 };
+            mov.b64 rb, { $4, $5 };
+            add.f32x2 rc, ra, rb;
+            mov.b64 { $0, $1 }, rc;
+        }
+        """,
+        "=r,=r,r,r,r,r",
+        [a, b],
+        dtype=gl.float32,
+        is_pure=True,
+        pack=2,
+    )
+
+
+@gluon.jit
+def _mul_f32x2(a, b):
+    return gl.inline_asm_elementwise(
+        """
+        {
+            .reg .b64 ra, rb, rc;
+            mov.b64 ra, { $2, $3 };
+            mov.b64 rb, { $4, $5 };
+            mul.f32x2 rc, ra, rb;
+            mov.b64 { $0, $1 }, rc;
+        }
+        """,
+        "=r,=r,r,r,r,r",
+        [a, b],
+        dtype=gl.float32,
+        is_pure=True,
+        pack=2,
+    )
+
+
 @gluon.jit
 def _attn_fwd_correction_compute(config, mi_consumer, o_consumer, m_i):
-    m_ij, mi_consumer = mi_consumer.get(gl.constexpr(gl.SliceLayout(1, config.o_splitn_layout)))
+    mi_layout: gl.constexpr = gl.SliceLayout(1, config.o_splitn_layout)
+    if config.mi_use_tmem:
+        m_ij, mi_consumer = mi_consumer.get(config.mi_2d_layout)
+        m_ij = gl.convert_layout(m_ij.reshape([config.SPLIT_M]), mi_layout)
+    else:
+        m_ij, mi_consumer = mi_consumer.get(mi_layout)
     alpha = gl.exp2(m_i - m_ij)
 
     o_tmem, o_bar, o_consumer = o_consumer.acquire()
-    if config.SPLIT_N_FACTOR == 1:
+    if config.SPLIT_D_FACTOR == 1:
         o = o_tmem.load(config.o_layout)
-        o = o * alpha[:, None]
+        o = _mul_f32x2(o, alpha[:, None])
         o_tmem.store(o)
     else:
-        for i in tl.static_range(config.SPLIT_N_FACTOR):
-            o_ref = o_tmem.slice(i * config.SPLIT_N, config.SPLIT_N)
+        for i in tl.static_range(config.SPLIT_D_FACTOR):
+            o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
             o = o_ref.load(config.o_splitn_layout)
-            o = o * alpha[:, None]
+            o = _mul_f32x2(o, alpha[:, None])
             o_ref.store(o)
     mbarrier.arrive(o_bar, count=1)
     return mi_consumer, o_consumer, m_ij
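
The `pack=2` argument to `gl.inline_asm_elementwise` hands each inline-assembly instance two consecutive f32 elements per operand, which is why the constraint string declares two 32-bit outputs and four 32-bit inputs; a single `mul.f32x2` (the FMUL2 of the commit message) then produces both results in one instruction. A scalar Python model of what one instance computes, purely for illustration:

def mul_f32x2_model(a_pair, b_pair):
    # One asm instance: two f32 lanes in per operand, two products out,
    # computed by a single mul.f32x2.
    (a0, a1), (b0, b1) = a_pair, b_pair
    return (a0 * b0, a1 * b1)
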
@@ -723,31 +785,48 @@ def _softmax_tile(tile_id: gl.constexpr, config, info, STAGE: gl.constexpr):
     p_producer = info.p_chnl.create_producer()
     mi_producer = info.mi_chnl.create_producer()
 
-    m_i = info.mi_chnl.mem.index(0).load(qk_slice_dim1)
+    if config.mi_use_tmem:
+        m_i = info.mi_chnl.mem.index(0).load(config.mi_2d_layout)
+        m_i = gl.convert_layout(m_i.reshape([config.SPLIT_M]), qk_slice_dim1)
+    else:
+        m_i = info.mi_chnl.mem.index(0).load(qk_slice_dim1)
     l_i = info.li_smem.load(qk_slice_dim1)
 
     for start_n in range(lo, hi, config.BLOCK_N):
         qk, qk_consumer = qk_consumer.get(config.qk_layout)
+        if config.HEAD_DIM == 128:
+            p_tmem, p_bar, p_producer = p_producer.acquire()
+
         if STAGE == 2:
             # Prevent LLVM from hoisting the partial sums, which triggers spilling.
             offs_n = gl.inline_asm_elementwise("mov.b32 $0, $0;", "=r,r", [offs_n], dtype=gl.int32, is_pure=True,
                                                pack=1)
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
-            qk = qk * config.qk_scale + gl.where(mask, 0, -1.0e6)
-            m_ij = gl.maximum(m_i, gl.max(qk, 1))
-            mi_producer = mi_producer.emplace(m_ij)
-            qk -= m_ij[:, None]
+            qk = gl.where(mask, qk, -1.0e8)
+        m_ij = gl.maximum(m_i, gl.max(qk, 1) * config.qk_scale)
+        if config.mi_use_tmem:
+            mi_producer = mi_producer.emplace(gl.convert_layout(m_ij.expand_dims(1), config.mi_2d_layout))
         else:
-            m_ij = gl.maximum(m_i, gl.max(qk, 1) * config.qk_scale)
             mi_producer = mi_producer.emplace(m_ij)
-            qk = qk * config.qk_scale - m_ij[:, None]
-
-        p = gl.exp2(qk)
+        qk = qk * config.qk_scale - m_ij[:, None]
 
-        l_ij = gl.sum(p, 1)
-        alpha = gl.exp2(m_i - m_ij)
-
-        p_producer = p_producer.emplace(p.to(config.dtype))
+        if config.HEAD_DIM == 64:
+            p = gl.exp2(qk)
+            l_ij = gl.sum(p, 1)
+            alpha = gl.exp2(m_i - m_ij)
+            p_producer = p_producer.emplace(p.to(config.dtype))
+        else:
+            qk0, qk1, = qk.reshape([config.SPLIT_M, 2, config.BLOCK_N // 2]).permute(0, 2, 1).split()
+            p0 = gl.exp2(qk0)
+            p_tmem.slice(0, config.BLOCK_N // 2).store(p0.to(config.dtype))
+            p1 = gl.exp2(qk1)
+            p_tmem.slice(config.BLOCK_N // 2, config.BLOCK_N // 2).store(p1.to(config.dtype))
+            mbarrier.arrive(p_bar, count=1)
+            p = gl.join(p0, p1).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N])
+            p = gl.convert_layout(p, config.qk_layout)
+
+            l_ij = gl.sum(p, 1)
+            alpha = gl.exp2(m_i - m_ij)
 
         l_i = l_i * alpha + l_ij
         m_i = m_ij
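
The `DHEAD=128` path above splits `qk` into its two column halves so that `exp2` and the f32 to f16 downcast of the first half can be stored to tensor memory while the second half is still being computed, which is what lets ptxas overlap the downcast with the exp (per the commit message). The reshape/permute/split and join/permute/reshape pair is a pure layout round trip; a small NumPy sketch, illustrative only and with arbitrary shapes, showing the two halves and that the recombined `p` matches the unsplit result:

import numpy as np

SPLIT_M, BLOCK_N = 64, 128                               # illustrative sizes
qk = np.random.rand(SPLIT_M, BLOCK_N).astype(np.float32)

# Split the BLOCK_N columns into two halves, mirroring reshape/permute/split.
halves = qk.reshape(SPLIT_M, 2, BLOCK_N // 2).transpose(0, 2, 1)
qk0, qk1 = halves[..., 0], halves[..., 1]                # columns [0, N/2) and [N/2, N)

p0 = np.exp2(qk0)                                        # first half is ready to downcast/store early
p1 = np.exp2(qk1)

# Recombine, mirroring join/permute/reshape; the round trip is exact.
p = np.stack([p0, p1], axis=-1).transpose(0, 2, 1).reshape(SPLIT_M, BLOCK_N)
assert np.allclose(p, np.exp2(qk))
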
@@ -773,7 +852,7 @@ def _attn_fwd_softmax1(config, #
 def _attn_fwd_inner(config, info0, info1, m_i0, m_i1, #
                     desc_k, desc_v, #
                     STAGE: gl.constexpr):
-    num_buffers: gl.constexpr = 2 if config.HEAD_DIM >= 128 else 3
+    num_buffers: gl.constexpr = 2 if config.HEAD_DIM == 128 else 3
     k_load_ctx = LoadContext.create(desc_k, num_buffers=num_buffers, num_consumers=2)
     v_load_ctx = LoadContext.create(desc_v, num_buffers=num_buffers, num_consumers=2)
 
@@ -793,7 +872,7 @@ def _attn_fwd_inner(config, info0, info1, m_i0, m_i1, #
         _attn_fwd_softmax1,
         _attn_fwd_mma,
         _attn_fwd_load,
-    ], [4, 4, 1, 1], [192, 200, 32, 32])
+    ], [4, 4, 1, 1], [192, 192, 32, 32])
 
     k_load_ctx.release()
     v_load_ctx.release()
@@ -809,8 +888,7 @@ def _gluon_attn(sm_scale, M, Z, H, N_CTX, #
                 num_warps: gl.constexpr):
     qk_scale = sm_scale
     qk_scale *= 1.44269504
-    config = AttentionConfig(qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num_warps,
-                             SPLIT_N_FACTOR=triton.cdiv(HEAD_DIM, 64))
+    config = AttentionConfig(qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num_warps, SPLIT_D_FACTOR=2)
 
     prog = config.get_program()
 
@@ -909,7 +987,7 @@ def is_blackwell():
 @pytest.mark.parametrize("H", [2, 48])
 @pytest.mark.parametrize("N_CTX", [256, 1024, 4 * 1024])
 @pytest.mark.parametrize("HEAD_DIM", [64, 128])
-@pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.skipif(not is_blackwell(), reason="Gluon attention is only supported on Blackwell GPUs")
 def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
