-
Notifications
You must be signed in to change notification settings - Fork 15.3k
Description
A matrix multiplication kernel that supports arbitrary input sizes (and which uses LDS to do an inter-wave reduction) begins producing incorrect results when MMRA annotations are added to the fences around a pair of barriers on gfx1201. The incorrect values occur even at -O0, but there the failure is nondeterministic.
(See iree-org/iree#22278 for context and a bit of debugging history)
Sadly, I can't think of a good way to extract this issue from the context of a matrix multiply.
iree-22278-reproducer.zip contains the input IR, a C++ test program, and a Makefile to compile and test the IR at issue.
The difference between the two IR files is:
--- global-fence.ll 2025-10-14 15:12:53.302090833 -0700
+++ lds-fence.ll 2025-10-14 15:12:14.559936178 -0700
@@ -208,15 +208,15 @@ define amdgpu_kernel void @matmul_accumu
%162 = fadd float %160, %161
%163 = tail call float @llvm.amdgcn.readlane.f32(float %162, i32 31)
%164 = insertelement <1 x float> poison, float %163, i64 0
- fence syncscope("workgroup") release
+ fence syncscope("workgroup") release, !mmra !12
tail call void @llvm.amdgcn.s.barrier.signal(i32 -1)
tail call void @llvm.amdgcn.s.barrier.wait(i16 -1)
- fence syncscope("workgroup") acquire
+ fence syncscope("workgroup") acquire, !mmra !12
store <1 x float> %164, ptr addrspace(3) %59, align 4
- fence syncscope("workgroup") release
+ fence syncscope("workgroup") release, !mmra !12
tail call void @llvm.amdgcn.s.barrier.signal(i32 -1)
tail call void @llvm.amdgcn.s.barrier.wait(i16 -1)
- fence syncscope("workgroup") acquire
+ fence syncscope("workgroup") acquire, !mmra !12
%165 = load <1 x float>, ptr addrspace(3) %60, align 4
%166 = extractelement <1 x float> %165, i64 0
%167 = fadd float %166, 0.000000e+00
For context, `%59` is the `workitem ID / 32` (aka `workitem ID >> 5`)th float in LDS and `%60` is the `workitem ID % 32` (aka `workitem ID & 0x1f`)th element. That is, after doing a reduction in each wave, the wave writes that readlane'd result to LDS at its wave ID and then each lane reads the lane'th element of LDS so that the rest of the reduction can be computed. (All of this is in the reproducer IR.)
This should, as far as I'm aware, work with an LDS barrier only - no global memory needs to be synchronized on. However, it doesn't.
The difference between the working and failing assembly at -O3 is
--- global-fence-O3.s 2025-10-15 10:27:08.525073548 -0700
+++ lds-fence-O3.s 2025-10-15 10:27:08.324047999 -0700
@@ -98,33 +98,31 @@ matmul_accumulate_DYNxDYNxf16_times_DYNx
v_cmp_lt_i64_e64 s31, s[34:35], s[22:23]
global_load_b32 v2, v1, s[4:5]
v_add_f32_e32 v7, 0, v28
- s_wait_loadcnt 0x0
s_barrier_signal -1
s_barrier_wait -1
- global_inv scope:SCOPE_SE
+ s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
v_add_f32_e32 v7, v7, v31
s_and_b32 vcc_lo, exec_lo, s31
- s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
v_add_f32_e32 v7, v7, v30
- v_add_f32_e32 v7, v7, v29
s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+ v_add_f32_e32 v7, v7, v29
v_add_f32_dpp v7, v7, v7 quad_perm:[1,0,3,2] row_mask:0xf bank_mask:0xf bound_ctrl:1
- v_add_f32_dpp v7, v7, v7 quad_perm:[2,3,0,1] row_mask:0xf bank_mask:0xf bound_ctrl:1
s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+ v_add_f32_dpp v7, v7, v7 quad_perm:[2,3,0,1] row_mask:0xf bank_mask:0xf bound_ctrl:1
v_add_f32_dpp v7, v7, v7 row_half_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1
- v_add_f32_dpp v7, v7, v7 row_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1
s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+ v_add_f32_dpp v7, v7, v7 row_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1
v_permlanex16_b32 v8, v7, -1, -1 op_sel:[1,0]
+ s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
v_add_f32_e32 v7, v7, v8
- s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
v_readlane_b32 s6, v7, 31
s_wait_alu 0xf1ff
+ s_delay_alu instid0(VALU_DEP_1)
v_mov_b32_e32 v7, s6
ds_store_b32 v23, v7
- s_wait_loadcnt_dscnt 0x0
+ s_wait_dscnt 0x0
s_barrier_signal -1
s_barrier_wait -1
- global_inv scope:SCOPE_SE
ds_load_b32 v7, v0
s_wait_dscnt 0x0
v_add_f32_e32 v7, 0, v7
@@ -137,8 +135,9 @@ matmul_accumulate_DYNxDYNxf16_times_DYNx
s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
v_permlanex16_b32 v8, v7, -1, -1 op_sel:[1,0]
v_add_f32_e32 v7, v8, v7
- s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
+ s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_1)
v_readlane_b32 s6, v7, 31
+ s_wait_loadcnt 0x0
s_wait_alu 0xf1ff
v_add_f32_e32 v2, s6, v2
global_store_b32 v1, v2, s[4:5]