Commit 949256e

Refactor

Signed-off-by: Tiotto, Ettore <[email protected]>

1 parent d9de8e7 · commit 949256e
1 file changed: +87 −123 lines
third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 87 additions & 123 deletions
@@ -2,6 +2,8 @@
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -11,6 +13,7 @@
 #include "triton/Tools/StrUtil.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include <variant>
 
 #define DEBUG_TYPE "tritonintelgpu-coalesce"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -150,7 +153,7 @@ struct CoalescePass
 
   static bool filterUser(Operation *op) {
     // Yield operations trigger updating the layout of the containing loop
-    // results, so don't skip them.
+    // results, don't skip them.
     if (isa<scf::YieldOp>(op))
       return false;
 
@@ -168,154 +171,123 @@ struct CoalescePass
     return false;
   }
 
-  // Propagate the \p root block argument operation output layout along the
-  // def-use chain.
-  static void propagateLayout(BlockArgument arg, Attribute layout,
-                              IRRewriter &rewriter) {
-    llvm::errs() << "arg: " << arg << "\n";
+  // Propagate the layout to \p root operation's result to the \p forOp loop
+  // init argument that uses it, and transitively to the operations in the loop
+  // body that use that argument.
+  static void propagate(scf::ForOp forOp, Operation *root, Attribute layout,
+                        IRRewriter &rewriter) {
+    assert(llvm::any_of(root->getUsers(),
+                        [&](Operation *user) { return user == forOp; }) &&
+           "Expecting the loop to be a user of the root operation");
+
+    for (BlockArgument arg : forOp.getRegionIterArgs()) {
+      Value loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1];
+      for (OpResult res : root->getResults()) {
+        if (res != loopArg || !tt::isTensorPointerType(res.getType()))
+          continue;
 
-    auto users = arg.getUsers();
-    if (users.empty()) {
-      llvm::errs() << "arg has no users\n";
-      return;
+        LDBG("loopArg: " << loopArg);
+
+        // Modify the layout of the loop init argument...
+        tt::PointerType ptrType = cast<tt::PointerType>(arg.getType());
+        auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+        arg.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                         ptrType.getAddressSpace()));
+
+        // ... and then propagate it to the operations in the loop.
+        propagateLayout(arg, layout, rewriter);
+      }
     }
+  }
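A note on the indexing in the new overload: block argument 0 of an scf.for body is the induction variable, so region iter argument i corresponds to init operand i, which is what arg.getArgNumber() - 1 computes. A minimal sketch of that correspondence, with a hypothetical helper name that is not part of this commit:

  // Hypothetical helper: map an scf.for init operand to its region iter
  // argument. Body block argument 0 is the induction variable, so init
  // operand i is body block argument i + 1.
  static BlockArgument getIterArgForInit(scf::ForOp forOp, Value init) {
    for (auto [initArg, iterArg] :
         llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs()))
      if (initArg == init)
        return iterArg;
    return BlockArgument();
  }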
+
+  // Modify the given loop \p forOp and propagate the result of the enclosing
+  // loop.
+  static void propagate(scf::ForOp forOp, Attribute layout,
+                        IRRewriter &rewriter) {
+    Operation *yieldOp = forOp.getBody()->getTerminator();
+
+    rewriter.modifyOpInPlace(forOp, [&]() {
+      for (auto [opType, res] :
+           llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) {
+        if (opType == res.getType())
+          continue;
+
+        assert(tt::isTensorPointerType(res.getType()) &&
+               tt::isTensorPointerType(opType) && "Expecting blocked pointers");
+        assert(cast<RankedTensorType>(
+                   cast<tt::PointerType>(opType).getPointeeType())
+                       .getEncoding() == layout &&
+               "Unexpected layout");
 
-    for (Operation *user : users) {
-      llvm::errs() << "arg's user: " << *user << "\n\n";
+        auto resType = cast<tt::PointerType>(res.getType());
+        RankedTensorType tensorType = getRankedTensorType(resType);
+        res.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                         resType.getAddressSpace()));
+      }
+    });
+
+    propagateLayout(forOp, layout, rewriter);
+  }
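Both propagate overloads rebuild the pointer type via getNewType, whose definition is outside this diff. Judging from its call sites here, it presumably keeps the shape and element type and swaps in the new encoding, roughly:

  // Assumed shape of getNewType (the actual definition lives elsewhere in
  // Coalesce.cpp): same shape and element type, new layout encoding.
  static RankedTensorType getNewType(RankedTensorType type, Attribute encoding) {
    return RankedTensorType::get(type.getShape(), type.getElementType(),
                                 encoding);
  }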
+
+  static void propagateLayout(BlockArgument arg, Attribute layout,
+                              IRRewriter &rewriter) {
+    LDBG("arg: " << arg);
+    for (Operation *user : arg.getUsers()) {
+      LDBG("arg's user: " << *user << "\n");
       if (filterUser(user)) {
-        llvm::errs() << "SKIP\n";
         continue;
       }
-
       if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
-        // Modify and propagate the result of the enclosing loop.
         auto forOp = yieldOp->getParentOfType<scf::ForOp>();
-
-        rewriter.modifyOpInPlace(forOp, [&]() {
-          for (auto [opType, res] :
-               llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) {
-            if (opType == res.getType())
-              continue;
-
-            assert(tt::isTensorPointerType(res.getType()) &&
-                   tt::isTensorPointerType(opType) &&
-                   "Expecting blocked pointers");
-            assert(cast<RankedTensorType>(
-                       cast<tt::PointerType>(opType).getPointeeType())
-                           .getEncoding() == layout &&
-                   "Unexpected layout");
-
-            auto resType = cast<tt::PointerType>(res.getType());
-            RankedTensorType tensorType = getRankedTensorType(resType);
-            res.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                             resType.getAddressSpace()));
-          }
-        });
-
-        propagateLayout(forOp, layout, rewriter);
+        propagate(forOp, layout, rewriter);
         continue;
       }
-
       changeAndPropagateLayout(user, layout, rewriter);
     }
   }
 
   static void propagateLayout(Operation *root, Attribute layout,
                               IRRewriter &rewriter) {
-    assert(root && root->getNumResults() != 0 &&
+    assert(root->getNumResults() != 0 &&
           "Expecting an operation yielding a result");
 
-    llvm::errs() << "root: " << *root << "\n";
-    auto users = root->getUsers();
-    if (users.empty()) {
-      llvm::errs() << "root has no users\n";
-      return;
-    }
-
-    for (Operation *user : users) {
-      llvm::errs() << "root's user: " << *user << "\n\n";
+    LDBG("root: " << *root);
+    for (Operation *user : root->getUsers()) {
+      LDBG("root's user: " << *user << "\n");
       if (filterUser(user)) {
-        llvm::errs() << "SKIP\n";
         continue;
       }
-
-      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
-        // Modify and propagate the result of the enclosing loop.
-        auto forOp = yieldOp->getParentOfType<scf::ForOp>();
-
-        rewriter.modifyOpInPlace(forOp, [&]() {
-          for (auto [opType, res] :
-               llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) {
-            if (opType == res.getType())
-              continue;
-
-            assert(tt::isTensorPointerType(res.getType()) &&
-                   tt::isTensorPointerType(opType) &&
-                   "Expecting blocked pointers");
-            assert(cast<RankedTensorType>(
-                       cast<tt::PointerType>(opType).getPointeeType())
-                           .getEncoding() == layout &&
-                   "Unexpected layout");
-
-            auto resType = cast<tt::PointerType>(res.getType());
-            RankedTensorType tensorType = getRankedTensorType(resType);
-            res.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                             resType.getAddressSpace()));
-          }
-        });
-
-        propagateLayout(forOp, layout, rewriter);
+      if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+        propagate(forOp, root, layout, rewriter);
         continue;
       }
-
-      if (auto forOp = dyn_cast<scf::ForOp>(user)) {
-        for (BlockArgument arg : forOp.getRegionIterArgs()) {
-          Value loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1];
-          for (OpResult res : root->getResults()) {
-            if (res == loopArg && tt::isTensorPointerType(res.getType())) {
-              llvm::errs() << "arg: " << arg << "\n";
-              llvm::errs() << "loopArg: " << loopArg << "\n";
-
-              // Modify the layout of the loop init argument...
-              tt::PointerType ptrType = cast<tt::PointerType>(arg.getType());
-              auto tensorType =
-                  cast<RankedTensorType>(ptrType.getPointeeType());
-              arg.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                               ptrType.getAddressSpace()));
-
-              // ... and then propagate it to the operations in the loop.
-              propagateLayout(arg, layout, rewriter);
-            }
-          }
-        }
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
+        auto forOp = yieldOp->getParentOfType<scf::ForOp>();
+        propagate(forOp, layout, rewriter);
         continue;
       }
-
       changeAndPropagateLayout(user, layout, rewriter);
     }
   }
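Net effect of this hunk: propagateLayout(Operation *) is now a thin dispatcher. A user that is an scf.for goes through the new propagate(forOp, root, ...) overload to update the loop init arguments; a user that is an scf.yield goes through propagate(forOp, ...) to update the loop result types; every other user recurses through changeAndPropagateLayout. The two previously inlined copies of the yield-handling loop are exactly what the new overloads factor out.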
 
-  // TODO: change the implementation to handle only operation yielding one
-  // result?
-  // Change the \p layout of the \p op result(s) and propagate the new
-  // result type to its users.
+  // Change the \p layout of the \p op result and propagate the new result type
+  // to its users.
   static void changeAndPropagateLayout(Operation *op, Attribute layout,
                                        IRRewriter &rewriter) {
-    assert(op && op->getNumResults() != 0 &&
+    assert(op && op->getNumResults() == 1 &&
           "Expecting operation yielding a result");
 
     rewriter.modifyOpInPlace(op, [&]() {
-      for (Value res : op->getResults()) {
-        if (!tt::isTensorPointerType(res.getType()))
-          continue;
-
-        auto ptrType = cast<tt::PointerType>(res.getType());
-        auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-        res.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                         ptrType.getAddressSpace()));
-      }
+      Value res = op->getOpResult(0);
+      assert(tt::isTensorPointerType(res.getType()) &&
+             "Expecting a block pointer");
+
+      auto ptrType = cast<tt::PointerType>(res.getType());
+      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+      res.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                       ptrType.getAddressSpace()));
     });
-    llvm::errs() << "Coalesced op: " << *op << "\n";
+    LDBG("Coalesced op: " << *op);
 
     propagateLayout(op, layout, rewriter);
   }
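For context, the entry point into this machinery is outside the diff. A hedged sketch of the assumed call site, based on the layoutMap dump in the last hunk (the iteration shape is an assumption, not something this commit shows):

  // Assumed driver (not part of this diff): start propagation from each
  // operation whose coalesced encoding was recorded in layoutMap.
  IRRewriter rewriter(moduleOp.getContext());
  for (auto &[op, layout] : layoutMap)
    changeAndPropagateLayout(op, layout, rewriter);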
@@ -400,22 +372,14 @@ struct CoalescePass
     if (!refTensorType || !refTensorType.getEncoding())
       return;
 
-    // static int n = 0;
-    // if (tt::isTensorPointerType(ptr.getType()))
-    //   n++;
-
-    // if (n != 2)
-    //   return;
-
     int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp);
     int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp);
     setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp,
                          layoutMap);
   });
 
   LLVM_DEBUG({
-    DBGS() << "layoutMap:"
-           << "\n";
+    DBGS() << "layoutMap:" << "\n";
     for (auto [op, encoding] : layoutMap) {
       DBGS() << "op: " << *op << "\n";
       DBGS() << "encoding: " << encoding << "\n";
