Commit 430f8b2

[BACKEND] Fix unnecessary cvt caused by wgmma wait op (#8579)
Fixes #8578. We were using the wrong output constraint, which led LLVM to extend the fp16 value to 32 bits. Fixing the constraint removes the conversion. Note that we still end up with a no-op sequence like:

```ptx
mov.b32 {%rs1, %rs2}, %r1
mov.b32 %r2, {%rs1, %rs2}
```

However, `ptxas` is able to optimize these out.
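For context, here is a minimal, self-contained sketch of the per-member constraint selection this commit introduces. The real code queries `mlir::DataLayout` on MLIR types inside the lowering pattern; the `Ty` struct and `constraintFor` helper below are hypothetical stand-ins for illustration only:

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR struct-member type.
struct Ty {
  uint32_t bits; // bitwidth as reported by the data layout
  bool isF32;    // whether the member is an f32
};

// Pick a PTX inline-asm output constraint per member, mirroring the fix:
// 64-bit -> "=l", 32-bit -> "=f" (f32) or "=r" (integer), 16-bit -> "=h".
// "=h" keeps an fp16 value in a 16-bit register, so LLVM no longer emits
// a cvt to widen it to 32 bits.
std::string constraintFor(const Ty &ty) {
  switch (ty.bits) {
  case 64:
    return "=l";
  case 32:
    return ty.isF32 ? "=f" : "=r";
  case 16:
    return "=h";
  default:
    throw std::runtime_error("unexpected bitwidth");
  }
}

int main() {
  // The struct from the new test: (f32, f32, i32, i32, f16, f16).
  std::vector<Ty> body = {{32, true},  {32, true},  {32, false},
                          {32, false}, {16, false}, {16, false}};
  std::string out;
  for (const auto &ty : body)
    out += constraintFor(ty) + ",";
  std::cout << out << "\n"; // prints: =f,=f,=r,=r,=h,=h,
}
```

This matches the constraint string the new lit test expects; the trailing `0,1,2,3,4,5` in the test's `"=f,=f,=r,=r,=h,=h,0,1,2,3,4,5"` are the usual LLVM inline-asm tied constraints, binding each input operand to the output of the same index.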
1 parent a295e60

File tree

2 files changed: +36 -3 lines


test/Conversion/nvgpu_to_llvm.mlir

13 additions, 0 deletions

```diff
@@ -58,6 +58,19 @@ llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) {
 
 // -----
 
+!struct = !llvm.struct<(f32, f32, i32, i32, f16, f16)>
+
+// CHECK-LABEL: @wgmma_wait
+llvm.func @wgmma_wait(%in: !struct) {
+  // CHECK: // wait for regs: $0,$1,$2,$3,$4,$5
+  // CHECK: wgmma.wait_group.sync.aligned 0;
+  // CHECK: "=f,=f,=r,=r,=h,=h,0,1,2,3,4,5"
+  %out = nvgpu.wgmma_wait_group %in {pendings = 0 : i32} : !struct
+  llvm.return
+}
+
+// -----
+
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: @tensor_memory_base_lowering
 // CHECK: %[[TID:.+]] = nvvm.read.ptx.sreg.tid.x : i32
```
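The FileCheck lines above pin both the register-wait comment and the full constraint string, so a regression back to a uniform width-extending constraint (e.g. `=r` for the f16 members) would fail this test.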

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

23 additions, 3 deletions

```diff
@@ -303,9 +303,29 @@ class WGMMAWaitGroupOpPattern : public OpRewritePattern<ttn::WGMMAWaitGroupOp> {
   Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
     auto outputStructType = cast<LLVM::LLVMStructType>(op.getType());
     uint32_t numOutputRegs = outputStructType.getBody().size();
-    std::string output =
-        outputStructType.getBody().front().isF32() ? "=f" : "=r";
-    return Constraints(numOutputRegs, output);
+    Constraints constraints;
+    constraints.reserve(numOutputRegs);
+    mlir::DataLayout dl(op->getParentOfType<mlir::ModuleOp>());
+    for (auto ty : outputStructType.getBody()) {
+      auto bitwidth = dl.getTypeSizeInBits(ty);
+      std::string c;
+      switch (bitwidth) {
+      case 64:
+        c = "=l";
+        break;
+      case 32:
+        c = ty.isF32() ? "=f" : "=r";
+        break;
+      case 16:
+        c = "=h";
+        break;
+      default:
+        llvm::report_fatal_error("Unexpected bitwidth in WGMMAWaitGroupOp: " +
+                                 Twine(bitwidth));
+      }
+      constraints.push_back(c);
+    }
+    return constraints;
   }
 
   OperandsAndConstraints
```
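The key change: the output constraint is now derived per struct member from its data-layout bitwidth, instead of replicating one constraint chosen from the first member alone. A mixed-precision result struct such as `(f32, f32, i32, i32, f16, f16)` therefore gets `=f,=f,=r,=r,=h,=h`, and the f16 members stay in 16-bit registers rather than being widened to 32 bits.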
