Skip to content

Commit b61c0b1

Browse files
Merge commit '40dd0c41758fc52ac1423c2e55332088fc865702'
2 parents 2ae6dd8 + 40dd0c4 commit b61c0b1

File tree

49 files changed

+1106
-820
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

49 files changed

+1106
-820
lines changed

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,13 @@ getLastUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
184184

185185
// Clean up attributes passing over schedules across stages in pipelining
186186
void removePipeliningAttributes(ModuleOp moduleOp);
187+
188+
// For LoadOp, DescriptorLoad, and DescriptorGather ops, determine if
189+
// they should be pipelined.
190+
bool isPipeliningBeneficial(Operation *op,
191+
triton::ModuleAxisInfoAnalysis &axisInfoAnalysis,
192+
bool filterSmall = true);
193+
187194
} // namespace triton
188195
} // namespace mlir
189196

lib/Analysis/AxisInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
276276
getAxisInfo(ub::PoisonOp op,
277277
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
278278
unsigned rank = 1;
279-
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType()))
279+
if (auto shape = dyn_cast<RankedTensorType>(op.getType()))
280280
rank = shape.getRank();
281281

282282
// Poison values are never accessed, thus assume optimistic values.
@@ -1227,6 +1227,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
12271227
return rhs;
12281228
if (rhs.getRank() == 0)
12291229
return lhs;
1230+
assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks");
12301231
DimVectorT contiguity;
12311232
DimVectorT divisibility;
12321233
DimVectorT constancy;

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 1 addition & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -88,64 +88,6 @@ class AssignLoadLatencies {
8888
scf::ForOp forOp;
8989
int numStages;
9090
DenseMap<Operation *, int> &opLatency;
91-
92-
public:
93-
static bool canHaveSharedEncoding(tt::LoadOp op) {
94-
// If used by an user with DotOp encoding, all the uses must be compatible.
95-
bool incompatible = false;
96-
getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible);
97-
return !incompatible;
98-
}
99-
100-
static bool
101-
isPipeliningBeneficial(Operation *op, Operation *finalUser,
102-
tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
103-
bool filterSmall) {
104-
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
105-
if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
106-
LDBG("Load " << *loadOp << " is too small for pipelining");
107-
return false;
108-
}
109-
}
110-
if (isa<tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
111-
return true;
112-
if (!canHaveSharedEncoding(cast<tt::LoadOp>(op))) {
113-
LDBG("Load " << *op << " cannot have shared encoding");
114-
return false;
115-
}
116-
117-
ttg::SharedEncodingTrait localAllocEnc;
118-
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
119-
return isa<ttg::LocalAllocOp>(user);
120-
})) {
121-
for (auto user : op->getUsers()) {
122-
auto localAlloc = dyn_cast<ttg::LocalAllocOp>(user);
123-
if (!localAlloc)
124-
continue;
125-
auto enc = mlir::cast<ttg::SharedEncodingTrait>(
126-
localAlloc.getType().getEncoding());
127-
if (!localAllocEnc) {
128-
localAllocEnc = enc;
129-
}
130-
if (enc != localAllocEnc) {
131-
// If the load is used by a LocalAllocOp, all the users need to have
132-
// the same encoding.
133-
return false;
134-
}
135-
}
136-
}
137-
138-
if (localAllocEnc) {
139-
auto registerTy = cast<RankedTensorType>(op->getResultTypes()[0]);
140-
auto vecBytes = getCopyVecBytes(registerTy, localAllocEnc);
141-
if (filterSmall && vecBytes < 4) {
142-
// At least 4 bytes need to be consecutive for cp.async
143-
return false;
144-
}
145-
}
146-
147-
return true;
148-
}
14991
};
15092

