Skip to content

Commit 08d4923

Browse files
committed
[AArch64][SVE] Add basic support for @llvm.masked.compressstore
This patch adds SVE support for the `masked.compressstore` intrinsic via the existing `VECTOR_COMPRESS` lowering and compressing the store mask via `VECREDUCE_ADD`. Currently, only `nxv4[i32|f32]` and `nxv2[i64|f64]` are directly supported, with other types promoted to these, where possible. This is done in preparation for LV support of this intrinsic, which is currently being worked on in #140723.
1 parent 44f72fb commit 08d4923

File tree

5 files changed

+218
-11
lines changed

5 files changed

+218
-11
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,10 +1962,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19621962
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
19631963

19641964
// We can lower types that have <vscale x {2|4}> elements to compact.
1965-
for (auto VT :
1966-
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
1967-
MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
1965+
for (auto VT : {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64,
1966+
MVT::nxv2f32, MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16,
1967+
MVT::nxv4i32, MVT::nxv4f32}) {
1968+
setOperationAction(ISD::MSTORE, VT, Custom);
1969+
// Use a custom lowering for masked stores that could be a supported
1970+
// compressing store. Note: These types still use the normal (Legal)
1971+
// lowering for non-compressing masked stores.
19681972
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
1973+
}
19691974

19701975
// If we have SVE, we can use SVE logic for legal (or smaller than legal)
19711976
// NEON vectors in the lowest bits of the SVE register.
@@ -7740,7 +7745,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77407745
case ISD::STORE:
77417746
return LowerSTORE(Op, DAG);
77427747
case ISD::MSTORE:
7743-
return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
7748+
return LowerMSTORE(Op, DAG);
77447749
case ISD::MGATHER:
77457750
return LowerMGATHER(Op, DAG);
77467751
case ISD::MSCATTER:
@@ -30180,6 +30185,36 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
3018030185
Store->isTruncatingStore());
3018130186
}
3018230187

30188+
// Custom lowering for ISD::MSTORE.
//
// Fixed-length vectors are forwarded to the existing SVE-container path.
// Scalable compressing masked stores are expanded into:
//   1. VECREDUCE_ADD of the i1 mask (number of active lanes; selects CNTP),
//   2. VECTOR_COMPRESS packing the selected elements into the low lanes
//      (selects COMPACT),
//   3. GET_ACTIVE_LANE_MASK covering the first N lanes (selects WHILELO),
//   4. a plain, non-compressing masked store of the packed value.
// Non-compressing scalable masked stores keep their normal (Legal) lowering.
SDValue AArch64TargetLowering::LowerMSTORE(SDValue Op,
                                           SelectionDAG &DAG) const {
  SDLoc DL(Op);
  auto *MST = cast<MaskedStoreSDNode>(Op);

  EVT DataVT = MST->getValue().getValueType();
  if (DataVT.isFixedLengthVector())
    return LowerFixedLengthVectorMStoreToSVE(Op, DAG);

  // Only compressing stores need custom handling here.
  if (!MST->isCompressingStore())
    return SDValue();

  SDValue Mask = MST->getMask();
  EVT PredVT = Mask.getValueType();

  // Each i1 lane contributes 0 or 1, so the reduction counts active lanes.
  SDValue ActiveLanes = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i64, Mask);

  // Pack the selected elements into the low lanes; the remaining lanes are
  // never stored, so poison is fine for the passthru.
  SDValue Packed = DAG.getNode(ISD::VECTOR_COMPRESS, DL, DataVT,
                               MST->getValue(), Mask, DAG.getPOISON(DataVT));

  // Predicate covering exactly the packed (low) lanes.
  SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
  SDValue LoLaneMask =
      DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, PredVT, Zero, ActiveLanes);

  return DAG.getMaskedStore(MST->getChain(), DL, Packed, MST->getBasePtr(),
                            MST->getOffset(), LoLaneMask, MST->getMemoryVT(),
                            MST->getMemOperand(), MST->getAddressingMode(),
                            MST->isTruncatingStore(),
                            /*isCompressing=*/false);
}
30217+
3018330218
SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
3018430219
SDValue Op, SelectionDAG &DAG) const {
3018530220
auto *Store = cast<MaskedStoreSDNode>(Op);
@@ -30194,7 +30229,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
3019430229
return DAG.getMaskedStore(
3019530230
Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
3019630231
Mask, Store->getMemoryVT(), Store->getMemOperand(),
30197-
Store->getAddressingMode(), Store->isTruncatingStore());
30232+
Store->getAddressingMode(), Store->isTruncatingStore(),
30233+
Store->isCompressingStore());
3019830234
}
3019930235

