Skip to content

Commit be3b9ad

Browse files
Merge commit 'fdac59428cd08d7d7438f330db7224de454c6d52'
2 parents 4c9df48 + fdac594 commit be3b9ad

File tree

19 files changed

+161
-83
lines changed

19 files changed

+161
-83
lines changed

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
df0864e761107b07e38f5503e0cbee0cebb4c5e8
1+
61f8a7f618901797ee8663389a29722f29216a96

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ using namespace mlir::triton;
101101
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
102102
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
103103
#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
104-
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
104+
#define call(...) LLVM::createLLVMCallOp(rewriter, loc, __VA_ARGS__)
105105

106106
// Types
107107
#define int_ty(width) rewriter.getIntegerType(width)
@@ -228,6 +228,12 @@ Value createIndexConstant(OpBuilder &builder, Location loc,
228228
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
229229
int64_t value);
230230

231+
LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc,
232+
LLVMFuncOp funcOp, ValueRange args);
233+
LLVM::CallIntrinsicOp
234+
createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
235+
TypeRange types, ValueRange args);
236+
231237
// Is v an integer or floating-point scalar constant equal to 0?
232238
bool isConstantZero(Value v);
233239

lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
109109
auto newCallOp = rewriter.create<LLVM::CallOp>(
110110
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
111111
promotedOperands, callOp->getAttrs());
112+
newCallOp.getProperties().setOpBundleSizes(
113+
rewriter.getDenseI32ArrayAttr({}));
114+
newCallOp.getProperties().setOperandSegmentSizes(
115+
{static_cast<int>(promotedOperands.size()), 0});
112116
return newCallOp;
113117
}
114118

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ struct MulhiUIOpConversion
299299
LLVM::LLVMFuncOp funcOp =
300300
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
301301
return {
302-
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
302+
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
303303
}
304304

305305
protected:
@@ -327,7 +327,7 @@ struct ExternElementwiseOpConversion
327327
LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(
328328
rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath());
329329
return {
330-
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
330+
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
331331
}
332332
};
333333

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
22
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3-
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
3+
#include "mlir/IR/Attributes.h"
44
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
5-
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
65
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
76
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
87
#include "llvm/ADT/STLExtras.h"
@@ -518,6 +517,24 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
518517
builder.getIntegerAttr(ty, value));
519518
}
520519