15193
class AssignMMALatencies {
@@ -280,8 +222,7 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
280222
if (!seen.insert(op).second || excluded.count(op))
281223
return;
282224
if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op)) {
283-
if (!AssignLoadLatencies::isPipeliningBeneficial(
284-
op, finalUser, axisInfoAnalysis, filterSmall))
225+
if (!isPipeliningBeneficial(op, axisInfoAnalysis, filterSmall))
285226
return;
286227
if (loadOpToIndLevel.count(op)) {
287228
int level = loadOpToIndLevel[op].first;

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -453,26 +453,17 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
453453
continue;
454454
}
455455
SharedEncodingTrait sharedEncoding;
456-
bool canUseAsyncCp = false;
457-
if (!isa<RankedTensorType>(op.getResultTypes()[0])) {
458-
canUseAsyncCp = op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32;
459-
sharedEncoding = ttg::SwizzledSharedEncodingAttr::get(
460-
forOp.getContext(), 1, 1, 1, {0},
461-
ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0}));
462-
if (canUseAsyncCp) {
456+
bool canUseAsyncCp =
457+
triton::isPipeliningBeneficial(&op, axisInfoAnalysis);
458+
if (canUseAsyncCp) {
459+
if (!isa<RankedTensorType>(op.getResultTypes()[0])) {
460+
sharedEncoding = ttg::SwizzledSharedEncodingAttr::get(
461+
forOp.getContext(), 1, 1, 1, {0},
462+
ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0}));
463463
scalarLoads.push_back(&op);
464+
} else {
465+
sharedEncoding = getSharedEncoding(&op);
464466
}
465-
} else {
466-
sharedEncoding = getSharedEncoding(&op);
467-
// Do not create async loads for small loads (cp.async requires at least
468-
// 4 bytes)
469-
canUseAsyncCp =
470-
isa<tt::LoadOp>(op) &&
471-
canBeConvertedToAsyncLoad(cast<tt::LoadOp>(op), axisInfoAnalysis);
472-
int copyVecBytes = getCopyVecBytes(
473-
cast<RankedTensorType>(op.getResultTypes()[0]), sharedEncoding);
474-
475-
canUseAsyncCp &= copyVecBytes >= 4;
476467
}
477468
if (canUseAsyncCp || isTMALoad(&op)) {
478469
if (loadRequiresAdditionalBuffer(&op)) {

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,10 @@ ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(RankedTensorType ty) {
603603
}
604604

605605
ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(Operation *op) {
606+
if (!isa<RankedTensorType>(op->getResultTypes()[0])) {
607+
return nullptr;
608+
}
609+
606610
// Try to use local alloc encoding if possible.
607611
ttg::SharedEncodingTrait localAllocEnc;
608612
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
@@ -933,3 +937,38 @@ void triton::removePipeliningAttributes(ModuleOp moduleOp) {
933937
op->removeAttr(mlir::triton::kScheduledMaxStageAttrName);
934938
});
935939
}
940+
941+
static bool canHaveSharedEncoding(tt::LoadOp op) {
942+
// If used by an user with DotOp encoding, all the uses must be compatible.
943+
bool incompatible = false;
944+
getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible);
945+
return !incompatible;
946+
}
947+
948+
bool triton::isPipeliningBeneficial(
949+
Operation *op, tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
950+
bool filterSmall) {
951+
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
952+
if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
953+
LDBG("Load " << *loadOp << " is too small for pipelining");
954+
return false;
955+
}
956+
}
957+
if (isa<tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
958+
return true;
959+
if (!canHaveSharedEncoding(cast<tt::LoadOp>(op))) {
960+
LDBG("Load " << *op << " cannot have shared encoding");
961+
return false;
962+
}
963+
964+
if (auto localAllocEnc = getSharedEncoding(op)) {
965+
auto registerTy = cast<RankedTensorType>(op->getResultTypes()[0]);
966+
auto vecBytes = mlir::triton::getCopyVecBytes(registerTy, localAllocEnc);
967+
if (filterSmall && vecBytes < 4) {
968+
// At least 4 bytes need to be consecutive for cp.async
969+
return false;
970+
}
971+
}
972+
973+
return true;
974+
}

python/examples/gluon/01-attention-forward.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,11 @@ class BarrierCounter:
5252
phase: gl.tensor
5353
num_barriers: gl.constexpr
5454

