Skip to content

Commit b2e8b8f

Browse files
authored
[RISCV] Lower f16/bf16 splat_vector by bitcasting to i16 instead of promoting to f32. (#108298)
If f16/bf16 scalar types are not legal we also need to custom legalize to prevent a crash. We do similar lowering for build_vector.
1 parent 35a0fd5 commit b2e8b8f

34 files changed

+2099
-2454
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
10801080
VT, Custom);
10811081
if (Subtarget.hasStdExtZfhmin())
10821082
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1083+
else
1084+
setOperationAction(ISD::SPLAT_VECTOR, MVT::f16, Custom);
10831085
// load/store
10841086
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
10851087

@@ -1117,6 +1119,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
11171119
VT, Custom);
11181120
if (Subtarget.hasStdExtZfbfmin())
11191121
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1122+
else
1123+
setOperationAction(ISD::SPLAT_VECTOR, MVT::bf16, Custom);
11201124
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);
11211125

11221126
setOperationAction(ISD::FNEG, VT, Expand);
@@ -6988,30 +6992,28 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
69886992
return lowerVECTOR_SPLICE(Op, DAG);
69896993
case ISD::BUILD_VECTOR:
69906994
return lowerBUILD_VECTOR(Op, DAG, Subtarget);
6991-
case ISD::SPLAT_VECTOR:
6992-
if ((Op.getValueType().getScalarType() == MVT::f16 &&
6993-
(Subtarget.hasVInstructionsF16Minimal() &&
6994-
Subtarget.hasStdExtZfhminOrZhinxmin() &&
6995-
!Subtarget.hasVInstructionsF16())) ||
6996-
(Op.getValueType().getScalarType() == MVT::bf16 &&
6997-
(Subtarget.hasVInstructionsBF16Minimal() &&
6998-
Subtarget.hasStdExtZfbfmin()))) {
6999-
if (Op.getValueType() == MVT::nxv32f16 ||
7000-
Op.getValueType() == MVT::nxv32bf16)
7001-
return SplitVectorOp(Op, DAG);
6995+
case ISD::SPLAT_VECTOR: {
6996+
MVT VT = Op.getSimpleValueType();
6997+
MVT EltVT = VT.getVectorElementType();
6998+
if ((EltVT == MVT::f16 && !Subtarget.hasStdExtZvfh()) ||
6999+
EltVT == MVT::bf16) {
70027000
SDLoc DL(Op);
7003-
SDValue NewScalar =
7004-
DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0));
7005-
SDValue NewSplat = DAG.getNode(
7006-
ISD::SPLAT_VECTOR, DL,
7007-
MVT::getVectorVT(MVT::f32, Op.getValueType().getVectorElementCount()),
7008-
NewScalar);
7009-
return DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), NewSplat,
7010-
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
7001+
SDValue Elt;
7002+
if ((EltVT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) ||
7003+
(EltVT == MVT::f16 && Subtarget.hasStdExtZfhmin()))
7004+
Elt = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, Subtarget.getXLenVT(),
7005+
Op.getOperand(0));
7006+
else
7007+
Elt = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Op.getOperand(0));
7008+
MVT IVT = VT.changeVectorElementType(MVT::i16);
7009+
return DAG.getNode(ISD::BITCAST, DL, VT,
7010+
DAG.getNode(ISD::SPLAT_VECTOR, DL, IVT, Elt));
70117011
}
7012-
if (Op.getValueType().getVectorElementType() == MVT::i1)
7012+
7013+
if (EltVT == MVT::i1)
70137014
return lowerVectorMaskSplat(Op, DAG);
70147015
return SDValue();
7016+
}
70157017
case ISD::VECTOR_SHUFFLE:
70167018
return lowerVECTOR_SHUFFLE(Op, DAG, Subtarget);
70177019
case ISD::CONCAT_VECTORS: {

0 commit comments

Comments
 (0)