Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1983,10 +1983,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);

// We can lower types that have <vscale x {2|4}> elements to compact.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
for (auto VT : {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64,
MVT::nxv2f32, MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16,
MVT::nxv4i32, MVT::nxv4f32}) {
setOperationAction(ISD::MSTORE, VT, Custom);
// Use a custom lowering for masked stores that could be a supported
// compressing store. Note: These types still use the normal (Legal)
// lowering for non-compressing masked stores.
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
}

// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
Expand Down Expand Up @@ -7932,7 +7937,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::STORE:
return LowerSTORE(Op, DAG);
case ISD::MSTORE:
return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
return LowerMSTORE(Op, DAG);
case ISD::MGATHER:
return LowerMGATHER(Op, DAG);
case ISD::MSCATTER:
Expand Down Expand Up @@ -30400,6 +30405,43 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
Store->isTruncatingStore());
}

// Custom-lower ISD::MSTORE.
// Fixed-length vectors are redirected to the SVE-container lowering.
// Scalable compressing masked stores are expanded into:
//   1. VECTOR_COMPRESS of the stored value under the original mask, and
//   2. a normal (non-compressing) masked store of the compressed value,
//      whose mask enables exactly the first CntActive lanes
//      (CntActive = number of set bits in the original mask).
// Non-compressing scalable masked stores fall through to the default
// lowering by returning an empty SDValue.
SDValue AArch64TargetLowering::LowerMSTORE(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
auto *Store = cast<MaskedStoreSDNode>(Op);
EVT VT = Store->getValue().getValueType();
// Fixed-length masked stores get the SVE-container treatment instead.
if (VT.isFixedLengthVector())
return LowerFixedLengthVectorMStoreToSVE(Op, DAG);

// Only compressing stores need custom expansion here; everything else
// keeps its normal (Legal) lowering.
if (!Store->isCompressingStore())
return SDValue();

EVT MaskVT = Store->getMask().getValueType();
// Integer vector type wide enough to hold the zero-extended predicate;
// presumably the promoted VT for this predicate type — see
// getPromotedVTForPredicate.
EVT MaskExtVT = getPromotedVTForPredicate(MaskVT);
EVT MaskReduceVT = MaskExtVT.getScalarType();
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);

// Count the active lanes: zero-extend the i1 mask elements and sum them.
SDValue MaskExt =
DAG.getNode(ISD::ZERO_EXTEND, DL, MaskExtVT, Store->getMask());
SDValue CntActive =
DAG.getNode(ISD::VECREDUCE_ADD, DL, MaskReduceVT, MaskExt);
// GET_ACTIVE_LANE_MASK below takes i64 bounds, so widen the count if the
// reduction produced a narrower scalar.
if (MaskReduceVT != MVT::i64)
CntActive = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CntActive);

// Pack the selected elements into the low lanes; lanes past the packed
// prefix are poison (they are masked off by CompressedMask below).
SDValue CompressedValue =
DAG.getNode(ISD::VECTOR_COMPRESS, DL, VT, Store->getValue(),
Store->getMask(), DAG.getPOISON(VT));
// Mask covering lanes [0, CntActive) — exactly the packed prefix.
SDValue CompressedMask =
DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, MaskVT, Zero, CntActive);

// Emit the result as a plain masked store (isCompressing=false) so it is
// not routed back through this expansion.
return DAG.getMaskedStore(Store->getChain(), DL, CompressedValue,
Store->getBasePtr(), Store->getOffset(),
CompressedMask, Store->getMemoryVT(),
Store->getMemOperand(), Store->getAddressingMode(),
Store->isTruncatingStore(),
/*isCompressing=*/false);
}

SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
SDValue Op, SelectionDAG &DAG) const {
auto *Store = cast<MaskedStoreSDNode>(Op);
Expand All @@ -30414,7 +30456,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
return DAG.getMaskedStore(
Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
Mask, Store->getMemoryVT(), Store->getMemOperand(),
Store->getAddressingMode(), Store->isTruncatingStore());
Store->getAddressingMode(), Store->isTruncatingStore(),
Store->isCompressingStore());
}

SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMSTORE(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;

Expand Down
18 changes: 12 additions & 6 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -645,37 +645,43 @@ def nontrunc_masked_store :
(masked_st node:$val, node:$ptr, undef, node:$pred), [{
return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
cast<MaskedStoreSDNode>(N)->isUnindexed() &&
!cast<MaskedStoreSDNode>(N)->isNonTemporal();
!cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
// truncating masked store fragments.
def trunc_masked_store :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(masked_st node:$val, node:$ptr, undef, node:$pred), [{
return cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
cast<MaskedStoreSDNode>(N)->isUnindexed();
cast<MaskedStoreSDNode>(N)->isUnindexed() &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
def trunc_masked_store_i8 :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(trunc_masked_store node:$val, node:$ptr, node:$pred), [{
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8 &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
def trunc_masked_store_i16 :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(trunc_masked_store node:$val, node:$ptr, node:$pred), [{
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16 &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
def trunc_masked_store_i32 :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(trunc_masked_store node:$val, node:$ptr, node:$pred), [{
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32 &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;

def non_temporal_store :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(masked_st node:$val, node:$ptr, undef, node:$pred), [{
return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
cast<MaskedStoreSDNode>(N)->isUnindexed() &&
cast<MaskedStoreSDNode>(N)->isNonTemporal();
cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;

multiclass masked_gather_scatter<PatFrags GatherScatterOp> {
Expand Down
42 changes: 42 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,48 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
return isLegalMaskedLoadStore(DataType, Alignment);
}

// Returns true for the element types a compressing store can handle:
// f32, f64, and the power-of-two integer widths i8/i16/i32/i64.
bool isElementTypeLegalForCompressStore(Type *Ty) const {
  return Ty->isFloatTy() || Ty->isDoubleTy() || Ty->isIntegerTy(8) ||
         Ty->isIntegerTy(16) || Ty->isIntegerTy(32) || Ty->isIntegerTy(64);
}

// Returns true if a masked compressing store of DataType (a vector type)
// with the given alignment can be lowered efficiently on this subtarget.
// Checks the vector's segment layout first, then the element type, then
// defers to the generic masked load/store legality check.
bool isLegalMaskedCompressStore(Type *DataType,
Align Alignment) const override {
auto VecTy = cast<VectorType>(DataType);
Type *ElTy = VecTy->getScalarType();
unsigned ElSizeInBits = ElTy->getScalarSizeInBits();
TypeSize VecSizeInBits = VecTy->getPrimitiveSizeInBits();

if (isa<FixedVectorType>(VecTy)) {
// Each 128-bit segment must contain 2 or 4 elements (packed).
if (ElSizeInBits != 32 && ElSizeInBits != 64)
return false;
// The fixed vector must be a whole number of 128-bit segments and must
// fit in the (known-minimum) SVE register width.
if (VecSizeInBits % 128 != 0 ||
VecSizeInBits > std::max(128U, ST->getMinSVEVectorSizeInBits()))
return false;
} else {
// Each segment must contain 2 or 4 elements, but the segments can be
// < 128-bits for unpacked vector types.
if (VecSizeInBits.getKnownMinValue() > 128)
return false;
unsigned ElementsPerSegment =
VecSizeInBits.getKnownMinValue() / ElSizeInBits;
if (ElementsPerSegment != 2 && ElementsPerSegment != 4)
return false;
}

// Element type must be one of f32/f64/i8/i16/i32/i64.
if (!isElementTypeLegalForCompressStore(DataType->getScalarType()))
return false;

// Finally, the plain masked load/store legality rules still apply.
return isLegalMaskedLoadStore(DataType, Alignment);
}

bool isLegalMaskedGatherScatter(Type *DataType) const {
if (!ST->isSVEAvailable())
return false;
Expand Down
Loading