Skip to content

Commit 98d23fe

Browse files
authored
[BACKEND] Rewrite tablegen combine into cpp patterns (#7709)
<!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because `test/Triton/combine.mlir already checks` (NFC change). - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent a325e03 commit 98d23fe

File tree

2 files changed

+37
-34
lines changed

2 files changed

+37
-34
lines changed

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,43 @@ class RankedReduceDescriptorLoads : public mlir::OpRewritePattern<ReshapeOp> {
231231
}
232232
};
233233

234+
template <typename OpTy>
235+
class CombineDotAddPattern : public mlir::OpRewritePattern<OpTy> {
236+
public:
237+
using OpRewritePattern<OpTy>::OpRewritePattern;
238+
239+
mlir::LogicalResult
240+
matchAndRewrite(OpTy addOp, mlir::PatternRewriter &rewriter) const override {
241+
auto dotOp = addOp.getRhs().template getDefiningOp<DotOp>();
242+
bool isDotLHS = false;
243+
if (!dotOp) {
244+
dotOp = addOp.getLhs().template getDefiningOp<DotOp>();
245+
if (!dotOp) {
246+
return failure();
247+
}
248+
isDotLHS = true;
249+
}
250+
if (!dotOp->hasOneUse()) {
251+
return failure();
252+
}
253+
if (!isZero(dotOp.getC()))
254+
return failure();
255+
rewriter.modifyOpInPlace(dotOp, [&] {
256+
dotOp.getCMutable().assign(isDotLHS ? addOp.getRhs() : addOp.getLhs());
257+
dotOp->moveBefore(addOp);
258+
});
259+
rewriter.replaceAllUsesWith(addOp, dotOp.getResult());
260+
return success();
261+
}
262+
};
263+
264+
// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
265+
// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
266+
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
267+
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
268+
using CombineDotAddIPattern = CombineDotAddPattern<arith::AddIOp>;
269+
using CombineDotAddFPattern = CombineDotAddPattern<arith::AddFOp>;
270+
234271
} // anonymous namespace
235272

236273
class CombineOpsPass : public impl::TritonCombineOpsBase<CombineOpsPass> {
@@ -240,12 +277,8 @@ class CombineOpsPass : public impl::TritonCombineOpsBase<CombineOpsPass> {
240277
RewritePatternSet patterns(context);
241278
ModuleOp m = getOperation();
242279

243-
// Dot Add %{
244280
patterns.add<CombineDotAddIPattern>(context);
245281
patterns.add<CombineDotAddFPattern>(context);
246-
patterns.add<CombineDotAddIRevPattern>(context);
247-
patterns.add<CombineDotAddFRevPattern>(context);
248-
// %}
249282
patterns.add<CombineSelectMaskedLoadPattern>(context);
250283
patterns.add<CombineAddPtrPattern>(context);
251284
patterns.add<CombineBroadcastMulReducePattern>(context);

lib/Dialect/Triton/Transforms/Combine.td

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,6 @@ include "mlir/Dialect/Arith/IR/ArithOps.td"
55
include "triton/Dialect/Triton/IR/TritonOps.td"
66
include "mlir/IR/PatternBase.td"
77

8-
9-
// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
10-
// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
11-
12-
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
13-
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
14-
def CombineDotAddIPattern : Pat<
15-
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $overflow),
16-
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
17-
[(Constraint<CPred<"isZero($0)">> $c),
18-
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
19-
def CombineDotAddFPattern : Pat<
20-
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath),
21-
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
22-
[(Constraint<CPred<"isZero($0)">> $c),
23-
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
24-
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
25-
26-
def CombineDotAddIRevPattern : Pat<
27-
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $overflow),
28-
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
29-
[(Constraint<CPred<"isZero($0)">> $c),
30-
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
31-
def CombineDotAddFRevPattern : Pat<
32-
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath),
33-
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
34-
[(Constraint<CPred<"isZero($0)">> $c),
35-
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
36-
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
37-
388
// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1))
399
// Note: leave (sub %c0, %c0) canceling to ArithDialect
4010
// (ref: ArithCanonicalization.td)

0 commit comments

Comments
 (0)