Skip to content

Commit 81f3ddf

Browse files
authored
AMDGPU: Rewrite VGPR MFMAs to AGPR when directly copied to AGPR class (#152480)
1 parent ed6cd8f commit 81f3ddf

File tree

2 files changed

+202
-6
lines changed

2 files changed

+202
-6
lines changed

llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,17 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
109109

110110
// Find AV_* registers assigned to AGPRs.
111111
const TargetRegisterClass *VirtRegRC = MRI.getRegClass(VReg);
112-
if (!TRI.isVectorSuperClass(VirtRegRC))
112+
if (!TRI.hasAGPRs(VirtRegRC))
113113
continue;
114114

115-
const TargetRegisterClass *AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
116-
if (!TRI.isAGPRClass(AssignedRC))
117-
continue;
115+
const TargetRegisterClass *AssignedRC = VirtRegRC;
116+
if (TRI.hasVGPRs(VirtRegRC)) {
117+
// If this is an AV register, we have to check if the actual assignment is
118+
// to an AGPR
119+
AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
120+
if (!TRI.isAGPRClass(AssignedRC))
121+
continue;
122+
}
118123

119124
LiveInterval &LI = LIS.getInterval(VReg);
120125

llvm/test/CodeGen/AMDGPU/rewrite-vgpr-mfma-to-agpr.ll

Lines changed: 193 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,199 @@ bb:
200200
ret void
201201
}
202202

