Skip to content

Commit a7dd1c1

Browse files
[LLVM][SVE] Implement isel for bfloat fptoi & itofp operations.
NOTE: This PR only considers scalable vectors because SVE VLS does not support bfloat (see useSVEForFixedLengthVectorVT()).
1 parent 607485f commit a7dd1c1

File tree

3 files changed

+869
-34
lines changed

3 files changed

+869
-34
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4582,6 +4582,10 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
45824582
bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
45834583

45844584
if (VT.isScalableVector()) {
4585+
// Let common code split the operation.
4586+
if (SrcVT == MVT::nxv8f32)
4587+
return Op;
4588+
45854589
if (VT.getScalarType() != MVT::bf16)
45864590
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
45874591

@@ -4724,6 +4728,22 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
47244728
assert(!(IsStrict && VT.isScalableVector()) &&
47254729
"Unimplemented SVE support for STRICT_FP_to_INT!");
47264730

4731+
// f16 conversions are promoted to f32 when full fp16 is not supported.
4732+
if ((InVT.getVectorElementType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
4733+
InVT.getVectorElementType() == MVT::bf16) {
4734+
EVT NewVT = VT.changeElementType(MVT::f32);
4735+
SDLoc dl(Op);
4736+
if (IsStrict) {
4737+
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NewVT, MVT::Other},
4738+
{Op.getOperand(0), Op.getOperand(1)});
4739+
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
4740+
{Ext.getValue(1), Ext.getValue(0)});
4741+
}
4742+
return DAG.getNode(
4743+
Op.getOpcode(), dl, Op.getValueType(),
4744+
DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
4745+
}
4746+
47274747
if (VT.isScalableVector()) {
47284748
if (VT.getVectorElementType() == MVT::i1) {
47294749
SDLoc DL(Op);
@@ -4733,6 +4753,10 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
47334753
return DAG.getSetCC(DL, VT, Cvt, Zero, ISD::SETNE);
47344754
}
47354755

4756+
// Let common code split the operation.
4757+
if (InVT == MVT::nxv8f32)
4758+
return Op;
4759+
47364760
unsigned Opcode = Op.getOpcode() == ISD::FP_TO_UINT
47374761
? AArch64ISD::FCVTZU_MERGE_PASSTHRU
47384762
: AArch64ISD::FCVTZS_MERGE_PASSTHRU;
@@ -4743,24 +4767,6 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
47434767
useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
47444768
return LowerFixedLengthFPToIntToSVE(Op, DAG);
47454769

4746-
unsigned NumElts = InVT.getVectorNumElements();
4747-
4748-
// f16 conversions are promoted to f32 when full fp16 is not supported.
4749-
if ((InVT.getVectorElementType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
4750-
InVT.getVectorElementType() == MVT::bf16) {
4751-
MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
4752-
SDLoc dl(Op);
4753-
if (IsStrict) {
4754-
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NewVT, MVT::Other},
4755-
{Op.getOperand(0), Op.getOperand(1)});
4756-
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
4757-
{Ext.getValue(1), Ext.getValue(0)});
4758-
}
4759-
return DAG.getNode(
4760-
Op.getOpcode(), dl, Op.getValueType(),
4761-
DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
4762-
}
4763-
47644770
uint64_t VTSize = VT.getFixedSizeInBits();
47654771
uint64_t InVTSize = InVT.getFixedSizeInBits();
47664772
if (VTSize < InVTSize) {
@@ -4795,7 +4801,7 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
47954801

47964802
// Use a scalar operation for conversions between single-element vectors of
47974803
// the same size.
4798-
if (NumElts == 1) {
4804+
if (InVT.getVectorNumElements() == 1) {
47994805
SDLoc dl(Op);
48004806
SDValue Extract = DAG.getNode(
48014807
ISD::EXTRACT_VECTOR_ELT, dl, InVT.getScalarType(),
@@ -5041,23 +5047,14 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
50415047
assert(!(IsStrict && VT.isScalableVector()) &&
50425048
"Unimplemented SVE support for ISD:::STRICT_INT_TO_FP!");
50435049

5044-
if (VT.isScalableVector()) {
5045-
if (InVT.getVectorElementType() == MVT::i1) {
5046-
SDValue FalseVal = DAG.getConstantFP(0.0, dl, VT);
5047-
SDValue TrueVal = IsSigned ? DAG.getConstantFP(-1.0, dl, VT)
5048-
: DAG.getConstantFP(1.0, dl, VT);
5049-
return DAG.getNode(ISD::VSELECT, dl, VT, In, TrueVal, FalseVal);
5050-
}
5051-
5052-
unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
5053-
: AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
5054-
return LowerToPredicatedOp(Op, DAG, Opcode);
5050+
// NOTE: i1->bf16 does not require promotion to f32.
5051+
if (VT.isScalableVector() && InVT.getVectorElementType() == MVT::i1) {
5052+
SDValue FalseVal = DAG.getConstantFP(0.0, dl, VT);
5053+
SDValue TrueVal = IsSigned ? DAG.getConstantFP(-1.0, dl, VT)
5054+
: DAG.getConstantFP(1.0, dl, VT);
5055+
return DAG.getNode(ISD::VSELECT, dl, VT, In, TrueVal, FalseVal);
50555056
}
50565057

5057-
if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) ||
5058-
useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
5059-
return LowerFixedLengthIntToFPToSVE(Op, DAG);
5060-
50615058
// Promote bf16 conversions to f32.
50625059
if (VT.getVectorElementType() == MVT::bf16) {
50635060
EVT F32 = VT.changeElementType(MVT::f32);
@@ -5074,6 +5071,20 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
50745071
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
50755072
}
50765073

5074+
if (VT.isScalableVector()) {
5075+
// Let common code split the operation.
5076+
if (VT == MVT::nxv8f32)
5077+
return Op;
5078+
5079+
unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
5080+
: AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
5081+
return LowerToPredicatedOp(Op, DAG, Opcode);
5082+
}
5083+
5084+
if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) ||
5085+
useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
5086+
return LowerFixedLengthIntToFPToSVE(Op, DAG);
5087+
50775088
uint64_t VTSize = VT.getFixedSizeInBits();
50785089
uint64_t InVTSize = InVT.getFixedSizeInBits();
50795090
if (VTSize < InVTSize) {

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5465,6 +5465,14 @@ multiclass sve_int_dup_fpimm_pred<string asm> {
54655465
(!cast<Instruction>(NAME # _S) $zd, $pg, fpimm32:$imm8)>;
54665466
def : Pat<(nxv2f64 (vselect nxv2i1:$pg, (splat_vector fpimm64:$imm8), nxv2f64:$zd)),
54675467
(!cast<Instruction>(NAME # _D) $zd, $pg, fpimm64:$imm8)>;
5468+
5469+
// Some half precision immediates alias with bfloat (e.g. f16(1.875) == bf16(1.0)).
5470+
def : Pat<(nxv8bf16 (vselect nxv8i1:$pg, (splat_vector fpimmbf16:$imm8), nxv8bf16:$zd)),
5471+
(!cast<Instruction>(NAME # _H) $zd, $pg, (fpimm16XForm bf16:$imm8))>;
5472+
def : Pat<(nxv4bf16 (vselect nxv4i1:$pg, (splat_vector fpimmbf16:$imm8), nxv4bf16:$zd)),
5473+
(!cast<Instruction>(NAME # _H) $zd, $pg, (fpimm16XForm bf16:$imm8))>;
5474+
def : Pat<(nxv2bf16 (vselect nxv2i1:$pg, (splat_vector fpimmbf16:$imm8), nxv2bf16:$zd)),
5475+
(!cast<Instruction>(NAME # _H) $zd, $pg, (fpimm16XForm bf16:$imm8))>;
54685476
}
54695477

54705478
class sve_int_dup_imm_pred<bits<2> sz8_64, bit m, string asm,

0 commit comments

Comments
 (0)