Skip to content

Commit 877cf7a

Browse files
authored
Update getLaneAndWarpId and getThreadId (#4715)
Update `getLaneAndWarpId` and `getThreadId` and make `getLaneId` use `getLaneAndWarpId` --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent b95b34f commit 877cf7a

File tree

10 files changed

+119
-78
lines changed

10 files changed

+119
-78
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -298,42 +298,53 @@ Value getThreadId(OpBuilder &rewriter, Location loc) {
298298
rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
299299
tid = rewriter.create<arith::IndexCastOp>(loc, i32_ty, tid);
300300

301+
Operation *lookupPt = &rewriter.getInsertionBlock()->front();
302+
int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
303+
int numWarps = triton::gpu::lookupNumWarps(lookupPt);
304+
int upperBound = numWarps * threadsPerWarp;
305+
306+
TritonLLVMOpBuilder b(loc, rewriter);
307+
301308
// If this is being created inside a warp specialize op, compute the relative
302309
// thread ID within the warp group.
303310
if (std::optional<int> startId =
304311
getWarpGroupStartThreadId(rewriter.getInsertionBlock())) {
305-
TritonLLVMOpBuilder b(loc, rewriter);
306312
tid = rewriter.create<arith::SubIOp>(loc, tid, b.i32_val(*startId));
307313
}
308314

309-
return tid;
310-
}
315+
if (llvm::isPowerOf2_32(upperBound)) {
316+
// help LLVM's known bits analysis:
317+
tid = b.and_(tid, b.i32_val(upperBound - 1));
318+
}
311319

312-
Value getLaneId(OpBuilder &rewriter, Location loc) {
313-
TritonLLVMOpBuilder b(loc, rewriter);
314-
Value tid = getThreadId(rewriter, loc);
315-
int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
316-
return b.urem(tid, b.i32_val(threadsPerWarp));
320+
return tid;
317321
}
318322

319323
std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc) {
320324
TritonLLVMOpBuilder b(loc, rewriter);
321325
Value tid = getThreadId(rewriter, loc);
322326
int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
323327
Value warpSizeVal = b.i32_val(threadsPerWarp);
324-
Value laneId = b.urem(tid, warpSizeVal);
325328

326329
// If there is only one warp, the warp ID is always 0.
327330
Operation *lookupPt = &rewriter.getInsertionBlock()->front();
331+
Value laneId;
328332
Value warpId;
329-
if (triton::gpu::lookupNumWarps(lookupPt) == 1)
333+
if (triton::gpu::lookupNumWarps(lookupPt) == 1) {
334+
laneId = tid;
330335
warpId = b.i32_val(0);
331-
else
336+
} else {
337+
laneId = b.urem(tid, warpSizeVal);
332338
warpId = b.udiv(tid, warpSizeVal);
339+
}
333340

334341
return {laneId, warpId};
335342
}
336343

