Skip to content

Commit ea82c9f

Browse files
authored
[Gluon] Clarify o_init meaning in attention tutorial (#7563)
Setting `o_init = True` twice in the mma loop looks like a bug, but it's actually alright because `o0_tmem` is unconditionally initialized outside the loop. This makes it clearer than `o_init` refers specifically to `o1`.
1 parent 3bd3e32 commit ea82c9f

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

python/tutorials/gluon/01-attention-forward.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
568568
s0_tmem, s0_bar, s0_producer = s0_producer.acquire()
569569
p0_tmem = _borrow_s_as_p(config, s0_tmem)
570570
tcgen05_mma(p0_tmem, v_smem, o0_tmem, use_acc=False, mbarriers=[o0_bar])
571-
o_init = False
571+
o1_init = False
572572

573573
for _ in range(num_mmas - 1):
574574
k_smem, k_bar, kv_consumer = kv_consumer.acquire()
@@ -577,25 +577,24 @@ def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
577577
o1_tmem, o1_bar, o_producer = o_producer.acquire()
578578
s1_tmem, s1_bar, s1_producer = s1_producer.acquire()
579579
p1_tmem = _borrow_s_as_p(config, s1_tmem)
580-
tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o_init, mbarriers=[o1_bar, v_bar])
581-
o_init = True
580+
tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o1_init, mbarriers=[o1_bar, v_bar])
581+
o1_init = True
582582

583583
tcgen05_mma(q1_smem, k_smem.permute((1, 0)), s1_tmem, use_acc=False, mbarriers=[s1_bar, k_bar])
584584

585585
v_smem, v_bar, kv_consumer = kv_consumer.acquire()
586586
o0_tmem, o0_bar, o_producer = o_producer.acquire()
587587
s0_tmem, s0_bar, s0_producer = s0_producer.acquire()
588588
p0_tmem = _borrow_s_as_p(config, s0_tmem)
589-
tcgen05_mma(p0_tmem, v_smem, o0_tmem, use_acc=o_init, mbarriers=[o0_bar])
590-
o_init = True
589+
tcgen05_mma(p0_tmem, v_smem, o0_tmem, mbarriers=[o0_bar])
591590

592591
tcgen05_commit(q0_bar)
593592
tcgen05_commit(q1_bar)
594593

595594
o1_tmem, o1_bar, o_producer = o_producer.acquire()
596595
s1_tmem, s1_bar, s1_producer = s1_producer.acquire()
597596
p1_tmem = _borrow_s_as_p(config, s1_tmem)
598-
tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o_init, mbarriers=[o1_bar, v_bar, s0_bar, s1_bar])
597+
tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o1_init, mbarriers=[o1_bar, v_bar, s0_bar, s1_bar])
599598

600599

601600
@gluon.jit

0 commit comments

Comments
 (0)