Commit c109dc7

[Gluon] Add an opt pass pipeline for gluon (#6992)
This PR adds a separate TTGIR optimization pipeline for code parsed directly from Gluon. The most important pieces are the inliner and a basic set of TTGIR-level optimizations; I added the passes that seemed obviously good to have.
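For orientation, here is a minimal sketch of how a backend could assemble the new Gluon TTGIR pipeline from the Python wrappers added in this PR. `make_gluon_ttgir` and the surrounding stage wiring are assumptions for illustration; only `add_inliner` and `add_canonicalizer` are introduced by this commit (see the python/src/passes.cc hunk below).

    # Hypothetical pipeline assembly; make_gluon_ttgir is an assumed name.
    from triton._C.libtriton import ir, passes

    def make_gluon_ttgir(mod, metadata, options):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        # Inline tt.func calls; the inliner also runs tritongpu-canonicalize
        # on each inlined callee (see the passes.cc hunk below).
        passes.ttgpuir.add_inliner(pm)
        # Layout-preserving TTGIR simplifications.
        passes.ttgpuir.add_canonicalizer(pm)
        pm.run(mod)
        return mod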
Parent: fed9ac4

File tree: 15 files changed, +244 −79 lines
include/triton/Dialect/Triton/IR/Interfaces.h

Lines changed: 36 additions & 0 deletions

@@ -1,9 +1,45 @@
 #ifndef TRITON_IR_INTERFACES_H_
 #define TRITON_IR_INTERFACES_H_
 
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Transforms/InliningUtils.h"
 
 #define GET_TYPEDEF_CLASSES
 #include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
 
+namespace mlir::triton {
+
+//===----------------------------------------------------------------------===//
+// TritonDialect Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+struct TritonInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  bool isLegalToInline(Operation *call, Operation *callable,
+                       bool wouldBeCloned) const final;
+  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+                       IRMapping &valueMapping) const final {
+    return true;
+  }
+  bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
+                       IRMapping &) const final {
+    return true;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Handle the given inlined terminator by replacing it with a new operation
+  /// as necessary.
+  void handleTerminator(Operation *op, Block *newDest) const final;
+  /// Handle the given inlined terminator by replacing it with a new operation
+  /// as necessary.
+  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final;
+};
+
+} // namespace mlir::triton
+
 #endif // TRITON_IR_TYPES_H_
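Hoisting `TritonInlinerInterface` out of the anonymous namespace in lib/Dialect/Triton/IR/Dialect.cpp and into this public header is what lets the TritonGPU dialect reuse it: the same interface is registered on TTGIR below (`addInterfaces<TritonInlinerInterface>()` in lib/Dialect/TritonGPU/IR/Dialect.cpp), which is what makes inlining possible at the TTGIR level for Gluon.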

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
@@ -360,4 +360,17 @@ def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::Mod
                              "mlir::triton::TritonDialect"];
 }
 
+def TritonGPUCanonicalize: Pass<"tritongpu-canonicalize"> {
+  let summary = "reduced set of simplifications for TTGIR";
+
+  let description = [{
+    The `tritongpu-canonicalize` pass applies a reduced set of simplification
+    and canonicalization patterns to the module.
+  }];
+  let dependentDialects = [
+    "mlir::arith::ArithDialect",
+    "mlir::scf::SCFDialect",
+  ];
+}
+
 #endif
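Exposing a reduced pattern set, rather than reusing MLIR's generic -canonicalize, means simplification cannot rewrite layouts a Gluon author chose explicitly. Assuming the standard MLIR pass registration this tablegen generates, the pass can also be invoked by its flag, e.g. `triton-opt -tritongpu-canonicalize input.mlir`.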

lib/Dialect/Triton/IR/Dialect.cpp

Lines changed: 37 additions & 58 deletions
@@ -1,16 +1,12 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Interfaces.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/raw_ostream.h"
 
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/IR/DialectImplementation.h"
-
-#include "mlir/Transforms/InliningUtils.h"
 #include "triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc"
 #include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
 #include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc"
@@ -22,62 +18,45 @@ using namespace mlir::triton;
 // TritonDialect Dialect Interfaces
 //===----------------------------------------------------------------------===//
 
-namespace {
-struct TritonInlinerInterface : public DialectInlinerInterface {
-  using DialectInlinerInterface::DialectInlinerInterface;
-
-  bool isLegalToInline(Operation *call, Operation *callable,
-                       bool wouldBeCloned) const final {
-    auto funcOp = dyn_cast<triton::FuncOp>(callable);
-    if (!funcOp)
-      return true;
-    if (funcOp->hasAttr("noinline"))
-      return !funcOp->getAttrOfType<BoolAttr>("noinline").getValue();
-    return true;
-  }
-
-  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
-                       IRMapping &valueMapping) const final {
-    return true;
-  }
-
-  bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
-                       IRMapping &) const final {
+bool TritonInlinerInterface::isLegalToInline(Operation *call,
+                                             Operation *callable,
+                                             bool wouldBeCloned) const {
+  auto funcOp = dyn_cast<triton::FuncOp>(callable);
+  if (!funcOp)
     return true;
-  }
-  //===--------------------------------------------------------------------===//
-  // Transformation Hooks
-  //===--------------------------------------------------------------------===//
-
-  /// Handle the given inlined terminator by replacing it with a new operation
-  /// as necessary.
-  void handleTerminator(Operation *op, Block *newDest) const final {
-    // Only return needs to be handled here.
-    auto returnOp = dyn_cast<triton::ReturnOp>(op);
-    if (!returnOp)
-      return;
-
-    // Replace the return with a branch to the dest.
-    OpBuilder builder(op);
-    builder.create<mlir::cf::BranchOp>(op->getLoc(), newDest,
-                                       returnOp.getOperands());
-    op->erase();
-  }
-
-  /// Handle the given inlined terminator by replacing it with a new operation
-  /// as necessary.
-  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
-    // Only return needs to be handled here.
-    auto returnOp = cast<triton::ReturnOp>(op);
+  if (funcOp->hasAttr("noinline"))
+    return !funcOp->getAttrOfType<BoolAttr>("noinline").getValue();
+  return true;
+}
 
-    // Replace the values directly with the return operands.
-    assert(returnOp.getNumOperands() == valuesToRepl.size());
-    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
-      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
-  }
-};
+/// Handle the given inlined terminator by replacing it with a new operation
+/// as necessary.
+void TritonInlinerInterface::handleTerminator(Operation *op,
+                                              Block *newDest) const {
+  // Only return needs to be handled here.
+  auto returnOp = dyn_cast<triton::ReturnOp>(op);
+  if (!returnOp)
+    return;
+
+  // Replace the return with a branch to the dest.
+  OpBuilder builder(op);
+  builder.create<mlir::cf::BranchOp>(op->getLoc(), newDest,
+                                     returnOp.getOperands());
+  op->erase();
+}
 
-} // namespace
+/// Handle the given inlined terminator by replacing it with a new operation
+/// as necessary.
+void TritonInlinerInterface::handleTerminator(Operation *op,
+                                              ValueRange valuesToRepl) const {
+  // Only return needs to be handled here.
+  auto returnOp = cast<triton::ReturnOp>(op);
+
+  // Replace the values directly with the return operands.
+  assert(returnOp.getNumOperands() == valuesToRepl.size());
+  for (const auto &it : llvm::enumerate(returnOp.getOperands()))
+    valuesToRepl[it.index()].replaceAllUsesWith(it.value());
+}
 
 
 void TritonDialect::initialize() {
   registerTypes();
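The `noinline` handling above mirrors the frontend: marking a device function with `@triton.jit(noinline=True)` sets the `noinline` attribute on the generated `tt.func` (visible as `noinline = false` in the test output further below), and `isLegalToInline` respects it. A small illustration:

    # Illustration: this helper keeps its call sites through the inliner
    # because jit(noinline=True) sets "noinline" on the tt.func.
    import triton

    @triton.jit(noinline=True)
    def add_one(x):
        return x + 1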

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
+#include "triton/Dialect/Triton/IR/Interfaces.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -3083,6 +3084,7 @@ void TritonGPUDialect::initialize() {
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
 #include "triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc"
       >();
+  addInterfaces<TritonInlinerInterface>();
   addInterfaces<TritonGPUOpAsmInterface>();
   addInterfaces<TritonGPUInferLayoutInterface>();
   addInterfaces<TritonGPUVerifyTensorLayoutInterface>();

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 add_triton_library(TritonGPUTransforms
   AccelerateMatmul.cpp
+  Canonicalize.cpp
   Coalesce.cpp
   F32DotTC.cpp
   FuseNestedLoops.cpp
lib/Dialect/TritonGPU/Transforms/Canonicalize.cpp

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+
+using namespace mlir;
+using namespace triton;
+namespace ttg = triton::gpu;
+namespace ttng = triton::nvidia_gpu;
+
+namespace mlir::triton::gpu {
+#define GEN_PASS_DEF_TRITONGPUCANONICALIZE
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
+} // namespace mlir::triton::gpu
+
+namespace {
+struct Canonicalize
+    : public ttg::impl::TritonGPUCanonicalizeBase<Canonicalize> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void Canonicalize::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet patterns(&getContext());
+
+  // Populate `arith` and `scf` canonicalizers.
+  ctx->getLoadedDialect<arith::ArithDialect>()->getCanonicalizationPatterns(
+      patterns);
+  ctx->getLoadedDialect<scf::SCFDialect>()->getCanonicalizationPatterns(
+      patterns);
+  populateForOpDeadArgumentElimination(patterns);
+
+  // Populate select Triton canonicalization patterns. The important patterns to
+  // EXCLUDE are those that modify layouts, especially `ConvertLayoutOp`
+  // patterns.
+  LoadOp::getCanonicalizationPatterns(patterns, ctx);
+  StoreOp::getCanonicalizationPatterns(patterns, ctx);
+  BroadcastOp::getCanonicalizationPatterns(patterns, ctx);
+  ExpandDimsOp::getCanonicalizationPatterns(patterns, ctx);
+  ttg::WarpSpecializeOp::getCanonicalizationPatterns(patterns, ctx);
+  ttng::TensorDescToTMAPtrOp::getCanonicalizationPatterns(patterns, ctx);
+}
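Note: as rendered in this extract the hunk ends right after populating the pattern set. Given the GreedyPatternRewriteDriver.h include, the pass presumably finishes by applying the collected patterns with the greedy driver (e.g. `applyPatternsGreedily(getOperation(), std::move(patterns))`, signaling pass failure on error); that call is not visible here.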

python/src/passes.cc

Lines changed: 7 additions & 0 deletions
@@ -51,6 +51,7 @@ void init_triton_passes_ttir(py::module &&m) {
 }
 
 void init_triton_passes_ttgpuir(py::module &&m) {
+  using namespace mlir;
   using namespace mlir::triton::gpu;
   ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce);
   ADD_PASS_WRAPPER_0("add_optimize_thread_locality",
@@ -85,6 +86,12 @@ void init_triton_passes_ttgpuir(py::module &&m) {
   ADD_PASS_WRAPPER_0("add_fuse_nested_loops", createTritonGPUFuseNestedLoops);
   ADD_PASS_WRAPPER_0("add_coalesce_async_copy",
                      createTritonGPUCoalesceAsyncCopy);
+  ADD_PASS_WRAPPER_0("add_canonicalizer", createTritonGPUCanonicalize);
+  ADD_PASS_WRAPPER_0("add_inliner", [] {
+    return createInlinerPass(/*opPipelines=*/{}, [](OpPassManager &pm) {
+      pm.addPass(createTritonGPUCanonicalize());
+    });
+  });
 }
 
 void init_triton_passes_convert(py::module &&m) {
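Design note: rather than binding MLIR's stock inliner, whose default pipeline would run the full canonicalizer on each callee, `add_inliner` constructs it with `tritongpu-canonicalize` as the per-callee pipeline, so inlining still simplifies code without disturbing TTGIR layout attributes.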

python/test/backend/test_device_backend.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ def __init__(self, device_type: str) -> None:
         self.driver = ExtensionDriver()
         self.version_key = None
 
-    def add_stages(self, arch, extern_libs, stages):
+    def add_stages(self, stages, options, language):
         filter_in_stages = ["ast", "ttir", "ttgir"]
         filter_out_stages = []
         for key, _ in stages.items():

python/test/gluon/test_frontend.py

Lines changed: 14 additions & 5 deletions
@@ -26,7 +26,7 @@ def test_convert_layout(fresh_knobs):
         1, ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[1, 32], warps_per_cta=[1, 4], order=[1, 0]))
     h = convert_layout_kernel.warmup(128, layout_a, layout_b, num_warps=layout_a.warps_per_cta[0], grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["ttgir"], """\
+        h.asm["source"], """\
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 module attributes {"ttg.num-warps" = 4 : i32} {
@@ -37,6 +37,15 @@ def test_convert_layout(fresh_knobs):
   } loc(#loc)
 } loc(#loc)
 #loc = loc(unknown)
+""")
+    expecttest.assert_expected_inline(
+        h.asm["ttgir"], """\
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func public @convert_layout_kernel() attributes {noinline = false} {
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
 """)
 
 
@@ -60,7 +69,7 @@ def test_shared_memory(fresh_knobs):
     h = shared_memory_kernel.warmup(8, 32, layout_a, layout_b, smem_layout, num_warps=layout_a.warps_per_cta[0],
                                     grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["ttgir"], """\
+        h.asm["source"], """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
@@ -103,7 +112,7 @@ def test_tensor_memory(fresh_knobs):
     tmem_layout = ttgl.nvidia.blackwell.TensorMemoryLayout(block=[128, 128], unpacked=True)
     h = tensor_memory_kernel.warmup(layout, tmem_layout, num_warps=4, grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["ttgir"], """\
+        h.asm["source"], """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
 module attributes {"ttg.num-warps" = 4 : i32} {
@@ -200,7 +209,7 @@ def test_mbarrier(fresh_knobs):
 
     h = mbarrier_kernel.warmup(grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["ttgir"], """\
+        h.asm["source"], """\
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 4 : i32} {
@@ -240,7 +249,7 @@ def test_tcgen05_mma(fresh_knobs):
 
     h = tcgen05_mma_kernel.warmup(nvmma_layout, acc_layout, grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["ttgir"], """\
+        h.asm["source"], """\
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #smem = #ttg.shared_memory
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
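The tests now pin both artifacts: `h.asm["source"]` holds the TTGIR exactly as Gluon emitted it, while `h.asm["ttgir"]` holds the result after the new pipeline. In test_convert_layout the redundant layout conversion folds away entirely, leaving just `tt.return`.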

python/triton/backends/compiler.py

Lines changed: 7 additions & 0 deletions
@@ -1,5 +1,6 @@
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, Union
 from types import ModuleType
 
@@ -13,6 +14,12 @@ class GPUTarget(object):
     warp_size: int
 
 
+class Language(Enum):
+    """The input language being compiled by the backend."""
+    TRITON = 0
+    GLUON = 1
+
+
 class BaseBackend(metaclass=ABCMeta):
 
     def __init__(self, target: GPUTarget) -> None:
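To show how the new `language` argument is meant to be consumed, here is a minimal sketch of a backend branching on it in `add_stages`. The helper names `make_ttir`, `make_ttgir`, and `gluon_to_ttgir` are assumptions; the signature matches the test-backend change above.

    # Sketch only: the stage-builder helper names are hypothetical.
    from triton.backends.compiler import BaseBackend, Language

    class MyBackend(BaseBackend):
        def add_stages(self, stages, options, language):
            if language == Language.TRITON:
                # Triton source lowers AST -> TTIR -> TTGIR.
                stages["ttir"] = lambda src, md: self.make_ttir(src, md, options)
                stages["ttgir"] = lambda src, md: self.make_ttgir(src, md, options)
            elif language == Language.GLUON:
                # Gluon parses straight to TTGIR, so skip the TTIR stage and
                # run the new inliner + canonicalizer pipeline instead.
                stages["ttgir"] = lambda src, md: self.gluon_to_ttgir(src, md, options)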
