Skip to content

Commit 83f279b

Browse files
Merge commit 'c109dc79e57db5f505994324eb362c78004f1d38'
2 parents 712dec1 + c109dc7 commit 83f279b

File tree

24 files changed

+530
-94
lines changed

24 files changed

+530
-94
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,45 @@
11
#ifndef TRITON_IR_INTERFACES_H_
22
#define TRITON_IR_INTERFACES_H_
33

4+
#include "mlir/IR/DialectImplementation.h"
45
#include "mlir/IR/OpDefinition.h"
6+
#include "mlir/Transforms/InliningUtils.h"
57

68
#define GET_TYPEDEF_CLASSES
79
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
810

11+
namespace mlir::triton {
12+
13+
//===----------------------------------------------------------------------===//
14+
// TritonDialect Dialect Interfaces
15+
//===----------------------------------------------------------------------===//
16+
17+
struct TritonInlinerInterface : public DialectInlinerInterface {
18+
using DialectInlinerInterface::DialectInlinerInterface;
19+
20+
bool isLegalToInline(Operation *call, Operation *callable,
21+
bool wouldBeCloned) const final;
22+
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
23+
IRMapping &valueMapping) const final {
24+
return true;
25+
}
26+
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
27+
IRMapping &) const final {
28+
return true;
29+
}
30+
31+
//===--------------------------------------------------------------------===//
32+
// Transformation Hooks
33+
//===--------------------------------------------------------------------===//
34+
35+
/// Handle the given inlined terminator by replacing it with a new operation
36+
/// as necessary.
37+
void handleTerminator(Operation *op, Block *newDest) const final;
38+
/// Handle the given inlined terminator by replacing it with a new operation
39+
/// as necessary.
40+
void handleTerminator(Operation *op, ValueRange valuesToRepl) const final;
41+
};
42+
43+
} // namespace mlir::triton
44+
945
#endif // TRITON_IR_TYPES_H_

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,4 +360,17 @@ def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::Mod
360360
"mlir::triton::TritonDialect"];
361361
}
362362

363+
def TritonGPUCanonicalize: Pass<"tritongpu-canonicalize"> {
364+
let summary = "reduced set of simplifications for TTGIR";
365+
366+
let description = [{
367+
The `tritongpu-canonicalize` pass applies a reduced set of simplification
368+
and canonicalization patterns to the module.
369+
}];
370+
let dependentDialects = [
371+
"mlir::arith::ArithDialect",
372+
"mlir::scf::SCFDialect",
373+
];
374+
}
375+
363376
#endif

lib/Dialect/Triton/IR/Dialect.cpp

Lines changed: 37 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
#include "triton/Dialect/Triton/IR/Dialect.h"
2+
#include "triton/Dialect/Triton/IR/Interfaces.h"
23
#include "triton/Dialect/Triton/IR/Types.h"
34

45
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
56
#include "mlir/Dialect/UB/IR/UBOps.h"
67
#include "llvm/ADT/StringSwitch.h"
78
#include "llvm/ADT/TypeSwitch.h"
8-
#include "llvm/Support/raw_ostream.h"
99

10-
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
11-
#include "mlir/IR/DialectImplementation.h"
12-
13-
#include "mlir/Transforms/InliningUtils.h"
1410
#include "triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc"
1511
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
1612
#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc"
@@ -22,62 +18,45 @@ using namespace mlir::triton;
2218
// TritonDialect Dialect Interfaces
2319
//===----------------------------------------------------------------------===//
2420

