Skip to content

Commit ce091da

Browse files
authored
[AMDGPU] Mark WMMA machine instructions as convergent (#165602)
The WMMA MI(s) are missing the isConvergent flag. This causes incorrect behavior in passes like machine-sink, where WMMA instructions get sunk into divergent branches. This patch fixes the issue by setting the isConvergent flag to 1 in the VOP3PInstructions.td file.
1 parent 1c85981 commit ce091da

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

llvm/lib/Target/AMDGPU/VOP3PInstructions.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,7 +1707,7 @@ multiclass WMMAInstGFX12<string Instr, VOP3PWMMA_Profile WMMAProfile, string Pse
17071707
defvar WMMAConstraints2Addr = !if(DiffVdstSrc2, "@earlyclobber $vdst", "@earlyclobber $vdst,$vdst = $src2");
17081708
defvar WMMAConstraints3Addr = "@earlyclobber $vdst";
17091709

1710-
let Mnemonic = Instr, mayRaiseFPException = 0, ReadsModeReg = 0 in {
1710+
let Mnemonic = Instr, mayRaiseFPException = 0, ReadsModeReg = 0, isConvergent = 1 in {
17111711
let Constraints = WMMAConstraints2Addr, isConvertibleToThreeAddress = 1 in
17121712
def _twoaddr : VOP3P_Pseudo<Instr, WMMAProfile>, WMMAInstInfo {
17131713
let PseudoInstr = Instr#PseudoInstrSuffix;
@@ -1734,7 +1734,7 @@ multiclass SWMMACInstGFX12<string Instr, VOP3PWMMA_Profile WMMAProfile, string P
17341734
let mayRaiseFPException = 0;
17351735
let ReadsModeReg = 0;
17361736
let AsmMatchConverter = "cvtSWMMAC";
1737-
1737+
let isConvergent = 1;
17381738
let Constraints = "@earlyclobber $vdst,$vdst = $srcTiedDef";
17391739
}
17401740
}
@@ -1906,8 +1906,10 @@ defm V_WMMA_SCALE_F32_32X16X128_F4_w32 : WMMAInstGFX12<"v_wmma_scale_f32_32x16
19061906
defm V_WMMA_SCALE16_F32_32X16X128_F4_w32 : WMMAInstGFX12<"v_wmma_scale16_f32_32x16x128_f4", F32_32X16X128_F4_SCALE16_w32, "_w32">;
19071907
} // End is_wmma_xdl = 1.
19081908

1909-
defm V_WMMA_LD_SCALE_PAIRED_B32 : VOP3PInst<"v_wmma_ld_scale_paired_b32", VOP_WMMA_LD_SCALE<i32, VCSrc_b32_Lo256>>;
1910-
defm V_WMMA_LD_SCALE16_PAIRED_B64 : VOP3PInst<"v_wmma_ld_scale16_paired_b64", VOP_WMMA_LD_SCALE<i64, VCSrc_b64_Lo256>>;
1909+
let isConvergent = 1 in {
1910+
defm V_WMMA_LD_SCALE_PAIRED_B32 : VOP3PInst<"v_wmma_ld_scale_paired_b32", VOP_WMMA_LD_SCALE<i32, VCSrc_b32_Lo256>>;
1911+
defm V_WMMA_LD_SCALE16_PAIRED_B64 : VOP3PInst<"v_wmma_ld_scale16_paired_b64", VOP_WMMA_LD_SCALE<i64, VCSrc_b64_Lo256>>;
1912+
}
19111913
} // End SubtargetPredicate = isGFX125xOnly
19121914
} // End WaveSizePredicate = isWave32
19131915

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 6
2+
# RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx12-generic -run-pass=machine-sink %s -o - | FileCheck %s
3+
4+
---
5+
name: wmma_test
6+
tracksRegLiveness: true
7+
body: |
8+
; CHECK-LABEL: name: wmma_test
9+
; CHECK: bb.0:
10+
; CHECK-NEXT: successors: %bb.2(0x40000000), %bb.1(0x40000000)
11+
; CHECK-NEXT: {{ $}}
12+
; CHECK-NEXT: [[DEF:%[0-9]+]]:vreg_128 = IMPLICIT_DEF
13+
; CHECK-NEXT: [[DEF1:%[0-9]+]]:vreg_128 = IMPLICIT_DEF
14+
; CHECK-NEXT: [[DEF2:%[0-9]+]]:sreg_32 = IMPLICIT_DEF
15+
; CHECK-NEXT: early-clobber %3:vreg_256 = V_WMMA_F32_16X16X16_F16_w32_threeaddr 8, [[DEF]], 8, [[DEF1]], 8, 0, 0, 0, implicit $exec
16+
; CHECK-NEXT: [[SI_IF:%[0-9]+]]:sreg_32 = SI_IF [[DEF2]], %bb.2, implicit-def dead $exec, implicit-def dead $scc, implicit $exec
17+
; CHECK-NEXT: S_BRANCH %bb.1
18+
; CHECK-NEXT: {{ $}}
19+
; CHECK-NEXT: bb.1:
20+
; CHECK-NEXT: successors: %bb.2(0x80000000)
21+
; CHECK-NEXT: {{ $}}
22+
; CHECK-NEXT: [[COPY:%[0-9]+]]:vreg_256 = COPY %3.sub1
23+
; CHECK-NEXT: {{ $}}
24+
; CHECK-NEXT: bb.2:
25+
; CHECK-NEXT: SI_END_CF [[SI_IF]], implicit-def dead $exec, implicit-def dead $scc, implicit $exec
26+
; CHECK-NEXT: S_ENDPGM 0
27+
bb.0:
28+
%0:vreg_128 = IMPLICIT_DEF
29+
%1:vreg_128 = IMPLICIT_DEF
30+
%2:sreg_32 = IMPLICIT_DEF
31+
early-clobber %3:vreg_256 = V_WMMA_F32_16X16X16_F16_w32_threeaddr 8, %0:vreg_128, 8, %1:vreg_128, 8, 0, 0, 0, implicit $exec
32+
%4:sreg_32 = SI_IF %2:sreg_32, %bb.2, implicit-def dead $exec, implicit-def dead $scc, implicit $exec
33+
S_BRANCH %bb.1
34+
35+
bb.1:
36+
%5:vreg_256 = COPY %3.sub1:vreg_256
37+
38+
bb.2:
39+
SI_END_CF %4:sreg_32, implicit-def dead $exec, implicit-def dead $scc, implicit $exec
40+
S_ENDPGM 0
41+
42+
...

0 commit comments

Comments
 (0)