Commit 469dd45

Merge commit '570f24d016702cbfe1179beae3eb03d24e9c6b40'
2 parents 71cc33b + 570f24d

18 files changed, +1009 -371 lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp
Lines changed: 7 additions & 0 deletions

@@ -2012,6 +2012,13 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand(
     ArrayRef<unsigned> sharedOrder, unsigned vectorSize, unsigned elemBitWidth,
     bool needTrans) const {
   int kDimIndex = operandIdx == 0 ? 1 : 0;
+
+  // Disable swizzling for scales
+  if (operandIdx >= 2) {
+    return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder,
+                                           ctaLayout);
+  }
+
   if (needTrans)
     kDimIndex = 1 - kDimIndex;

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp
Lines changed: 107 additions & 1 deletion

@@ -148,6 +148,94 @@ class SinkTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
   }
 };

+// Combine back TMEM alloc and store. This is equivalent but gives us a more
+// canonical form to do further optimizations.
+class CombineTMEMStoreAndAlloc : public OpRewritePattern<TMEMTokenStoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TMEMTokenStoreOp store,
+                                PatternRewriter &rewriter) const override {
+    if (!matchPattern(store.getPred(), m_One()))
+      return failure();
+    auto alloc = store.getDep().getDefiningOp<TMEMTokenAllocOp>();
+    if (!alloc)
+      return failure();
+    if (alloc->getBlock() != store->getBlock())
+      return failure();
+    alloc.getSrcMutable().assign(store.getSrc());
+    rewriter.replaceOp(store, alloc.getToken());
+    return success();
+  }
+};
+
+// Hoists a tmem alloc outside an if op like this:
+// %0 = scf.if {
+//   %1, %token0 = tmem.alloc %init
+//   ...
+//   %2 = tmem.load %1, %token1
+//   scf.yield %2
+// } else {
+//   scf.yield %init
+// }
+// ->
+// %a, %token0 = tmem.alloc %init
+// %token2 = scf.if {
+//
+//   ...
+//   scf.yield %token1
+// } else {
+//   scf.yield %token0
+// }
+// %2 = tmem.load %a, %token2
+class HoistTMEMAllocOutOfIf : public OpRewritePattern<ttng::TMEMAllocOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ttng::TMEMAllocOp alloc,
+                                PatternRewriter &rewriter) const override {
+    if (!alloc.getToken())
+      return failure();
+    Value init = alloc.getSrc();
+    if (!init)
+      return failure();
+    auto ifOp = dyn_cast<scf::IfOp>(alloc->getParentOp());
+    if (!ifOp)
+      return failure();
+    auto thenOp = ifOp.thenBlock()->getTerminator();
+    auto elseOp = ifOp.elseBlock()->getTerminator();
+    SmallVector<int> yieldArgs;
+    for (auto [thenOperand, elseOperand] :
+         llvm::zip(thenOp->getOpOperands(), elseOp->getOpOperands())) {
+      auto load = thenOperand.get().getDefiningOp<TMEMTokenLoadOp>();
+      if (!load || load.getSrc() != alloc.getResult())
+        continue;
+      if (elseOperand.get() != init)
+        continue;
+      yieldArgs.push_back(thenOperand.getOperandNumber());
+    }
+    if (yieldArgs.empty())
+      return failure();
+    // Since init is used in the else terminator we know that it dominates the
+    // if op.
+    alloc->moveBefore(ifOp);
+    rewriter.setInsertionPointAfter(ifOp);
+    for (int argNo : yieldArgs) {
+      auto load =
+          cast<TMEMTokenLoadOp>(thenOp->getOperand(argNo).getDefiningOp());
+      auto newLoad = cast<TMEMTokenLoadOp>(rewriter.clone(*load));
+      rewriter.modifyOpInPlace(ifOp, [&] {
+        ifOp->getResult(argNo).replaceAllUsesWith(newLoad.getResult());
+        newLoad.getDepMutable().assign(ifOp->getResult(argNo));
+        thenOp->setOperand(argNo, load.getToken());
+        elseOp->setOperand(argNo, alloc.getToken());
+        ifOp->getResult(argNo).setType(newLoad.getToken().getType());
+      });
+    }
+    return success();
+  }
+};
+
 // Remove loop-carried tensor dependencies if they are fed immediately into a
 // TMEM store by pulling the store into the previous iteration.
 class RotateTMEMStoreInLoop : public OpRewritePattern<TMEMTokenStoreOp> {

@@ -412,11 +500,29 @@ struct HoistTMEMAlloc
     mlir::RewritePatternSet patterns(&getContext());
     patterns.add<RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
                  CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
-                 SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext());
+                 SinkTMEMLoad, RemoveUnusedTMEMLoad, CombineTMEMStoreAndAlloc,
+                 HoistTMEMAllocOutOfIf>(&getContext());
     scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      llvm_unreachable("Failed to hoist tmem_store");
     }
+
+    // TODO: currently some code assumes that a mutable tmem alloc doesn't have
+    // an initial value. As a workaround we break up the op in order to keep
+    // this form for the downstream passes. We should remove this once the
+    // downstream passes are fixed.
+    m.walk([&](ttng::TMEMAllocOp alloc) {
+      if (alloc.getType().getMutableMemory() && alloc.getSrc()) {
+        OpBuilder builder(alloc);
+        builder.setInsertionPointAfter(alloc);
+        auto store = builder.create<ttng::TMEMStoreOp>(
+            alloc.getLoc(), builder.getType<AsyncTokenType>(),
+            alloc.getResult(), alloc.getToken(), alloc.getSrc(),
+            builder.create<arith::ConstantIntOp>(alloc.getLoc(), 1, 1));
+        alloc.getToken().replaceAllUsesExcept(store.getToken(), store);
+        alloc.getSrcMutable().clear();
+      }
+    });
   }
 };

python/src/passes.h
Lines changed: 5 additions & 0 deletions

@@ -36,3 +36,8 @@
 #define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \
   m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \
                  ty3 val3) { pm.addPass(builder({val0, val1, val2, val3})); })
+
+#define ADD_PASS_OPTION_WRAPPER_5(name, builder, ty0, ty1, ty2, ty3, ty4) \
+  m.def(name, \
+        [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3, \
+           ty4 val4) { pm.addPass(builder({val0, val1, val2, val3, val4})); })

python/tutorials/gluon/01-attention-forward.py
Lines changed: 52 additions & 17 deletions

@@ -173,6 +173,18 @@ def issue_async_tma_load(smem, bar, desc, offset):
     tma.async_copy_global_to_shared(desc, [offset, 0], bar, smem)


+@gluon.jit
+def _interleave_n(a, b, size: gl.constexpr, f: gl.constexpr, i: gl.constexpr = 0):
+    if a.shape[1] == size:
+        return f(a, b, i)
+    else:
+        a0, a1 = a.reshape([a.shape[0], 2, a.shape[1] // 2]).permute(0, 2, 1).split()
+        b0, b1 = b.reshape([b.shape[0], 2, b.shape[1] // 2]).permute(0, 2, 1).split()
+        c0 = _interleave_n(a0, b0, size, f, i)
+        c1 = _interleave_n(a1, b1, size, f, i + a.shape[1] // 2)
+        return gl.convert_layout(gl.join(c0, c1).permute(0, 2, 1).reshape(a.shape), a.type.layout)
+
+
 # ===-----------------------------------------------------------------------===#
 # Gluon Attention
 # ===-----------------------------------------------------------------------===#
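Not part of the commit: a minimal NumPy sketch of the bookkeeping that _interleave_n performs. It recursively halves the column dimension until chunks of width `size` remain, calls f(chunk_a, chunk_b, chunk_start_offset) on each chunk, and stitches the results back together. The real helper does the halving with reshape/permute/split on Gluon tensors so the compiler emits the per-fragment code interleaved; plain slicing below only illustrates the chunk/offset arithmetic.

import numpy as np

# Hypothetical stand-in for _interleave_n (plain NumPy, no Gluon layouts):
# apply f to contiguous column chunks of width `size`, passing each chunk's
# starting column offset.
def interleave_n(a, b, size, f, i=0):
    if a.shape[1] == size:
        return f(a, b, i)
    half = a.shape[1] // 2
    c0 = interleave_n(a[:, :half], b[:, :half], size, f, i)
    c1 = interleave_n(a[:, half:], b[:, half:], size, f, i + half)
    return np.concatenate([c0, c1], axis=1)

a = np.arange(2 * 64).reshape(2, 64)
b = np.zeros_like(a)
# f here just labels every column with the start offset of its 16-wide chunk.
out = interleave_n(a, b, 16, lambda x, y, i: np.full_like(x, i))
assert (out == np.repeat(np.arange(0, 64, 16), 16)).all()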
@@ -556,7 +568,7 @@ def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
     s0_tmem, s0_bar, s0_producer = s0_producer.acquire()
     p0_tmem = _borrow_s_as_p(config, s0_tmem)
     tcgen05_mma(p0_tmem, v_smem, o0_tmem, use_acc=False, mbarriers=[o0_bar])
-    o_init = False
+    o1_init = False

     for _ in range(num_mmas - 1):
         k_smem, k_bar, kv_consumer = kv_consumer.acquire()
@@ -565,43 +577,69 @@ def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
         o1_tmem, o1_bar, o_producer = o_producer.acquire()
         s1_tmem, s1_bar, s1_producer = s1_producer.acquire()
         p1_tmem = _borrow_s_as_p(config, s1_tmem)
-        tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o_init, mbarriers=[o1_bar, v_bar])
-        o_init = True
+        tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o1_init, mbarriers=[o1_bar, v_bar])
+        o1_init = True

         tcgen05_mma(q1_smem, k_smem.permute((1, 0)), s1_tmem, use_acc=False, mbarriers=[s1_bar, k_bar])

         v_smem, v_bar, kv_consumer = kv_consumer.acquire()
         o0_tmem, o0_bar, o_producer = o_producer.acquire()
         s0_tmem, s0_bar, s0_producer = s0_producer.acquire()
         p0_tmem = _borrow_s_as_p(config, s0_tmem)
-        tcgen05_mma(p0_tmem, v_smem, o0_tmem, use_acc=o_init, mbarriers=[o0_bar])
-        o_init = True
+        tcgen05_mma(p0_tmem, v_smem, o0_tmem, mbarriers=[o0_bar])

     tcgen05_commit(q0_bar)
     tcgen05_commit(q1_bar)

     o1_tmem, o1_bar, o_producer = o_producer.acquire()
     s1_tmem, s1_bar, s1_producer = s1_producer.acquire()
     p1_tmem = _borrow_s_as_p(config, s1_tmem)
-    tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o_init, mbarriers=[o1_bar, v_bar, s0_bar, s1_bar])
+    tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o1_init, mbarriers=[o1_bar, v_bar, s0_bar, s1_bar])
+
+
+@gluon.jit
+def _mask_inner(qk, mask, i: gl.constexpr):
+    mask_i_bit = mask & (1 << i) == 0
+    return gl.where(mask_i_bit, qk, -float("inf"))
+
+
+@gluon.jit
+def _mask_frag(qk, col_limit_right, s: gl.constexpr):
+    col_limit_right_s = col_limit_right - s
+    col_limit_right_cur = max(col_limit_right_s, 0)
+    mask = -1 << col_limit_right_cur
+    return _interleave_n(qk, mask, 1, _mask_inner)
+
+
+@gluon.jit
+def _mask_bits(qk, col_limit_right):
+    # FIXME: This is a more concise implementation (which compiles faster) but
+    # it results in slightly slower code due to the lack of interleaving.
+    offs_n = gl.arange(0, qk.shape[1], layout=gl.SliceLayout(0, qk.type.layout))[None, :]
+    s = offs_n & ~0xf
+    i = offs_n & 0xf
+
+    col_lim_right_s = col_limit_right - s
+    col_lim_right_cur = max(col_lim_right_s, 0)
+    mask = -1 << col_lim_right_cur
+    mask_i_bit = (mask & (1 << i)) == 0
+    return gl.where(mask_i_bit, qk, -float("inf"))


 @gluon.jit
 def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
                         s_consumer, corr_producer, exp_turnstile, corr_bar, #
-                        offs_m, offs_n, m_i, l_i0, l_i1, STAGE: gl.constexpr):
+                        offs_m, m_i, l_i0, l_i1, STAGE: gl.constexpr):
     lo, hi = prog.get_loop_bounds(STAGE)

     for start_n in range(lo, hi, config.BLOCK_N):
         s_tmem, s_bar, s_consumer = s_consumer.acquire()
         qk = s_tmem.load(config.qk_layout)

         if STAGE == 2:
-            # Prevent LLVM from hoisting the partial sums, which triggers spilling.
-            offs_n = gl.inline_asm_elementwise("mov.b32 $0, $0;", "=r,r", [offs_n], dtype=gl.int32, is_pure=True,
-                                               pack=1)
-            mask = offs_m[:, None] < (start_n + offs_n[None, :])
-            qk = gl.where(mask, -1.0e8, qk)
+            col_limit_right = (offs_m - start_n + 1)[:, None].broadcast_to(qk.shape)
+            qk = _interleave_n(qk, col_limit_right, 16, _mask_frag)
+
         m_ij = gl.maximum(m_i, gl.max(qk, 1) * config.qk_scale)
         alpha = gl.exp2(m_i - m_ij)

@@ -682,11 +720,8 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
 @gluon.jit
 def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr, #
                   s_chnl, corr_chnl, exp_turnstile):
-    qk_slice_dim0: gl.constexpr = gl.SliceLayout(0, config.qk_layout)
     qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout)

-    offs_n = gl.arange(0, config.BLOCK_N, qk_slice_dim0)
-
     s_consumer = s_chnl.create_consumer()
     corr_producer = corr_chnl.create_producer()
     _, corr_bar, corr_producer = corr_producer.acquire()
@@ -709,11 +744,11 @@ def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,
     if STAGE & 1:
         m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = _softmax_inner_loop( #
             tile_id, config, prog, s_consumer, corr_producer, exp_turnstile, corr_bar, #
-            offs_m, offs_n, m_i, l_i0, l_i1, STAGE=4 - STAGE)
+            offs_m, m_i, l_i0, l_i1, STAGE=4 - STAGE)
     if STAGE & 2:
         m_i, l_i0, l_i1, corr_bar, s_consumer, corr_producer, exp_turnstile = _softmax_inner_loop( #
             tile_id, config, prog, s_consumer, corr_producer, exp_turnstile, corr_bar, #
-            offs_m, offs_n, m_i, l_i0, l_i1, STAGE=2)
+            offs_m, m_i, l_i0, l_i1, STAGE=2)

     if config.use_fadd2_reduce:
         l_i = l_i0 + l_i1