3020030236
SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,7 @@ class AArch64TargetLowering : public TargetLowering {
755755
SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
756756
SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
757757
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
758+
SDValue LowerMSTORE(SDValue Op, SelectionDAG &DAG) const;
758759

759760
SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;
760761

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -645,37 +645,43 @@ def nontrunc_masked_store :
645645
(masked_st node:$val, node:$ptr, undef, node:$pred), [{
646646
return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
647647
cast<MaskedStoreSDNode>(N)->isUnindexed() &&
648-
!cast<MaskedStoreSDNode>(N)->isNonTemporal();
648+
!cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
649+
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
649650
}]>;
650651
// truncating masked store fragments.
651652
// Truncating, unindexed masked store that is not a compressing store.
def trunc_masked_store :
  PatFrag<(ops node:$val, node:$ptr, node:$pred),
          (masked_st node:$val, node:$ptr, undef, node:$pred), [{
  return cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
         cast<MaskedStoreSDNode>(N)->isUnindexed() &&
         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
657659
// Truncating masked store with an i8 memory element type (non-compressing).
def trunc_masked_store_i8 :
  PatFrag<(ops node:$val, node:$ptr, node:$pred),
          (trunc_masked_store node:$val, node:$ptr, node:$pred), [{
  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8 &&
         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
662665
// Truncating masked store with an i16 memory element type (non-compressing).
def trunc_masked_store_i16 :
  PatFrag<(ops node:$val, node:$ptr, node:$pred),
          (trunc_masked_store node:$val, node:$ptr, node:$pred), [{
  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16 &&
         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
667671
// Truncating masked store with an i32 memory element type (non-compressing).
def trunc_masked_store_i32 :
  PatFrag<(ops node:$val, node:$ptr, node:$pred),
          (trunc_masked_store node:$val, node:$ptr, node:$pred), [{
  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32 &&
         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
672677

673678
// Non-truncating, unindexed, non-temporal masked store (non-compressing).
def non_temporal_store :
  PatFrag<(ops node:$val, node:$ptr, node:$pred),
          (masked_st node:$val, node:$ptr, undef, node:$pred), [{
  return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
         cast<MaskedStoreSDNode>(N)->isUnindexed() &&
         cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
680686

681687
multiclass masked_gather_scatter<PatFrags GatherScatterOp> {

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,29 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
333333
return isLegalMaskedLoadStore(DataType, Alignment);
334334
}
335335

336+
/// Returns true if \p Ty is an element type the compressing-store lowering
/// supports: f32, f64, or an integer of exactly 8, 16, 32 or 64 bits.
bool isElementTypeLegalForCompressStore(Type *Ty) const {
  if (Ty->isFloatTy() || Ty->isDoubleTy())
    return true;
  return Ty->isIntegerTy(8) || Ty->isIntegerTy(16) || Ty->isIntegerTy(32) ||
         Ty->isIntegerTy(64);
}
346+
347+
/// Returns true if a masked compress-store of \p DataType is legal.
/// Only vectors with a known-minimum element count of 2 or 4 are accepted
/// (the types that map onto the COMPACT-based lowering), with the usual
/// element-type and masked load/store legality requirements on top.
bool isLegalMaskedCompressStore(Type *DataType,
                                Align Alignment) const override {
  ElementCount NumElts = cast<VectorType>(DataType)->getElementCount();
  unsigned MinElts = NumElts.getKnownMinValue();
  if (MinElts != 2 && MinElts != 4)
    return false;

  return isElementTypeLegalForCompressStore(DataType->getScalarType()) &&
         isLegalMaskedLoadStore(DataType, Alignment);
}
358+
336359
bool isLegalMaskedGatherScatter(Type *DataType) const {
337360
if (!ST->isSVEAvailable())
338361
return false;
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s

;; Full SVE vectors (supported with +sve)

define void @test_compressstore_nxv4i32(ptr %p, <vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv4i32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    ptrue p1.s
; CHECK-NEXT:    compact z0.s, p0, z0.s
; CHECK-NEXT:    cntp x8, p1, p0.s
; CHECK-NEXT:    whilelo p0.s, xzr, x8
; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
; CHECK-NEXT:    ret
  tail call void @llvm.masked.compressstore.nxv4i32(<vscale x 4 x i32> %vec, ptr align 4 %p, <vscale x 4 x i1> %mask)
  ret void
}

define void @test_compressstore_nxv2i64(ptr %p, <vscale x 2 x i64> %vec, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv2i64:
; CHECK:       // %bb.0:
; CHECK-NEXT:    ptrue p1.d
; CHECK-NEXT:    compact z0.d, p0, z0.d
; CHECK-NEXT:    cntp x8, p1, p0.d
; CHECK-NEXT:    whilelo p0.d, xzr, x8
; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
; CHECK-NEXT:    ret
  tail call void @llvm.masked.compressstore.nxv2i64(<vscale x 2 x i64> %vec, ptr align 8 %p, <vscale x 2 x i1> %mask)
  ret void
}

define void @test_compressstore_nxv4f32(ptr %p, <vscale x 4 x float> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv4f32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    ptrue p1.s
; CHECK-NEXT:    compact z0.s, p0, z0.s
; CHECK-NEXT:    cntp x8, p1, p0.s
; CHECK-NEXT:    whilelo p0.s, xzr, x8
; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
; CHECK-NEXT:    ret
  tail call void @llvm.masked.compressstore.nxv4f32(<vscale x 4 x float> %vec, ptr align 4 %p, <vscale x 4 x i1> %mask)
  ret void
}

; TODO: Legal and nonstreaming check
define void @test_compressstore_nxv2f64(ptr %p, <vscale x 2 x double> %vec, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv2f64:
; CHECK:       // %bb.0:
; CHECK-NEXT:    ptrue p1.d
; CHECK-NEXT:    compact z0.d, p0, z0.d
; CHECK-NEXT:    cntp x8, p1, p0.d
; CHECK-NEXT:    whilelo p0.d, xzr, x8
; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
; CHECK-NEXT:    ret
  tail call void @llvm.masked.compressstore.nxv2f64(<vscale x 2 x double> %vec, ptr align 8 %p, <vscale x 2 x i1> %mask)
  ret void
}

;; SVE vector types promoted to 32/64-bit elements (non-exhaustive)

define void @test_compressstore_nxv2i8(ptr %p, <vscale x 2 x i8> %vec, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv2i8:
; CHECK:       // %bb.0:
; CHECK-NEXT:    ptrue p1.d
; CHECK-NEXT:    compact z0.d, p0, z0.d
; CHECK-NEXT:    cntp x8, p1, p0.d
; CHECK-NEXT:    whilelo p0.d, xzr, x8
; CHECK-NEXT:    st1b { z0.d }, p0, [x0]
; CHECK-NEXT:    ret
  tail call void @llvm.masked.compressstore.nxv2i8(<vscale x 2 x i8> %vec, ptr align 1 %p, <vscale x 2 x i1> %mask)
  ret void
}

define void @test_compressstore_nxv4i16(ptr %p, <vscale x 4 x i16> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv4i16:
; CHECK:       // %bb.0:
; CHECK-NEXT:    ptrue p1.s
; CHECK-NEXT:    compact z0.s, p0, z0.s
; CHECK-NEXT:    cntp x8, p1, p0.s
; CHECK-NEXT:    whilelo p0.s, xzr, x8
; CHECK-NEXT:    st1h { z0.s }, p0, [x0]
; CHECK-NEXT:    ret
  tail call void @llvm.masked.compressstore.nxv4i16(<vscale x 4 x i16> %vec, ptr align 2 %p, <vscale x 4 x i1> %mask)
  ret void
}

;; NEON vector types (promoted to SVE)

define void @test_compressstore_v2f64(ptr %p, <2 x double> %vec, <2 x i1> %mask) {
; CHECK-LABEL: test_compressstore_v2f64:
; CHECK:       // %bb.0:
; CHECK-NEXT:    ushll v1.2d, v1.2s, #0
; CHECK-NEXT:    ptrue p0.d, vl2
; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT:    ptrue p1.d
; CHECK-NEXT:    shl v1.2d, v1.2d, #63
; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
; CHECK-NEXT:    cntp x8, p1, p0.d
; CHECK-NEXT:    compact z0.d, p0, z0.d
; CHECK-NEXT:    whilelo p0.d, xzr, x8
; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
; CHECK-NEXT:    ret
  tail call void @llvm.masked.compressstore.v2f64(<2 x double> %vec, ptr align 8 %p, <2 x i1> %mask)
  ret void
}

define void @test_compressstore_v4i32(ptr %p, <4 x i32> %vec, <4 x i1> %mask) {
; CHECK-LABEL: test_compressstore_v4i32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
; CHECK-NEXT:    ptrue p0.s, vl4
; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT:    ptrue p1.s
; CHECK-NEXT:    shl v1.4s, v1.4s, #31
; CHECK-NEXT:    cmpne p0.s, p0/z, z1.s, #0
; CHECK-NEXT:    cntp x8, p1, p0.s
; CHECK-NEXT:    compact z0.s, p0, z0.s
; CHECK-NEXT:    whilelo p0.s, xzr, x8
; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
; CHECK-NEXT:    ret
  tail call void @llvm.masked.compressstore.v4i32(<4 x i32> %vec, ptr align 4 %p, <4 x i1> %mask)
  ret void
}

define void @test_compressstore_v2i64(ptr %p, <2 x i64> %vec, <2 x i1> %mask) {
; CHECK-LABEL: test_compressstore_v2i64:
; CHECK:       // %bb.0:
; CHECK-NEXT:    ushll v1.2d, v1.2s, #0
; CHECK-NEXT:    ptrue p0.d, vl2
; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT:    ptrue p1.d
; CHECK-NEXT:    shl v1.2d, v1.2d, #63
; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
; CHECK-NEXT:    cntp x8, p1, p0.d
; CHECK-NEXT:    compact z0.d, p0, z0.d
; CHECK-NEXT:    whilelo p0.d, xzr, x8
; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
; CHECK-NEXT:    ret
  tail call void @llvm.masked.compressstore.v2i64(<2 x i64> %vec, ptr align 8 %p, <2 x i1> %mask)
  ret void
}

0 commit comments

Comments
 (0)