Skip to content

Commit 86dffe2

Browse files
Ensure CallOp with SPIR_FUNC cconv (#3692)
This PR ensures `CallOp` has `SPIR_FUNC` cconv on SPIR target. By adding a pattern to fix-up calling convention, more common patterns can be reused. This PR also fixes cases of `CallOp` without `SPIR_FUNC` calling convention. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent cbb1cce commit 86dffe2

File tree

11 files changed

+118
-139
lines changed

11 files changed

+118
-139
lines changed

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,8 +1985,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
19851985
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
19861986
// CHECK-LABEL: print_ptr
19871987
tt.func @print_ptr(%arg0 : tensor<256x!tt.ptr<i32>, #blocked0>) {
1988-
// CHECK: llvm.call @_Z18__spirv_ocl_printf(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} %{{.*}}) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32, i32, i32, i32, !llvm.ptr<1>) -> i32
1989-
// CHECK-NEXT: llvm.call @_Z18__spirv_ocl_printf(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32, i32, i32, i32, !llvm.ptr<1>) -> i32
1988+
// CHECK: llvm.call spir_funccc @_Z18__spirv_ocl_printf(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} %{{.*}}) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32, i32, i32, i32, !llvm.ptr<1>) -> i32
1989+
// CHECK-NEXT: llvm.call spir_funccc @_Z18__spirv_ocl_printf(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32, i32, i32, i32, !llvm.ptr<1>) -> i32
19901990
tt.print "ptr: " {hex = false, isSigned = array<i32: 0>} : %arg0 : tensor<256x!tt.ptr<i32>, #blocked0>
19911991
tt.return
19921992
}
@@ -1998,7 +1998,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
19981998
// Test that %u format specifier is used if isSigned is false
19991999
// CHECK: llvm.mlir.global internal constant @printfFormat_("pid (%u, %u, %u) idx ()int32 tensor: %u\0A\00") {addr_space = 2 : i32}
20002000
// CHECK-LABEL: print_int32_tensor_issigned_off
2001-
// CHECK: llvm.call @_Z18__spirv_ocl_printf(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32, i32, i32, i32) -> i32
2001+
// CHECK: llvm.call spir_funccc @_Z18__spirv_ocl_printf(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32, i32, i32, i32) -> i32
20022002
tt.func @print_int32_tensor_issigned_off(%arg0 : i32) {
20032003
tt.print "int32 tensor: " {hex = false, isSigned = array<i32: 0>} : %arg0 : i32
20042004
tt.return
@@ -2011,7 +2011,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
20112011
// Test that %i format specifier is used if isSigned is true
20122012
// CHECK: llvm.mlir.global internal constant @printfFormat_("pid (%u, %u, %u) idx ()int32 tensor: %i\0A\00") {addr_space = 2 : i32}
20132013
// CHECK-LABEL: print_int32_tensor_issigned_on
2014-
// CHECK: llvm.call @_Z18__spirv_ocl_printf(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32, i32, i32, i32) -> i32
2014+
// CHECK: llvm.call spir_funccc @_Z18__spirv_ocl_printf(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32, i32, i32, i32) -> i32
20152015
tt.func @print_int32_tensor_issigned_on(%arg0 : i32) {
20162016
tt.print "int32 tensor: " {hex = false, isSigned = array<i32: 1>} : %arg0 : i32
20172017
tt.return

test/TritonIntelGPU/tritonintelgpu-rewrite-stack-ptr.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup
77
%0 = tt.load %arg0 : !tt.ptr<f32>
88
%1 = tt.load %arg1 : !tt.ptr<f32>
99
// CHECK: llvm.mlir.poison : !llvm.ptr<3>
10-
// CHECK: llvm.call @noinline_simple_fn__fp32_fp32_Pfp32__(%8, %17, %arg2, %18, %arg2)
10+
// CHECK: llvm.call spir_funccc @noinline_simple_fn__fp32_fp32_Pfp32__(%8, %17, %arg2, %18, %arg2)
1111
tt.call @noinline_simple_fn__fp32_fp32_Pfp32__(%0, %1, %arg2) : (f32, f32, !tt.ptr<f32>) -> ()
1212
tt.return
1313
}
@@ -31,7 +31,7 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup
3131
tt.func public @kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
3232
%0 = tt.load %arg0 : !tt.ptr<f32>
3333
%1 = tt.load %arg1 : !tt.ptr<f32>
34-
// CHECK: llvm.call @noinline_shared_fn__fp32_fp32_Pfp32__(%8, %17, %arg2, %arg3, %arg2)
34+
// CHECK: llvm.call spir_funccc @noinline_shared_fn__fp32_fp32_Pfp32__(%8, %17, %arg2, %arg3, %arg2)
3535
tt.call @noinline_shared_fn__fp32_fp32_Pfp32__(%0, %1, %arg2) {allocation.offset = 0 : i32} : (f32, f32, !tt.ptr<f32>) -> ()
3636
tt.return
3737
}

test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir

Lines changed: 64 additions & 64 deletions
Large diffs are not rendered by default.

third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ inline bool hasSpirvTargetArch(Operation *op) {
5757
triton::gpu::intel::TritonIntelGPUDialect::getTargetArchAttrName());
5858
return !arch || arch.str().substr(0, 4) == "spir";
5959
}
60+
61+
inline LLVM::cconv::CConv getRequiredCConv(Operation *op) {
62+
if (hasSpirvTargetArch(op))
63+
return LLVM::cconv::CConv::SPIR_FUNC;
64+
llvm_unreachable("Unexpected target architecture");
65+
}
6066
} // namespace mlir::triton::gpu::intel
6167

