Skip to content

Commit 7de2b81

Browse files
[LLVM][AArch64] Add native ct.select support for ARM64
This patch implements architecture-specific lowering for ct.select on AArch64 using CSEL (conditional select) instructions for constant-time selection. Implementation details: - Uses CSEL family of instructions for scalar integer types - Uses FCSEL for floating-point types (F16, BF16, F32, F64) - Post-RA MC lowering to convert pseudo-instructions to real CSEL/FCSEL - Handles vector types appropriately - Comprehensive test coverage for AArch64 The implementation includes: - ISelLowering: Custom lowering to CTSELECT pseudo-instructions - InstrInfo: Pseudo-instruction definitions and patterns - MCInstLower: Post-RA lowering of pseudo-instructions to actual CSEL/FCSEL - Proper handling of condition codes for constant-time guarantees
1 parent 6ac8221 commit 7de2b81

File tree

6 files changed

+368
-110
lines changed

6 files changed

+368
-110
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,12 +511,36 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
511511
setOperationAction(ISD::BR_CC, MVT::f64, Custom);
512512
setOperationAction(ISD::SELECT, MVT::i32, Custom);
513513
setOperationAction(ISD::SELECT, MVT::i64, Custom);
514+
setOperationAction(ISD::CTSELECT, MVT::i8, Promote);
515+
setOperationAction(ISD::CTSELECT, MVT::i16, Promote);
516+
setOperationAction(ISD::CTSELECT, MVT::i32, Custom);
517+
setOperationAction(ISD::CTSELECT, MVT::i64, Custom);
514518
if (Subtarget->hasFPARMv8()) {
515519
setOperationAction(ISD::SELECT, MVT::f16, Custom);
516520
setOperationAction(ISD::SELECT, MVT::bf16, Custom);
517521
}
522+
if (Subtarget->hasFullFP16()) {
523+
setOperationAction(ISD::CTSELECT, MVT::f16, Custom);
524+
setOperationAction(ISD::CTSELECT, MVT::bf16, Custom);
525+
} else {
526+
setOperationAction(ISD::CTSELECT, MVT::f16, Promote);
527+
setOperationAction(ISD::CTSELECT, MVT::bf16, Promote);
528+
}
518529
setOperationAction(ISD::SELECT, MVT::f32, Custom);
519530
setOperationAction(ISD::SELECT, MVT::f64, Custom);
531+
setOperationAction(ISD::CTSELECT, MVT::f32, Custom);
532+
setOperationAction(ISD::CTSELECT, MVT::f64, Custom);
533+
for (MVT VT : MVT::vector_valuetypes()) {
534+
MVT elemType = VT.getVectorElementType();
535+
if (elemType == MVT::i8 || elemType == MVT::i16) {
536+
setOperationAction(ISD::CTSELECT, VT, Promote);
537+
} else if ((elemType == MVT::f16 || elemType == MVT::bf16) &&
538+
!Subtarget->hasFullFP16()) {
539+
setOperationAction(ISD::CTSELECT, VT, Promote);
540+
} else {
541+
setOperationAction(ISD::CTSELECT, VT, Expand);
542+
}
543+
}
520544
setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
521545
setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
522546
setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
@@ -3328,6 +3352,20 @@ void AArch64TargetLowering::fixupPtrauthDiscriminator(
33283352
IntDiscOp.setImm(IntDisc);
33293353
}
33303354

3355+
MachineBasicBlock *AArch64TargetLowering::EmitCTSELECT(MachineInstr &MI,
3356+
MachineBasicBlock *MBB,
3357+
unsigned Opcode) const {
3358+
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3359+
DebugLoc DL = MI.getDebugLoc();
3360+
MachineInstrBuilder Builder = BuildMI(*MBB, MI, DL, TII->get(Opcode));
3361+
for (unsigned Idx = 0; Idx < MI.getNumOperands(); ++Idx) {
3362+
Builder.add(MI.getOperand(Idx));
3363+
}
3364+
Builder->setFlag(MachineInstr::NoMerge);
3365+
MBB->remove_instr(&MI);
3366+
return MBB;
3367+
}
3368+
33313369
MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
33323370
MachineInstr &MI, MachineBasicBlock *BB) const {
33333371

@@ -7590,6 +7628,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
75907628
return LowerSELECT(Op, DAG);
75917629
case ISD::SELECT_CC:
75927630
return LowerSELECT_CC(Op, DAG);
7631+
case ISD::CTSELECT:
7632+
return LowerCTSELECT(Op, DAG);
75937633
case ISD::JumpTable:
75947634
return LowerJumpTable(Op, DAG);
75957635
case ISD::BR_JT:
@@ -12149,6 +12189,22 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
1214912189
return Res;
1215012190
}
1215112191

12192+
SDValue AArch64TargetLowering::LowerCTSELECT(SDValue Op,
12193+
SelectionDAG &DAG) const {
12194+
SDValue CCVal = Op->getOperand(0);
12195+
SDValue TVal = Op->getOperand(1);
12196+
SDValue FVal = Op->getOperand(2);
12197+
SDLoc DL(Op);
12198+
12199+
EVT VT = Op.getValueType();
12200+
12201+
SDValue Zero = DAG.getConstant(0, DL, CCVal.getValueType());
12202+
SDValue CC;
12203+
SDValue Cmp = getAArch64Cmp(CCVal, Zero, ISD::SETNE, CC, DAG, DL);
12204+
12205+
return DAG.getNode(AArch64ISD::CTSELECT, DL, VT, TVal, FVal, CC, Cmp);
12206+
}
12207+
1215212208
SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,
1215312209
SelectionDAG &DAG) const {
1215412210
// Jump table entries as PC relative offsets. No additional tweaking

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323

2424
namespace llvm {
2525

26+
namespace AArch64ISD {
27+
// Forward declare the enum from the generated file
28+
enum GenNodeType : unsigned;
29+
} // namespace AArch64ISD
30+
2631
class AArch64TargetMachine;
2732

2833
namespace AArch64 {
@@ -202,6 +207,9 @@ class AArch64TargetLowering : public TargetLowering {
202207
MachineOperand &AddrDiscOp,
203208
const TargetRegisterClass *AddrDiscRC) const;
204209

210+
MachineBasicBlock *EmitCTSELECT(MachineInstr &MI, MachineBasicBlock *BB,
211+
unsigned Opcode) const;
212+
205213
MachineBasicBlock *
206214
EmitInstrWithCustomInserter(MachineInstr &MI,
207215
MachineBasicBlock *MBB) const override;
@@ -684,6 +692,7 @@ class AArch64TargetLowering : public TargetLowering {
684692
iterator_range<SDNode::user_iterator> Users,
685693
SDNodeFlags Flags, const SDLoc &dl,
686694
SelectionDAG &DAG) const;
695+
SDValue LowerCTSELECT(SDValue Op, SelectionDAG &DAG) const;
687696
SDValue LowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
688697
SDValue LowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
689698
SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
@@ -919,6 +928,8 @@ class AArch64TargetLowering : public TargetLowering {
919928
bool hasMultipleConditionRegisters(EVT VT) const override {
920929
return VT.isScalableVector();
921930
}
931+
932+
bool isSelectSupported(SelectSupportKind Kind) const override { return true; }
922933
};
923934

924935
namespace AArch64 {

0 commit comments

Comments
 (0)