520+
LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc,
521+
LLVMFuncOp funcOp, ValueRange args) {
522+
auto op = builder.create<LLVM::CallOp>(loc, funcOp, args);
523+
op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({}));
524+
op.getProperties().setOperandSegmentSizes({static_cast<int>(args.size()), 0});
525+
return op;
526+
}
527+
528+
LLVM::CallIntrinsicOp
529+
createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
530+
TypeRange types, ValueRange args) {
531+
auto op = builder.create<LLVM::CallIntrinsicOp>(loc, types, args);
532+
op.getProperties().setIntrin(builder.getStringAttr(intrinsic));
533+
op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({}));
534+
op.getProperties().setOperandSegmentSizes({static_cast<int>(args.size()), 0});
535+
return op;
536+
}
537+
521538
bool isConstantZero(Value v) {
522539
if (auto constantOp = v.getDefiningOp<arith::ConstantOp>()) {
523540
if (auto attr = dyn_cast<IntegerAttr>(constantOp.getValue())) {

test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,10 +336,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
336336
// CHECK-LABEL: llvm.func spir_kernelcc @test(
337337
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<3>,
338338
// CHECK-SAME: %[[VAL_1:.*]]: vector<16xf32>) -> vector<16xf32>
339-
// CHECK: %[[VAL_2:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv() {{{.*}}} : () -> i32
340-
// CHECK: %[[VAL_3:.*]] = llvm.sext %[[VAL_2]] : i32 to i64
341-
// CHECK: %[[VAL_4:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_idv() {{{.*}}} : () -> i32
342-
// CHECK: %[[VAL_5:.*]] = llvm.sext %[[VAL_4]] : i32 to i64
339+
// CHECK: %[[VAL_2:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {{{.*}}} : () -> i32
340+
// CHECK: %[[VAL_3:.*]] = llvm.zext %[[VAL_2]] : i32 to i64
341+
// CHECK: %[[VAL_4:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {{{.*}}} : () -> i32
342+
// CHECK: %[[VAL_5:.*]] = llvm.zext %[[VAL_4]] : i32 to i64
343343
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(16 : i64) : i64
344344
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(256 : i64) : i64
345345
// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_3]] : i64

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
7070
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_12]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
7171
// CHECK: %[[VAL_14:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_13]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
7272
// CHECK: %[[BLOCK_POINTER:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_14]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
73-
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv()
74-
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.sext %[[SUB_GROUP_ID_RAW]] : i32 to i64
73+
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
74+
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
7575
// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
7676
// CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(1 : i32) : i32
7777
// CHECK: %[[SUB_GROUP_ID_N:.*]] = llvm.urem %[[SUB_GROUP_ID]], %[[VAL_17]] : i32
@@ -142,8 +142,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
142142
// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_11]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
143143
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_12]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
144144
// CHECK: %[[BLOCK_POINTER:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_13]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
145-
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv()
146-
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.sext %[[SUB_GROUP_ID_RAW]] : i32 to i64
145+
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
146+
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
147147
// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
148148
// CHECK: %[[VAL_16:.*]] = llvm.mlir.constant(1 : i32) : i32
149149
// CHECK: %[[VAL_17:.*]] = llvm.urem %[[SUB_GROUP_ID]], %[[VAL_16]] : i32

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
2323
%12 = arith.truncf %11#0 : tensor<64x64xf32, #dpas> to tensor<64x64xf16, #dpas>
2424
%13 = tt.make_tensor_ptr %arg2, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #dpas>>
2525
// The next two lines is used to start checking constant related to the BlockStore.
26-
// CHECK-COUNT-3: llvm.call spir_funccc @_Z16get_sub_group_idv
26+
// CHECK-COUNT-3: llvm.call spir_funccc @_Z16get_sub_group_id
2727
// CHECK-COUNT-39: llvm.extractvalue
2828
// Next constant must be equal to warpsPerCTA[0]
2929
// CHECK: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
@@ -83,8 +83,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
8383
// CHECK: %[[VAL_80:.*]] = llvm.insertvalue %[[CST_1]], %[[VAL_79]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
8484
// CHECK: %[[BLOCK_PTR:.*]] = llvm.insertvalue %[[base]], %[[VAL_80]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
8585
// CHECK: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
86-
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv()
87-
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.sext %[[SUB_GROUP_ID_RAW]] : i32 to i64
86+
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
87+
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
8888
// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
8989
// CHECK: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
9090
// CHECK: %[[SUB_GROUP_ID_N:.*]] = llvm.urem %[[SUB_GROUP_ID]], %[[CST_1]] : i32

test/TritonIntelGPU/prefetch-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
88
%c1_i64 = arith.constant 1 : i64
99

1010
// CHECK: %[[ROW_MAJOR_BLOCK_PTR:.*]] = llvm.insertvalue %arg0, {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
11-
// CHECK: %[[VAL_17:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv()
12-
// CHECK: %[[VAL_18:.*]] = llvm.sext %[[VAL_17]] : i32 to i64
11+
// CHECK: %[[VAL_17:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
12+
// CHECK: %[[VAL_18:.*]] = llvm.zext %[[VAL_17]] : i32 to i64
1313
// CHECK: %[[VAL_19:.*]] = llvm.trunc %[[VAL_18]] : i64 to i32
1414
// CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(1 : i32) : i32
1515
// CHECK: %[[VAL_21:.*]] = llvm.urem %[[VAL_19]], %[[VAL_20]] : i32

third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
55
#include "mlir/Pass/Pass.h"
66
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
7+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
78

89
namespace mlir {
910
namespace triton {
@@ -187,11 +188,11 @@ class CallOpConversion : public mlir::RewritePattern {
187188
rewriter.create<LLVM::FPToSIOp>(loc, returnType, op->getResult(0));
188189
} else if (calleeName == "__triton_hip_fast_fdividef") {
189190
assert(operands.size() == 2);
190-
auto name = StringAttr::get(callOp.getContext(), "llvm.amdgcn.rcp.f32");
191-
LLVM::FastmathFlagsAttr defaultFlags{};
192-
auto rcpOp = rewriter.create<LLVM::CallIntrinsicOp>(
193-
loc, returnType, name, operands[1], defaultFlags);
191+
const char *intrinsic = "llvm.amdgcn.rcp.f32";
192+
auto rcpOp = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic,
193+
returnType, operands[1]);
194194

195+
LLVM::FastmathFlagsAttr defaultFlags{};
195196
replacementOp = rewriter.create<LLVM::FMulOp>(
196197
loc, returnType, operands[0], rcpOp->getResult(0), defaultFlags);
197198
}

0 commit comments

Comments
 (0)