6268
#endif // TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H

third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_triton_library(TritonIntelGPUToLLVM
22
ArithOpsToLLVM.cpp
33
BF16Casts.cpp
4+
ControlFlowOpToLLVM.cpp
45
ConvertLayoutOpToLLVM.cpp
56
DecomposeUnsupportedConversions.cpp
67
DotOpToLLVM/DPAS.cpp
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include "PatternTritonGPUOpToLLVM.h"
2+
#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
3+
4+
namespace {
5+
6+
struct FixCallCConv : public ConvertOpToLLVMPattern<LLVM::CallOp> {
7+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
8+
9+
LogicalResult
10+
matchAndRewrite(LLVM::CallOp op, LLVM::CallOp::Adaptor adaptor,
11+
ConversionPatternRewriter &rewriter) const override {
12+
rewriter.startOpModification(op);
13+
op.setCConv(triton::gpu::intel::getRequiredCConv(op));
14+
rewriter.finalizeOpModification(op);
15+
return success();
16+
}
17+
};
18+
19+
} // namespace
20+
21+
void mlir::triton::intel::populateControlFlowOpToLLVMPattern(
22+
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
23+
const TargetInfoBase &targetInfo, PatternBenefit benefit) {
24+
patterns.add<FixCallCConv>(typeConverter);
25+
mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
26+
targetInfo, benefit);
27+
}

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,34 +1085,6 @@ struct FpToFpOpConversion
10851085
}
10861086
};
10871087

1088-
struct ExternElementwiseOpConversion
1089-
: public ElementwiseOpConversionBase<ExternElementwiseOp,
1090-
ExternElementwiseOpConversion> {
1091-
using Base = ElementwiseOpConversionBase<ExternElementwiseOp,
1092-
ExternElementwiseOpConversion>;
1093-
using Base::Base;
1094-
using Adaptor = typename Base::OpAdaptor;
1095-
typedef typename Base::OpAdaptor OpAdaptor;
1096-
1097-
SmallVector<Value> createDestOps(ExternElementwiseOp op, OpAdaptor adaptor,
1098-
ConversionPatternRewriter &rewriter,
1099-
Type elemTy, MultipleOperandsRange operands,
1100-
Location loc) const {
1101-
StringRef funcName = op.getSymbol();
1102-
if (funcName.empty())
1103-
llvm::errs() << "ExternElementwiseOpConversion";
1104-
1105-
Type funcType = getFunctionType(elemTy, operands[0]);
1106-
LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(
1107-
rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath());
1108-
1109-
auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]);
1110-
callOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1111-
1112-
return {callOp.getResult()};
1113-
}
1114-
};
1115-
11161088
template <typename SourceOp, typename DestOp>
11171089
struct ElementwiseOpConversion
11181090
: ElementwiseOpConversionBase<SourceOp,
@@ -1292,38 +1264,6 @@ struct AbsFOpConversion
12921264
}
12931265
};
12941266