344+
Value getLaneId(OpBuilder &rewriter, Location loc) {
345+
return getLaneAndWarpId(rewriter, loc).first;
346+
}
347+
337348
// Helper function: applies linear layout vectorized over register indices
338349
SmallVector<SmallVector<std::pair<StringAttr, Value>>>
339350
applyLinearLayoutVec(Location loc, RewriterBase &rewriter,

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
147147
// The first constant 0 skips the LDS offset which is also 0
148148
// COMMON: llvm.getelementptr
149149
// COMMON: llvm.mlir.constant(0 : i32) : i32
150+
// COMMON: llvm.mlir.constant(0 : i32) : i32
150151
// COMMON: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32
151152
// COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
152153
%1 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = ca into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>

test/Conversion/amd/mfma-shortcut.mlir

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
4747
// GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
4848

4949
// GFX942: [[threadId:%.*]] = rocdl.workitem.id.x
50-
// GFX942: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
50+
// GFX942: [[c255:%.*]] = llvm.mlir.constant(255 : i32)
51+
// GFX942: [[RTID:%.*]] = llvm.and [[threadId]], [[c255]]
52+
// GFX942: [[laneId:%.*]] = llvm.urem [[RTID]], [[c64]]
5153
// GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
5254

5355
// GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
@@ -128,7 +130,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
128130
// GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
129131

130132
// GFX942: [[threadId:%.*]] = rocdl.workitem.id.x
131-
// GFX942: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
133+
// GFX942: [[c255:%.*]] = llvm.mlir.constant(255 : i32)
134+
// GFX942: [[RTID:%.*]] = llvm.and [[threadId]], [[c255]]
135+
// GFX942: [[laneId:%.*]] = llvm.urem [[RTID]], [[c64]]
132136
// GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
133137

134138
// GFX942: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]]

test/Conversion/intel/dot_layout_offset.mlir

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32}
1111
// COM: Base index of the dot layout.
1212
// CHECK: %[[THREAD_ID_I64:.*]] = llvm.call spir_funccc @_Z12get_local_idj
1313
// CHECK: %[[THREAD_ID_I32:.*]] = llvm.trunc %[[THREAD_ID_I64]] : i64 to i32
14+
// CHECK: %[[CST_63:.*]] = llvm.mlir.constant(63 : i32) : i32
15+
// CHECK: %[[RTID:.*]] = llvm.and %[[THREAD_ID_I32]], %[[CST_63]] : i32
1416
// CHECK: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32
15-
// CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_145]] : i32
16-
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_I32]], %[[VAL_145]] : i32
17+
// CHECK: %[[LANE_ID:.*]] = llvm.urem %[[RTID]], %[[VAL_145]] : i32
18+
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[RTID]], %[[VAL_145]] : i32
1719
// CHECK-COUNT-4: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
1820
// CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(1 : i32) : i32
1921
// CHECK: %[[VAL_150:.*]] = llvm.and %[[LANE_ID]], %[[VAL_149]] : i32
@@ -333,9 +335,11 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.thr
333335
// COM: Base index of the dot layout.
334336
// CHECK: %[[THREAD_ID_I64:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAL_142]])
335337
// CHECK: %[[THREAD_ID_I32:.*]] = llvm.trunc %[[THREAD_ID_I64]] : i64 to i32
336-
// CHECK: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32
337-
// CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_145]] : i32
338-
// CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_I32]], %[[VAL_145]] : i32
338+
// CHECK-DAG: %[[CST_63:.*]] = llvm.mlir.constant(63 : i32) : i32
339+
// CHECK-DAG: %[[RTID:.*]] = llvm.and %[[THREAD_ID_32:.*]], %[[CST_63]] : i32
340+
// CHECK-DAG: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32
341+
// CHECK-DAG: %[[LANE_ID:.*]] = llvm.urem %[[RTID]], %[[VAL_145]] : i32
342+
// CHECK-DAG: %[[WARP_ID:.*]] = llvm.udiv %[[RTID]], %[[VAL_145]] : i32
339343
// CHECK-COUNT-4: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
340344
// CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(1 : i32) : i32
341345
// CHECK: %[[VAL_150:.*]] = llvm.and %[[LANE_ID]], %[[VAL_149]] : i32

