
Commit 607e3df
Allow mix of x2 & x4 multivector loads and intrinsics

shouldUseFormStridedPseudo previously required every copy source feeding a
FORM_TRANSPOSED_REG_TUPLE pseudo to belong to the register class matching the
pseudo's own arity (ZPR2StridedOrContiguous for the X2 form,
ZPR4StridedOrContiguous for the X4 form). It now accepts either class, so a
four-vector structured load can feed two-vector intrinsics and vice versa.
getRegAllocationHints correspondingly derives the strided hint class from the
tuple's register class ID instead of from the defining load's operand count.
1 parent 81c3d47 commit 607e3df

3 files changed: +109 -14 lines

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (7 additions, 12 deletions)

@@ -8763,17 +8763,9 @@ static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
 bool shouldUseFormStridedPseudo(MachineInstr &MI) {
   MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
 
-  const TargetRegisterClass *RegClass = nullptr;
-  switch (MI.getOpcode()) {
-  case AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO:
-    RegClass = &AArch64::ZPR2StridedOrContiguousRegClass;
-    break;
-  case AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO:
-    RegClass = &AArch64::ZPR4StridedOrContiguousRegClass;
-    break;
-  default:
-    llvm_unreachable("Unexpected opcode.");
-  }
+  assert((MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
+          MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO) &&
+         "Unexpected opcode.");
 
   MCRegister SubReg = MCRegister::NoRegister;
   for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
@@ -8790,8 +8782,11 @@ bool shouldUseFormStridedPseudo(MachineInstr &MI) {
     SubReg = OpSubReg;
 
     MachineOperand *CopySrcOp = MRI.getOneDef(CopySrc.getReg());
+    const TargetRegisterClass *CopySrcClass =
+        MRI.getRegClass(CopySrcOp->getReg());
     if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg ||
-        MRI.getRegClass(CopySrcOp->getReg()) != RegClass)
+        (CopySrcClass != &AArch64::ZPR2StridedOrContiguousRegClass &&
+         CopySrcClass != &AArch64::ZPR4StridedOrContiguousRegClass))
       return false;
   }
 

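The relaxed per-operand condition is easiest to read in isolation. Below is a
minimal standalone sketch, as it might sit alongside shouldUseFormStridedPseudo
in AArch64ISelLowering.cpp; the helper name is hypothetical and not part of the
patch, and the sketch queries the register class only after the def operand has
been validated, since MachineRegisterInfo::getOneDef returns nullptr when the
register does not have exactly one definition.

// Hypothetical helper (not in the patch) summarizing the relaxed check:
// a copy source may now come from either the x2 or the x4
// strided-or-contiguous class, rather than having to match the arity of
// the tuple pseudo itself.
static bool isValidTupleCopySrc(const MachineRegisterInfo &MRI,
                                const MachineOperand *CopySrcOp,
                                unsigned OpSubReg, unsigned SubReg) {
  // Validate the operand before dereferencing it; getOneDef may have
  // returned nullptr.
  if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg)
    return false;
  const TargetRegisterClass *CopySrcClass =
      MRI.getRegClass(CopySrcOp->getReg());
  return CopySrcClass == &AArch64::ZPR2StridedOrContiguousRegClass ||
         CopySrcClass == &AArch64::ZPR4StridedOrContiguousRegClass;
}
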
llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp (3 additions, 2 deletions)

@@ -1125,8 +1125,9 @@ bool AArch64RegisterInfo::getRegAllocationHints(
 
       unsigned LdOps = Use.getNumOperands() - 1;
       const TargetRegisterClass *StridedRC =
-          LdOps == 2 ? &AArch64::ZPR2StridedRegClass
-                     : &AArch64::ZPR4StridedRegClass;
+          RegID == AArch64::ZPR2StridedOrContiguousRegClassID
+              ? &AArch64::ZPR2StridedRegClass
+              : &AArch64::ZPR4StridedRegClass;
 
       SmallVector<MCPhysReg, 4> StridedOrder;
       for (MCPhysReg Reg : Order)

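Since a tuple's inputs may now be defined by a load of the other arity, the
operand count of the load (LdOps) no longer implies the tuple size, which is
why the hint class is keyed off the register class ID being allocated. A
minimal sketch of just that selection, written as a hypothetical free function
(RegID plays the same role as in the surrounding code):

// Hypothetical helper (not in the patch): choose the strided hint class
// from the class of the tuple being formed, not from the defining load's
// operand count. An x4 load feeding an x2 tuple must still be hinted
// with ZPR2Strided registers, and vice versa.
static const TargetRegisterClass *pickStridedHintClass(unsigned RegID) {
  return RegID == AArch64::ZPR2StridedOrContiguousRegClassID
             ? &AArch64::ZPR2StridedRegClass
             : &AArch64::ZPR4StridedRegClass;
}
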
llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll (99 additions, 0 deletions)

@@ -354,6 +354,53 @@ entry:
   ret void
 }
 
+define void @udot_single_za32_u16_vg1x2_x4load_x2tuple(ptr %ptr, i64 %stride, <vscale x 8 x i16> %zn) #0 {
+; CHECK-LABEL: udot_single_za32_u16_vg1x2_x4load_x2tuple:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-5
+; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    ptrue pn8.b
+; CHECK-NEXT:    add x9, x0, x1
+; CHECK-NEXT:    str z14, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    mov w8, wzr
+; CHECK-NEXT:    str z13, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z10, [sp, #3, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z9, [sp, #4, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    ld1h { z1.h, z5.h, z9.h, z13.h }, pn8/z, [x0]
+; CHECK-NEXT:    ld1h { z2.h, z6.h, z10.h, z14.h }, pn8/z, [x9]
+; CHECK-NEXT:    udot za.s[w8, 0, vgx2], { z1.h, z2.h }, z0.h
+; CHECK-NEXT:    udot za.s[w8, 0, vgx2], { z5.h, z6.h }, z0.h
+; CHECK-NEXT:    udot za.s[w8, 0, vgx2], { z9.h, z10.h }, z0.h
+; CHECK-NEXT:    udot za.s[w8, 0, vgx2], { z13.h, z14.h }, z0.h
+; CHECK-NEXT:    ldr z14, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z13, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z10, [sp, #3, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z9, [sp, #4, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #5
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
+  %1 = tail call { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.ld1.pn.x4.nxv8i16(target("aarch64.svcount") %0, ptr %ptr)
+  %2 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %1, 0
+  %3 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %1, 1
+  %4 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %1, 2
+  %5 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %1, 3
+  %arrayidx2 = getelementptr inbounds i8, ptr %ptr, i64 %stride
+  %6 = tail call { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.ld1.pn.x4.nxv8i16(target("aarch64.svcount") %0, ptr %arrayidx2)
+  %7 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %6, 0
+  %8 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %6, 1
+  %9 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %6, 2
+  %10 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %6, 3
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x2.nxv8i16(i32 0, <vscale x 8 x i16> %2, <vscale x 8 x i16> %7, <vscale x 8 x i16> %zn)
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x2.nxv8i16(i32 0, <vscale x 8 x i16> %3, <vscale x 8 x i16> %8, <vscale x 8 x i16> %zn)
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x2.nxv8i16(i32 0, <vscale x 8 x i16> %4, <vscale x 8 x i16> %9, <vscale x 8 x i16> %zn)
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x2.nxv8i16(i32 0, <vscale x 8 x i16> %5, <vscale x 8 x i16> %10, <vscale x 8 x i16> %zn)
+  ret void
+}
+
 define void @udot_single_za32_u16_vg1x4(i32 %slice, <vscale x 16 x i8> %unused, <vscale x 8 x i16> %zn0, <vscale x 8 x i16> %zn1, <vscale x 8 x i16> %zn2, <vscale x 8 x i16> %zn3, <vscale x 8 x i16> %zn4) #0 {
 ; CHECK-LABEL: udot_single_za32_u16_vg1x4:
 ; CHECK: // %bb.0:
@@ -1196,6 +1243,58 @@ entry:
   ret void
 }
 
+define void @udot_single_za32_u16_vg1x4_x2load_x4tuple(ptr %ptr, i64 %stride, <vscale x 16 x i8> %zn) #0 {
+; CHECK-LABEL: udot_single_za32_u16_vg1x4_x2load_x4tuple:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-5
+; CHECK-NEXT:    lsl x9, x1, #1
+; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    ptrue pn8.b
+; CHECK-NEXT:    str z12, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    st1b { z10.b, z11.b }, pn8, [sp, #2, mul vl] // 32-byte Folded Spill
+; CHECK-NEXT:    ptrue pn8.b
+; CHECK-NEXT:    str z9, [sp, #4, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    add x10, x9, x1
+; CHECK-NEXT:    mov w8, wzr
+; CHECK-NEXT:    ld1b { z1.b, z9.b }, pn8/z, [x0]
+; CHECK-NEXT:    ld1b { z2.b, z10.b }, pn8/z, [x0, x1]
+; CHECK-NEXT:    ld1b { z3.b, z11.b }, pn8/z, [x0, x9]
+; CHECK-NEXT:    ld1b { z4.b, z12.b }, pn8/z, [x0, x10]
+; CHECK-NEXT:    ptrue pn8.b
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z1.b - z4.b }, z0.b
+; CHECK-NEXT:    udot za.s[w8, 0, vgx4], { z9.b - z12.b }, z0.b
+; CHECK-NEXT:    ldr z12, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ld1b { z10.b, z11.b }, pn8/z, [sp, #2, mul vl] // 32-byte Folded Reload
+; CHECK-NEXT:    ldr z9, [sp, #4, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #5
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
+  %1 = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.aarch64.sve.ld1.pn.x2.nxv16i8(target("aarch64.svcount") %0, ptr %ptr)
+  %2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %1, 0
+  %3 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %1, 1
+  %arrayidx2 = getelementptr inbounds i8, ptr %ptr, i64 %stride
+  %4 = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.aarch64.sve.ld1.pn.x2.nxv16i8(target("aarch64.svcount") %0, ptr %arrayidx2)
+  %5 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %4, 0
+  %6 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %4, 1
+  %mul3 = shl i64 %stride, 1
+  %arrayidx4 = getelementptr inbounds i8, ptr %ptr, i64 %mul3
+  %7 = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.aarch64.sve.ld1.pn.x2.nxv16i8(target("aarch64.svcount") %0, ptr %arrayidx4)
+  %8 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %7, 0
+  %9 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %7, 1
+  %mul5 = mul i64 %stride, 3
+  %arrayidx6 = getelementptr inbounds i8, ptr %ptr, i64 %mul5
+  %10 = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.aarch64.sve.ld1.pn.x2.nxv16i8(target("aarch64.svcount") %0, ptr %arrayidx6)
+  %11 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %10, 0
+  %12 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %10, 1
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %2, <vscale x 16 x i8> %5, <vscale x 16 x i8> %8, <vscale x 16 x i8> %11, <vscale x 16 x i8> %zn)
+  call void @llvm.aarch64.sme.udot.single.za32.vg1x4.nxv16i8(i32 0, <vscale x 16 x i8> %3, <vscale x 16 x i8> %6, <vscale x 16 x i8> %9, <vscale x 16 x i8> %12, <vscale x 16 x i8> %zn)
+  ret void
+}
+
 define void @udot_lane_za64_u16_vg1x2(i32 %slice, <vscale x 16 x i8> %unused, <vscale x 8 x i16> %zn0, <vscale x 8 x i16> %zn1, <vscale x 8 x i16> %zn2) #1 {
 ; CHECK-LABEL: udot_lane_za64_u16_vg1x2:
 ; CHECK: // %bb.0:

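The two new tests cover both directions of the mix: the first forms x2 tuples
from the lanes of x4 strided loads (the strided groups z1/z5/z9/z13 and
z2/z6/z10/z14 feed four vgx2 udot instructions), while the second forms x4
tuples from four x2 strided loads feeding two vgx4 udot instructions.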