Skip to content

Commit e7cef2e

Browse files
authored
[BACKEND] hint to LLVM that we can bound threadIdx.x (#7249)
1 parent dcf41f6 commit e7cef2e

File tree

5 files changed

+36
-16
lines changed

5 files changed

+36
-16
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 22 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -250,42 +250,53 @@ Value getThreadId(OpBuilder &rewriter, Location loc) {
250250
rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
251251
tid = rewriter.create<arith::IndexCastOp>(loc, i32_ty, tid);
252252

253+
Operation *lookupPt = &rewriter.getInsertionBlock()->front();
254+
int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
255+
int numWarps = triton::gpu::lookupNumWarps(lookupPt);
256+
int upperBound = numWarps * threadsPerWarp;
257+
258+
TritonLLVMOpBuilder b(loc, rewriter);
259+
253260
// If this is being created inside a warp specialize op, compute the relative
254261
// thread ID within the warp group.
255262
if (std::optional<int> startId =
256263
getWarpGroupStartThreadId(rewriter.getInsertionBlock())) {
257-
TritonLLVMOpBuilder b(loc, rewriter);
258264
tid = rewriter.create<arith::SubIOp>(loc, tid, b.i32_val(*startId));
259265
}
260266

261-
return tid;
262-
}
267+
if (llvm::isPowerOf2_32(upperBound)) {
268+
// help LLVM's known bits analysis:
269+
tid = b.and_(tid, b.i32_val(upperBound - 1));
270+
}
263271

264-
Value getLaneId(OpBuilder &rewriter, Location loc) {
265-
TritonLLVMOpBuilder b(loc, rewriter);
266-
Value tid = getThreadId(rewriter, loc);
267-
int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
268-
return b.urem(tid, b.i32_val(threadsPerWarp));
272+
return tid;
269273
}
270274

271275
std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc) {
272276
TritonLLVMOpBuilder b(loc, rewriter);
273277
Value tid = getThreadId(rewriter, loc);
274278
int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
275279
Value warpSizeVal = b.i32_val(threadsPerWarp);
276-
Value laneId = b.urem(tid, warpSizeVal);
277280

278281
// If there is only one warp, the warp ID is always 0.
279282
Operation *lookupPt = &rewriter.getInsertionBlock()->front();
283+
Value laneId;
280284
Value warpId;
281-
if (triton::gpu::lookupNumWarps(lookupPt) == 1)
285+
if (triton::gpu::lookupNumWarps(lookupPt) == 1) {
286+
laneId = tid;
282287
warpId = b.i32_val(0);
283-
else
288+
} else {
289+
laneId = b.urem(tid, warpSizeVal);
284290
warpId = b.udiv(tid, warpSizeVal);
291+
}
285292

286293
return {laneId, warpId};
287294
}
288295

296+
Value getLaneId(OpBuilder &rewriter, Location loc) {
297+
return getLaneAndWarpId(rewriter, loc).first;
298+
}
299+
289300
SmallVector<SmallVector<Value>>
290301
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
291302
Attribute layout, RankedTensorType type, bool withCTAOffset) {

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
147147
// The first constant 0 skips the LDS offset which is also 0
148148
// COMMON: llvm.getelementptr
149149
// COMMON: llvm.mlir.constant(0 : i32) : i32
150+
// COMMON: llvm.mlir.constant(0 : i32) : i32
150151
// COMMON: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32
151152
// COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
152153
%1 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = ca into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>

test/Conversion/amd/mfma-shortcut.mlir

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
4747
// GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
4848

4949
// GFX942: [[threadId:%.*]] = rocdl.workitem.id.x
50-
// GFX942: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
50+
// GFX942: [[c255:%.*]] = llvm.mlir.constant(255 : i32)
51+
// GFX942: [[RTID:%.*]] = llvm.and [[threadId]], [[c255]]
52+
// GFX942: [[laneId:%.*]] = llvm.urem [[RTID]], [[c64]]
5153
// GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
5254

5355
// GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
@@ -128,7 +130,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
128130
// GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
129131

130132
// GFX942: [[threadId:%.*]] = rocdl.workitem.id.x
131-
// GFX942: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
133+
// GFX942: [[c255:%.*]] = llvm.mlir.constant(255 : i32)
134+
// GFX942: [[RTID:%.*]] = llvm.and [[threadId]], [[c255]]
135+
// GFX942: [[laneId:%.*]] = llvm.urem [[RTID]], [[c64]]
132136
// GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
133137

134138
// GFX942: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]]

test/Conversion/reduce_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ tt.func private @reduce_linear_layout(%arg0: tensor<32x16xi32, #linear>) -> tens
2525
// is not needed.
2626

2727
// Reduce within threads
28-
// CHECK-NEXT: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]]
28+
// CHECK: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]]
2929
// CHECK-NEXT: [[SUM1:%.*]] = add i32 [[SRC1]], [[SRC3]]
3030

3131
// Reduce within warp.

test/Conversion/tritonnvidiagpu_to_llvm.mlir

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
2828
// CHECK-LABEL: arrive_barrier
2929
tt.func @arrive_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) {
3030
// CHECK-NEXT: [[TID:%.*]] = nvvm.read.ptx.sreg.tid.x
31+
// CHECK-NEXT: [[C127:%.*]] = llvm.mlir.constant(127 : i32)
32+
// CHECK-NEXT: [[RTID:%.*]] = llvm.and [[TID]], [[C127]]
3133
// CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
32-
// CHECK-NEXT: [[IS_ZERO:%.*]] = llvm.icmp "eq" [[TID]], [[C0]]
34+
// CHECK-NEXT: [[IS_ZERO:%.*]] = llvm.icmp "eq" [[RTID]], [[C0]]
3335
// CHECK-NEXT: "@$0 mbarrier.arrive.shared::cta.b64 _, [$1], 2;", "b,r" [[IS_ZERO]], %arg0
3436
ttng.arrive_barrier %alloc, 2 : !ttg.memdesc<1xi64, #shared0, #smem>
3537
tt.return
@@ -38,8 +40,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
3840
// CHECK-LABEL: arrive_barrier_pred
3941
tt.func @arrive_barrier_pred(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) {
4042
// CHECK-NEXT: [[TID:%.*]] = nvvm.read.ptx.sreg.tid.x
43+
// CHECK-NEXT: [[C127:%.*]] = llvm.mlir.constant(127 : i32)
44+
// CHECK-NEXT: [[RTID:%.*]] = llvm.and [[TID]], [[C127]]
4145
// CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
42-
// CHECK-NEXT: [[IS_ZERO:%.*]] = llvm.icmp "eq" [[TID]], [[C0]]
46+
// CHECK-NEXT: [[IS_ZERO:%.*]] = llvm.icmp "eq" [[RTID]], [[C0]]
4347
// CHECK-NEXT: [[PRED:%.*]] = llvm.and [[IS_ZERO]], %arg1
4448
// CHECK-NEXT: "@$0 mbarrier.arrive.shared::cta.b64 _, [$1], 2;", "b,r" [[PRED]], %arg0
4549
ttng.arrive_barrier %alloc, 2, %pred : !ttg.memdesc<1xi64, #shared0, #smem>

0 commit comments

Comments
 (0)