Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,12 +511,36 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BR_CC, MVT::f64, Custom);
setOperationAction(ISD::SELECT, MVT::i32, Custom);
setOperationAction(ISD::SELECT, MVT::i64, Custom);
setOperationAction(ISD::CTSELECT, MVT::i8, Promote);
setOperationAction(ISD::CTSELECT, MVT::i16, Promote);
setOperationAction(ISD::CTSELECT, MVT::i32, Custom);
setOperationAction(ISD::CTSELECT, MVT::i64, Custom);
if (Subtarget->hasFPARMv8()) {
setOperationAction(ISD::SELECT, MVT::f16, Custom);
setOperationAction(ISD::SELECT, MVT::bf16, Custom);
}
if (Subtarget->hasFullFP16()) {
setOperationAction(ISD::CTSELECT, MVT::f16, Custom);
setOperationAction(ISD::CTSELECT, MVT::bf16, Custom);
} else {
setOperationAction(ISD::CTSELECT, MVT::f16, Promote);
setOperationAction(ISD::CTSELECT, MVT::bf16, Promote);
}
setOperationAction(ISD::SELECT, MVT::f32, Custom);
setOperationAction(ISD::SELECT, MVT::f64, Custom);
setOperationAction(ISD::CTSELECT, MVT::f32, Custom);
setOperationAction(ISD::CTSELECT, MVT::f64, Custom);
for (MVT VT : MVT::vector_valuetypes()) {
MVT elemType = VT.getVectorElementType();
if (elemType == MVT::i8 || elemType == MVT::i16) {
setOperationAction(ISD::CTSELECT, VT, Promote);
} else if ((elemType == MVT::f16 || elemType == MVT::bf16) &&
!Subtarget->hasFullFP16()) {
setOperationAction(ISD::CTSELECT, VT, Promote);
} else {
setOperationAction(ISD::CTSELECT, VT, Expand);
}
}
setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
Expand Down Expand Up @@ -3328,6 +3352,20 @@ void AArch64TargetLowering::fixupPtrauthDiscriminator(
IntDiscOp.setImm(IntDisc);
}

MachineBasicBlock *AArch64TargetLowering::EmitCTSELECT(MachineInstr &MI,
MachineBasicBlock *MBB,
unsigned Opcode) const {
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
DebugLoc DL = MI.getDebugLoc();
MachineInstrBuilder Builder = BuildMI(*MBB, MI, DL, TII->get(Opcode));
for (unsigned Idx = 0; Idx < MI.getNumOperands(); ++Idx) {
Builder.add(MI.getOperand(Idx));
}
Builder->setFlag(MachineInstr::NoMerge);
MBB->remove_instr(&MI);
return MBB;
}

MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
MachineInstr &MI, MachineBasicBlock *BB) const {

Expand Down Expand Up @@ -7590,6 +7628,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerSELECT(Op, DAG);
case ISD::SELECT_CC:
return LowerSELECT_CC(Op, DAG);
case ISD::CTSELECT:
return LowerCTSELECT(Op, DAG);
case ISD::JumpTable:
return LowerJumpTable(Op, DAG);
case ISD::BR_JT:
Expand Down Expand Up @@ -12149,6 +12189,22 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
return Res;
}

SDValue AArch64TargetLowering::LowerCTSELECT(SDValue Op,
SelectionDAG &DAG) const {
SDValue CCVal = Op->getOperand(0);
SDValue TVal = Op->getOperand(1);
SDValue FVal = Op->getOperand(2);
SDLoc DL(Op);

EVT VT = Op.getValueType();

SDValue Zero = DAG.getConstant(0, DL, CCVal.getValueType());
SDValue CC;
SDValue Cmp = getAArch64Cmp(CCVal, Zero, ISD::SETNE, CC, DAG, DL);

return DAG.getNode(AArch64ISD::CTSELECT, DL, VT, TVal, FVal, CC, Cmp);
}

SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,
SelectionDAG &DAG) const {
// Jump table entries as PC relative offsets. No additional tweaking
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@

namespace llvm {

namespace AArch64ISD {
// Forward declare the enum from the generated file
enum GenNodeType : unsigned;
} // namespace AArch64ISD

class AArch64TargetMachine;

namespace AArch64 {
Expand Down Expand Up @@ -202,6 +207,9 @@ class AArch64TargetLowering : public TargetLowering {
MachineOperand &AddrDiscOp,
const TargetRegisterClass *AddrDiscRC) const;

MachineBasicBlock *EmitCTSELECT(MachineInstr &MI, MachineBasicBlock *BB,
unsigned Opcode) const;

MachineBasicBlock *
EmitInstrWithCustomInserter(MachineInstr &MI,
MachineBasicBlock *MBB) const override;
Expand Down Expand Up @@ -684,6 +692,7 @@ class AArch64TargetLowering : public TargetLowering {
iterator_range<SDNode::user_iterator> Users,
SDNodeFlags Flags, const SDLoc &dl,
SelectionDAG &DAG) const;
SDValue LowerCTSELECT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
Expand Down Expand Up @@ -919,6 +928,8 @@ class AArch64TargetLowering : public TargetLowering {
bool hasMultipleConditionRegisters(EVT VT) const override {
return VT.isScalableVector();
}

bool isSelectSupported(SelectSupportKind Kind) const override { return true; }
};

namespace AArch64 {
Expand Down
Loading
Loading