@@ -144,6 +144,25 @@ static inline EVT getPackedSVEVectorVT(EVT VT) {
144
144
return MVT::nxv4f32;
145
145
case MVT::f64:
146
146
return MVT::nxv2f64;
147
+ case MVT::bf16:
148
+ return MVT::nxv8bf16;
149
+ }
150
+ }
151
+
152
+ // NOTE: Currently there's only a need to return integer vector types. If this
153
+ // changes then just add an extra "type" parameter.
154
+ static inline EVT getPackedSVEVectorVT(ElementCount EC) {
155
+ switch (EC.getKnownMinValue()) {
156
+ default:
157
+ llvm_unreachable("unexpected element count for vector");
158
+ case 16:
159
+ return MVT::nxv16i8;
160
+ case 8:
161
+ return MVT::nxv8i16;
162
+ case 4:
163
+ return MVT::nxv4i32;
164
+ case 2:
165
+ return MVT::nxv2i64;
147
166
}
148
167
}
149
168
@@ -3988,14 +4007,10 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
3988
4007
!static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
3989
4008
return SDValue();
3990
4009
3991
- // Handle FP data
4010
+ // Handle FP data by using an integer gather and casting the result.
3992
4011
if (VT.isFloatingPoint()) {
3993
- ElementCount EC = VT.getVectorElementCount();
3994
- auto ScalarIntVT =
3995
- MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
3996
- PassThru = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
3997
- MVT::getVectorVT(ScalarIntVT, EC), PassThru);
3998
-
4012
+ EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount());
4013
+ PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG);
3999
4014
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
4000
4015
}
4001
4016
@@ -4015,7 +4030,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
4015
4030
SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops);
4016
4031
4017
4032
if (VT.isFloatingPoint()) {
4018
- SDValue Cast = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Gather);
4033
+ SDValue Cast = getSVESafeBitCast( VT, Gather, DAG );
4019
4034
return DAG.getMergeValues({Cast, Gather}, DL);
4020
4035
}
4021
4036
@@ -4052,15 +4067,10 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
4052
4067
!static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
4053
4068
return SDValue();
4054
4069
4055
- // Handle FP data
4070
+ // Handle FP data by casting the data so an integer scatter can be used.
4056
4071
if (VT.isFloatingPoint()) {
4057
- VT = VT.changeVectorElementTypeToInteger();
4058
- ElementCount EC = VT.getVectorElementCount();
4059
- auto ScalarIntVT =
4060
- MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
4061
- StoreVal = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
4062
- MVT::getVectorVT(ScalarIntVT, EC), StoreVal);
4063
-
4072
+ EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
4073
+ StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
4064
4074
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
4065
4075
}
4066
4076
@@ -17157,3 +17167,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
17157
17167
auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
17158
17168
return convertFromScalableVector(DAG, Op.getValueType(), Promote);
17159
17169
}
17170
+
17171
+ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
17172
+ SelectionDAG &DAG) const {
17173
+ SDLoc DL(Op);
17174
+ EVT InVT = Op.getValueType();
17175
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
17176
+
17177
+ assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
17178
+ InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
17179
+ "Only expect to cast between legal scalable vector types!");
17180
+ assert((VT.getVectorElementType() == MVT::i1) ==
17181
+ (InVT.getVectorElementType() == MVT::i1) &&
17182
+ "Cannot cast between data and predicate scalable vector types!");
17183
+
17184
+ if (InVT == VT)
17185
+ return Op;
17186
+
17187
+ if (VT.getVectorElementType() == MVT::i1)
17188
+ return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
17189
+
17190
+ EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
17191
+ EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
17192
+ assert((VT == PackedVT || InVT == PackedInVT) &&
17193
+ "Cannot cast between unpacked scalable vector types!");
17194
+
17195
+ // Pack input if required.
17196
+ if (InVT != PackedInVT)
17197
+ Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
17198
+
17199
+ Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
17200
+
17201
+ // Unpack result if required.
17202
+ if (VT != PackedVT)
17203
+ Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
17204
+
17205
+ return Op;
17206
+ }
0 commit comments