Skip to content

Commit 19a683e

Browse files
[LLVM][ARM] Add native ct.select support for ARM32 and Thumb
This patch implements architecture-specific lowering for ct.select on ARM (both ARM32 and Thumb modes) using conditional move instructions and bitwise operations for constant-time selection.

Implementation details:
- Uses pseudo-instructions that are expanded post-RA into bitwise operations
- Post-RA expansion in ARMBaseInstrInfo for BUNDLE pseudo-instructions
- Handles scalar integer, floating-point, and half-precision types
- Handles vector types with NEON when available
- Supports both the ARM and Thumb instruction sets (Thumb1 and Thumb2)
- Special handling for Thumb1, which lacks conditional execution
- Comprehensive test coverage, including half-precision and vector types

The implementation includes:
- ISelLowering: custom lowering to CTSELECT pseudo-instructions
- ISelDAGToDAG: selection of the appropriate pseudo-instructions
- BaseInstrInfo: post-RA expansion of BUNDLE to bitwise instruction sequences
- InstrInfo.td: pseudo-instruction definitions for the different types
- TargetMachine: registration of the post-RA expansion pass
- Proper handling of condition codes and register allocation constraints
1 parent cbb5490 commit 19a683e

File tree

10 files changed

+4538
-29
lines changed

10 files changed

+4538
-29
lines changed

llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp

Lines changed: 334 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,18 +1526,350 @@ void ARMBaseInstrInfo::expandMEMCPY(MachineBasicBlock::iterator MI) const {
15261526
BB->erase(MI);
15271527
}
15281528