25-
namespace {
26-
struct TritonInlinerInterface : public DialectInlinerInterface {
27-
using DialectInlinerInterface::DialectInlinerInterface;
28-
29-
bool isLegalToInline(Operation *call, Operation *callable,
30-
bool wouldBeCloned) const final {
31-
auto funcOp = dyn_cast<triton::FuncOp>(callable);
32-
if (!funcOp)
33-
return true;
34-
if (funcOp->hasAttr("noinline"))
35-
return !funcOp->getAttrOfType<BoolAttr>("noinline").getValue();
36-
return true;
37-
}
38-
39-
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
40-
IRMapping &valueMapping) const final {
41-
return true;
42-
}
43-
44-
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
45-
IRMapping &) const final {
21+
bool TritonInlinerInterface::isLegalToInline(Operation *call,
22+
Operation *callable,
23+
bool wouldBeCloned) const {
24+
auto funcOp = dyn_cast<triton::FuncOp>(callable);
25+
if (!funcOp)
4626
return true;
47-
}
48-
//===--------------------------------------------------------------------===//
49-
// Transformation Hooks
50-
//===--------------------------------------------------------------------===//
51-
52-
/// Handle the given inlined terminator by replacing it with a new operation
53-
/// as necessary.
54-
void handleTerminator(Operation *op, Block *newDest) const final {
55-
// Only return needs to be handled here.
56-
auto returnOp = dyn_cast<triton::ReturnOp>(op);
57-
if (!returnOp)
58-
return;
59-
60-
// Replace the return with a branch to the dest.
61-
OpBuilder builder(op);
62-
builder.create<mlir::cf::BranchOp>(op->getLoc(), newDest,
63-
returnOp.getOperands());
64-
op->erase();
65-
}
66-
67-
/// Handle the given inlined terminator by replacing it with a new operation
68-
/// as necessary.
69-
void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
70-
// Only return needs to be handled here.
71-
auto returnOp = cast<triton::ReturnOp>(op);
27+
if (funcOp->hasAttr("noinline"))
28+
return !funcOp->getAttrOfType<BoolAttr>("noinline").getValue();
29+
return true;
30+
}
7231

73-
// Replace the values directly with the return operands.
74-
assert(returnOp.getNumOperands() == valuesToRepl.size());
75-
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
76-
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
77-
}
78-
};
32+
/// Handle the given inlined terminator by replacing it with a new operation
33+
/// as necessary.
34+
void TritonInlinerInterface::handleTerminator(Operation *op,
35+
Block *newDest) const {
36+
// Only return needs to be handled here.
37+
auto returnOp = dyn_cast<triton::ReturnOp>(op);
38+
if (!returnOp)
39+
return;
40+
41+
// Replace the return with a branch to the dest.
42+
OpBuilder builder(op);
43+
builder.create<mlir::cf::BranchOp>(op->getLoc(), newDest,
44+
returnOp.getOperands());
45+
op->erase();
46+
}
7947

80-
} // namespace
48+
/// Handle the given inlined terminator by replacing it with a new operation
49+
/// as necessary.
50+
void TritonInlinerInterface::handleTerminator(Operation *op,
51+
ValueRange valuesToRepl) const {
52+
// Only return needs to be handled here.
53+
auto returnOp = cast<triton::ReturnOp>(op);
54+
55+
// Replace the values directly with the return operands.
56+
assert(returnOp.getNumOperands() == valuesToRepl.size());
57+
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
58+
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
59+
}
8160

