Commit 8d58556
[LLVM][ARM] Add native ct.select support for ARM32 and Thumb
This patch implements architecture-specific lowering for ct.select on ARM (both ARM32 and Thumb modes) using conditional move instructions and bitwise operations for constant-time selection.

Implementation details:
- Uses pseudo-instructions that are expanded Post-RA to bitwise operations
- Post-RA expansion in ARMBaseInstrInfo for BUNDLE pseudo-instructions
- Handles scalar integer types, floating-point, and half-precision types
- Handles vector types with NEON when available
- Support for both ARM and Thumb instruction sets (Thumb1 and Thumb2)
- Special handling for Thumb1 which lacks conditional execution
- Comprehensive test coverage including half-precision and vectors

The implementation includes:
- ISelLowering: Custom lowering to CTSELECT pseudo-instructions
- ISelDAGToDAG: Selection of appropriate pseudo-instructions
- BaseInstrInfo: Post-RA expansion of BUNDLE to bitwise instruction sequences
- InstrInfo.td: Pseudo-instruction definitions for different types
- TargetMachine: Registration of Post-RA expansion pass
- Proper handling of condition codes and register allocation constraints
1 parent 6ac8221 commit 8d58556
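The expansion strategy described in the commit message reduces to one branchless identity: materialize an all-ones or all-zeros mask from the condition, then combine the two inputs with AND/BIC/ORR. A minimal C++ sketch of that identity for a 32-bit scalar (the function name and the use of uint32_t are illustrative, not part of the patch):

    #include <cstdint>

    // Branchless select: returns a when cond is 1, b when cond is 0.
    // Models the RSB/AND/BIC/ORR sequence the post-RA expansion emits.
    uint32_t ct_select_scalar(uint32_t cond, uint32_t a, uint32_t b) {
      uint32_t mask = 0u - cond;        // cond = 1 -> 0xFFFFFFFF, cond = 0 -> 0x00000000
      return (a & mask) | (b & ~mask);  // both inputs are always read and combined
    }

Because both operands are processed the same way on every call, the instruction trace does not depend on the value of cond.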

10 files changed: +4499, -29 lines

llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp

Lines changed: 335 additions & 2 deletions
@@ -1526,18 +1526,351 @@ void ARMBaseInstrInfo::expandMEMCPY(MachineBasicBlock::iterator MI) const {
   BB->erase(MI);
 }
 
+// Expands the ctselect pseudo for vector operands, post-RA.
+bool ARMBaseInstrInfo::expandCtSelectVector(MachineInstr &MI) const {
+  MachineBasicBlock *MBB = MI.getParent();
+  DebugLoc DL = MI.getDebugLoc();
+
+  Register DestReg = MI.getOperand(0).getReg();
+  Register MaskReg = MI.getOperand(1).getReg();
+
+  // These operations will differ by operand register size.
+  unsigned AndOp = ARM::VANDd;
+  unsigned BicOp = ARM::VBICd;
+  unsigned OrrOp = ARM::VORRd;
+  unsigned BroadcastOp = ARM::VDUP32d;
+
+  const TargetRegisterInfo *TRI = &getRegisterInfo();
+  const TargetRegisterClass *RC = TRI->getMinimalPhysRegClass(DestReg);
+
+  if (ARM::QPRRegClass.hasSubClassEq(RC)) {
+    AndOp = ARM::VANDq;
+    BicOp = ARM::VBICq;
+    OrrOp = ARM::VORRq;
+    BroadcastOp = ARM::VDUP32q;
+  }
+
+  unsigned RsbOp = Subtarget.isThumb2() ? ARM::t2RSBri : ARM::RSBri;
+
+  // Any vector pseudo has: ((outs $dst, $tmp_mask, $bcast_mask), (ins $src1,
+  // $src2, $cond))
+  Register VectorMaskReg = MI.getOperand(2).getReg();
+  Register Src1Reg = MI.getOperand(3).getReg();
+  Register Src2Reg = MI.getOperand(4).getReg();
+  Register CondReg = MI.getOperand(5).getReg();
+
+  // The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)
+
+  // 1. mask = 0 - cond
+  // When cond = 0: mask = 0x00000000.
+  // When cond = 1: mask = 0xFFFFFFFF.
+  MachineInstr *FirstNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
+                                 .addReg(CondReg)
+                                 .addImm(0)
+                                 .add(predOps(ARMCC::AL))
+                                 .add(condCodeOp())
+                                 .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 2. A = src1 & mask
+  // For vectors, broadcast the scalar mask so it matches operand size.
+  BuildMI(*MBB, MI, DL, get(BroadcastOp), VectorMaskReg)
+      .addReg(MaskReg)
+      .add(predOps(ARMCC::AL))
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
+      .addReg(Src1Reg)
+      .addReg(VectorMaskReg)
+      .add(predOps(ARMCC::AL))
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 3. B = src2 & ~mask
+  BuildMI(*MBB, MI, DL, get(BicOp), VectorMaskReg)
+      .addReg(Src2Reg)
+      .addReg(VectorMaskReg)
+      .add(predOps(ARMCC::AL))
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 4. result = A | B
+  auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
+                       .addReg(DestReg)
+                       .addReg(VectorMaskReg)
+                       .add(predOps(ARMCC::AL))
+                       .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  auto BundleStart = FirstNewMI->getIterator();
+  auto BundleEnd = LastNewMI->getIterator();
+
+  // Add instruction bundling
+  finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));
+
+  MI.eraseFromParent();
+  return true;
+}
+
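For the NEON path above, the scalar mask is broadcast into a D or Q register (VDUP32d/VDUP32q) and the same AND/BIC/ORR combination is applied lane-wise. A rough C++ model of the 4 x i32 case, with illustrative names and plain arrays standing in for NEON registers:

    #include <cstdint>

    // Lane-wise model of the Q-register expansion: one 32-bit mask is broadcast,
    // then every lane computes (src1 & mask) | (src2 & ~mask), mirroring
    // VDUP, VAND, VBIC and VORR.
    void ct_select_v4i32(uint32_t cond, const uint32_t src1[4],
                         const uint32_t src2[4], uint32_t dst[4]) {
      uint32_t mask = 0u - cond;  // RSB: all-ones or all-zeros
      for (int lane = 0; lane < 4; ++lane)
        dst[lane] = (src1[lane] & mask) | (src2[lane] & ~mask);
    }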
+// Expands the ctselect pseudo for thumb1, post-RA.
+bool ARMBaseInstrInfo::expandCtSelectThumb(MachineInstr &MI) const {
+  MachineBasicBlock *MBB = MI.getParent();
+  DebugLoc DL = MI.getDebugLoc();
+
+  // Pseudos in thumb1 mode have: (outs $dst, $tmp_mask), (ins $src1, $src2,
+  // $cond). The register class here is always tGPR.
+  Register DestReg = MI.getOperand(0).getReg();
+  Register MaskReg = MI.getOperand(1).getReg();
+  Register Src1Reg = MI.getOperand(2).getReg();
+  Register Src2Reg = MI.getOperand(3).getReg();
+  Register CondReg = MI.getOperand(4).getReg();
+
+  // Access register info.
+  MachineFunction *MF = MBB->getParent();
+  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+  MachineRegisterInfo &MRI = MF->getRegInfo();
+
+  unsigned RegSize = TRI->getRegSizeInBits(MaskReg, MRI);
+  unsigned ShiftAmount = RegSize - 1;
+
+  // Option 1: Shift-based mask (preferred - no flag modification).
+  MachineInstr *FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::tMOVr), MaskReg)
+                                 .addReg(CondReg)
+                                 .add(predOps(ARMCC::AL))
+                                 .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // Instead of using RSB, we use LSL and ASR to get the mask. This avoids
+  // the flag modification caused by RSB.
+  // tLSLri: (outs tGPR:$Rd, s_cc_out:$s), (ins tGPR:$Rm, imm0_31:$imm5, pred:$p)
+  BuildMI(*MBB, MI, DL, get(ARM::tLSLri), MaskReg)
+      .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+      .addReg(MaskReg)                                      // $Rm
+      .addImm(ShiftAmount)                                  // imm0_31:$imm5
+      .add(predOps(ARMCC::AL))                              // pred:$p
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // tASRri: (outs tGPR:$Rd, s_cc_out:$s), (ins tGPR:$Rm, imm_sr:$imm5, pred:$p)
+  BuildMI(*MBB, MI, DL, get(ARM::tASRri), MaskReg)
+      .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+      .addReg(MaskReg)                                      // $Rm
+      .addImm(ShiftAmount)                                  // imm_sr:$imm5
+      .add(predOps(ARMCC::AL))                              // pred:$p
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 2. xor_diff = src1 ^ src2
+  BuildMI(*MBB, MI, DL, get(ARM::tMOVr), DestReg)
+      .addReg(Src1Reg)
+      .add(predOps(ARMCC::AL))
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
+  // pred:$p) with constraint "$Rn = $Rdn"
+  BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
+      .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+      .addReg(DestReg)                                      // tied input $Rn
+      .addReg(Src2Reg)                                      // $Rm
+      .add(predOps(ARMCC::AL))                              // pred:$p
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 3. masked_xor = xor_diff & mask
+  // tAND has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
+  // pred:$p) with constraint "$Rn = $Rdn"
+  BuildMI(*MBB, MI, DL, get(ARM::tAND), DestReg)
+      .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+      .addReg(DestReg)                                      // tied input $Rn
+      .addReg(MaskReg, RegState::Kill)                      // $Rm
+      .add(predOps(ARMCC::AL))                              // pred:$p
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 4. result = src2 ^ masked_xor
+  // tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
+  // pred:$p) with constraint "$Rn = $Rdn"
+  auto LastMI =
+      BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
+          .addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
+          .addReg(DestReg)                                      // tied input $Rn
+          .addReg(Src2Reg)                                      // $Rm
+          .add(predOps(ARMCC::AL))                              // pred:$p
+          .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // Add instruction bundling
+  auto BundleStart = FirstNewMI->getIterator();
+  finalizeBundle(*MBB, BundleStart, std::next(LastMI->getIterator()));
+
+  MI.eraseFromParent();
+  return true;
+}
+
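The Thumb1 expander avoids RSB (per the comments, to sidestep its flag modification), forms the mask with LSL/ASR instead, and then selects through an XOR identity rather than AND/BIC/ORR. A C++ sketch of the equivalent computation, assuming 32-bit registers (names are illustrative):

    #include <cstdint>

    // Thumb1 model: shift the condition's low bit into the sign position and
    // arithmetic-shift it back (tLSLri + tASRri) to get the mask, then apply
    // the tEOR/tAND/tEOR select: result = b ^ ((a ^ b) & mask).
    uint32_t ct_select_thumb1(uint32_t cond, uint32_t a, uint32_t b) {
      uint32_t mask = (uint32_t)((int32_t)(cond << 31) >> 31);  // 0x0 or 0xFFFFFFFF
      return b ^ ((a ^ b) & mask);  // mask all-ones -> a, mask zero -> b
    }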
+// Expands the ctselect pseudo, post-RA.
+bool ARMBaseInstrInfo::expandCtSelect(MachineInstr &MI) const {
+  MachineBasicBlock *MBB = MI.getParent();
+  DebugLoc DL = MI.getDebugLoc();
+
+  Register DestReg = MI.getOperand(0).getReg();
+  Register MaskReg = MI.getOperand(1).getReg();
+  Register DestRegSavedRef = DestReg;
+  Register Src1Reg, Src2Reg, CondReg;
+
+  // These operations will differ by operand register size.
+  unsigned RsbOp = ARM::RSBri;
+  unsigned AndOp = ARM::ANDrr;
+  unsigned BicOp = ARM::BICrr;
+  unsigned OrrOp = ARM::ORRrr;
+
+  if (Subtarget.isThumb2()) {
+    RsbOp = ARM::t2RSBri;
+    AndOp = ARM::t2ANDrr;
+    BicOp = ARM::t2BICrr;
+    OrrOp = ARM::t2ORRrr;
+  }
+
+  unsigned Opcode = MI.getOpcode();
+  bool IsFloat = Opcode == ARM::CTSELECTf32 || Opcode == ARM::CTSELECTf16 ||
+                 Opcode == ARM::CTSELECTbf16;
+  MachineInstr *FirstNewMI = nullptr;
+  if (IsFloat) {
+    // Each float pseudo has: (outs $dst, $tmp_mask, $scratch1, $scratch2),
+    // (ins $src1, $src2, $cond). We use two scratch registers in tablegen for
+    // bitwise ops on float types.
+    Register GPRScratch1 = MI.getOperand(2).getReg();
+    Register GPRScratch2 = MI.getOperand(3).getReg();
+
+    // choice a from __builtin_ct_select(cond, a, b)
+    Src1Reg = MI.getOperand(4).getReg();
+    // choice b from __builtin_ct_select(cond, a, b)
+    Src2Reg = MI.getOperand(5).getReg();
+    // cond from __builtin_ct_select(cond, a, b)
+    CondReg = MI.getOperand(6).getReg();
+
+    // Move fp src1 to GPR scratch1 so we can do our bitwise ops.
+    FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch1)
+                     .addReg(Src1Reg)
+                     .add(predOps(ARMCC::AL))
+                     .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+    // Move src2 to scratch2.
+    BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch2)
+        .addReg(Src2Reg)
+        .add(predOps(ARMCC::AL))
+        .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+    Src1Reg = GPRScratch1;
+    Src2Reg = GPRScratch2;
+    // Reuse GPRScratch1 for dest after we are done working with src1.
+    DestReg = GPRScratch1;
+  } else {
+    // Any non-float, non-vector pseudo has: (outs $dst, $tmp_mask),
+    // (ins $src1, $src2, $cond).
+    Src1Reg = MI.getOperand(2).getReg();
+    Src2Reg = MI.getOperand(3).getReg();
+    CondReg = MI.getOperand(4).getReg();
+  }
+
+  // The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)
+
+  // 1. mask = 0 - cond
+  // When cond = 0: mask = 0x00000000.
+  // When cond = 1: mask = 0xFFFFFFFF.
+  auto TmpNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
+                      .addReg(CondReg)
+                      .addImm(0)
+                      .add(predOps(ARMCC::AL))
+                      .add(condCodeOp())
+                      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // If the float path above did not emit any instructions, the mask
+  // computation is the first instruction of the bundle.
+  if (!FirstNewMI)
+    FirstNewMI = TmpNewMI;
+
+  // 2. A = src1 & mask
+  BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
+      .addReg(Src1Reg)
+      .addReg(MaskReg)
+      .add(predOps(ARMCC::AL))
+      .add(condCodeOp())
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 3. B = src2 & ~mask
+  BuildMI(*MBB, MI, DL, get(BicOp), MaskReg)
+      .addReg(Src2Reg)
+      .addReg(MaskReg)
+      .add(predOps(ARMCC::AL))
+      .add(condCodeOp())
+      .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  // 4. result = A | B
+  auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
+                       .addReg(DestReg)
+                       .addReg(MaskReg)
+                       .add(predOps(ARMCC::AL))
+                       .add(condCodeOp())
+                       .setMIFlag(MachineInstr::MIFlag::NoMerge);
+
+  if (IsFloat) {
+    // Return our result from GPR to the correct register type.
+    LastNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVSR), DestRegSavedRef)
+                    .addReg(DestReg)
+                    .add(predOps(ARMCC::AL))
+                    .setMIFlag(MachineInstr::MIFlag::NoMerge);
+  }
+
+  auto BundleStart = FirstNewMI->getIterator();
+  auto BundleEnd = LastNewMI->getIterator();
+
+  // Add instruction bundling
+  finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));
+
+  MI.eraseFromParent();
+  return true;
+}
+
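For the scalar floating-point pseudos, the values are moved into the integer register file (VMOVRS), selected with the same bitwise sequence, and moved back (VMOVSR). A C++ sketch of that round trip for f32, using memcpy as a stand-in for the register moves (function name is illustrative):

    #include <cstdint>
    #include <cstring>

    // f32 model: reinterpret the float bits as integers (VMOVRS), run the
    // mask/AND/BIC/ORR selection, then move the result bits back (VMOVSR).
    float ct_select_f32(uint32_t cond, float a, float b) {
      uint32_t ai, bi;
      std::memcpy(&ai, &a, sizeof(ai));
      std::memcpy(&bi, &b, sizeof(bi));
      uint32_t mask = 0u - cond;
      uint32_t ri = (ai & mask) | (bi & ~mask);
      float r;
      std::memcpy(&r, &ri, sizeof(r));
      return r;
    }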
 bool ARMBaseInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
-  if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
+  auto opcode = MI.getOpcode();
+
+  if (opcode == TargetOpcode::LOAD_STACK_GUARD) {
     expandLoadStackGuard(MI);
     MI.getParent()->erase(MI);
     return true;
   }
 
-  if (MI.getOpcode() == ARM::MEMCPY) {
+  if (opcode == ARM::MEMCPY) {
     expandMEMCPY(MI);
     return true;
   }
 
+  if (opcode == ARM::CTSELECTf64) {
+    if (Subtarget.isThumb1Only()) {
+      LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode
+                        << " replaced by: " << MI);
+      return expandCtSelectThumb(MI);
+    } else {
+      LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode
+                        << " replaced by: " << MI);
+      return expandCtSelectVector(MI);
+    }
+  }
+
+  if (opcode == ARM::CTSELECTv8i8 || opcode == ARM::CTSELECTv4i16 ||
+      opcode == ARM::CTSELECTv2i32 || opcode == ARM::CTSELECTv1i64 ||
+      opcode == ARM::CTSELECTv2f32 || opcode == ARM::CTSELECTv4f16 ||
+      opcode == ARM::CTSELECTv4bf16 || opcode == ARM::CTSELECTv16i8 ||
+      opcode == ARM::CTSELECTv8i16 || opcode == ARM::CTSELECTv4i32 ||
+      opcode == ARM::CTSELECTv2i64 || opcode == ARM::CTSELECTv4f32 ||
+      opcode == ARM::CTSELECTv2f64 || opcode == ARM::CTSELECTv8f16 ||
+      opcode == ARM::CTSELECTv8bf16) {
+    LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode
+                      << " replaced by: " << MI);
+    return expandCtSelectVector(MI);
+  }
+
+  if (opcode == ARM::CTSELECTint || opcode == ARM::CTSELECTf16 ||
+      opcode == ARM::CTSELECTbf16 || opcode == ARM::CTSELECTf32) {
+    if (Subtarget.isThumb1Only()) {
+      LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode
+                        << " replaced by: " << MI);
+      return expandCtSelectThumb(MI);
+    } else {
+      LLVM_DEBUG(dbgs() << "Opcode " << opcode << " replaced by: " << MI);
+      return expandCtSelect(MI);
+    }
+  }
+
   // This hook gets to expand COPY instructions before they become
   // copyPhysReg() calls. Look for VMOVS instructions that can legally be
   // widened to VMOVD. We prefer the VMOVD when possible because it may be

llvm/lib/Target/ARM/ARMBaseInstrInfo.h

Lines changed: 6 additions & 0 deletions
@@ -221,6 +221,12 @@ class ARMBaseInstrInfo : public ARMGenInstrInfo {
                         const TargetRegisterInfo *TRI, Register VReg,
                         MachineInstr::MIFlag Flags = MachineInstr::NoFlags) const override;
 
+  bool expandCtSelectVector(MachineInstr &MI) const;
+
+  bool expandCtSelectThumb(MachineInstr &MI) const;
+
+  bool expandCtSelect(MachineInstr &MI) const;
+
   bool expandPostRAPseudo(MachineInstr &MI) const override;
 
   bool shouldSink(const MachineInstr &MI) const override;
