Skip to content

Commit cd6d25f

Browse files
authored
[Gluon] Fix inlining functions with gluon.set_auto_layout op (#7553)
The gluon dialect is missing an inliner interface implementation, without which the inliner defaults to blocking all inlining.
1 parent f7d4eba commit cd6d25f

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

lib/Dialect/Gluon/IR/Dialect.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include "triton/Dialect/Gluon/IR/Dialect.h"
22

3-
#include "mlir/IR/DialectImplementation.h"
43
#include "mlir/Support/LLVM.h"
4+
#include "triton/Dialect/Triton/IR/Interfaces.h"
55
#include "llvm/ADT/TypeSwitch.h"
66

77
using namespace mlir;
@@ -111,6 +111,7 @@ void GluonDialect::initialize() {
111111
#define GET_OP_LIST
112112
#include "triton/Dialect/Gluon/IR/Ops.cpp.inc"
113113
>();
114+
addInterfaces<TritonInlinerInterface>();
114115
addInterfaces<GluonInferLayoutInterface>();
115116
}
116117

test/Gluon/inlining.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: triton-opt %s --gluon-inline | FileCheck %s
2+
3+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
4+
5+
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
6+
tt.func private @set_encoding(%arg0 : tensor<16xi32, #gluon.auto_encoding>) -> tensor<16xi32, #blocked> {
7+
%cvt = gluon.set_auto_layout %arg0 : tensor<16xi32, #gluon.auto_encoding> -> tensor<16xi32, #blocked>
8+
tt.return %cvt : tensor<16xi32, #blocked>
9+
}
10+
11+
tt.func public @infer_make_range() -> tensor<16xi32, #blocked> {
12+
// CHECK-DAG: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
13+
// CHECK: [[CST:%.*]] = arith.constant dense<0> : tensor<16xi32, #gluon.auto_encoding>
14+
// CHECK: [[SET:%.*]] = gluon.set_auto_layout [[CST]] : tensor<16xi32, #gluon.auto_encoding> -> tensor<16xi32, [[BLOCKED]]>
15+
// CHECK: tt.return [[SET]] : tensor<16xi32, [[BLOCKED]]>
16+
%cst = arith.constant dense<0> : tensor<16xi32, #gluon.auto_encoding>
17+
%0 = tt.call @"set_encoding"(%cst) : (tensor<16xi32, #gluon.auto_encoding>) -> tensor<16xi32, #blocked>
18+
tt.return %0 : tensor<16xi32, #blocked>
19+
}
20+
}

0 commit comments

Comments
 (0)