8261
void TritonDialect::initialize() {
8362
registerTypes();

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Support/LLVM.h"
1212
#include "triton/Analysis/Utility.h"
13+
#include "triton/Dialect/Triton/IR/Interfaces.h"
1314
#include "triton/Dialect/Triton/IR/Utility.h"
1415
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
1516
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -3119,6 +3120,7 @@ void TritonGPUDialect::initialize() {
31193120
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
31203121
#include "triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc"
31213122
>();
3123+
addInterfaces<TritonInlinerInterface>();
31223124
addInterfaces<TritonGPUOpAsmInterface>();
31233125
addInterfaces<TritonGPUInferLayoutInterface>();
31243126
addInterfaces<TritonGPUVerifyTensorLayoutInterface>();

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_triton_library(TritonGPUTransforms
22
AccelerateMatmul.cpp
3+
Canonicalize.cpp
34
Coalesce.cpp
45
F32DotTC.cpp
56
FuseNestedLoops.cpp
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#include "mlir/Dialect/Arith/IR/Arith.h"
2+
#include "mlir/Dialect/SCF/IR/SCF.h"
3+
#include "mlir/Pass/Pass.h"
4+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
5+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
6+
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
7+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
8+
9+
using namespace mlir;
10+
using namespace triton;
11+
namespace ttg = triton::gpu;
12+
namespace ttng = triton::nvidia_gpu;
13+
14+
namespace mlir::triton::gpu {
15+
#define GEN_PASS_DEF_TRITONGPUCANONICALIZE
16+
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
17+
} // namespace mlir::triton::gpu
18+
19+
namespace {
20+
struct Canonicalize
21+
: public ttg::impl::TritonGPUCanonicalizeBase<Canonicalize> {
22+
void runOnOperation() override;
23+
};
24+
} // namespace
25+
26+
void Canonicalize::runOnOperation() {
27+
MLIRContext *ctx = &getContext();
28+
RewritePatternSet patterns(&getContext());
29+
30+
// Populate `arith` and `scf` canonicalizers.
31+
ctx->getLoadedDialect<arith::ArithDialect>()->getCanonicalizationPatterns(
32+
patterns);
33+
ctx->getLoadedDialect<scf::SCFDialect>()->getCanonicalizationPatterns(
34+
patterns);
35+
populateForOpDeadArgumentElimination(patterns);
36+
37+
// Populate select Triton canonicalization patterns. The important patterns to
38+
// EXCLUDE are those that modify layouts, especially `ConvertLayoutOp`
39+
// patterns.
40+
LoadOp::getCanonicalizationPatterns(patterns, ctx);
41+
StoreOp::getCanonicalizationPatterns(patterns, ctx);
42+
BroadcastOp::getCanonicalizationPatterns(patterns, ctx);
43+
ExpandDimsOp::getCanonicalizationPatterns(patterns, ctx);
44+
ttg::WarpSpecializeOp::getCanonicalizationPatterns(patterns, ctx);
45+
ttng::TensorDescToTMAPtrOp::getCanonicalizationPatterns(patterns, ctx);
46+
}

python/src/gluon_ir.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "mlir/IR/Types.h"
77
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
88
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
9+
#include "triton/Dialect/TritonGPU/IR/Types.h"
910
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
1011

1112
using namespace mlir;
@@ -80,6 +81,17 @@ void init_gluon_ir(py::module &&m) {
8081
ctx, swizzleByteWidth, transposed, elementBitwidth, fp4Padded,
8182
ctaLayout);
8283
})
84+
.def("get_swizzled_shared_layout",
85+
[](GluonOpBuilder &self, int vec, int perPhase, int maxPhase,
86+
std::vector<unsigned> &order, std::vector<unsigned> &ctasPerCga,
87+
std::vector<unsigned> &ctaSplitNum,
88+
std::vector<unsigned> &ctaOrder) -> Attribute {
89+
auto ctx = self.getContext();
90+
auto ctaLayout = ttg::CTALayoutAttr::get(ctx, ctasPerCga,
91+
ctaSplitNum, ctaOrder);
92+
return ttg::SwizzledSharedEncodingAttr::get(
93+
ctx, vec, perPhase, maxPhase, order, ctaLayout);
94+
})
8395
.def("get_tensor_memory_layout",
8496
[](GluonOpBuilder &self, std::vector<unsigned> &block, bool unpacked,
8597
std::vector<unsigned> &ctaSplitNum) -> Attribute {
@@ -94,6 +106,10 @@ void init_gluon_ir(py::module &&m) {
94106
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
95107
return self.create<ttg::ConvertLayoutOp>(resultTy, value);
96108
})
109+
.def("create_local_alloc",
110+
[](GluonOpBuilder &self, Type resultTy) -> Value {
111+
return self.create<ttg::LocalAllocOp>(resultTy);
112+
})
97113
.def("create_local_alloc",
98114
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
99115
return self.create<ttg::LocalAllocOp>(resultTy, value);
@@ -106,10 +122,19 @@ void init_gluon_ir(py::module &&m) {
106122
[](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value {
107123
return self.create<ttg::LocalLoadOp>(resultTy, memDesc);
108124
})
125+
.def("create_local_dealloc",
126+
[](GluonOpBuilder &self, Value memDesc) -> Operation * {
127+
return self.create<ttg::LocalDeallocOp>(memDesc);
128+
})
129+
109130
.def("create_tmem_alloc",
110131
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
111132
return self.create<ttng::TMEMAllocOp>(resultTy, value);
112133
})
134+
.def("create_tmem_alloc",
135+
[](GluonOpBuilder &self, Type resultTy, py::none value) -> Value {
136+
return self.create<ttng::TMEMAllocOp>(resultTy, Value{});
137+
})
113138
.def("create_tmem_store",
114139
[](GluonOpBuilder &self, Value memDesc, Value value, Value pred) {
115140
self.create<ttng::TMEMStoreOp>(memDesc, value, pred);
@@ -123,6 +148,38 @@ void init_gluon_ir(py::module &&m) {
123148
int N) -> Value {
124149
return self.create<ttng::TMEMSubSliceOp>(resultTy, memDesc, N);
125150
})
151+
.def("create_mbarrier_init",
152+
[](GluonOpBuilder &self, Value memDesc, int count) {
153+
self.create<ttng::InitBarrierOp>(memDesc, count);
154+
})
155+
.def("create_mbarrier_inval",
156+
[](GluonOpBuilder &self, Value memDesc) {
157+
self.create<ttng::InvalBarrierOp>(memDesc);
158+
})
159+
.def("create_mbarrier_expect",
160+
[](GluonOpBuilder &self, Value memDesc, int bytes, Value pred) {
161+
self.create<ttng::BarrierExpectOp>(memDesc, bytes, pred);
162+
})
163+
.def("create_mbarrier_wait",
164+
[](GluonOpBuilder &self, Value memDesc, Value phase, Value pred,
165+
std::vector<Value> &deps) {
166+
self.create<ttng::WaitBarrierOp>(memDesc, phase, pred, deps);
167+
})
168+
.def("create_mbarrier_arrive",
169+
[](GluonOpBuilder &self, Value memDesc, int count, Value pred) {
170+
self.create<ttng::ArriveBarrierOp>(memDesc, count, pred);
171+
})
172+
.def("create_tcgen05_mma",
173+
[](GluonOpBuilder &self, Value a, Value b, Value acc, Value useAcc,
174+
Value pred, std::vector<Value> &mbarriers,
175+
std::vector<Value> &mbarrier_preds) {
176+
Value accDep;
177+
bool two_ctas = false;
178+
auto tokType = self.getBuilder().getType<ttg::AsyncTokenType>();
179+
self.create<ttng::TCGen5MMAOp>(tokType, a, b, acc, accDep, useAcc,
180+
pred, two_ctas, mbarriers,
181+
mbarrier_preds);
182+
})
126183
.def("create_warp_return",
127184
[](GluonOpBuilder &self) -> Operation * {
128185
return self.create<ttg::WarpReturnOp>();

python/src/passes.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ void init_triton_passes_ttir(py::module &&m) {
5151
}
5252

5353
void init_triton_passes_ttgpuir(py::module &&m) {
54+
using namespace mlir;
5455
using namespace mlir::triton::gpu;
5556
ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce);
5657
ADD_PASS_WRAPPER_0("add_optimize_thread_locality",
@@ -85,6 +86,12 @@ void init_triton_passes_ttgpuir(py::module &&m) {
8586
ADD_PASS_WRAPPER_0("add_fuse_nested_loops", createTritonGPUFuseNestedLoops);
8687
ADD_PASS_WRAPPER_0("add_coalesce_async_copy",
8788
createTritonGPUCoalesceAsyncCopy);
89+
ADD_PASS_WRAPPER_0("add_canonicalizer", createTritonGPUCanonicalize);
90+
ADD_PASS_WRAPPER_0("add_inliner", [] {
91+
return createInlinerPass(/*opPipelines=*/{}, [](OpPassManager &pm) {
92+
pm.addPass(createTritonGPUCanonicalize());
93+
});
94+
});
8895
}
8996

9097
void init_triton_passes_convert(py::module &&m) {

python/test/backend/test_device_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(self, device_type: str) -> None:
9494
self.driver = ExtensionDriver()
9595
self.version_key = None
9696

97-
def add_stages(self, arch, extern_libs, stages):
97+
def add_stages(self, stages, options, language):
9898
filter_in_stages = ["ast", "ttir", "ttgir"]
9999
filter_out_stages = []
100100
for key, _ in stages.items():

0 commit comments

Comments
 (0)