Skip to content

Commit 8f3b43e

Browse files
authored
[TritonArithToLinalg] Preserve function visibility during conversion. (#347)
For cases where functions are not inlined we want to preserve visibility information.
1 parent c1f338b commit 8f3b43e

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ class TritonArithToLinalgPass
216216
func.getAllResultAttrs(resAttrs);
217217

218218
auto funcFunc = builder.create<func::FuncOp>(func.getLoc(), name, type);
219+
// Preserve the visibility attribute
220+
funcFunc.setVisibility(func.getVisibility());
219221
funcFunc.setAllArgAttrs(argAttrs);
220222
funcFunc.setAllResultAttrs(resAttrs);
221223

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s
2+
3+
module {
4+
tt.func public @kernel(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
5+
%c0_i32 = arith.constant 0 : i32
6+
%0 = tt.get_program_id x : i32
7+
%1 = tt.load %arg0 : !tt.ptr<i32>
8+
%2 = arith.cmpi eq, %0, %c0_i32 : i32
9+
%3 = scf.if %2 -> (i32) {
10+
%4 = tt.call @test_core.add_fn_return__i32_i32__(%1, %0) : (i32, i32) -> i32
11+
scf.yield %4 : i32
12+
} else {
13+
scf.yield %1 : i32
14+
}
15+
tt.store %arg0, %3 : !tt.ptr<i32>
16+
tt.return
17+
}
18+
tt.func private @test_core.add_fn_return__i32_i32__(%arg0: i32, %arg1: i32) -> i32 attributes {noinline = false} {
19+
%c2_i32 = arith.constant 2 : i32
20+
%c1_i32 = arith.constant 1 : i32
21+
%c0_i32 = arith.constant 0 : i32
22+
%0 = arith.cmpi eq, %arg1, %c0_i32 : i32
23+
cf.cond_br %0, ^bb1(%c1_i32 : i32), ^bb1(%c2_i32 : i32)
24+
^bb1(%1: i32): // 2 preds: ^bb0, ^bb0
25+
%2 = arith.addi %arg0, %1 : i32
26+
tt.return %2 : i32
27+
}
28+
}
29+
30+
// Public is implicit
31+
// CHECK-LABEL: func.func @kernel
32+
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi32> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32)
33+
34+
// CHECK-LABEL: func.func private @test_core.add_fn_return__i32_i32__
35+
// CHECK-SAME: ([[PARAM_0_:%.+]]: i32, [[PARAM_1_:%.+]]: i32) -> i32

0 commit comments

Comments
 (0)