Skip to content

Commit 5cdcb89

Browse files
committed
Add call arg test + fix crash
1 parent 362ebb6 commit 5cdcb89

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//===----------------------------------------------------------------------===//
77

88
#include "mlir/Dialect/Arith/IR/Arith.h"
9+
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
910
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1011
#include "mlir/Dialect/SCF/IR/SCF.h"
1112
#include "mlir/IR/Builders.h"
@@ -102,8 +103,7 @@ class TritonPtrToMemrefPass
102103
patterns, typeConverter);
103104
populateFunctionOpInterfaceTypeConversionPattern<triton::FuncOp>(
104105
patterns, typeConverter);
105-
populateFunctionOpInterfaceTypeConversionPattern<func::CallOp>(
106-
patterns, typeConverter);
106+
populateCallOpTypeConversionPattern(patterns, typeConverter);
107107

108108
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
109109
signalPassFailure();
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: triton-shared-opt --triton-arith-to-linalg --triton-ptr-to-memref %s | FileCheck %s
2+
3+
module {
4+
tt.func @_sum_combine__fp32(%arg0: !tt.ptr<f32>) -> f32{
5+
%0 = arith.constant 42.0 : f32
6+
tt.return %0 : f32
7+
}
8+
tt.func @test(%arg0: !tt.ptr<f32>) -> f32{
9+
%0 = tt.call @_sum_combine__fp32(%arg0) : (!tt.ptr<f32>) -> f32
10+
tt.return %0 : f32
11+
}
12+
}
13+
14+
// CHECK: module {
15+
// CHECK: func.func @_sum_combine__fp32(%arg0: memref<*xf32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) -> f32 {
16+
// CHECK: %cst = arith.constant 4.200000e+01 : f32
17+
// CHECK: return %cst : f32
18+
// CHECK: }
19+
// CHECK: func.func @test(%arg0: memref<*xf32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) -> f32 {
20+
// CHECK: %0 = call @_sum_combine__fp32(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (memref<*xf32>, i32, i32, i32, i32, i32, i32) -> f32
21+
// CHECK: return %0 : f32
22+
// CHECK: }
23+
// CHECK: }

0 commit comments

Comments
 (0)