203-
declare <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float, float, <32 x float>, i32 immarg, i32 immarg, i32 immarg) #1
204-
declare noundef i32 @llvm.amdgcn.workitem.id.x() #2
203+
; The inline asm requires the value be copied to an AGPR class, not
204+
; the AV_* pseudo we usually expect for register allocator live range
205+
; splits.
206+
define amdgpu_kernel void @test_rewrite_mfma_direct_copy_to_agpr_class(ptr addrspace(1) %arg) #0 {
207+
; CHECK-LABEL: test_rewrite_mfma_direct_copy_to_agpr_class:
208+
; CHECK: ; %bb.0: ; %bb
209+
; CHECK-NEXT: s_load_dwordx2 s[0:1], s[4:5], 0x0
210+
; CHECK-NEXT: v_and_b32_e32 v0, 0x3ff, v0
211+
; CHECK-NEXT: v_lshlrev_b32_e32 v0, 7, v0
212+
; CHECK-NEXT: v_mov_b32_e32 v32, 2.0
213+
; CHECK-NEXT: v_mov_b32_e32 v33, 4.0
214+
; CHECK-NEXT: s_waitcnt lgkmcnt(0)
215+
; CHECK-NEXT: global_load_dwordx4 a[28:31], v0, s[0:1] offset:112
216+
; CHECK-NEXT: global_load_dwordx4 a[24:27], v0, s[0:1] offset:96
217+
; CHECK-NEXT: global_load_dwordx4 a[20:23], v0, s[0:1] offset:80
218+
; CHECK-NEXT: global_load_dwordx4 a[16:19], v0, s[0:1] offset:64
219+
; CHECK-NEXT: global_load_dwordx4 a[12:15], v0, s[0:1] offset:48
220+
; CHECK-NEXT: global_load_dwordx4 a[8:11], v0, s[0:1] offset:32
221+
; CHECK-NEXT: global_load_dwordx4 a[4:7], v0, s[0:1] offset:16
222+
; CHECK-NEXT: global_load_dwordx4 a[0:3], v0, s[0:1]
223+
; CHECK-NEXT: s_waitcnt vmcnt(0)
224+
; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 a[0:31], v32, v33, a[0:31]
225+
; CHECK-NEXT: ;;#ASMSTART
226+
; CHECK-NEXT: ; use a[0:31]
227+
; CHECK-NEXT: ;;#ASMEND
228+
; CHECK-NEXT: s_endpgm
229+
bb:
230+
%id = call i32 @llvm.amdgcn.workitem.id.x()
231+
%gep = getelementptr <32 x float>, ptr addrspace(1) %arg, i32 %id
232+
%in = load <32 x float>, ptr addrspace(1) %gep, align 128
233+
%mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 2.0, float 4.0, <32 x float> %in, i32 0, i32 0, i32 0)
234+
call void asm sideeffect "; use $0", "a"(<32 x float> %mai)
235+
ret void
236+
}
237+
238+
; TODO: Handle rewriting this case
239+
define void @test_rewrite_mfma_imm_src2(float %arg0, float %arg1) #0 {
240+
; CHECK-LABEL: test_rewrite_mfma_imm_src2:
241+
; CHECK: ; %bb.0: ; %bb
242+
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
243+
; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[0:31], v0, v1, 2.0
244+
; CHECK-NEXT: s_nop 7
245+
; CHECK-NEXT: s_nop 7
246+
; CHECK-NEXT: s_nop 1
247+
; CHECK-NEXT: v_accvgpr_write_b32 a0, v0
248+
; CHECK-NEXT: v_accvgpr_write_b32 a1, v1
249+
; CHECK-NEXT: v_accvgpr_write_b32 a2, v2
250+
; CHECK-NEXT: v_accvgpr_write_b32 a3, v3
251+
; CHECK-NEXT: v_accvgpr_write_b32 a4, v4
252+
; CHECK-NEXT: v_accvgpr_write_b32 a5, v5
253+
; CHECK-NEXT: v_accvgpr_write_b32 a6, v6
254+
; CHECK-NEXT: v_accvgpr_write_b32 a7, v7
255+
; CHECK-NEXT: v_accvgpr_write_b32 a8, v8
256+
; CHECK-NEXT: v_accvgpr_write_b32 a9, v9
257+
; CHECK-NEXT: v_accvgpr_write_b32 a10, v10
258+
; CHECK-NEXT: v_accvgpr_write_b32 a11, v11
259+
; CHECK-NEXT: v_accvgpr_write_b32 a12, v12
260+
; CHECK-NEXT: v_accvgpr_write_b32 a13, v13
261+
; CHECK-NEXT: v_accvgpr_write_b32 a14, v14
262+
; CHECK-NEXT: v_accvgpr_write_b32 a15, v15
263+
; CHECK-NEXT: v_accvgpr_write_b32 a16, v16
264+
; CHECK-NEXT: v_accvgpr_write_b32 a17, v17
265+
; CHECK-NEXT: v_accvgpr_write_b32 a18, v18
266+
; CHECK-NEXT: v_accvgpr_write_b32 a19, v19
267+
; CHECK-NEXT: v_accvgpr_write_b32 a20, v20
268+
; CHECK-NEXT: v_accvgpr_write_b32 a21, v21
269+
; CHECK-NEXT: v_accvgpr_write_b32 a22, v22
270+
; CHECK-NEXT: v_accvgpr_write_b32 a23, v23
271+
; CHECK-NEXT: v_accvgpr_write_b32 a24, v24
272+
; CHECK-NEXT: v_accvgpr_write_b32 a25, v25
273+
; CHECK-NEXT: v_accvgpr_write_b32 a26, v26
274+
; CHECK-NEXT: v_accvgpr_write_b32 a27, v27
275+
; CHECK-NEXT: v_accvgpr_write_b32 a28, v28
276+
; CHECK-NEXT: v_accvgpr_write_b32 a29, v29
277+
; CHECK-NEXT: v_accvgpr_write_b32 a30, v30
278+
; CHECK-NEXT: v_accvgpr_write_b32 a31, v31
279+
; CHECK-NEXT: ;;#ASMSTART
280+
; CHECK-NEXT: ; use a[0:31]
281+
; CHECK-NEXT: ;;#ASMEND
282+
; CHECK-NEXT: s_setpc_b64 s[30:31]
283+
bb:
284+
%mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> splat (float 2.0), i32 0, i32 0, i32 0)
285+
call void asm sideeffect "; use $0", "a"(<32 x float> %mai)
286+
ret void
287+
}
288+
289+
; TODO: Handle rewriting this case
290+
define void @test_rewrite_mfma_subreg_extract0(float %arg0, float %arg1, ptr addrspace(1) %ptr) #0 {
291+
; CHECK-LABEL: test_rewrite_mfma_subreg_extract0:
292+
; CHECK: ; %bb.0: ; %bb
293+
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
294+
; CHECK-NEXT: global_load_dwordx4 v[30:33], v[2:3], off offset:112
295+
; CHECK-NEXT: global_load_dwordx4 v[26:29], v[2:3], off offset:96
296+
; CHECK-NEXT: global_load_dwordx4 v[22:25], v[2:3], off offset:80
297+
; CHECK-NEXT: global_load_dwordx4 v[18:21], v[2:3], off offset:64
298+
; CHECK-NEXT: global_load_dwordx4 v[14:17], v[2:3], off offset:48
299+
; CHECK-NEXT: global_load_dwordx4 v[10:13], v[2:3], off offset:32
300+
; CHECK-NEXT: global_load_dwordx4 v[6:9], v[2:3], off offset:16
301+
; CHECK-NEXT: s_nop 0
302+
; CHECK-NEXT: global_load_dwordx4 v[2:5], v[2:3], off
303+
; CHECK-NEXT: s_waitcnt vmcnt(0)
304+
; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[2:33], v0, v1, v[2:33]
305+
; CHECK-NEXT: s_nop 7
306+
; CHECK-NEXT: s_nop 7
307+
; CHECK-NEXT: s_nop 1
308+
; CHECK-NEXT: v_accvgpr_write_b32 a0, v2
309+
; CHECK-NEXT: v_accvgpr_write_b32 a1, v3
310+
; CHECK-NEXT: v_accvgpr_write_b32 a2, v4
311+
; CHECK-NEXT: v_accvgpr_write_b32 a3, v5
312+
; CHECK-NEXT: ;;#ASMSTART
313+
; CHECK-NEXT: ; use a[0:3]
314+
; CHECK-NEXT: ;;#ASMEND
315+
; CHECK-NEXT: s_setpc_b64 s[30:31]
316+
bb:
317+
%src2 = load <32 x float>, ptr addrspace(1) %ptr
318+
%mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> %src2, i32 0, i32 0, i32 0)
319+
%extract.sub4 = shufflevector <32 x float> %mai, <32 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
320+
call void asm sideeffect "; use $0", "a"(<4 x float> %extract.sub4)
321+
ret void
322+
}
323+
324+
define void @test_rewrite_mfma_subreg_extract1(float %arg0, float %arg1, ptr addrspace(1) %ptr) #0 {
325+
; CHECK-LABEL: test_rewrite_mfma_subreg_extract1:
326+
; CHECK: ; %bb.0: ; %bb
327+
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
328+
; CHECK-NEXT: global_load_dwordx4 v[30:33], v[2:3], off offset:112
329+
; CHECK-NEXT: global_load_dwordx4 v[26:29], v[2:3], off offset:96
330+
; CHECK-NEXT: global_load_dwordx4 v[22:25], v[2:3], off offset:80
331+
; CHECK-NEXT: global_load_dwordx4 v[18:21], v[2:3], off offset:64
332+
; CHECK-NEXT: global_load_dwordx4 v[14:17], v[2:3], off offset:48
333+
; CHECK-NEXT: global_load_dwordx4 v[10:13], v[2:3], off offset:32
334+
; CHECK-NEXT: global_load_dwordx4 v[6:9], v[2:3], off offset:16
335+
; CHECK-NEXT: s_nop 0
336+
; CHECK-NEXT: global_load_dwordx4 v[2:5], v[2:3], off
337+
; CHECK-NEXT: s_waitcnt vmcnt(0)
338+
; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[2:33], v0, v1, v[2:33]
339+
; CHECK-NEXT: s_nop 7
340+
; CHECK-NEXT: s_nop 7
341+
; CHECK-NEXT: s_nop 1
342+
; CHECK-NEXT: v_accvgpr_write_b32 a0, v6
343+
; CHECK-NEXT: v_accvgpr_write_b32 a1, v7
344+
; CHECK-NEXT: v_accvgpr_write_b32 a2, v8
345+
; CHECK-NEXT: v_accvgpr_write_b32 a3, v9
346+
; CHECK-NEXT: ;;#ASMSTART
347+
; CHECK-NEXT: ; use a[0:3]
348+
; CHECK-NEXT: ;;#ASMEND
349+
; CHECK-NEXT: s_setpc_b64 s[30:31]
350+
bb:
351+
%src2 = load <32 x float>, ptr addrspace(1) %ptr
352+
%mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> %src2, i32 0, i32 0, i32 0)
353+
%extract.sub4 = shufflevector <32 x float> %mai, <32 x float> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
354+
call void asm sideeffect "; use $0", "a"(<4 x float> %extract.sub4)
355+
ret void
356+
}
357+
358+
; odd offset
359+
define void @test_rewrite_mfma_subreg_extract2(float %arg0, float %arg1, ptr addrspace(1) %ptr) #0 {
360+
; CHECK-LABEL: test_rewrite_mfma_subreg_extract2:
361+
; CHECK: ; %bb.0: ; %bb
362+
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
363+
; CHECK-NEXT: global_load_dwordx4 v[30:33], v[2:3], off offset:112
364+
; CHECK-NEXT: global_load_dwordx4 v[26:29], v[2:3], off offset:96
365+
; CHECK-NEXT: global_load_dwordx4 v[22:25], v[2:3], off offset:80
366+
; CHECK-NEXT: global_load_dwordx4 v[18:21], v[2:3], off offset:64
367+
; CHECK-NEXT: global_load_dwordx4 v[14:17], v[2:3], off offset:48
368+
; CHECK-NEXT: global_load_dwordx4 v[10:13], v[2:3], off offset:32
369+
; CHECK-NEXT: global_load_dwordx4 v[6:9], v[2:3], off offset:16
370+
; CHECK-NEXT: s_nop 0
371+
; CHECK-NEXT: global_load_dwordx4 v[2:5], v[2:3], off
372+
; CHECK-NEXT: s_waitcnt vmcnt(0)
373+
; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[2:33], v0, v1, v[2:33]
374+
; CHECK-NEXT: s_nop 7
375+
; CHECK-NEXT: s_nop 7
376+
; CHECK-NEXT: s_nop 1
377+
; CHECK-NEXT: v_accvgpr_write_b32 a0, v3
378+
; CHECK-NEXT: v_accvgpr_write_b32 a1, v4
379+
; CHECK-NEXT: v_accvgpr_write_b32 a2, v5
380+
; CHECK-NEXT: v_accvgpr_write_b32 a3, v6
381+
; CHECK-NEXT: ;;#ASMSTART
382+
; CHECK-NEXT: ; use a[0:3]
383+
; CHECK-NEXT: ;;#ASMEND
384+
; CHECK-NEXT: s_setpc_b64 s[30:31]
385+
bb:
386+
%src2 = load <32 x float>, ptr addrspace(1) %ptr
387+
%mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> %src2, i32 0, i32 0, i32 0)
388+
%extract.sub4 = shufflevector <32 x float> %mai, <32 x float> poison, <4 x i32> <i32 1, i32 2, i32 3, i32 4>
389+
call void asm sideeffect "; use $0", "a"(<4 x float> %extract.sub4)
390+
ret void
391+
}
392+
393+
declare <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half>, <4 x half>, <4 x float>, i32 immarg, i32 immarg, i32 immarg) #2
394+
declare <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float, float, <32 x float>, i32 immarg, i32 immarg, i32 immarg) #2
395+
declare noundef range(i32 0, 1024) i32 @llvm.amdgcn.workitem.id.x() #3
205396

206397
attributes #0 = { nounwind "amdgpu-flat-work-group-size"="1,256" "amdgpu-waves-per-eu"="4,4" }
207398
attributes #1 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }

0 commit comments

Comments
 (0)