Skip to content

Commit 1f5dc71

Browse files
tfruan2000 and Jokeren authored
[BACKEND] Optimize code style in rewrite-tensor-pointer and add more tests (triton-lang#4724)
The core Triton team is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --- Hello, maintainers and reviewers! While reading the [RewriteTensorPointer.cpp](https://github.com/triton-lang/triton/blob/main/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp) pass, I noticed that the current implementation is somewhat redundant and that its test is hard to understand, so I submitted this PR.
PR description: - Use `llvm::make_early_inc_range` to ensure no issues arise while visiting ops, instead of first copying each block's ops into a separate vector. ```cpp for (auto &region : op->getRegions()) { for (auto &block : region) { SmallVector<Operation *> blockCopy; for (auto &nestedOp : block) blockCopy.push_back(&nestedOp); for (auto &nestedOp : blockCopy) { if (auto newOp = rewriteOp(nestedOp, eraser)) -> for (Region &region : op->getRegions()) { for (Block &block : region) { for (Operation &nestedOp : llvm::make_early_inc_range(block)) { if (auto newOp = rewriteOp(&nestedOp, eraser)) { visitOperation(newOp, eraser); } ``` - Update the operand list in place through the reference parameter instead of constructing and returning a new SmallVector. ```cpp static SmallVector<Value> generateNewOperands(const SmallVector<Value> &oldOperands, unsigned index, const SmallVector<Value> &newValues) { -> static void generateNewOperands(SmallVector<Value> &oldOperands, unsigned index, ArrayRef<Value> newValues) { ``` - Delete some dead code. - Add detailed tests; see `test/Triton/rewrite-tensor-pointer.mlir`. Co-authored-by: Keren Zhou <[email protected]>
1 parent 5000e32 commit 1f5dc71

File tree

2 files changed

+215
-117
lines changed

2 files changed

+215
-117
lines changed

lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#include <memory>
22
#include <stack>
33

4+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
45
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
6+
#include "mlir/Dialect/SCF/IR/SCF.h"
57
#include "mlir/Pass/Pass.h"
68
#include "mlir/Support/LLVM.h"
79
#include "triton/Analysis/Utility.h"
@@ -171,10 +173,7 @@ struct RewritedInfo {
171173
auto otherTensorType = RankedTensorType::get(tensorShape, elementType);
172174

173175
// Set zero padding value
174-
TypedAttr attr =
175-
elementType.isIntOrIndex()
176-
? cast<TypedAttr>(builder.getIntegerAttr(elementType, 0))
177-
: cast<TypedAttr>(builder.getFloatAttr(elementType, 0));
176+
TypedAttr attr = builder.getZeroAttr(elementType);
178177

179178
// Float NaN padding case
180179
if (padding.value() == triton::PaddingOption::PAD_NAN) {
@@ -209,18 +208,20 @@ class RewriteTensorPointerPass
209208
});
210209
}
211210

212-
static SmallVector<Value>
213-
generateNewOperands(const SmallVector<Value> &oldOperands, unsigned index,
214-
const SmallVector<Value> &newValues) {
215-
assert(index < oldOperands.size());
216-
SmallVector<Value> newOperands;
217-
for (int i = 0; i < index; ++i)
218-
newOperands.push_back(oldOperands[i]);
219-
for (auto value : newValues)
220-
newOperands.push_back(value);
221-
for (auto i = index + 1; i < oldOperands.size(); ++i)
222-
newOperands.push_back(oldOperands[i]);
223-
return newOperands;
211+
static void generateNewOperands(SmallVector<Value> &oldOperands,
212+
unsigned index, ArrayRef<Value> newValues) {
213+
size_t size = oldOperands.size();
214+
assert(index < size);
215+
SmallVector<Value> operands = oldOperands;
216+
oldOperands.reserve(size - 1 + newValues.size());
217+
oldOperands.clear();
218+
if (index != 0) {
219+
oldOperands.append(operands.begin(), operands.begin() + index);
220+
}
221+
oldOperands.append(newValues.begin(), newValues.end());
222+
if (index != size - 1) {
223+
oldOperands.append(operands.begin() + index + 1, operands.end());
224+
}
224225
}
225226

226227
Operation *rewriteMakeTensorPtrOp(OpBuilder &builder,
@@ -358,7 +359,7 @@ class RewriteTensorPointerPass
358359
}
359360
auto rematerialize = [&](Block *block) {
360361
for (Operation &opInIf : block->getOperations()) {
361-
auto newOp = builder.clone(opInIf, mapping);
362+
builder.clone(opInIf, mapping);
362363
}
363364
};
364365
builder.setInsertionPointToStart(newOp.thenBlock());
@@ -403,8 +404,7 @@ class RewriteTensorPointerPass
403404
// Expand the tensor pointer into offsets
404405
assert(rewritedInfo.count(newIterOperands[i]));
405406
auto info = rewritedInfo[newIterOperands[i]];
406-
newIterOperands =
407-
generateNewOperands(newIterOperands, i, info.getOffsets());
407+
generateNewOperands(newIterOperands, i, info.getOffsets());
408408
i += info.length() - 1;
409409
size += info.length() - 1;
410410
}
@@ -439,9 +439,7 @@ class RewriteTensorPointerPass
439439
// Clone body
440440
builder.setInsertionPointToStart(newForOp.getBody());
441441
for (auto &opInFor : *op.getBody()) {
442-
auto *newOp = builder.clone(opInFor, mapping);
443-
for (unsigned i = 0; i < opInFor.getNumResults(); ++i)
444-
mapping.map(opInFor.getResult(i), newOp->getResult(i));
442+
builder.clone(opInFor, mapping);
445443
}
446444

447445
// Replace later usages
@@ -476,7 +474,7 @@ class RewriteTensorPointerPass
476474

477475
assert(rewritedInfo.count(newOperands[i]));
478476
auto info = rewritedInfo[newOperands[i]];
479-
newOperands = generateNewOperands(newOperands, i, info.getOffsets());
477+
generateNewOperands(newOperands, i, info.getOffsets());
480478
i += info.length() - 1;
481479
size += info.length() - 1;
482480
}
@@ -492,15 +490,13 @@ class RewriteTensorPointerPass
492490
// Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers
493491
// Rewriting functions return the next operation to visit, if there is no
494492
// next one, simply return `nullptr`
495-
std::pair<Value, RewritedInfo> rewrited;
496493
if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
497494
return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser);
498495
} else if (auto advanceOp = dyn_cast<triton::AdvanceOp>(op)) {
499496
return rewriteAdvanceOp(builder, advanceOp, eraser);
500497
} else if (isa<triton::LoadOp>(op) || isa<triton::StoreOp>(op)) {
501498
return rewriteLoadStoreOp(builder, op, eraser);
502-
} else if (op->getDialect()->getNamespace() == "scf" ||
503-
op->getDialect()->getNamespace() == "cf") {
499+
} else if (isa<scf::SCFDialect, cf::ControlFlowDialect>(op->getDialect())) {
504500
if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
505501
return rewriteIfOp(builder, ifOp, eraser);
506502
}
@@ -524,18 +520,12 @@ class RewriteTensorPointerPass
524520
}
525521

526522
void visitOperation(Operation *op, std::stack<Operation *> &eraser) {
527-
for (auto &region : op->getRegions()) {
528-
for (auto &block : region) {
529-
// We need an extra copy because erasing operations may break the
530-
// iterator behavior
531-
SmallVector<Operation *> blockCopy;
532-
for (auto &nestedOp : block)
533-
blockCopy.push_back(&nestedOp);
534-
535-
// Rewrite and recursively visit
536-
for (auto &nestedOp : blockCopy) {
537-
if (auto newOp = rewriteOp(nestedOp, eraser))
523+
for (Region &region : op->getRegions()) {
524+
for (Block &block : region) {
525+
for (Operation &nestedOp : llvm::make_early_inc_range(block)) {
526+
if (auto newOp = rewriteOp(&nestedOp, eraser)) {
538527
visitOperation(newOp, eraser);
528+
}
539529
}
540530
}
541531
}

0 commit comments

Comments
 (0)