Skip to content

Commit c104666

Browse files
authored
Fix issue in prefetching column major matrix. (#4611)
The prefetching lowering uses the incorrect shape sizes to get the tiling shape for column major matrix. --------- Signed-off-by: Lu,Chengjun <[email protected]>
1 parent eb7d154 commit c104666

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

test/TritonIntelGPU/prefetch-to-llvm.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
7373
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
7474
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
7575
// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
76-
// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(2 : i32) : i32
76+
// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(1 : i32) : i32
7777
// CHECK: %[[VAL_19:.*]] = llvm.urem %[[SUB_GROUP_ID]], %[[VAL_18]] : i32
7878
// CHECK: %[[VAL_20:.*]] = llvm.udiv %[[SUB_GROUP_ID]], %[[VAL_18]] : i32
79-
// CHECK: %[[CST_8:.*]] = llvm.mlir.constant(4 : i32) : i32
79+
// CHECK: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
8080
// CHECK: %[[VAL_22:.*]] = llvm.urem %[[VAL_20]], %[[CST_8]] : i32
8181
// CHECK: %[[VAL_23:.*]] = llvm.udiv %[[VAL_20]], %[[CST_8]] : i32
8282
// CHECK: %[[OFFSET_0:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
@@ -94,20 +94,20 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
9494
// CHECK: %[[VAL_24:.*]] = llvm.mul %[[COL_STRIDE_i64]], %[[CST_2]] : i64
9595
// CHECK: %[[COL_MAJOR_PITCH:.*]] = llvm.trunc %[[VAL_24]] : i64 to i32
9696
// CHECK: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
97-
// CHECK: %[[CST_32:.*]] = llvm.mlir.constant(16 : i32) : i32
97+
// CHECK: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
9898
// CHECK: %[[VAL_26:.*]] = llvm.mul %[[VAL_19]], %[[CST_32]] : i32
9999
// CHECK: %[[VAL_27:.*]] = llvm.add %[[VAL_26]], %[[CST_0]] : i32
100100
// CHECK: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
101101
// CHECK: %[[VAL_28:.*]] = llvm.urem %[[VAL_27]], %[[CST_32]] : i32
102102
// CHECK: %[[COL_MAJOR_OFFSET_X:.*]] = llvm.add %[[VAL_28]], %[[OFFSET_1]] : i32
103103
// CHECK: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
104-
// CHECK: %[[CST_2:.*]] = llvm.mlir.constant(4 : i32) : i32
104+
// CHECK: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
105105
// CHECK: %[[VAL_30:.*]] = llvm.mul %[[VAL_22]], %[[CST_2]] : i32
106106
// CHECK: %[[VAL_31:.*]] = llvm.add %[[VAL_30]], %[[CST_0]] : i32
107107
// CHECK: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
108108
// CHECK: %[[VAL_32:.*]] = llvm.urem %[[VAL_31]], %[[CST_16]] : i32
109109
// CHECK: %[[COL_MAJOR_OFFSET_Y:.*]] = llvm.add %[[VAL_32]], %[[OFFSET_0]] : i32
110-
// CHECK: triton_gen.2Dblockprefetch %[[BASE_]], %[[COL_MAJOR_BASE_WIDTH]], %[[COL_MAJOR_BASE_HEIGHT]], %[[COL_MAJOR_PITCH]], %[[COL_MAJOR_OFFSET_X]], %[[COL_MAJOR_OFFSET_Y]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 4, v_blocks = 1, cache_control = L1C_L3C}
110+
// CHECK: triton_gen.2Dblockprefetch %[[BASE_]], %[[COL_MAJOR_BASE_WIDTH]], %[[COL_MAJOR_BASE_HEIGHT]], %[[COL_MAJOR_PITCH]], %[[COL_MAJOR_OFFSET_X]], %[[COL_MAJOR_OFFSET_Y]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 2, v_blocks = 2, cache_control = L1C_L3C}
111111
%columnMajorPtr = tt.make_tensor_ptr %arg0, [%arg4, %arg2], [%c1_i64, %arg5], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<32x16xf16>>
112112
ttig.prefetch %columnMajorPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false, ttig.block_io = "column_major"} : !tt.ptr<tensor<32x16xf16>>
113113

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,10 @@ struct PrefetchOpConversion
559559
// Swap the shape to make it row major and then get the tiling
560560
// size base on row major shape.
561561
std::swap(tensorShape[0], tensorShape[1]);
562+
563+
// Create the new tensor type with swapped row and col.
564+
tensorType = RankedTensorType::get(
565+
tensorShape, tensorType.getElementType(), tensorType.getEncoding());
562566
}
563567

564568
unsigned numWarps = triton::gpu::lookupNumWarps(op);

0 commit comments

Comments
 (0)