[ConstantTime] Native ct.select support for ARM32 and Thumb #166707
base: users/wizardengineer/ct-select-clang
Conversation
Warning: This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.
This stack of pull requests is managed by Graphite. Learn more about stacking.
✅ With the latest revision this PR passed the C/C++ code formatter. |
This patch implements architecture-specific lowering for ct.select on ARM (both ARM32 and Thumb modes) using conditional move instructions and bitwise operations for constant-time selection.

Implementation details:
- Uses pseudo-instructions that are expanded Post-RA to bitwise operations
- Post-RA expansion in ARMBaseInstrInfo for BUNDLE pseudo-instructions
- Handles scalar integer types, floating-point, and half-precision types
- Handles vector types with NEON when available
- Support for both ARM and Thumb instruction sets (Thumb1 and Thumb2)
- Special handling for Thumb1, which lacks conditional execution
- Comprehensive test coverage including half-precision and vectors

The implementation includes:
- ISelLowering: Custom lowering to CTSELECT pseudo-instructions
- ISelDAGToDAG: Selection of appropriate pseudo-instructions
- BaseInstrInfo: Post-RA expansion of BUNDLE to bitwise instruction sequences
- InstrInfo.td: Pseudo-instruction definitions for different types
- TargetMachine: Registration of Post-RA expansion pass
- Proper handling of condition codes and register allocation constraints
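For reference, the sketch below shows the branch-free selection patterns that the Post-RA expansion emits, written as a small C++ sketch (illustrative only, not part of the patch; the helper names are invented). The first function mirrors the AND/BIC/ORR sequence used for ARM and Thumb2; the second mirrors the XOR sequence used for Thumb1, where the mask is produced with LSL/ASR shifts instead of RSB so no extra flag update is needed. In the actual lowering the expanded instructions are bundled and marked NoMerge so later passes do not fold or reorder the sequence.

```cpp
#include <cstdint>

// Illustrative helpers only; not part of the patch.

// AND/BIC/ORR form (ARM / Thumb2), as in expandCtSelect():
//   mask = 0 - cond;  result = (src1 & mask) | (src2 & ~mask)
static inline uint32_t ct_select_andor(uint32_t cond, uint32_t a, uint32_t b) {
  uint32_t mask = 0u - cond;  // cond == 1 -> 0xFFFFFFFF, cond == 0 -> 0
  return (a & mask) | (b & ~mask);
}

// XOR form (Thumb1), as in expandCtSelectThumb():
//   result = src2 ^ ((src1 ^ src2) & mask)
static inline uint32_t ct_select_xor(uint32_t cond, uint32_t a, uint32_t b) {
  // LSL #31 then ASR #31; relies on arithmetic right shift, as on ARM.
  uint32_t mask = (uint32_t)((int32_t)(cond << 31) >> 31);
  return b ^ ((a ^ b) & mask);
}
```

Both return `a` when `cond` is non-zero and `b` otherwise, matching the operand comments in `expandCtSelect` for `__builtin_ct_select(cond, a, b)`.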
@llvm/pr-subscribers-backend-arm
Author: Julius Alexandre (wizardengineer)
Patch is 166.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166707.diff
10 Files Affected:
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
index 22769dbf38719..6d8a3b72244fe 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
@@ -1526,18 +1526,351 @@ void ARMBaseInstrInfo::expandMEMCPY(MachineBasicBlock::iterator MI) const {
BB->erase(MI);
}
+// Expands the ctselect pseudo for vector operands, post-RA.
+bool ARMBaseInstrInfo::expandCtSelectVector(MachineInstr &MI) const {
+ MachineBasicBlock *MBB = MI.getParent();
+ DebugLoc DL = MI.getDebugLoc();
+
+ Register DestReg = MI.getOperand(0).getReg();
+ Register MaskReg = MI.getOperand(1).getReg();
+
+ // These operations will differ by operand register size.
+ unsigned AndOp = ARM::VANDd;
+ unsigned BicOp = ARM::VBICd;
+ unsigned OrrOp = ARM::VORRd;
+ unsigned BroadcastOp = ARM::VDUP32d;
+
+ const TargetRegisterInfo *TRI = &getRegisterInfo();
+ const TargetRegisterClass *RC = TRI->getMinimalPhysRegClass(DestReg);
+
+ if (ARM::QPRRegClass.hasSubClassEq(RC)) {
+ AndOp = ARM::VANDq;
+ BicOp = ARM::VBICq;
+ OrrOp = ARM::VORRq;
+ BroadcastOp = ARM::VDUP32q;
+ }
+
+ unsigned RsbOp = Subtarget.isThumb2() ? ARM::t2RSBri : ARM::RSBri;
+
+ // Any vector pseudo has: ((outs $dst, $tmp_mask, $bcast_mask), (ins $src1,
+ // $src2, $cond))
+ Register VectorMaskReg = MI.getOperand(2).getReg();
+ Register Src1Reg = MI.getOperand(3).getReg();
+ Register Src2Reg = MI.getOperand(4).getReg();
+ Register CondReg = MI.getOperand(5).getReg();
+
+ // The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)
+
+ // 1. mask = 0 - cond
+ // When cond = 0: mask = 0x00000000.
+ // When cond = 1: mask = 0xFFFFFFFF.
+
+ MachineInstr *FirstNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
+ .addReg(CondReg)
+ .addImm(0)
+ .add(predOps(ARMCC::AL))
+ .add(condCodeOp())
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 2. A = src1 & mask
+ // For vectors, broadcast the scalar mask so it matches operand size.
+ BuildMI(*MBB, MI, DL, get(BroadcastOp), VectorMaskReg)
+ .addReg(MaskReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
+ .addReg(Src1Reg)
+ .addReg(VectorMaskReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 3. B = src2 & ~mask
+ BuildMI(*MBB, MI, DL, get(BicOp), VectorMaskReg)
+ .addReg(Src2Reg)
+ .addReg(VectorMaskReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 4. result = A | B
+ auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
+ .addReg(DestReg)
+ .addReg(VectorMaskReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ auto BundleStart = FirstNewMI->getIterator();
+ auto BundleEnd = LastNewMI->getIterator();
+
+ // Add instruction bundling
+ finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));
+
+ MI.eraseFromParent();
+ return true;
+}
+
+// Expands the ctselect pseudo for thumb1, post-RA.
+bool ARMBaseInstrInfo::expandCtSelectThumb(MachineInstr &MI) const {
+ MachineBasicBlock *MBB = MI.getParent();
+ DebugLoc DL = MI.getDebugLoc();
+
+  // Pseudos in Thumb1 mode have: ((outs $dst, $tmp_mask), (ins $src1, $src2,
+  // $cond)). The register class here is always tGPR.
+ Register DestReg = MI.getOperand(0).getReg();
+ Register MaskReg = MI.getOperand(1).getReg();
+ Register Src1Reg = MI.getOperand(2).getReg();
+ Register Src2Reg = MI.getOperand(3).getReg();
+ Register CondReg = MI.getOperand(4).getReg();
+
+ // Access register info
+ MachineFunction *MF = MBB->getParent();
+ const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+ MachineRegisterInfo &MRI = MF->getRegInfo();
+
+ unsigned RegSize = TRI->getRegSizeInBits(MaskReg, MRI);
+ unsigned ShiftAmount = RegSize - 1;
+
+  // 1. Build the mask with shifts (preferred: avoids the flag modification from RSB)
+ MachineInstr *FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::tMOVr), MaskReg)
+ .addReg(CondReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // Instead of using RSB, we can use LSL and ASR to get the mask. This is to
+ // avoid the flag modification caused by RSB. tLSLri: (outs tGPR:$Rd,
+ // s_cc_out:$s), (ins tGPR:$Rm, imm0_31:$imm5, pred:$p)
+ BuildMI(*MBB, MI, DL, get(ARM::tLSLri), MaskReg)
+ .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+ .addReg(MaskReg) // $Rm
+ .addImm(ShiftAmount) // imm0_31:$imm5
+ .add(predOps(ARMCC::AL)) // pred:$p
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // tASRri: (outs tGPR:$Rd, s_cc_out:$s), (ins tGPR:$Rm, imm_sr:$imm5, pred:$p)
+ BuildMI(*MBB, MI, DL, get(ARM::tASRri), MaskReg)
+ .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+ .addReg(MaskReg) // $Rm
+ .addImm(ShiftAmount) // imm_sr:$imm5
+ .add(predOps(ARMCC::AL)) // pred:$p
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 2. xor_diff = src1 ^ src2
+ BuildMI(*MBB, MI, DL, get(ARM::tMOVr), DestReg)
+ .addReg(Src1Reg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
+  // tGPR:$Rm, pred:$p) with constraint "$Rn = $Rdn"
+ BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
+ .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+ .addReg(DestReg) // tied input $Rn
+ .addReg(Src2Reg) // $Rm
+ .add(predOps(ARMCC::AL)) // pred:$p
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 3. masked_xor = xor_diff & mask
+  // tAND has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
+  // tGPR:$Rm, pred:$p) with constraint "$Rn = $Rdn"
+ BuildMI(*MBB, MI, DL, get(ARM::tAND), DestReg)
+ .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+ .addReg(DestReg) // tied input $Rn
+ .addReg(MaskReg, RegState::Kill) // $Rm
+ .add(predOps(ARMCC::AL)) // pred:$p
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 4. result = src2 ^ masked_xor
+  // tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
+  // tGPR:$Rm, pred:$p) with constraint "$Rn = $Rdn"
+ auto LastMI =
+ BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
+ .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+ .addReg(DestReg) // tied input $Rn
+ .addReg(Src2Reg) // $Rm
+ .add(predOps(ARMCC::AL)) // pred:$p
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // Add instruction bundling
+ auto BundleStart = FirstNewMI->getIterator();
+ finalizeBundle(*MBB, BundleStart, std::next(LastMI->getIterator()));
+
+ MI.eraseFromParent();
+ return true;
+}
+
+// Expands the ctselect pseudo, post-RA.
+bool ARMBaseInstrInfo::expandCtSelect(MachineInstr &MI) const {
+ MachineBasicBlock *MBB = MI.getParent();
+ DebugLoc DL = MI.getDebugLoc();
+
+ Register DestReg = MI.getOperand(0).getReg();
+ Register MaskReg = MI.getOperand(1).getReg();
+ Register DestRegSavedRef = DestReg;
+ Register Src1Reg, Src2Reg, CondReg;
+
+ // These operations will differ by operand register size.
+ unsigned RsbOp = ARM::RSBri;
+ unsigned AndOp = ARM::ANDrr;
+ unsigned BicOp = ARM::BICrr;
+ unsigned OrrOp = ARM::ORRrr;
+
+ if (Subtarget.isThumb2()) {
+ RsbOp = ARM::t2RSBri;
+ AndOp = ARM::t2ANDrr;
+ BicOp = ARM::t2BICrr;
+ OrrOp = ARM::t2ORRrr;
+ }
+
+ unsigned Opcode = MI.getOpcode();
+ bool IsFloat = Opcode == ARM::CTSELECTf32 || Opcode == ARM::CTSELECTf16 ||
+ Opcode == ARM::CTSELECTbf16;
+ MachineInstr *FirstNewMI = nullptr;
+ if (IsFloat) {
+    // Each float pseudo has: ((outs $dst, $tmp_mask, $scratch1, $scratch2),
+    // (ins $src1, $src2, $cond)). We use two scratch registers in tablegen for
+    // bitwise ops on float types.
+ Register GPRScratch1 = MI.getOperand(2).getReg();
+ Register GPRScratch2 = MI.getOperand(3).getReg();
+
+ // choice a from __builtin_ct_select(cond, a, b)
+ Src1Reg = MI.getOperand(4).getReg();
+ // choice b from __builtin_ct_select(cond, a, b)
+ Src2Reg = MI.getOperand(5).getReg();
+ // cond from __builtin_ct_select(cond, a, b)
+ CondReg = MI.getOperand(6).getReg();
+
+ // Move fp src1 to GPR scratch1 so we can do our bitwise ops
+ FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch1)
+ .addReg(Src1Reg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // Move src2 to scratch2
+ BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch2)
+ .addReg(Src2Reg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ Src1Reg = GPRScratch1;
+ Src2Reg = GPRScratch2;
+ // Reuse GPRScratch1 for dest after we are done working with src1.
+ DestReg = GPRScratch1;
+ } else {
+    // Any non-float, non-vector pseudo has: ((outs $dst, $tmp_mask),
+    // (ins $src1, $src2, $cond)).
+ Src1Reg = MI.getOperand(2).getReg();
+ Src2Reg = MI.getOperand(3).getReg();
+ CondReg = MI.getOperand(4).getReg();
+ }
+
+ // The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)
+
+ // 1. mask = 0 - cond
+ // When cond = 0: mask = 0x00000000.
+ // When cond = 1: mask = 0xFFFFFFFF.
+ auto TmpNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
+ .addReg(CondReg)
+ .addImm(0)
+ .add(predOps(ARMCC::AL))
+ .add(condCodeOp())
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // If the float path above did not already start the bundle, this RSB is
+  // the bundle's first instruction.
+ if (!FirstNewMI)
+ FirstNewMI = TmpNewMI;
+
+ // 2. A = src1 & mask
+ BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
+ .addReg(Src1Reg)
+ .addReg(MaskReg)
+ .add(predOps(ARMCC::AL))
+ .add(condCodeOp())
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 3. B = src2 & ~mask
+ BuildMI(*MBB, MI, DL, get(BicOp), MaskReg)
+ .addReg(Src2Reg)
+ .addReg(MaskReg)
+ .add(predOps(ARMCC::AL))
+ .add(condCodeOp())
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ // 4. result = A | B
+ auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
+ .addReg(DestReg)
+ .addReg(MaskReg)
+ .add(predOps(ARMCC::AL))
+ .add(condCodeOp())
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+ if (IsFloat) {
+ // Return our result from GPR to the correct register type.
+ LastNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVSR), DestRegSavedRef)
+ .addReg(DestReg)
+ .add(predOps(ARMCC::AL))
+ .setMIFlag(MachineInstr::MIFlag::NoMerge);
+ }
+
+ auto BundleStart = FirstNewMI->getIterator();
+ auto BundleEnd = LastNewMI->getIterator();
+
+ // Add instruction bundling
+ finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));
+
+ MI.eraseFromParent();
+ return true;
+}
+
bool ARMBaseInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
- if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
+ auto opcode = MI.getOpcode();
+
+ if (opcode == TargetOpcode::LOAD_STACK_GUARD) {
expandLoadStackGuard(MI);
MI.getParent()->erase(MI);
return true;
}
- if (MI.getOpcode() == ARM::MEMCPY) {
+ if (opcode == ARM::MEMCPY) {
expandMEMCPY(MI);
return true;
}
+ if (opcode == ARM::CTSELECTf64) {
+ if (Subtarget.isThumb1Only()) {
+      LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode
+                        << " replaced by: " << MI);
+ return expandCtSelectThumb(MI);
+ } else {
+      LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode
+                        << " replaced by: " << MI);
+ return expandCtSelectVector(MI);
+ }
+ }
+
+ if (opcode == ARM::CTSELECTv8i8 || opcode == ARM::CTSELECTv4i16 ||
+ opcode == ARM::CTSELECTv2i32 || opcode == ARM::CTSELECTv1i64 ||
+ opcode == ARM::CTSELECTv2f32 || opcode == ARM::CTSELECTv4f16 ||
+ opcode == ARM::CTSELECTv4bf16 || opcode == ARM::CTSELECTv16i8 ||
+ opcode == ARM::CTSELECTv8i16 || opcode == ARM::CTSELECTv4i32 ||
+ opcode == ARM::CTSELECTv2i64 || opcode == ARM::CTSELECTv4f32 ||
+ opcode == ARM::CTSELECTv2f64 || opcode == ARM::CTSELECTv8f16 ||
+ opcode == ARM::CTSELECTv8bf16) {
+    LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode << " replaced by: " << MI);
+ return expandCtSelectVector(MI);
+ }
+
+ if (opcode == ARM::CTSELECTint || opcode == ARM::CTSELECTf16 ||
+ opcode == ARM::CTSELECTbf16 || opcode == ARM::CTSELECTf32) {
+ if (Subtarget.isThumb1Only()) {
+      LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode
+                        << " replaced by: " << MI);
+ return expandCtSelectThumb(MI);
+ } else {
+      LLVM_DEBUG(dbgs() << "Opcode " << opcode << " replaced by: " << MI);
+ return expandCtSelect(MI);
+ }
+ }
+
// This hook gets to expand COPY instructions before they become
// copyPhysReg() calls. Look for VMOVS instructions that can legally be
// widened to VMOVD. We prefer the VMOVD when possible because it may be
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
index 2869e7f708046..f0e090f09f5dc 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
@@ -221,6 +221,12 @@ class ARMBaseInstrInfo : public ARMGenInstrInfo {
const TargetRegisterInfo *TRI, Register VReg,
MachineInstr::MIFlag Flags = MachineInstr::NoFlags) const override;
+ bool expandCtSelectVector(MachineInstr &MI) const;
+
+ bool expandCtSelectThumb(MachineInstr &MI) const;
+
+ bool expandCtSelect(MachineInstr &MI) const;
+
bool expandPostRAPseudo(MachineInstr &MI) const override;
bool shouldSink(const MachineInstr &MI) const override;
diff --git a/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp b/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
index 847b7af5a9b11..3fdc5734baaa5 100644
--- a/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
+++ b/llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp
@@ -4200,6 +4200,92 @@ void ARMDAGToDAGISel::Select(SDNode *N) {
// Other cases are autogenerated.
break;
}
+ case ARMISD::CTSELECT: {
+ EVT VT = N->getValueType(0);
+ unsigned PseudoOpcode;
+ bool IsFloat = false;
+ bool IsVector = false;
+
+ if (VT == MVT::f16) {
+ PseudoOpcode = ARM::CTSELECTf16;
+ IsFloat = true;
+ } else if (VT == MVT::bf16) {
+ PseudoOpcode = ARM::CTSELECTbf16;
+ IsFloat = true;
+ } else if (VT == MVT::f32) {
+ PseudoOpcode = ARM::CTSELECTf32;
+ IsFloat = true;
+ } else if (VT == MVT::f64) {
+ PseudoOpcode = ARM::CTSELECTf64;
+ IsVector = true;
+ } else if (VT == MVT::v8i8) {
+ PseudoOpcode = ARM::CTSELECTv8i8;
+ IsVector = true;
+ } else if (VT == MVT::v4i16) {
+ PseudoOpcode = ARM::CTSELECTv4i16;
+ IsVector = true;
+ } else if (VT == MVT::v2i32) {
+ PseudoOpcode = ARM::CTSELECTv2i32;
+ IsVector = true;
+ } else if (VT == MVT::v1i64) {
+ PseudoOpcode = ARM::CTSELECTv1i64;
+ IsVector = true;
+ } else if (VT == MVT::v2f32) {
+ PseudoOpcode = ARM::CTSELECTv2f32;
+ IsVector = true;
+ } else if (VT == MVT::v4f16) {
+ PseudoOpcode = ARM::CTSELECTv4f16;
+ IsVector = true;
+ } else if (VT == MVT::v4bf16) {
+ PseudoOpcode = ARM::CTSELECTv4bf16;
+ IsVector = true;
+ } else if (VT == MVT::v16i8) {
+ PseudoOpcode = ARM::CTSELECTv16i8;
+ IsVector = true;
+ } else if (VT == MVT::v8i16) {
+ PseudoOpcode = ARM::CTSELECTv8i16;
+ IsVector = true;
+ } else if (VT == MVT::v4i32) {
+ PseudoOpcode = ARM::CTSELECTv4i32;
+ IsVector = true;
+ } else if (VT == MVT::v2i64) {
+ PseudoOpcode = ARM::CTSELECTv2i64;
+ IsVector = true;
+ } else if (VT == MVT::v4f32) {
+ PseudoOpcode = ARM::CTSELECTv4f32;
+ IsVector = true;
+ } else if (VT == MVT::v2f64) {
+ PseudoOpcode = ARM::CTSELECTv2f64;
+ IsVector = true;
+ } else if (VT == MVT::v8f16) {
+ PseudoOpcode = ARM::CTSELECTv8f16;
+ IsVector = true;
+ } else if (VT == MVT::v8bf16) {
+ PseudoOpcode = ARM::CTSELECTv8bf16;
+ IsVector = true;
+ } else {
+ // i1, i8, i16, i32, i64
+ PseudoOpcode = ARM::CTSELECTint;
+ }
+
+ SmallVector<EVT, 4> VTs;
+ VTs.push_back(VT); // $dst
+ VTs.push_back(MVT::i32); // $tmp_mask (always GPR)
+
+ if (IsVector) {
+ VTs.push_back(VT); // $bcast_mask (same type as dst for vectors)
+ } else if (IsFloat) {
+ VTs.push_back(MVT::i32); // $scratch1 (GPR)
+ VTs.push_back(MVT::i32); // $scratch2 (GPR)
+ }
+
+ // src1, src2, cond
+ SDValue Ops[] = {N->getOperand(0), N->getOperand(1), N->getOperand(2)};
+
+ SDNode *ResNode = CurDAG->getMachineNode(PseudoOpcode, SDLoc(N), VTs, Ops);
+ ReplaceNode(N, ResNode);
+ return;
+ }
case ARMISD::VZIP: {
EVT VT = N->getValueType(0);
// vzip.32 Dd, Dm is a pseudo-instruction expanded to vtrn.32 Dd, Dm.
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 6b0653457cbaf..63005f1c9f989 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -203,6 +203,7 @@ void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT) {
setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);
setOperationAction(ISD::VSELECT, VT, Expand);
+ setOperationAction(ISD::CTSELECT, VT, Custom);
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand);
if (VT.isInteger()) {
setOperationAction(ISD::SHL, VT, Custom);
@@ -304,6 +305,7 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
setOperationAction(ISD::CTPOP, VT, Expand);
setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);
+ setOperationAction(ISD::CTSELECT, VT, Custom);
// Vector reductions
setOperationAction(ISD::VECREDUCE_ADD, VT, Legal);
@@ -355,6 +357,7 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
setOperationAction(ISD::MSTORE, VT, Legal);
setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);
+ setOperationAction(ISD::CTSELECT, VT, Custom);
// Pre and Post inc are supported on loads and stores
for (unsigned im = (unsigned)ISD::PRE_INC;
@@ -408,6 +411,28 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) {
setOperationAction(ISD::VECREDUCE_FMIN, MVT::v2f16, Custom);
setOperationAction(ISD::VECREDUCE_FMAX, MVT::v2f16, Custom);
+ if (Subtarget->hasFullFP16()) {
+ setOperationAction(ISD::CTSELECT, MVT::v4f16, Custom);
+ setOperationAction(ISD::CTSELECT, MVT::v8f16, Custom);
+ }
+
+ if (Subtarget->hasBF16()) {
+ setOperationAction(ISD::CTSELECT, MVT::v4bf16, Custom);
+ setOperationAction(ISD::CTSELECT, MVT::v8bf16, Custom);
+ }
+
+ // small exotic vectors get scalarised for ctselect
+ setOperationAction(ISD::CTSELECT, MVT::v1i8, Expand);
+ setOperationAction(ISD::CTSELECT, MVT::v1i16, Expand);
+ setOperationAction(ISD::CTSELECT, MVT::v1i32, Expand);
+ setOperationAction(ISD::CTSELECT, MVT::v1f32, Expand);
+ setOperationAction(ISD::CTSELECT, MVT::v2i8, Expand);
+
+ setOperationAction(ISD::CTSELECT, MVT::v2i16, Promote);
+ setOperationPromotedToType(ISD::CTSELECT, MVT::v2i16, MVT::v4i16);
+ setOperationAction(IS...
[truncated]