1529+
// Expands the ctselect pseudo for vector operands, post-RA.
1530+
bool ARMBaseInstrInfo::expandCtSelectVector(MachineInstr &MI) const {
1531+
MachineBasicBlock *MBB = MI.getParent();
1532+
DebugLoc DL = MI.getDebugLoc();
1533+
1534+
Register DestReg = MI.getOperand(0).getReg();
1535+
Register MaskReg = MI.getOperand(1).getReg();
1536+
1537+
// These operations will differ by operand register size.
1538+
unsigned AndOp = ARM::VANDd;
1539+
unsigned BicOp = ARM::VBICd;
1540+
unsigned OrrOp = ARM::VORRd;
1541+
unsigned BroadcastOp = ARM::VDUP32d;
1542+
1543+
const TargetRegisterInfo *TRI = &getRegisterInfo();
1544+
const TargetRegisterClass *RC = TRI->getMinimalPhysRegClass(DestReg);
1545+
1546+
if (ARM::QPRRegClass.hasSubClassEq(RC)) {
1547+
AndOp = ARM::VANDq;
1548+
BicOp = ARM::VBICq;
1549+
OrrOp = ARM::VORRq;
1550+
BroadcastOp = ARM::VDUP32q;
1551+
}
1552+
1553+
unsigned RsbOp = Subtarget.isThumb2() ? ARM::t2RSBri : ARM::RSBri;
1554+
1555+
// Any vector pseudo has: ((outs $dst, $tmp_mask, $bcast_mask), (ins $src1, $src2, $cond))
1556+
Register VectorMaskReg = MI.getOperand(2).getReg();
1557+
Register Src1Reg = MI.getOperand(3).getReg();
1558+
Register Src2Reg = MI.getOperand(4).getReg();
1559+
Register CondReg = MI.getOperand(5).getReg();
1560+
1561+
// The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)
1562+
1563+
// 1. mask = 0 - cond
1564+
// When cond = 0: mask = 0x00000000.
1565+
// When cond = 1: mask = 0xFFFFFFFF.
1566+
1567+
MachineInstr *FirstNewMI =
1568+
BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
1569+
.addReg(CondReg)
1570+
.addImm(0)
1571+
.add(predOps(ARMCC::AL))
1572+
.add(condCodeOp())
1573+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1574+
1575+
// 2. A = src1 & mask
1576+
// For vectors, broadcast the scalar mask so it matches operand size.
1577+
BuildMI(*MBB, MI, DL, get(BroadcastOp), VectorMaskReg)
1578+
.addReg(MaskReg)
1579+
.add(predOps(ARMCC::AL))
1580+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1581+
1582+
BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
1583+
.addReg(Src1Reg)
1584+
.addReg(VectorMaskReg)
1585+
.add(predOps(ARMCC::AL))
1586+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1587+
1588+
// 3. B = src2 & ~mask
1589+
BuildMI(*MBB, MI, DL, get(BicOp), VectorMaskReg)
1590+
.addReg(Src2Reg)
1591+
.addReg(VectorMaskReg)
1592+
.add(predOps(ARMCC::AL))
1593+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1594+
1595+
// 4. result = A | B
1596+
auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
1597+
.addReg(DestReg)
1598+
.addReg(VectorMaskReg)
1599+
.add(predOps(ARMCC::AL))
1600+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1601+
1602+
auto BundleStart = FirstNewMI->getIterator();
1603+
auto BundleEnd = LastNewMI->getIterator();
1604+
1605+
// Add instruction bundling
1606+
finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));
1607+
1608+
MI.eraseFromParent();
1609+
return true;
1610+
}
1611+
1612+
// Expands the ctselect pseudo for thumb1, post-RA.
1613+
bool ARMBaseInstrInfo::expandCtSelectThumb(MachineInstr &MI) const {
1614+
MachineBasicBlock *MBB = MI.getParent();
1615+
DebugLoc DL = MI.getDebugLoc();
1616+
1617+
// pseudos in thumb1 mode have: (outs $dst, $tmp_mask), (ins $src1, $src2, $cond))
1618+
// register class here is always tGPR.
1619+
Register DestReg = MI.getOperand(0).getReg();
1620+
Register MaskReg = MI.getOperand(1).getReg();
1621+
Register Src1Reg = MI.getOperand(2).getReg();
1622+
Register Src2Reg = MI.getOperand(3).getReg();
1623+
Register CondReg = MI.getOperand(4).getReg();
1624+
1625+
// Access register info
1626+
MachineFunction *MF = MBB->getParent();
1627+
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
1628+
MachineRegisterInfo &MRI = MF->getRegInfo();
1629+
1630+
unsigned RegSize = TRI->getRegSizeInBits(MaskReg, MRI);
1631+
unsigned ShiftAmount = RegSize - 1;
1632+
1633+
// Option 1: Shift-based mask (preferred - no flag modification)
1634+
MachineInstr *FirstNewMI =
1635+
BuildMI(*MBB, MI, DL, get(ARM::tMOVr), MaskReg)
1636+
.addReg(CondReg)
1637+
.add(predOps(ARMCC::AL))
1638+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1639+
1640+
// Instead of using RSB, we can use LSL and ASR to get the mask. This is to avoid the flag modification caused by RSB.
1641+
// tLSLri: (outs tGPR:$Rd, s_cc_out:$s), (ins tGPR:$Rm, imm0_31:$imm5, pred:$p)
1642+
BuildMI(*MBB, MI, DL, get(ARM::tLSLri), MaskReg)
1643+
.addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
1644+
.addReg(MaskReg) // $Rm
1645+
.addImm(ShiftAmount) // imm0_31:$imm5
1646+
.add(predOps(ARMCC::AL)) // pred:$p
1647+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1648+
1649+
// tASRri: (outs tGPR:$Rd, s_cc_out:$s), (ins tGPR:$Rm, imm_sr:$imm5, pred:$p)
1650+
BuildMI(*MBB, MI, DL, get(ARM::tASRri), MaskReg)
1651+
.addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
1652+
.addReg(MaskReg) // $Rm
1653+
.addImm(ShiftAmount) // imm_sr:$imm5
1654+
.add(predOps(ARMCC::AL)) // pred:$p
1655+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1656+
1657+
// 2. xor_diff = src1 ^ src2
1658+
BuildMI(*MBB, MI, DL, get(ARM::tMOVr), DestReg)
1659+
.addReg(Src1Reg)
1660+
.add(predOps(ARMCC::AL))
1661+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1662+
1663+
// tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn, pred:$p) with constraint "$Rn = $Rdn"
1664+
BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
1665+
.addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
1666+
.addReg(DestReg) // tied input $Rn
1667+
.addReg(Src2Reg) // $Rm
1668+
.add(predOps(ARMCC::AL)) // pred:$p
1669+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1670+
1671+
// 3. masked_xor = xor_diff & mask
1672+
// tAND has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn, pred:$p) with constraint "$Rn = $Rdn"
1673+
BuildMI(*MBB, MI, DL, get(ARM::tAND), DestReg)
1674+
.addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
1675+
.addReg(DestReg) // tied input $Rn
1676+
.addReg(MaskReg, RegState::Kill) // $Rm
1677+
.add(predOps(ARMCC::AL)) // pred:$p
1678+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1679+
1680+
// 4. result = src2 ^ masked_xor
1681+
// tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn, pred:$p) with constraint "$Rn = $Rdn"
1682+
auto LastMI = BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
1683+
.addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
1684+
.addReg(DestReg) // tied input $Rn
1685+
.addReg(Src2Reg) // $Rm
1686+
.add(predOps(ARMCC::AL)) // pred:$p
1687+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1688+
1689+
// Add instruction bundling
1690+
auto BundleStart = FirstNewMI->getIterator();
1691+
finalizeBundle(*MBB, BundleStart, std::next(LastMI->getIterator()));
1692+
1693+
MI.eraseFromParent();
1694+
return true;
1695+
}
1696+
1697+
// Expands the ctselect pseudo, post-RA.
1698+
bool ARMBaseInstrInfo::expandCtSelect(MachineInstr &MI) const {
1699+
MachineBasicBlock *MBB = MI.getParent();
1700+
DebugLoc DL = MI.getDebugLoc();
1701+
1702+
Register DestReg = MI.getOperand(0).getReg();
1703+
Register MaskReg = MI.getOperand(1).getReg();
1704+
Register DestRegSavedRef = DestReg;
1705+
Register Src1Reg, Src2Reg, CondReg;
1706+
1707+
// These operations will differ by operand register size.
1708+
unsigned RsbOp = ARM::RSBri;
1709+
unsigned AndOp = ARM::ANDrr;
1710+
unsigned BicOp = ARM::BICrr;
1711+
unsigned OrrOp = ARM::ORRrr;
1712+
1713+
if (Subtarget.isThumb2()) {
1714+
RsbOp = ARM::t2RSBri;
1715+
AndOp = ARM::t2ANDrr;
1716+
BicOp = ARM::t2BICrr;
1717+
OrrOp = ARM::t2ORRrr;
1718+
}
1719+
1720+
unsigned Opcode = MI.getOpcode();
1721+
bool IsFloat = Opcode == ARM::CTSELECTf32 || Opcode == ARM::CTSELECTf16 || Opcode == ARM::CTSELECTbf16;
1722+
MachineInstr *FirstNewMI = nullptr;
1723+
if (IsFloat) {
1724+
// Each float pseudo has: (outs $dst, $tmp_mask, $scratch1, $scratch2), (ins $src1, $src2, $cond))
1725+
// We use two scratch registers in tablegen for bitwise ops on float types,.
1726+
Register GPRScratch1 = MI.getOperand(2).getReg();
1727+
Register GPRScratch2 = MI.getOperand(3).getReg();
1728+
1729+
// choice a from __builtin_ct_select(cond, a, b)
1730+
Src1Reg = MI.getOperand(4).getReg();
1731+
// choice b from __builtin_ct_select(cond, a, b)
1732+
Src2Reg = MI.getOperand(5).getReg();
1733+
// cond from __builtin_ct_select(cond, a, b)
1734+
CondReg = MI.getOperand(6).getReg();
1735+
1736+
// Move fp src1 to GPR scratch1 so we can do our bitwise ops
1737+
FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch1)
1738+
.addReg(Src1Reg)
1739+
.add(predOps(ARMCC::AL))
1740+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1741+
1742+
// Move src2 to scratch2
1743+
BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch2)
1744+
.addReg(Src2Reg)
1745+
.add(predOps(ARMCC::AL))
1746+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1747+
1748+
Src1Reg = GPRScratch1;
1749+
Src2Reg = GPRScratch2;
1750+
// Reuse GPRScratch1 for dest after we are done working with src1.
1751+
DestReg = GPRScratch1;
1752+
} else {
1753+
// Any non-float, non-vector pseudo has: (outs $dst, $tmp_mask), (ins $src1, $src2, $cond))
1754+
Src1Reg = MI.getOperand(2).getReg();
1755+
Src2Reg = MI.getOperand(3).getReg();
1756+
CondReg = MI.getOperand(4).getReg();
1757+
}
1758+
1759+
// The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)
1760+
1761+
// 1. mask = 0 - cond
1762+
// When cond = 0: mask = 0x00000000.
1763+
// When cond = 1: mask = 0xFFFFFFFF.
1764+
auto TmpNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
1765+
.addReg(CondReg)
1766+
.addImm(0)
1767+
.add(predOps(ARMCC::AL))
1768+
.add(condCodeOp())
1769+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1770+
1771+
// We use the first instruction in the bundle as the first instruction.
1772+
if (!FirstNewMI)
1773+
FirstNewMI = TmpNewMI;
1774+
1775+
// 2. A = src1 & mask
1776+
BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
1777+
.addReg(Src1Reg)
1778+
.addReg(MaskReg)
1779+
.add(predOps(ARMCC::AL))
1780+
.add(condCodeOp())
1781+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1782+
1783+
// 3. B = src2 & ~mask
1784+
BuildMI(*MBB, MI, DL, get(BicOp), MaskReg)
1785+
.addReg(Src2Reg)
1786+
.addReg(MaskReg)
1787+
.add(predOps(ARMCC::AL))
1788+
.add(condCodeOp())
1789+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1790+
1791+
// 4. result = A | B
1792+
auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
1793+
.addReg(DestReg)
1794+
.addReg(MaskReg)
1795+
.add(predOps(ARMCC::AL))
1796+
.add(condCodeOp())
1797+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1798+
1799+
if (IsFloat) {
1800+
// Return our result from GPR to the correct register type.
1801+
LastNewMI =BuildMI(*MBB, MI, DL, get(ARM::VMOVSR), DestRegSavedRef)
1802+
.addReg(DestReg)
1803+
.add(predOps(ARMCC::AL))
1804+
.setMIFlag(MachineInstr::MIFlag::NoMerge);
1805+
}
1806+
1807+
auto BundleStart = FirstNewMI->getIterator();
1808+
auto BundleEnd = LastNewMI->getIterator();
1809+
1810+
// Add instruction bundling
1811+
finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));
1812+
1813+
MI.eraseFromParent();
1814+
return true;
1815+
}
1816+
15291817
bool ARMBaseInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
1530-
if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
1818+
auto opcode = MI.getOpcode();
1819+
1820+
if (opcode == TargetOpcode::LOAD_STACK_GUARD) {
15311821
expandLoadStackGuard(MI);
15321822
MI.getParent()->erase(MI);
15331823
return true;
15341824
}
15351825