55+
@gluon.constexpr_function
5556
def __init__(self, index, phase, num_barriers):
5657
self.index = index
5758
self.phase = phase
58-
self.num_barriers = num_barriers
59+
self.num_barriers = gl.constexpr(num_barriers)
5960

6061
@gluon.must_use_result
6162
@gluon.jit
@@ -79,6 +80,7 @@ class ChannelType:
7980
num_buffers: gl.constexpr
8081
num_consumers: gl.constexpr
8182

83+
@gluon.constexpr_function
8284
def __init__(self, mem, ready_bars, empty_bars, num_buffers, num_consumers):
8385
self.mem = mem
8486
self.ready_bars = ready_bars
@@ -143,6 +145,7 @@ class Producer:
143145
channel: ChannelType
144146
counter: BarrierCounter
145147

148+
@gluon.constexpr_function
146149
def __init__(self, channel, counter):
147150
self.channel = channel
148151
self.counter = counter
@@ -158,6 +161,7 @@ class Consumer:
158161
channel: ChannelType
159162
counter: BarrierCounter
160163

164+
@gluon.constexpr_function
161165
def __init__(self, channel, counter):
162166
self.channel = channel
163167
self.counter = counter
@@ -234,6 +238,7 @@ class AttentionConfig:
234238
num_kv_buffers: gl.constexpr
235239
use_exp2_turnstile: gl.constexpr
236240

241+
@gluon.constexpr_function
237242
def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE, dtype,
238243
num_warps):
239244
self.qk_scale = qk_scale
@@ -250,7 +255,7 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
250255
self.num_warps = gl.constexpr(num_warps)
251256