test/Conversion/intel/dpas_to_block_layout_convert.mlir

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
1010
tt.func public @convert_dpas(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
1111
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #mma>
1212

13-
// CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32
14-
// CHECK-DAG: %[[CST_16384:.*]] = llvm.mlir.constant(16384 : i32) : i32
15-
// CHECK-DAG: %[[CST_8192:.*]] = llvm.mlir.constant(8192 : i32) : i32
16-
// CHECK-DAG: %[[CST_128:.*]] = llvm.mlir.constant(128 : i32) : i32
17-
// CHECK-DAG: %[[CST_64:.*]] = llvm.mlir.constant(64 : i32) : i32
18-
// CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
19-
// CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
20-
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
21-
// CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
22-
// CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
23-
// CHECK-DAG: %[[SMEM:.*]] = llvm.mlir.addressof @global_smem : !llvm.ptr<3>
24-
// CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
25-
// CHECK-DAG: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
13+
// CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32
14+
// CHECK-DAG: %[[CST_16384:.*]] = llvm.mlir.constant(16384 : i32) : i32
15+
// CHECK-DAG: %[[CST_8192:.*]] = llvm.mlir.constant(8192 : i32) : i32
16+
// CHECK-DAG: %[[CST_128:.*]] = llvm.mlir.constant(128 : i32) : i32
17+
// CHECK-DAG: %[[CST_64:.*]] = llvm.mlir.constant(64 : i32) : i32
18+
// CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
19+
// CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
20+
// CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
21+
// CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
22+
// CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
23+
// CHECK-DAG: %[[SMEM:.*]] = llvm.mlir.addressof @global_smem : !llvm.ptr<3>
24+
// CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
25+
// CHECK-DAG: %[[CST_511:.*]] = llvm.mlir.constant(511 : i32) : i32
26+
// CHECK-DAG: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
2627
// COM: The following operations are generated for the conversion of DPAS layout to blocked layout. The conversion replica size is 128*256. So there is 1 round of load/store with synchronization.
2728
// CHECK: %[[threadId_64:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]]) {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return} : (i32) -> i64
2829
// CHECK: %[[threadId:.*]] = llvm.trunc %[[threadId_64]] : i64 to i32
29-
// CHECK: %[[laneId:.*]] = llvm.urem %[[threadId]], %[[CST_16]] : i32
30-
// CHECK: %[[warpId:.*]] = llvm.udiv %[[threadId]], %[[CST_16]] : i32
30+
// CHECK: %[[rtid:.*]] = llvm.and %[[threadId:.*]], %[[CST_511]] : i32
31+
// CHECK: %[[laneId:.*]] = llvm.urem %[[rtid]], %[[CST_16]] : i32
32+
// CHECK: %[[warpId:.*]] = llvm.udiv %[[rtid]], %[[CST_16]] : i32
3133
// CHECK: %[[VAL_25:.*]] = llvm.and %[[laneId]], %[[CST_1]] : i32
3234
// CHECK: %[[VAL_26:.*]] = llvm.icmp "eq" %[[VAL_25]], %[[CST_0]] : i32
3335
// CHECK: %[[VAL_27:.*]] = llvm.select %[[VAL_26]], %[[CST_0]], %[[CST_1]] : i1, i32
@@ -115,12 +117,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
115117
// CHECK-DAG: %[[SMEM:.*]] = llvm.mlir.addressof @global_smem : !llvm.ptr<3>
116118
// CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
117119
// CHECK-DAG: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
120+
// CHECK-DAG: %[[CST_511:.*]] = llvm.mlir.constant(511 : i32) : i32
118121

119122
// COM: The following operations are generated for the conversion of DPAS layout to blocked layout. The conversion replica size is 64*256. So there are 2 rounds of load/store with synchronization.
120123
// CHECK: %[[threadId_64:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]]) {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return} : (i32) -> i64
121124
// CHECK: %[[threadId:.*]] = llvm.trunc %[[threadId_64]] : i64 to i32
122-
// CHECK: %[[laneId:.*]] = llvm.urem %[[threadId]], %[[CST_16]] : i32
123-
// CHECK: %[[warpId:.*]] = llvm.udiv %[[threadId]], %[[CST_16]] : i32
125+
// CHECK: %[[rtid:.*]] = llvm.and %[[threadId]], %[[CST_511]] : i32
126+
// CHECK: %[[laneId:.*]] = llvm.urem %[[rtid]], %[[CST_16]] : i32
127+
// CHECK: %[[warpId:.*]] = llvm.udiv %[[rtid]], %[[CST_16]] : i32
124128
// CHECK: %[[VAL_25:.*]] = llvm.and %[[laneId]], %[[CST_1]] : i32
125129
// CHECK: %[[VAL_26:.*]] = llvm.icmp "eq" %[[VAL_25]], %[[CST_0]] : i32
126130
// CHECK: %[[VAL_27:.*]] = llvm.select %[[VAL_26]], %[[CST_0]], %[[CST_1]] : i1, i32

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,8 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
11761176
// CHECK-LABEL: atomic_add_f32_scalar_no_store
11771177
tt.func @atomic_add_f32_scalar_no_store(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
11781178
// CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1179-
// CHECK: [[MASKLANE:%.*]] = llvm.and
1179+
// CHECK: [[CST_NEG_ONE:%.*]] = llvm.mlir.constant(-1 : i32) : i32
1180+
// CHECK-NEXT: [[MASKLANE:%.*]] = llvm.and
11801181
// CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
11811182
// CHECK: [[MASKWARP:%.*]] = llvm.and
11821183
// CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]
@@ -1212,7 +1213,8 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
12121213
// CHECK-LABEL: atomic_add_f32_scalar
12131214
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
12141215
// CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32
1215-
// CHECK: [[MASKLANE:%.*]] = llvm.and
1216+
// CHECK: [[CST_NEG_ONE:%.*]] = llvm.mlir.constant(-1 : i32) : i32
1217+
// CHECK-NEXT: [[MASKLANE:%.*]] = llvm.and
12161218
// CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]]
12171219
// CHECK: [[MASKWARP:%.*]] = llvm.and
12181220
// CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]]

test/Conversion/tritonnvidiagpu_to_llvm.mlir

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
2828
// CHECK-LABEL: arrive_barrier
2929
tt.func @arrive_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) {
3030
// CHECK-NEXT: [[TID:%.*]] = nvvm.read.ptx.sreg.tid.x
31+
// CHECK-NEXT: [[C127:%.*]] = llvm.mlir.constant(127 : i32)
32+
// CHECK-NEXT: [[RTID:%.*]] = llvm.and [[TID]], [[C127]]
3133
// CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
32-
// CHECK-NEXT: [[IS_ZERO:%.*]] = llvm.icmp "eq" [[TID]], [[C0]]
34+
// CHECK-NEXT: [[IS_ZERO:%.*]] = llvm.icmp "eq" [[RTID]], [[C0]]
3335
// CHECK-NEXT: "@$0 mbarrier.arrive.shared::cta.b64 _, [$1], 2;", "b,r" [[IS_ZERO]], %arg0
3436
ttng.arrive_barrier %alloc, 2 : !ttg.memdesc<1xi64, #shared0, #smem>
3537
tt.return
@@ -38,8 +40,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
3840
// CHECK-LABEL: arrive_barrier_pred
3941
tt.func @arrive_barrier_pred(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) {
4042
// CHECK-NEXT: [[TID:%.*]] = nvvm.read.ptx.sreg.tid.x
43+
// CHECK-NEXT: [[C127:%.*]] = llvm.mlir.constant(127 : i32)
44+
// CHECK-NEXT: [[RTID:%.*]] = llvm.and [[TID]], [[C127]]
4145
// CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
42-
// CHECK-NEXT: [[IS_ZERO:%.*]] = llvm.icmp "eq" [[TID]], [[C0]]
46+
// CHECK-NEXT: [[IS_ZERO:%.*]] = llvm.icmp "eq" [[RTID]], [[C0]]
4347
// CHECK-NEXT: [[PRED:%.*]] = llvm.and [[IS_ZERO]], %arg1
4448
// CHECK-NEXT: "@$0 mbarrier.arrive.shared::cta.b64 _, [$1], 2;", "b,r" [[PRED]], %arg0
4549
ttng.arrive_barrier %alloc, 2, %pred : !ttg.memdesc<1xi64, #shared0, #smem>

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,12 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
244244
%0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x16xf16, #blocked>>
245245
// CHECK: llvm.call spir_funccc @_Z12get_local_idj
246246
// CHECK-NOT: llvm.icmp "slt"
247-
// CHECK: %[[threadID:.*]] = llvm.call spir_funccc @_Z12get_local_idj
248-
// CHECK: %[[VAL_583:.*]] = llvm.trunc %[[threadID]] : i64 to i32
249-
// CHECK: %[[VAL_584:.*]] = llvm.mlir.constant(16 : i32) : i32
250-
// CHECK: %[[VAL_586:.*]] = llvm.udiv %[[VAL_583]], %[[VAL_584]] : i32
247+
// CHECK: %[[THREAD_ID:.*]] = llvm.call spir_funccc @_Z12get_local_idj
248+
// CHECK: %[[THREAD_ID_32:.*]] = llvm.trunc %[[THREAD_ID]] : i64 to i32
249+
// CHECK-DAG: %[[CST_127:.*]] = llvm.mlir.constant(127 : i32) : i32
250+
// CHECK-DAG: %[[RTID:.*]] = llvm.and %[[THREAD_ID_32:.*]], %[[CST_127]] : i32
251+
// CHECK-DAG: %[[VAL_584:.*]] = llvm.mlir.constant(16 : i32) : i32
252+
// CHECK: %[[VAL_586:.*]] = llvm.udiv %[[RTID]], %[[VAL_584]] : i32
251253
// CHECK: %[[VAL_587:.*]] = llvm.mlir.constant(3 : i32) : i32
252254
// CHECK: %[[VAL_588:.*]] = llvm.and %[[VAL_586]], %[[VAL_587]] : i32
253255
// CHECK: %[[threadPred:.*]] = llvm.icmp "eq" %[[VAL_588]], {{.*}} : i32

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
323323
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
324324
// CHECK: [[THREADID_i64:%.*]] = llvm.call spir_funccc @_Z12get_local_idj([[C0]])
325325
// CHECK: [[THREADID:%.*]] = llvm.trunc [[THREADID_i64]] : i64 to i32
326+
// CHECK: [[C127:%.*]] = llvm.mlir.constant(127 : i32) : i32
327+
// CHECK: [[RTID:%.*]] = llvm.and [[THREADID]], [[C127]] : i32
326328
// CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
327-
// CHECK: [[REM:%.*]] = llvm.urem [[THREADID]], [[C8]] : i32
329+
// CHECK: [[REM:%.*]] = llvm.urem [[RTID]], [[C8]] : i32
330+
328331
// CHECK: [[NEWVAL:%.*]] = llvm.call spir_funccc @_Z17sub_group_shuffleij([[OLDVAL]], [[REM]])
329332
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
330333
// CHECK: [[VEC1:%.*]] = llvm.insertelement [[NEWVAL]], [[VEC]][[[C0]] : i32] : vector<2xi32>
@@ -334,8 +337,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
334337
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
335338
// CHECK: [[THREADID_i64:%.*]] = llvm.call spir_funccc @_Z12get_local_idj([[C0]])
336339
// CHECK: [[THREADID:%.*]] = llvm.trunc [[THREADID_i64]] : i64 to i32
340+
// CHECK: [[C127:%.*]] = llvm.mlir.constant(127 : i32) : i32
341+
// CHECK: [[RTID:%.*]] = llvm.and [[THREADID]], [[C127]] : i32
337342
// CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
338-
// CHECK: [[REM:%.*]] = llvm.urem [[THREADID]], [[C8]] : i32
343+
// CHECK: [[REM:%.*]] = llvm.urem [[RTID]], [[C8]] : i32
339344
// CHECK: [[NEWVAL:%.*]] = llvm.call spir_funccc @_Z17sub_group_shuffleij([[OLDVAL]], [[REM]])
340345
// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
341346
// CHECK: [[VEC2:%.*]] = llvm.insertelement [[NEWVAL]], [[VEC1]][[[C1]] : i32] : vector<2xi32>

0 commit comments

Comments
 (0)