1295-
struct MulhiUIOpConversion
1296-
: public ElementwiseOpConversionBase<MulhiUIOp, MulhiUIOpConversion> {
1297-
using Base = ElementwiseOpConversionBase<MulhiUIOp, MulhiUIOpConversion>;
1298-
using Base::Base;
1299-
using Adaptor = typename Base::OpAdaptor;
1300-
explicit MulhiUIOpConversion(LLVMTypeConverter &typeConverter,
1301-
ModuleAxisInfoAnalysis &axisAnalysisPass,
1302-
const TargetInfoBase &targetInfo,
1303-
PatternBenefit benefit = 1)
1304-
: ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit),
1305-
targetInfo(targetInfo) {}
1306-
SmallVector<Value> createDestOps(MulhiUIOp op, Adaptor adaptor,
1307-
ConversionPatternRewriter &rewriter,
1308-
Type elemTy, MultipleOperandsRange operands,
1309-
Location loc) const {
1310-
1311-
Type resultElementTy = getElementTypeOrSelf(op.getResult().getType());
1312-
assert(resultElementTy.isInteger(32) || resultElementTy.isInteger(64));
1313-
1314-
std::string funcName = targetInfo.getMulhiFuncName(resultElementTy);
1315-
Type funcType = getFunctionType(elemTy, operands[0]);
1316-
LLVM::LLVMFuncOp funcOp =
1317-
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
1318-
auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]);
1319-
callOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1320-
return {callOp.getResult()};
1321-
}
1322-
1323-
protected:
1324-
const TargetInfoBase &targetInfo;
1325-
};
1326-
13271267
struct PreciseSqrtOpConversion
13281268
: ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion> {
13291269
using Base =
@@ -1401,10 +1341,6 @@ void populateElementwiseOpToLLVMPatterns(
14011341

14021342
mlir::triton::populateElementwiseOpToLLVMPatterns(
14031343
typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
1404-
patterns.add<MulhiUIOpConversion>(typeConverter, axisInfoAnalysis, targetInfo,
1405-
benefit);
1406-
patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
1407-
benefit);
14081344

14091345
patterns.add<AbsFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
14101346
patterns.add<ElementwiseOpConversion<arith::DivFOp, LLVM::FDivOp>>(

third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
4949
RewritePatternSet &patterns,
5050
PatternBenefit benefit);
5151

52+
void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter,
53+
RewritePatternSet &patterns,
54+
const TargetInfoBase &targetInfo,
55+
PatternBenefit benefit);
56+
5257
void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
5358
RewritePatternSet &patterns,
5459
const TargetInfoBase &targetInfo,

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,8 @@ class TritonGPUToLLVMPipelineManager {
240240
if (isAdvancedPathEnabled) {
241241
intel::populateArithOpsToLLVMPatterns(typeConverter, patterns, benefit);
242242
intel::populateBF16CastsLLVMPatterns(typeConverter, patterns, benefit);
243-
mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
244-
targetInfo, benefit);
243+
intel::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
244+
targetInfo, benefit);
245245
intel::populateTritonOpsToLLVMPatterns(typeConverter, patterns, benefit);
246246
} else {
247247
intel::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo,
@@ -272,8 +272,8 @@ class TritonGPUToLLVMPipelineManager {
272272
benefit);
273273
mlir::triton::populateMemoryOpToLLVMPatterns(typeConverter, targetInfo,
274274
patterns, benefit);
275-
mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
276-
targetInfo, benefit);
275+
intel::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
276+
targetInfo, benefit);
277277
mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo,
278278
patterns, benefit);
279279
intel::populateFp4ToFpToLLVMPatterns(typeConverter, patterns, benefit);

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart,
221221
operands.push_back(printfPromoteValue(
222222
rewriter, arg, isSigned.empty() ? true : isSigned[i]));
223223
}
224-
b.call(funcOp, operands);
224+
auto callOp = b.call(funcOp, operands);
225+
callOp.setCConv(triton::gpu::intel::getRequiredCConv(callOp));
225226
}
226227

227228
void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, ValueRange args,

0 commit comments

Comments
 (0)