252257
self.SPLIT_D_FACTOR = gl.constexpr(2)
253-
self.SPLIT_EXP_FACTOR = 256 // HEAD_DIM
258+
self.SPLIT_EXP_FACTOR = gl.constexpr(256 // HEAD_DIM)
254259
self.SPLIT_QK_LOAD_FACTOR = gl.constexpr(2 if STAGE == 1 else 1)
255260
self.SPLIT_M = gl.constexpr(self.BLOCK_M // 2)
256261
self.SPLIT_D = gl.constexpr(self.HEAD_DIM // self.SPLIT_D_FACTOR)
@@ -305,6 +310,7 @@ class ProgramScheduler:
305310
num_pid_in_group: gl.tensor
306311
num_tiles: gl.tensor
307312

313+
@gluon.constexpr_function
308314
def __init__(self, config, start_pid, num_pid_n, num_pid_in_group, num_tiles):
309315
self.config = config
310316
self.start_pid = start_pid
@@ -339,6 +345,7 @@ class AttentionProgram:
339345
offset_y: gl.tensor
340346
qo_offset_y: gl.tensor
341347

348+
@gluon.constexpr_function
342349
def __init__(self, config, start_m, off_hz, offset_y, qo_offset_y):
343350
self.config = config
344351
self.start_m = start_m
@@ -840,12 +847,13 @@ def attention_kernel( #
840847

841848
chnls = (q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile)
842849
descs = (desc_q, desc_k, desc_v, desc_o)
843-
gl.warp_specialize((config, chnls, descs, M, STAGE), _attn_fwd_correction, (config, chnls, descs, M, STAGE), [
844-
_attn_fwd_softmax0,
845-
_attn_fwd_softmax1,
846-
_attn_fwd_mma,
847-
_attn_fwd_load,
848-
_attn_fwd_epilogue,
850+
gl.warp_specialize([
851+
(_attn_fwd_correction, (config, chnls, descs, M, STAGE)),
852+
(_attn_fwd_softmax0, (config, chnls, descs, M, STAGE)),
853+
(_attn_fwd_softmax1, (config, chnls, descs, M, STAGE)),
854+
(_attn_fwd_mma, (config, chnls, descs, M, STAGE)),
855+
(_attn_fwd_load, (config, chnls, descs, M, STAGE)),
856+
(_attn_fwd_epilogue, (config, chnls, descs, M, STAGE)),
849857
], [4, 4, 1, 1, 1], [192, 192, 24, 24, 24])
850858

851859
q_chnl.release()

python/src/gluon_ir.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,16 @@ void init_gluon_ir(py::module &&m) {
384384
ctx, version, warpsPerCta, instrShape, transposed, ctaLayout,
385385
tilesPerWarp, elementBitWidth);
386386
})
387+
.def("get_amd_mfma_scale_layout",
388+
[](GluonOpBuilder &self, unsigned opIdx, std::vector<int64_t> &shape,
389+
unsigned mfmaMDim, std::vector<unsigned> &tilesPerWarp,
390+
std::vector<unsigned> &warpsPerCTA) -> py::object {
391+
auto ctx = self.getContext();
392+
auto ll = ttg::chooseScaledMfmaScaleLayout(
393+
ctx, opIdx, shape, mfmaMDim, tilesPerWarp, warpsPerCTA);
394+
auto attr = ttg::LinearEncodingAttr::get(ctx, ll);
395+
return layoutToGluon(attr);
396+
})
387397
.def("get_amd_wmma_layout",
388398
[](GluonOpBuilder &self, unsigned version, bool transposed,
389399
std::vector<unsigned> &warpsPerCta,
@@ -397,6 +407,15 @@ void init_gluon_ir(py::module &&m) {
397407
return ttg::AMDWmmaEncodingAttr::get(
398408
ctx, version, transposed, warpsPerCta, ctaLayout, instrShape);
399409
})
410+
.def("get_amd_wmma_scale_layout",
411+
[](GluonOpBuilder &self, unsigned opIdx, std::vector<int64_t> &shape,
412+
std::vector<unsigned> &warpsPerCTA) -> py::object {
413+
auto ctx = self.getContext();
414+
auto ll = ttg::chooseScaledWmmaScaleLayout(ctx, opIdx, warpsPerCTA,
415+
shape);
416+
auto attr = ttg::LinearEncodingAttr::get(ctx, ll);
417+
return layoutToGluon(attr);
418+
})
400419
.def("get_intel_dpas_layout",
401420
[](GluonOpBuilder &self, unsigned repeatCount,
402421
unsigned systolicDepth, unsigned executionSize,

python/src/ir.h

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ class TritonOpBuilder {
3535
if (!block.empty())
3636
setLastLoc(block.begin()->getLoc());
3737
else
38-
setLastLoc(builder->getUnknownLoc());
38+
setLastLoc(getLocForBlock(&block));
3939
builder->setInsertionPointToStart(&block);
4040
}
4141

4242
void setInsertionPointToEnd(mlir::Block &block) {
4343
if (!block.empty())
4444
setLastLoc(block.back().getLoc());
4545
else
46-
setLastLoc(builder->getUnknownLoc());
46+
setLastLoc(getLocForBlock(&block));
4747
builder->setInsertionPointToEnd(&block);
4848
}
4949

@@ -53,10 +53,14 @@ class TritonOpBuilder {
5353
}
5454

5555
void restoreInsertionPoint(mlir::OpBuilder::InsertPoint pt) {
56-
if (pt.isSet() && pt.getPoint() != pt.getBlock()->end())
57-
setLastLoc(pt.getPoint()->getLoc());
58-
else
59-
setLastLoc(builder->getUnknownLoc());
56+
setLastLoc(builder->getUnknownLoc());
57+
if (pt.isSet()) {
58+
if (pt.getPoint() != pt.getBlock()->end())
59+
setLastLoc(pt.getPoint()->getLoc());
60+
else
61+
setLastLoc(getLocForBlock(pt.getBlock()));
62+
}
63+
6064
builder->restoreInsertionPoint(pt);
6165
}
6266

@@ -87,4 +91,10 @@ class TritonOpBuilder {
8791
std::unique_ptr<mlir::Location> lastLoc;
8892
bool lineInfoEnabled =
8993
!mlir::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
94+
95+
mlir::Location getLocForBlock(mlir::Block *block) {
96+
if (auto parentOp = block->getParentOp())
97+
return parentOp->getLoc();
98+
return builder->getUnknownLoc();
99+
}
90100
};

0 commit comments

Comments (0)