Skip to content

Commit ecee648

Browse files
[LoadStoreOpToLLVM] Minor changes (#4047)
1. In `get2DPrefetchShapePerWarp`, similar to calculating `numRows`, when calculating `numCols`, it should take tensor shape into consideration. 2. `triton::getPointeeBitWidth(ptr.getType())` is the same as `tensorTy.getElementType().getIntOrFloatBitWidth()` when `isTensorPointerType(ptr.getType())`. 3. `rewriteTensorPointerLoad` should only accept `isTensorPointerType(ptr.getType())`. Signed-off-by: Whitney Tsang <[email protected]>
1 parent f3009ec commit ecee648

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

test/TritonIntelGPU/prefetch-to-llvm.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
22

3+
// CHECK-DAG: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x1cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
34
// CHECK-DAG: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x2cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
45
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
56
tt.func public @matmul_with_prefetch(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
@@ -36,7 +37,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
3637
// CHECK: %[[VAL_57:.*]] = llvm.mlir.constant(0 : i32) : i32
3738
// CHECK: %[[VAL_59:.*]] = llvm.insertelement %[[COLUMN_MAJOR_WARP_OFF_X]], {{.*}}{{\[}}%[[VAL_57]] : i32] : vector<2xi32>
3839
// CHECK: %[[ROW_MAJOR_COORD:.*]] = llvm.insertelement %[[COLUMN_MAJOR_WARP_OFF_Y]], {{.*}}{{\[}}%[[VAL_56]] : i32] : vector<2xi32>
39-
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x2cPU3AS1viiiDv2_i(%[[ROW_MAJOR_BASE]], %[[ROW_MAJOR_WIDTH]], %[[ROW_MAJOR_HEIGHT]], %[[ROW_MAJOR_STRIDE]], %[[ROW_MAJOR_COORD]]) {{.*}} : (!llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> ()
40+
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x1cPU3AS1viiiDv2_i(%[[ROW_MAJOR_BASE]], %[[ROW_MAJOR_WIDTH]], %[[ROW_MAJOR_HEIGHT]], %[[ROW_MAJOR_STRIDE]], %[[ROW_MAJOR_COORD]]) {{.*}} : (!llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> ()
4041
%rowMajorPtr = tt.make_tensor_ptr %arg0, [%arg2, %arg4], [%arg5, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x16xf16>>
4142
triton_intel_gpu.prefetch %rowMajorPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x16xf16>>
4243

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ SmallVector<unsigned, 2> get2DPrefetchShapePerWarp(RankedTensorType tensorTy) {
141141
unsigned elemSizeInBytes = elemSizeInBits / 8;
142142
unsigned maxBytesPerCol = 64;
143143
unsigned numRows = std::min<unsigned>(tensorShape[0], 32);
144-
unsigned numCols = maxBytesPerCol / elemSizeInBytes;
144+
unsigned numCols =
145+
std::min<unsigned>(tensorShape[1], maxBytesPerCol / elemSizeInBytes);
145146
return {numRows, numCols};
146147
}
147148

@@ -173,15 +174,11 @@ struct LoadStoreConversionBase {
173174
}
174175

175176
unsigned getVectorSize(Value ptr) const {
176-
auto tensorTy = getRankedTensorType(ptr.getType());
177-
if (!tensorTy)
177+
if (!isTensorOrTensorPointerType(ptr.getType()))
178178
return 1;
179179

180180
unsigned contiguity = getContiguity(ptr);
181-
unsigned pointeeBitWidth =
182-
isTensorPointerType(ptr.getType())
183-
? tensorTy.getElementType().getIntOrFloatBitWidth()
184-
: triton::getPointeeBitWidth(tensorTy);
181+
unsigned pointeeBitWidth = triton::getPointeeBitWidth(ptr.getType());
185182
// The maximum vector size is 128 bits.
186183
return std::min<unsigned>(128 / pointeeBitWidth, contiguity);
187184
}
@@ -1005,9 +1002,12 @@ struct LoadOpConversion
10051002
LogicalResult
10061003
rewriteTensorPointerLoad(triton::LoadOp op, OpAdaptor adaptor,
10071004
ConversionPatternRewriter &rewriter) const {
1005+
Value ptr = op.getPtr();
1006+
assert(isTensorPointerType(ptr.getType()) &&
1007+
"Expecting tensor of pointer type");
1008+
10081009
Location loc = op.getLoc();
10091010
auto b = TritonLLVMOpBuilder(loc, rewriter);
1010-
Value ptr = op.getPtr();
10111011
Value mask = op.getMask();
10121012
Value other = op.getOther();
10131013
Type resultType = op.getType();

0 commit comments

Comments
 (0)