Skip to content

Commit e5aa2ab

Browse files
authored
[TritonGPU] Allow inlining ttng ops and actually run the canonicalizer (#7018)
1 parent b04547a commit e5aa2ab

File tree

5 files changed

+31
-1
lines changed

5 files changed

+31
-1
lines changed

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ def TritonGPUCanonicalize: Pass<"tritongpu-canonicalize"> {
369369
}];
370370
let dependentDialects = [
371371
"mlir::arith::ArithDialect",
372+
"mlir::cf::ControlFlowDialect",
372373
"mlir::scf::SCFDialect",
373374
];
374375
}

lib/Dialect/TritonGPU/Transforms/Canonicalize.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "mlir/Dialect/Arith/IR/Arith.h"
2+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
23
#include "mlir/Dialect/SCF/IR/SCF.h"
34
#include "mlir/Pass/Pass.h"
45
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -32,6 +33,8 @@ void Canonicalize::runOnOperation() {
3233
patterns);
3334
ctx->getLoadedDialect<scf::SCFDialect>()->getCanonicalizationPatterns(
3435
patterns);
36+
ctx->getLoadedDialect<cf::ControlFlowDialect>()->getCanonicalizationPatterns(
37+
patterns);
3538
populateForOpDeadArgumentElimination(patterns);
3639

3740
// Populate select Triton canonicalization patterns. The important patterns to
@@ -43,4 +46,6 @@ void Canonicalize::runOnOperation() {
4346
ExpandDimsOp::getCanonicalizationPatterns(patterns, ctx);
4447
ttg::WarpSpecializeOp::getCanonicalizationPatterns(patterns, ctx);
4548
ttng::TensorDescToTMAPtrOp::getCanonicalizationPatterns(patterns, ctx);
49+
50+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
4651
}

lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "mlir/IR/DialectImplementation.h"
3131
#include "mlir/IR/OpImplementation.h"
3232
#include "triton/Analysis/Utility.h"
33+
#include "triton/Dialect/Triton/IR/Interfaces.h"
3334
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
3435
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
3536
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
@@ -264,6 +265,7 @@ void TritonNvidiaGPUDialect::initialize() {
264265
#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc"
265266
>();
266267
addInterfaces<TritonGPUOpAsmInterface>();
268+
addInterfaces<TritonInlinerInterface>();
267269
}
268270

269271
// verify TritonNvidiaGPU ops

test/TritonGPU/inline.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -allow-unregistered-dialect -inline | FileCheck %s
1+
// RUN: triton-opt %s -inline | FileCheck %s
22

33
#smem = #ttg.shared_memory
44
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>

test/TritonNvidiaGPU/inline.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: triton-opt %s -inline | FileCheck %s
2+
3+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
4+
#smem = #ttg.shared_memory
5+
6+
module attributes {"ttg.num-warps" = 4 : i32} {
7+
8+
// CHECK-LABEL: @inline_ttng_ops
9+
tt.func public @inline_ttng_ops() {
10+
// CHECK-NEXT: ttg.local_alloc
11+
// CHECK-NEXT: ttng.init_barrier
12+
tt.call @function_with_ttng_ops() : () -> ()
13+
tt.return
14+
}
15+
16+
tt.func private @function_with_ttng_ops() {
17+
%0 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable>
18+
ttng.init_barrier %0, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable>
19+
tt.return
20+
}
21+
22+
}

0 commit comments

Comments
 (0)