Skip to content

Commit 2b8db40

Browse files
[SVE] Restrict the usage of REINTERPRET_CAST.
In order to limit the number of combinations of REINTERPRET_CAST, whilst at the same time prevent overlap with BITCAST, this patch establishes the following rules: 1. The operand and result element types must be the same. 2. The operand and/or result type must be an unpacked type. Differential Revision: https://reviews.llvm.org/D94593
1 parent 141e45b commit 2b8db40

File tree

3 files changed

+87
-33
lines changed

3 files changed

+87
-33
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,25 @@ static inline EVT getPackedSVEVectorVT(EVT VT) {
144144
return MVT::nxv4f32;
145145
case MVT::f64:
146146
return MVT::nxv2f64;
147+
case MVT::bf16:
148+
return MVT::nxv8bf16;
149+
}
150+
}
151+
152+
// NOTE: Currently there's only a need to return integer vector types. If this
153+
// changes then just add an extra "type" parameter.
154+
static inline EVT getPackedSVEVectorVT(ElementCount EC) {
155+
switch (EC.getKnownMinValue()) {
156+
default:
157+
llvm_unreachable("unexpected element count for vector");
158+
case 16:
159+
return MVT::nxv16i8;
160+
case 8:
161+
return MVT::nxv8i16;
162+
case 4:
163+
return MVT::nxv4i32;
164+
case 2:
165+
return MVT::nxv2i64;
147166
}
148167
}
149168

@@ -3988,14 +4007,10 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
39884007
!static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
39894008
return SDValue();
39904009

3991-
// Handle FP data
4010+
// Handle FP data by using an integer gather and casting the result.
39924011
if (VT.isFloatingPoint()) {
3993-
ElementCount EC = VT.getVectorElementCount();
3994-
auto ScalarIntVT =
3995-
MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
3996-
PassThru = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
3997-
MVT::getVectorVT(ScalarIntVT, EC), PassThru);
3998-
4012+
EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount());
4013+
PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG);
39994014
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
40004015
}
40014016

@@ -4015,7 +4030,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
40154030
SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops);
40164031

40174032
if (VT.isFloatingPoint()) {
4018-
SDValue Cast = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Gather);
4033+
SDValue Cast = getSVESafeBitCast(VT, Gather, DAG);
40194034
return DAG.getMergeValues({Cast, Gather}, DL);
40204035
}
40214036

@@ -4052,15 +4067,10 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
40524067
!static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
40534068
return SDValue();
40544069

4055-
// Handle FP data
4070+
// Handle FP data by casting the data so an integer scatter can be used.
40564071
if (VT.isFloatingPoint()) {
4057-
VT = VT.changeVectorElementTypeToInteger();
4058-
ElementCount EC = VT.getVectorElementCount();
4059-
auto ScalarIntVT =
4060-
MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
4061-
StoreVal = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
4062-
MVT::getVectorVT(ScalarIntVT, EC), StoreVal);
4063-
4072+
EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
4073+
StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
40644074
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
40654075
}
40664076

