Skip to content

Commit ab774ad

Browse files
[LoadStoreOpToLLVM] Remove unnecessary trunc (#4643)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 32c4e43 commit ab774ad

File tree

4 files changed

+15
-25
lines changed

4 files changed

+15
-25
lines changed

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
151151
// CHECK: %[[VAL_41:.*]] = llvm.mlir.constant(0 : i32) : i32
152152
// CHECK: %[[offsetX_:.*]] = llvm.add %[[VAL_41]], %[[OFFSET_1]] : i32
153153
// CHECK: %[[offsetY_:.*]] = llvm.add %[[VAL_40]], %[[OFFSET_0]] : i32
154-
// CHECK: %[[VAL_44:.*]] = llvm.trunc %[[offsetY_]] : i32 to i32
155-
// CHECK: %[[VAL_45:.*]] = llvm.trunc %[[offsetX_]] : i32 to i32
156154
// CHECK: %[[ROW_STRIDE_IN_BYTES:.*]] = llvm.mul %[[ROW_STRIDE_i32]], %[[ELEM_SIZE_IN_BYTES]] : i32
157155
// CHECK: %[[HEIGHT:.*]] = llvm.mul %[[HEIGHT_i32]], %[[ELEM_SIZE_IN_BYTES]] : i32
158-
// CHECK: triton_gen.2Dblockload %[[BASE]], %[[HEIGHT]], %[[WIDTH_i32]], %[[ROW_STRIDE_IN_BYTES]], %[[VAL_45]], %[[VAL_44]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 2, transpose = false, vnni_transform = false, cache_control = Default}
156+
// CHECK: triton_gen.2Dblockload %[[BASE]], %[[HEIGHT]], %[[WIDTH_i32]], %[[ROW_STRIDE_IN_BYTES]], %[[offsetX_]], %[[offsetY_]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 2, transpose = false, vnni_transform = false, cache_control = Default}
159157
%ptrA = tt.make_tensor_ptr %arg0, [%arg2, %arg4], [%arg5, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x32xf16, #dot0>>
160158
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x32xf16, #dot0>>
161159
%B = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot1>
@@ -216,11 +214,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
216214
// CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(0 : i32) : i32
217215
// CHECK: %[[offsetX_:.*]] = llvm.add %[[VAL_39]], %[[OFFSET_1]] : i32
218216
// CHECK: %[[offsetY_:.*]] = llvm.add %[[VAL_40]], %[[OFFSET_0]] : i32
219-
// CHECK: %[[VAL_43:.*]] = llvm.trunc %[[offsetY_]] : i32 to i32
220-
// CHECK: %[[VAL_44:.*]] = llvm.trunc %[[offsetX_]] : i32 to i32
221217
// CHECK: %[[ROW_STRIDE_IN_BYTES:.*]] = llvm.mul %[[ROW_STRIDE_i32]], %[[ELEM_SIZE_IN_BYTES]] : i32
222218
// CHECK: %[[HEIGHT:.*]] = llvm.mul %[[HEIGHT_i32]], %[[ELEM_SIZE_IN_BYTES]] : i32
223-
// CHECK: triton_gen.2Dblockload %[[BASE]], %[[HEIGHT]], %[[WIDTH_i32]], %[[ROW_STRIDE_IN_BYTES]], %[[VAL_44]], %[[VAL_43]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 2, transpose = false, vnni_transform = true, cache_control = Default}
219+
// CHECK: triton_gen.2Dblockload %[[BASE]], %[[HEIGHT]], %[[WIDTH_i32]], %[[ROW_STRIDE_IN_BYTES]], %[[offsetX_]], %[[offsetY_]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 2, transpose = false, vnni_transform = true, cache_control = Default}
224220
%ptrB = tt.make_tensor_ptr %arg1, [%arg4, %arg3], [%arg7, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x32xf16, #dot1>>
225221
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x32xf16, #dot1>>
226222
%A = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot0>

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
202202
// CHECK: %[[VAL_194:.*]] = llvm.insertelement %[[VAL_103]], %[[VAL_192]]{{\[}}{{.*}} : i32] : vector<8xf16>
203203
// CHECK: %[[VAL_196:.*]] = llvm.insertelement %[[VAL_104]], %[[VAL_194]]{{\[}}{{.*}} : i32] : vector<8xf16>
204204
// CHECK: %[[VAL_197:.*]] = llvm.bitcast %[[VAL_196]] : vector<8xf16> to vector<8xi16>
205-
// CHECK: %[[VAL_198:.*]] = llvm.trunc %[[offsetY]] : i32 to i32
206-
// CHECK: %[[VAL_199:.*]] = llvm.trunc %[[offsetX]] : i32 to i32
207-
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], %[[VAL_199]], %[[VAL_198]], %[[VAL_197]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
205+
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], %[[offsetX]], %[[offsetY]], %[[VAL_197]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
208206

209207
// COM: replica [0, 1]
210208
// CHECK: %[[VAL_207:.*]] = llvm.mlir.constant(16 : i32) : i32

test/TritonIntelGPU/prefetch-to-llvm.mlir

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,14 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
5050
// CHECK: %[[VAL_27:.*]] = llvm.add %[[VAL_26]], %[[CST_0]] : i32
5151
// CHECK: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
5252
// CHECK: %[[VAL_28:.*]] = llvm.urem %[[VAL_27]], %[[CST_32]] : i32
53-
// CHECK: %[[VAL_29:.*]] = llvm.add %[[VAL_28]], %[[OFFSET_1]] : i32
53+
// CHECK: %[[ROW_MAJOR_OFFSET_X:.*]] = llvm.add %[[VAL_28]], %[[OFFSET_1]] : i32
5454
// CHECK: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
5555
// CHECK: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
5656
// CHECK: %[[VAL_30:.*]] = llvm.mul %[[VAL_22]], %[[CST_2]] : i32
5757
// CHECK: %[[VAL_31:.*]] = llvm.add %[[VAL_30]], %[[CST_0]] : i32
5858
// CHECK: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
5959
// CHECK: %[[VAL_32:.*]] = llvm.urem %[[VAL_31]], %[[CST_16]] : i32
60-
// CHECK: %[[VAL_33:.*]] = llvm.add %[[VAL_32]], %[[OFFSET_0]] : i32
61-
// CHECK: %[[ROW_MAJOR_OFFSET_Y:.*]] = llvm.trunc %[[VAL_33]] : i32 to i32
62-
// CHECK: %[[ROW_MAJOR_OFFSET_X:.*]] = llvm.trunc %[[VAL_29]] : i32 to i32
60+
// CHECK: %[[ROW_MAJOR_OFFSET_Y:.*]] = llvm.add %[[VAL_32]], %[[OFFSET_0]] : i32
6361
// CHECK: triton_gen.2Dblockprefetch %[[BASE_]], %[[ROW_MAJOR_BASE_WIDTH]], %[[ROW_MAJOR_BASE_HEIGHT]], %[[ROW_MAJOR_PITCH]], %[[ROW_MAJOR_OFFSET_X]], %[[ROW_MAJOR_OFFSET_Y]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 2, v_blocks = 2, cache_control = L1C_L3C}
6462
%rowMajorPtr = tt.make_tensor_ptr %arg0, [%arg2, %arg4], [%arg5, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf16>>
6563
ttig.prefetch %rowMajorPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false, ttig.block_io = "row_major"} : !tt.ptr<tensor<16x32xf16>>
@@ -101,16 +99,14 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
10199
// CHECK: %[[VAL_27:.*]] = llvm.add %[[VAL_26]], %[[CST_0]] : i32
102100
// CHECK: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
103101
// CHECK: %[[VAL_28:.*]] = llvm.urem %[[VAL_27]], %[[CST_32]] : i32
104-
// CHECK: %[[VAL_29:.*]] = llvm.add %[[VAL_28]], %[[OFFSET_1]] : i32
102+
// CHECK: %[[COL_MAJOR_OFFSET_X:.*]] = llvm.add %[[VAL_28]], %[[OFFSET_1]] : i32
105103
// CHECK: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
106104
// CHECK: %[[CST_2:.*]] = llvm.mlir.constant(4 : i32) : i32
107105
// CHECK: %[[VAL_30:.*]] = llvm.mul %[[VAL_22]], %[[CST_2]] : i32
108106
// CHECK: %[[VAL_31:.*]] = llvm.add %[[VAL_30]], %[[CST_0]] : i32
109107
// CHECK: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
110108
// CHECK: %[[VAL_32:.*]] = llvm.urem %[[VAL_31]], %[[CST_16]] : i32
111-
// CHECK: %[[VAL_33:.*]] = llvm.add %[[VAL_32]], %[[OFFSET_0]] : i32
112-
// CHECK: %[[COL_MAJOR_OFFSET_Y:.*]] = llvm.trunc %[[VAL_33]] : i32 to i32
113-
// CHECK: %[[COL_MAJOR_OFFSET_X:.*]] = llvm.trunc %[[VAL_29]] : i32 to i32
109+
// CHECK: %[[COL_MAJOR_OFFSET_Y:.*]] = llvm.add %[[VAL_32]], %[[OFFSET_0]] : i32
114110
// CHECK: triton_gen.2Dblockprefetch %[[BASE_]], %[[COL_MAJOR_BASE_WIDTH]], %[[COL_MAJOR_BASE_HEIGHT]], %[[COL_MAJOR_PITCH]], %[[COL_MAJOR_OFFSET_X]], %[[COL_MAJOR_OFFSET_Y]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 4, v_blocks = 1, cache_control = L1C_L3C}
115111
%columnMajorPtr = tt.make_tensor_ptr %arg0, [%arg4, %arg2], [%c1_i64, %arg5], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<32x16xf16>>
116112
ttig.prefetch %columnMajorPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false, ttig.block_io = "column_major"} : !tt.ptr<tensor<32x16xf16>>

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -533,8 +533,8 @@ struct PrefetchOpConversion
533533
/*base_width*/ baseWidth,
534534
/*base_height*/ baseHeight,
535535
/*base_pitch*/ rowStrideInBytes,
536-
/*x*/ b.trunc(i32_ty, offsetX),
537-
/*y*/ b.trunc(i32_ty, offsetY),
536+
/*x*/ offsetX,
537+
/*y*/ offsetY,
538538
/*elem_size_in_bits*/ elemSizeInBits,
539539
/*tile_width*/ tileWidthInElem,
540540
/*tile_height*/ tileHeightInElem,
@@ -927,8 +927,8 @@ struct LoadOpToBlockIOConversion
927927
/*base_width*/ b.mul(baseWidth, elemSizeInBytes),
928928
/*base_height*/ baseHeight,
929929
/*base_pitch*/ b.mul(pitch, elemSizeInBytes),
930-
/*x*/ b.trunc(i32_ty, offsetX),
931-
/*y*/ b.trunc(i32_ty, offsetY),
930+
/*x*/ offsetX,
931+
/*y*/ offsetY,
932932
/*elem_size_in_bits*/ elemSizeInBits,
933933
/*tile_width*/ tileWidth,
934934
/*tile_height*/ tileHeight,
@@ -1448,8 +1448,8 @@ struct LoadOpToBlockIOConversion
14481448
/*base_width*/ b.mul(baseWidth, elemSizeInBytes),
14491449
/*base_height*/ baseHeight,
14501450
/*base_pitch*/ b.mul(pitch, elemSizeInBytes),
1451-
/*x*/ b.trunc(i32_ty, offsetX),
1452-
/*y*/ b.trunc(i32_ty, offsetY),
1451+
/*x*/ offsetX,
1452+
/*y*/ offsetY,
14531453
/*elem_size_in_bits*/ elemSizeInBits,
14541454
/*tile_width*/ tileWidth,
14551455
/*tile_height*/ tileHeight,
@@ -2559,8 +2559,8 @@ struct StoreOpToBlockIOConversion
25592559
/*base_width*/ baseWidth,
25602560
/*base_height*/ height,
25612561
/*base_pitch*/ basePitch,
2562-
/*x*/ b.trunc(i32_ty, offsetX),
2563-
/*y*/ b.trunc(i32_ty, offsetY),
2562+
/*x*/ offsetX,
2563+
/*y*/ offsetY,
25642564
/*elem_size_in_bits*/ elemSizeInBits,
25652565
/*tile_width*/ elemsPerInstr[1],
25662566
/*tile_height*/ elemsPerInstr[0],

0 commit comments

Comments
 (0)