11#include " triton/Dialect/Triton/IR/Dialect.h"
2+ #include " triton/Dialect/Triton/IR/Interfaces.h"
23#include " triton/Dialect/Triton/IR/Types.h"
34
45#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
56#include " mlir/Dialect/UB/IR/UBOps.h"
67#include " llvm/ADT/StringSwitch.h"
78#include " llvm/ADT/TypeSwitch.h"
8- #include " llvm/Support/raw_ostream.h"
99
10- #include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
11- #include " mlir/IR/DialectImplementation.h"
12-
13- #include " mlir/Transforms/InliningUtils.h"
1410#include " triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc"
1511#include " triton/Dialect/Triton/IR/Dialect.cpp.inc"
1612#include " triton/Dialect/Triton/IR/OpInterfaces.cpp.inc"
@@ -22,62 +18,45 @@ using namespace mlir::triton;
2218// TritonDialect Dialect Interfaces
2319// ===----------------------------------------------------------------------===//
2420
25- namespace {
26- struct TritonInlinerInterface : public DialectInlinerInterface {
27- using DialectInlinerInterface::DialectInlinerInterface;
28-
29- bool isLegalToInline (Operation *call, Operation *callable,
30- bool wouldBeCloned) const final {
31- auto funcOp = dyn_cast<triton::FuncOp>(callable);
32- if (!funcOp)
33- return true ;
34- if (funcOp->hasAttr (" noinline" ))
35- return !funcOp->getAttrOfType <BoolAttr>(" noinline" ).getValue ();
36- return true ;
37- }
38-
39- bool isLegalToInline (Region *dest, Region *src, bool wouldBeCloned,
40- IRMapping &valueMapping) const final {
41- return true ;
42- }
43-
44- bool isLegalToInline (Operation *, Region *, bool wouldBeCloned,
45- IRMapping &) const final {
21+ bool TritonInlinerInterface::isLegalToInline (Operation *call,
22+ Operation *callable,
23+ bool wouldBeCloned) const {
24+ auto funcOp = dyn_cast<triton::FuncOp>(callable);
25+ if (!funcOp)
4626 return true ;
47- }
48- // ===--------------------------------------------------------------------===//
49- // Transformation Hooks
50- // ===--------------------------------------------------------------------===//
51-
52- // / Handle the given inlined terminator by replacing it with a new operation
53- // / as necessary.
54- void handleTerminator (Operation *op, Block *newDest) const final {
55- // Only return needs to be handled here.
56- auto returnOp = dyn_cast<triton::ReturnOp>(op);
57- if (!returnOp)
58- return ;
59-
60- // Replace the return with a branch to the dest.
61- OpBuilder builder (op);
62- builder.create <mlir::cf::BranchOp>(op->getLoc (), newDest,
63- returnOp.getOperands ());
64- op->erase ();
65- }
66-
67- // / Handle the given inlined terminator by replacing it with a new operation
68- // / as necessary.
69- void handleTerminator (Operation *op, ValueRange valuesToRepl) const final {
70- // Only return needs to be handled here.
71- auto returnOp = cast<triton::ReturnOp>(op);
27+ if (funcOp->hasAttr (" noinline" ))
28+ return !funcOp->getAttrOfType <BoolAttr>(" noinline" ).getValue ();
29+ return true ;
30+ }
7231
73- // Replace the values directly with the return operands.
74- assert (returnOp.getNumOperands () == valuesToRepl.size ());
75- for (const auto &it : llvm::enumerate (returnOp.getOperands ()))
76- valuesToRepl[it.index ()].replaceAllUsesWith (it.value ());
77- }
78- };
32+ // / Handle the given inlined terminator by replacing it with a new operation
33+ // / as necessary.
34+ void TritonInlinerInterface::handleTerminator (Operation *op,
35+ Block *newDest) const {
36+ // Only return needs to be handled here.
37+ auto returnOp = dyn_cast<triton::ReturnOp>(op);
38+ if (!returnOp)
39+ return ;
40+
41+ // Replace the return with a branch to the dest.
42+ OpBuilder builder (op);
43+ builder.create <mlir::cf::BranchOp>(op->getLoc (), newDest,
44+ returnOp.getOperands ());
45+ op->erase ();
46+ }
7947
80- } // namespace
48+ // / Handle the given inlined terminator by replacing it with a new operation
49+ // / as necessary.
50+ void TritonInlinerInterface::handleTerminator (Operation *op,
51+ ValueRange valuesToRepl) const {
52+ // Only return needs to be handled here.
53+ auto returnOp = cast<triton::ReturnOp>(op);
54+
55+ // Replace the values directly with the return operands.
56+ assert (returnOp.getNumOperands () == valuesToRepl.size ());
57+ for (const auto &it : llvm::enumerate (returnOp.getOperands ()))
58+ valuesToRepl[it.index ()].replaceAllUsesWith (it.value ());
59+ }
8160
8261void TritonDialect::initialize () {
8362 registerTypes ();
0 commit comments