Commit 135ddf1

[AArch64][SVE] Add basic support for @llvm.masked.compressstore (#168350)
This patch adds SVE support for the `masked.compressstore` intrinsic by reusing the existing `VECTOR_COMPRESS` lowering and counting the active lanes of the store mask via `VECREDUCE_ADD`. Currently, only `nxv4[i32|f32]` and `nxv2[i64|f64]` are directly supported; other types are promoted to these where possible. This is done in preparation for loop-vectorizer (LV) support of this intrinsic, which is being worked on in #140723.
1 parent f54c6b4 commit 135ddf1
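
For reference, the IR this patch targets looks like the following (a hypothetical usage sketch, not taken from the commit; `nxv4f32` is one of the directly supported types named above). The intrinsic stores only the lanes whose mask bit is set, packed contiguously at the pointer:

; Hypothetical example: store the selected lanes of %vec contiguously at %p.
define void @store_compressed_nxv4f32(ptr %p, <vscale x 4 x float> %vec, <vscale x 4 x i1> %mask) {
  tail call void @llvm.masked.compressstore.nxv4f32(<vscale x 4 x float> %vec, ptr align 4 %p, <vscale x 4 x i1> %mask)
  ret void
}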

7 files changed: +390, -23 lines

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 15 additions & 12 deletions
@@ -10607,23 +10607,26 @@ TargetLowering::IncrementMemoryAddress(SDValue Addr, SDValue Mask,
   assert(DataVT.getVectorElementCount() == MaskVT.getVectorElementCount() &&
          "Incompatible types of Data and Mask");
   if (IsCompressedMemory) {
-    if (DataVT.isScalableVector())
-      report_fatal_error(
-          "Cannot currently handle compressed memory with scalable vectors");
     // Incrementing the pointer according to number of '1's in the mask.
-    EVT MaskIntVT = EVT::getIntegerVT(*DAG.getContext(), MaskVT.getSizeInBits());
-    SDValue MaskInIntReg = DAG.getBitcast(MaskIntVT, Mask);
-    if (MaskIntVT.getSizeInBits() < 32) {
-      MaskInIntReg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskInIntReg);
-      MaskIntVT = MVT::i32;
+    if (DataVT.isScalableVector()) {
+      EVT MaskExtVT = MaskVT.changeElementType(MVT::i32);
+      SDValue MaskExt = DAG.getNode(ISD::ZERO_EXTEND, DL, MaskExtVT, Mask);
+      Increment = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, MaskExt);
+    } else {
+      EVT MaskIntVT =
+          EVT::getIntegerVT(*DAG.getContext(), MaskVT.getSizeInBits());
+      SDValue MaskInIntReg = DAG.getBitcast(MaskIntVT, Mask);
+      if (MaskIntVT.getSizeInBits() < 32) {
+        MaskInIntReg =
+            DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, MaskInIntReg);
+        MaskIntVT = MVT::i32;
+      }
+      Increment = DAG.getNode(ISD::CTPOP, DL, MaskIntVT, MaskInIntReg);
     }
-
-    // Count '1's with POPCNT.
-    Increment = DAG.getNode(ISD::CTPOP, DL, MaskIntVT, MaskInIntReg);
-    Increment = DAG.getZExtOrTrunc(Increment, DL, AddrVT);
     // Scale is an element size in bytes.
     SDValue Scale = DAG.getConstant(DataVT.getScalarSizeInBits() / 8, DL,
                                     AddrVT);
+    Increment = DAG.getZExtOrTrunc(Increment, DL, AddrVT);
     Increment = DAG.getNode(ISD::MUL, DL, AddrVT, Increment, Scale);
   } else if (DataVT.isScalableVector()) {
     Increment = DAG.getVScale(DL, AddrVT,
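
The scalable-vector branch above can no longer bitcast the mask to a single integer and use CTPOP, because the mask width is not a compile-time constant; instead it zero-extends the predicate and sums its lanes with VECREDUCE_ADD. A rough IR equivalent of the pointer increment it computes, assuming nxv4i32 data (sketch only; the names below are illustrative and not part of the patch):

; Sketch: count the active lanes, then scale by the element size in bytes.
define i64 @pointer_increment_nxv4i32(<vscale x 4 x i1> %mask) {
  %ext = zext <vscale x 4 x i1> %mask to <vscale x 4 x i32>
  %cnt = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %ext)   ; VECREDUCE_ADD
  %cnt.ext = zext i32 %cnt to i64                                            ; getZExtOrTrunc to AddrVT
  %bytes = mul i64 %cnt.ext, 4                                               ; scale by sizeof(i32)
  ret i64 %bytes
}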

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 48 additions & 5 deletions
@@ -1987,10 +1987,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
 
     // We can lower types that have <vscale x {2|4}> elements to compact.
-    for (auto VT :
-         {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
-          MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+    for (auto VT : {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64,
+                    MVT::nxv2f32, MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16,
+                    MVT::nxv4i32, MVT::nxv4f32}) {
       setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
+      // Use a custom lowering for masked stores that could be a supported
+      // compressing store. Note: These types still use the normal (Legal)
+      // lowering for non-compressing masked stores.
+      setOperationAction(ISD::MSTORE, VT, Custom);
+    }
 
     // If we have SVE, we can use SVE logic for legal (or smaller than legal)
     // NEON vectors in the lowest bits of the SVE register.
@@ -7936,7 +7941,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::MSTORE:
-    return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
+    return LowerMSTORE(Op, DAG);
   case ISD::MGATHER:
     return LowerMGATHER(Op, DAG);
   case ISD::MSCATTER:
@@ -30439,6 +30444,43 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
                       Store->isTruncatingStore());
 }
 
+SDValue AArch64TargetLowering::LowerMSTORE(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  auto *Store = cast<MaskedStoreSDNode>(Op);
+  EVT VT = Store->getValue().getValueType();
+  if (VT.isFixedLengthVector())
+    return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
+
+  if (!Store->isCompressingStore())
+    return SDValue();
+
+  EVT MaskVT = Store->getMask().getValueType();
+  EVT MaskExtVT = getPromotedVTForPredicate(MaskVT);
+  EVT MaskReduceVT = MaskExtVT.getScalarType();
+  SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+
+  SDValue MaskExt =
+      DAG.getNode(ISD::ZERO_EXTEND, DL, MaskExtVT, Store->getMask());
+  SDValue CntActive =
+      DAG.getNode(ISD::VECREDUCE_ADD, DL, MaskReduceVT, MaskExt);
+  if (MaskReduceVT != MVT::i64)
+    CntActive = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CntActive);
+
+  SDValue CompressedValue =
+      DAG.getNode(ISD::VECTOR_COMPRESS, DL, VT, Store->getValue(),
+                  Store->getMask(), DAG.getPOISON(VT));
+  SDValue CompressedMask =
+      DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, MaskVT, Zero, CntActive);
+
+  return DAG.getMaskedStore(Store->getChain(), DL, CompressedValue,
+                            Store->getBasePtr(), Store->getOffset(),
+                            CompressedMask, Store->getMemoryVT(),
+                            Store->getMemOperand(), Store->getAddressingMode(),
+                            Store->isTruncatingStore(),
+                            /*isCompressing=*/false);
+}
+
 SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
     SDValue Op, SelectionDAG &DAG) const {
   auto *Store = cast<MaskedStoreSDNode>(Op);
@@ -30453,7 +30495,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
   return DAG.getMaskedStore(
       Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
       Mask, Store->getMemoryVT(), Store->getMemOperand(),
-      Store->getAddressingMode(), Store->isTruncatingStore());
+      Store->getAddressingMode(), Store->isTruncatingStore(),
+      Store->isCompressingStore());
 }
 
 SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
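
LowerMSTORE rewrites a scalable compressing masked store into a plain masked store of compressed data: it counts the active lanes of the original mask, packs the selected elements with VECTOR_COMPRESS, and builds a new mask covering the first CntActive lanes with GET_ACTIVE_LANE_MASK. Expressed as IR rather than DAG nodes, the transform is roughly the following sketch, assuming nxv4i32 data (the function name and values are illustrative, not from the patch):

define void @compressstore_as_masked_store(ptr %p, <vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask) {
  ; Count the active lanes of the original mask (ZERO_EXTEND + VECREDUCE_ADD).
  %ext = zext <vscale x 4 x i1> %mask to <vscale x 4 x i32>
  %cnt = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %ext)
  %cnt.ext = zext i32 %cnt to i64
  ; Pack the selected elements into the low lanes (VECTOR_COMPRESS).
  %compressed = call <vscale x 4 x i32> @llvm.experimental.vector.compress.nxv4i32(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> poison)
  ; Mask covering the first %cnt lanes (GET_ACTIVE_LANE_MASK).
  %newmask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 %cnt.ext)
  ; Plain, non-compressing masked store of the packed data.
  call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> %compressed, ptr %p, i32 4, <vscale x 4 x i1> %newmask)
  ret void
}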

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
@@ -761,6 +761,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerMSTORE(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;
 
llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 12 additions & 6 deletions
@@ -645,37 +645,43 @@ def nontrunc_masked_store :
   (masked_st node:$val, node:$ptr, undef, node:$pred), [{
   return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
          cast<MaskedStoreSDNode>(N)->isUnindexed() &&
-         !cast<MaskedStoreSDNode>(N)->isNonTemporal();
+         !cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 // truncating masked store fragments.
 def trunc_masked_store :
   PatFrag<(ops node:$val, node:$ptr, node:$pred),
           (masked_st node:$val, node:$ptr, undef, node:$pred), [{
   return cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
-         cast<MaskedStoreSDNode>(N)->isUnindexed();
+         cast<MaskedStoreSDNode>(N)->isUnindexed() &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 def trunc_masked_store_i8 :
   PatFrag<(ops node:$val, node:$ptr, node:$pred),
           (trunc_masked_store node:$val, node:$ptr, node:$pred), [{
-  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8 &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 def trunc_masked_store_i16 :
   PatFrag<(ops node:$val, node:$ptr, node:$pred),
           (trunc_masked_store node:$val, node:$ptr, node:$pred), [{
-  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16 &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 def trunc_masked_store_i32 :
   PatFrag<(ops node:$val, node:$ptr, node:$pred),
           (trunc_masked_store node:$val, node:$ptr, node:$pred), [{
-  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32 &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 
 def non_temporal_store :
   PatFrag<(ops node:$val, node:$ptr, node:$pred),
           (masked_st node:$val, node:$ptr, undef, node:$pred), [{
   return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
          cast<MaskedStoreSDNode>(N)->isUnindexed() &&
-         cast<MaskedStoreSDNode>(N)->isNonTemporal();
+         cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 
 multiclass masked_gather_scatter<PatFrags GatherScatterOp> {

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 17 additions & 0 deletions
@@ -334,6 +334,23 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
+  bool isElementTypeLegalForCompressStore(Type *Ty) const {
+    return Ty->isFloatTy() || Ty->isDoubleTy() || Ty->isIntegerTy(32) ||
+           Ty->isIntegerTy(64);
+  }
+
+  bool isLegalMaskedCompressStore(Type *DataType,
+                                  Align Alignment) const override {
+    if (!ST->isSVEAvailable())
+      return false;
+
+    if (isa<FixedVectorType>(DataType) &&
+        DataType->getPrimitiveSizeInBits() < 128)
+      return false;
+
+    return isElementTypeLegalForCompressStore(DataType->getScalarType());
+  }
+
   bool isLegalMaskedGatherScatter(Type *DataType) const {
     if (!ST->isSVEAvailable())
       return false;
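
The new isLegalMaskedCompressStore hook advertises support to the middle end (for example the planned loop-vectorizer use in #140723): SVE must be available, fixed-width data types must be at least 128 bits, and the element type must be f32, f64, i32, or i64. Two illustrative cases, not part of the commit, showing a data type on each side of that check:

; <vscale x 2 x double>: scalable with f64 elements -> the hook returns true when SVE is available.
define void @compressstore_legal(ptr %p, <vscale x 2 x double> %v, <vscale x 2 x i1> %m) {
  tail call void @llvm.masked.compressstore.nxv2f64(<vscale x 2 x double> %v, ptr align 8 %p, <vscale x 2 x i1> %m)
  ret void
}

; <2 x float>: a 64-bit fixed vector, below the 128-bit minimum -> the hook returns false.
define void @compressstore_not_legal(ptr %p, <2 x float> %v, <2 x i1> %m) {
  tail call void @llvm.masked.compressstore.v2f32(<2 x float> %v, ptr align 4 %p, <2 x i1> %m)
  ret void
}
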
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+; RUN: llc -mtriple=aarch64 -mattr=+sve2p2 < %s
+
+;; These masked.compressstore operations could be natively supported with +sve2p2
+;; (or by promoting to 32/64 bit elements + a truncstore), but currently are not
+;; supported.
+
+; XFAIL: *
+
+define void @test_compressstore_nxv8i16(ptr %p, <vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
+  tail call void @llvm.masked.compressstore.nxv8i16(<vscale x 8 x i16> %vec, ptr align 2 %p, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+define void @test_compressstore_nxv16i8(ptr %p, <vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask) {
+  tail call void @llvm.masked.compressstore.nxv16i8(<vscale x 16 x i8> %vec, ptr align 1 %p, <vscale x 16 x i1> %mask)
+  ret void
+}