1536-
if (MI.getOpcode() == ARM::MEMCPY) {
1826+
if (opcode == ARM::MEMCPY) {
15371827
expandMEMCPY(MI);
15381828
return true;
15391829
}
15401830

1831+
if (opcode == ARM::CTSELECTf64) {
1832+
if (Subtarget.isThumb1Only()) {
1833+
LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode << "replaced by: " << MI);
1834+
return expandCtSelectThumb(MI);
1835+
} else {
1836+
LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode << "replaced by: " << MI);
1837+
return expandCtSelectVector(MI);
1838+
}
1839+
}
1840+
1841+
if (opcode == ARM::CTSELECTv8i8 ||
1842+
opcode == ARM::CTSELECTv4i16 ||
1843+
opcode == ARM::CTSELECTv2i32 ||
1844+
opcode == ARM::CTSELECTv1i64 ||
1845+
opcode == ARM::CTSELECTv2f32 ||
1846+
opcode == ARM::CTSELECTv4f16 ||
1847+
opcode == ARM::CTSELECTv4bf16 ||
1848+
opcode == ARM::CTSELECTv16i8 ||
1849+
opcode == ARM::CTSELECTv8i16 ||
1850+
opcode == ARM::CTSELECTv4i32 ||
1851+
opcode == ARM::CTSELECTv2i64 ||
1852+
opcode == ARM::CTSELECTv4f32 ||
1853+
opcode == ARM::CTSELECTv2f64 ||
1854+
opcode == ARM::CTSELECTv8f16 ||
1855+
opcode == ARM::CTSELECTv8bf16) {
1856+
LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode << "replaced by: " << MI);
1857+
return expandCtSelectVector(MI);
1858+
}
1859+
1860+
if (opcode == ARM::CTSELECTint ||
1861+
opcode == ARM::CTSELECTf16 ||
1862+
opcode == ARM::CTSELECTbf16 ||
1863+
opcode == ARM::CTSELECTf32) {
1864+
if (Subtarget.isThumb1Only()) {
1865+
LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode << "replaced by: " << MI);
1866+
return expandCtSelectThumb(MI);
1867+
} else {
1868+
LLVM_DEBUG(dbgs() << "Opcode " << opcode << "replaced by: " << MI);
1869+
return expandCtSelect(MI);
1870+
}
1871+
}
1872+
15411873
// This hook gets to expand COPY instructions before they become
15421874
// copyPhysReg() calls. Look for VMOVS instructions that can legally be
15431875
// widened to VMOVD. We prefer the VMOVD when possible because it may be

llvm/lib/Target/ARM/ARMBaseInstrInfo.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,12 @@ class ARMBaseInstrInfo : public ARMGenInstrInfo {
221221
const TargetRegisterInfo *TRI, Register VReg,
222222
MachineInstr::MIFlag Flags = MachineInstr::NoFlags) const override;
223223

224+
/// Expands a ctselect pseudo with vector operands, post-RA, into a bundled
/// RSB / VDUP32 / VAND / VBIC / VORR sequence and erases the pseudo.
bool expandCtSelectVector(MachineInstr &MI) const;
225+
226+
/// Expands a ctselect pseudo for Thumb1, post-RA, into a bundled
/// shift-mask / EOR / AND / EOR sequence and erases the pseudo.
bool expandCtSelectThumb(MachineInstr &MI) const;
227+
228+
/// Expands a scalar ctselect pseudo (integer, f16, bf16 or f32), post-RA,
/// into a bundled RSB / AND / BIC / ORR sequence and erases the pseudo.
bool expandCtSelect(MachineInstr &MI) const;
229+
224230
bool expandPostRAPseudo(MachineInstr &MI) const override;
225231

226232
bool shouldSink(const MachineInstr &MI) const override;

0 commit comments

Comments
 (0)