@@ -17157,3 +17167,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
1715717167
auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
1715817168
return convertFromScalableVector(DAG, Op.getValueType(), Promote);
1715917169
}
17170+
17171+
SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
17172+
SelectionDAG &DAG) const {
17173+
SDLoc DL(Op);
17174+
EVT InVT = Op.getValueType();
17175+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
17176+
17177+
assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
17178+
InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
17179+
"Only expect to cast between legal scalable vector types!");
17180+
assert((VT.getVectorElementType() == MVT::i1) ==
17181+
(InVT.getVectorElementType() == MVT::i1) &&
17182+
"Cannot cast between data and predicate scalable vector types!");
17183+
17184+
if (InVT == VT)
17185+
return Op;
17186+
17187+
if (VT.getVectorElementType() == MVT::i1)
17188+
return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
17189+
17190+
EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
17191+
EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
17192+
assert((VT == PackedVT || InVT == PackedInVT) &&
17193+
"Cannot cast between unpacked scalable vector types!");
17194+
17195+
// Pack input if required.
17196+
if (InVT != PackedInVT)
17197+
Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
17198+
17199+
Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
17200+
17201+
// Unpack result if required.
17202+
if (VT != PackedVT)
17203+
Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
17204+
17205+
return Op;
17206+
}

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ enum NodeType : unsigned {
314314
DUP_MERGE_PASSTHRU,
315315
INDEX_VECTOR,
316316

317+
// Cast between vectors of the same element type but differ in length.
317318
REINTERPRET_CAST,
318319

319320
LD1_MERGE_ZERO,
@@ -1022,6 +1023,17 @@ class AArch64TargetLowering : public TargetLowering {
10221023
// NEON vector. This changes when OverrideNEON is true, allowing SVE to be
10231024
// used for 64bit and 128bit vectors as well.
10241025
bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
1026+
1027+
// With the exception of data-predicate transitions, no instructions are
1028+
// required to cast between legal scalable vector types. However:
1029+
// 1. Packed and unpacked types have different bit lengths, meaning BITCAST
1030+
// is not universally useable.
1031+
// 2. Most unpacked integer types are not legal and thus integer extends
1032+
// cannot be used to convert between unpacked and packed types.
1033+
// These can make "bitcasting" a multiphase process. REINTERPRET_CAST is used
1034+
// to transition between unpacked and packed types of the same element type,
1035+
// with BITCAST used otherwise.
1036+
SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
10251037
};
10261038

10271039
namespace AArch64 {

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,6 +1721,7 @@ let Predicates = [HasSVE] in {
17211721
def : Pat<(nxv2f64 (bitconvert (nxv8bf16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
17221722
}
17231723

1724+
// These allow casting from/to unpacked predicate types.
17241725
def : Pat<(nxv16i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
17251726
def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
17261727
def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
@@ -1735,23 +1736,17 @@ let Predicates = [HasSVE] in {
17351736
def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
17361737
def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
17371738

1738-
def : Pat<(nxv2i64 (reinterpret_cast (nxv2f64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1739-
def : Pat<(nxv2i64 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1740-
def : Pat<(nxv2i64 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1741-
def : Pat<(nxv4i32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1742-
def : Pat<(nxv4i32 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1743-
def : Pat<(nxv2i64 (reinterpret_cast (nxv2bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1744-
def : Pat<(nxv4i32 (reinterpret_cast (nxv4bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1745-
1746-
def : Pat<(nxv2f16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1747-
def : Pat<(nxv2f32 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1748-
def : Pat<(nxv2f64 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1749-
def : Pat<(nxv4f16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1750-
def : Pat<(nxv4f32 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1751-
def : Pat<(nxv8f16 (reinterpret_cast (nxv8i16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1752-
def : Pat<(nxv2bf16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1753-
def : Pat<(nxv4bf16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1754-
def : Pat<(nxv8bf16 (reinterpret_cast (nxv8i16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1739+
// These allow casting from/to unpacked floating-point types.
1740+
def : Pat<(nxv2f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1741+
def : Pat<(nxv8f16 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1742+
def : Pat<(nxv4f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1743+
def : Pat<(nxv8f16 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1744+
def : Pat<(nxv2f32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1745+
def : Pat<(nxv4f32 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1746+
def : Pat<(nxv2bf16 (reinterpret_cast (nxv8bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1747+
def : Pat<(nxv8bf16 (reinterpret_cast (nxv2bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1748+
def : Pat<(nxv4bf16 (reinterpret_cast (nxv8bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
1749+
def : Pat<(nxv8bf16 (reinterpret_cast (nxv4bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
17551750

17561751
def : Pat<(nxv16i1 (and PPR:$Ps1, PPR:$Ps2)),
17571752
(AND_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>;

0 commit comments

Comments
 (0)