Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 335 additions & 2 deletions llvm/lib/Target/ARM/ARMBaseInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1526,18 +1526,351 @@ void ARMBaseInstrInfo::expandMEMCPY(MachineBasicBlock::iterator MI) const {
BB->erase(MI);
}

// Expands the ctselect pseudo for vector operands, post-RA.
bool ARMBaseInstrInfo::expandCtSelectVector(MachineInstr &MI) const {
MachineBasicBlock *MBB = MI.getParent();
DebugLoc DL = MI.getDebugLoc();

Register DestReg = MI.getOperand(0).getReg();
Register MaskReg = MI.getOperand(1).getReg();

// These operations will differ by operand register size.
unsigned AndOp = ARM::VANDd;
unsigned BicOp = ARM::VBICd;
unsigned OrrOp = ARM::VORRd;
unsigned BroadcastOp = ARM::VDUP32d;

const TargetRegisterInfo *TRI = &getRegisterInfo();
const TargetRegisterClass *RC = TRI->getMinimalPhysRegClass(DestReg);

if (ARM::QPRRegClass.hasSubClassEq(RC)) {
AndOp = ARM::VANDq;
BicOp = ARM::VBICq;
OrrOp = ARM::VORRq;
BroadcastOp = ARM::VDUP32q;
}

unsigned RsbOp = Subtarget.isThumb2() ? ARM::t2RSBri : ARM::RSBri;

// Any vector pseudo has: ((outs $dst, $tmp_mask, $bcast_mask), (ins $src1,
// $src2, $cond))
Register VectorMaskReg = MI.getOperand(2).getReg();
Register Src1Reg = MI.getOperand(3).getReg();
Register Src2Reg = MI.getOperand(4).getReg();
Register CondReg = MI.getOperand(5).getReg();

// The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)

// 1. mask = 0 - cond
// When cond = 0: mask = 0x00000000.
// When cond = 1: mask = 0xFFFFFFFF.

MachineInstr *FirstNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
.addReg(CondReg)
.addImm(0)
.add(predOps(ARMCC::AL))
.add(condCodeOp())
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// 2. A = src1 & mask
// For vectors, broadcast the scalar mask so it matches operand size.
BuildMI(*MBB, MI, DL, get(BroadcastOp), VectorMaskReg)
.addReg(MaskReg)
.add(predOps(ARMCC::AL))
.setMIFlag(MachineInstr::MIFlag::NoMerge);

BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
.addReg(Src1Reg)
.addReg(VectorMaskReg)
.add(predOps(ARMCC::AL))
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// 3. B = src2 & ~mask
BuildMI(*MBB, MI, DL, get(BicOp), VectorMaskReg)
.addReg(Src2Reg)
.addReg(VectorMaskReg)
.add(predOps(ARMCC::AL))
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// 4. result = A | B
auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
.addReg(DestReg)
.addReg(VectorMaskReg)
.add(predOps(ARMCC::AL))
.setMIFlag(MachineInstr::MIFlag::NoMerge);

auto BundleStart = FirstNewMI->getIterator();
auto BundleEnd = LastNewMI->getIterator();

// Add instruction bundling
finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));

MI.eraseFromParent();
return true;
}

// Expands the ctselect pseudo for thumb1, post-RA.
bool ARMBaseInstrInfo::expandCtSelectThumb(MachineInstr &MI) const {
MachineBasicBlock *MBB = MI.getParent();
DebugLoc DL = MI.getDebugLoc();

// pseudos in thumb1 mode have: (outs $dst, $tmp_mask), (ins $src1, $src2,
// $cond)) register class here is always tGPR.
Register DestReg = MI.getOperand(0).getReg();
Register MaskReg = MI.getOperand(1).getReg();
Register Src1Reg = MI.getOperand(2).getReg();
Register Src2Reg = MI.getOperand(3).getReg();
Register CondReg = MI.getOperand(4).getReg();

// Access register info
MachineFunction *MF = MBB->getParent();
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
MachineRegisterInfo &MRI = MF->getRegInfo();

unsigned RegSize = TRI->getRegSizeInBits(MaskReg, MRI);
unsigned ShiftAmount = RegSize - 1;

// Option 1: Shift-based mask (preferred - no flag modification)
MachineInstr *FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::tMOVr), MaskReg)
.addReg(CondReg)
.add(predOps(ARMCC::AL))
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// Instead of using RSB, we can use LSL and ASR to get the mask. This is to
// avoid the flag modification caused by RSB. tLSLri: (outs tGPR:$Rd,
// s_cc_out:$s), (ins tGPR:$Rm, imm0_31:$imm5, pred:$p)
BuildMI(*MBB, MI, DL, get(ARM::tLSLri), MaskReg)
.addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
.addReg(MaskReg) // $Rm
.addImm(ShiftAmount) // imm0_31:$imm5
.add(predOps(ARMCC::AL)) // pred:$p
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// tASRri: (outs tGPR:$Rd, s_cc_out:$s), (ins tGPR:$Rm, imm_sr:$imm5, pred:$p)
BuildMI(*MBB, MI, DL, get(ARM::tASRri), MaskReg)
.addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
.addReg(MaskReg) // $Rm
.addImm(ShiftAmount) // imm_sr:$imm5
.add(predOps(ARMCC::AL)) // pred:$p
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// 2. xor_diff = src1 ^ src2
BuildMI(*MBB, MI, DL, get(ARM::tMOVr), DestReg)
.addReg(Src1Reg)
.add(predOps(ARMCC::AL))
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
// pred:$p) with constraint "$Rn = $Rdn"
BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
.addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
.addReg(DestReg) // tied input $Rn
.addReg(Src2Reg) // $Rm
.add(predOps(ARMCC::AL)) // pred:$p
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// 3. masked_xor = xor_diff & mask
// tAND has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
// pred:$p) with constraint "$Rn = $Rdn"
BuildMI(*MBB, MI, DL, get(ARM::tAND), DestReg)
.addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
.addReg(DestReg) // tied input $Rn
.addReg(MaskReg, RegState::Kill) // $Rm
.add(predOps(ARMCC::AL)) // pred:$p
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// 4. result = src2 ^ masked_xor
// tEOR has tied operands: (outs tGPR:$Rdn, s_cc_out:$s), (ins tGPR:$Rn,
// pred:$p) with constraint "$Rn = $Rdn"
auto LastMI =
BuildMI(*MBB, MI, DL, get(ARM::tEOR), DestReg)
.addReg(ARM::CPSR, RegState::Define | RegState::Dead) // s_cc_out:$s
.addReg(DestReg) // tied input $Rn
.addReg(Src2Reg) // $Rm
.add(predOps(ARMCC::AL)) // pred:$p
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// Add instruction bundling
auto BundleStart = FirstNewMI->getIterator();
finalizeBundle(*MBB, BundleStart, std::next(LastMI->getIterator()));

MI.eraseFromParent();
return true;
}

// Expands the ctselect pseudo, post-RA.
bool ARMBaseInstrInfo::expandCtSelect(MachineInstr &MI) const {
MachineBasicBlock *MBB = MI.getParent();
DebugLoc DL = MI.getDebugLoc();

Register DestReg = MI.getOperand(0).getReg();
Register MaskReg = MI.getOperand(1).getReg();
Register DestRegSavedRef = DestReg;
Register Src1Reg, Src2Reg, CondReg;

// These operations will differ by operand register size.
unsigned RsbOp = ARM::RSBri;
unsigned AndOp = ARM::ANDrr;
unsigned BicOp = ARM::BICrr;
unsigned OrrOp = ARM::ORRrr;

if (Subtarget.isThumb2()) {
RsbOp = ARM::t2RSBri;
AndOp = ARM::t2ANDrr;
BicOp = ARM::t2BICrr;
OrrOp = ARM::t2ORRrr;
}

unsigned Opcode = MI.getOpcode();
bool IsFloat = Opcode == ARM::CTSELECTf32 || Opcode == ARM::CTSELECTf16 ||
Opcode == ARM::CTSELECTbf16;
MachineInstr *FirstNewMI = nullptr;
if (IsFloat) {
// Each float pseudo has: (outs $dst, $tmp_mask, $scratch1, $scratch2), (ins
// $src1, $src2, $cond)) We use two scratch registers in tablegen for
// bitwise ops on float types,.
Register GPRScratch1 = MI.getOperand(2).getReg();
Register GPRScratch2 = MI.getOperand(3).getReg();

// choice a from __builtin_ct_select(cond, a, b)
Src1Reg = MI.getOperand(4).getReg();
// choice b from __builtin_ct_select(cond, a, b)
Src2Reg = MI.getOperand(5).getReg();
// cond from __builtin_ct_select(cond, a, b)
CondReg = MI.getOperand(6).getReg();

// Move fp src1 to GPR scratch1 so we can do our bitwise ops
FirstNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch1)
.addReg(Src1Reg)
.add(predOps(ARMCC::AL))
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// Move src2 to scratch2
BuildMI(*MBB, MI, DL, get(ARM::VMOVRS), GPRScratch2)
.addReg(Src2Reg)
.add(predOps(ARMCC::AL))
.setMIFlag(MachineInstr::MIFlag::NoMerge);

Src1Reg = GPRScratch1;
Src2Reg = GPRScratch2;
// Reuse GPRScratch1 for dest after we are done working with src1.
DestReg = GPRScratch1;
} else {
// Any non-float, non-vector pseudo has: (outs $dst, $tmp_mask), (ins $src1,
// $src2, $cond))
Src1Reg = MI.getOperand(2).getReg();
Src2Reg = MI.getOperand(3).getReg();
CondReg = MI.getOperand(4).getReg();
}

// The following sequence of steps yields: (src1 & mask) | (src2 & ~mask)

// 1. mask = 0 - cond
// When cond = 0: mask = 0x00000000.
// When cond = 1: mask = 0xFFFFFFFF.
auto TmpNewMI = BuildMI(*MBB, MI, DL, get(RsbOp), MaskReg)
.addReg(CondReg)
.addImm(0)
.add(predOps(ARMCC::AL))
.add(condCodeOp())
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// We use the first instruction in the bundle as the first instruction.
if (!FirstNewMI)
FirstNewMI = TmpNewMI;

// 2. A = src1 & mask
BuildMI(*MBB, MI, DL, get(AndOp), DestReg)
.addReg(Src1Reg)
.addReg(MaskReg)
.add(predOps(ARMCC::AL))
.add(condCodeOp())
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// 3. B = src2 & ~mask
BuildMI(*MBB, MI, DL, get(BicOp), MaskReg)
.addReg(Src2Reg)
.addReg(MaskReg)
.add(predOps(ARMCC::AL))
.add(condCodeOp())
.setMIFlag(MachineInstr::MIFlag::NoMerge);

// 4. result = A | B
auto LastNewMI = BuildMI(*MBB, MI, DL, get(OrrOp), DestReg)
.addReg(DestReg)
.addReg(MaskReg)
.add(predOps(ARMCC::AL))
.add(condCodeOp())
.setMIFlag(MachineInstr::MIFlag::NoMerge);

if (IsFloat) {
// Return our result from GPR to the correct register type.
LastNewMI = BuildMI(*MBB, MI, DL, get(ARM::VMOVSR), DestRegSavedRef)
.addReg(DestReg)
.add(predOps(ARMCC::AL))
.setMIFlag(MachineInstr::MIFlag::NoMerge);
}

auto BundleStart = FirstNewMI->getIterator();
auto BundleEnd = LastNewMI->getIterator();

// Add instruction bundling
finalizeBundle(*MBB, BundleStart, std::next(BundleEnd));

MI.eraseFromParent();
return true;
}

bool ARMBaseInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
auto opcode = MI.getOpcode();

if (opcode == TargetOpcode::LOAD_STACK_GUARD) {
expandLoadStackGuard(MI);
MI.getParent()->erase(MI);
return true;
}

if (MI.getOpcode() == ARM::MEMCPY) {
if (opcode == ARM::MEMCPY) {
expandMEMCPY(MI);
return true;
}

if (opcode == ARM::CTSELECTf64) {
if (Subtarget.isThumb1Only()) {
LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode
<< "replaced by: " << MI);
return expandCtSelectThumb(MI);
} else {
LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode
<< "replaced by: " << MI);
return expandCtSelectVector(MI);
}
}

if (opcode == ARM::CTSELECTv8i8 || opcode == ARM::CTSELECTv4i16 ||
opcode == ARM::CTSELECTv2i32 || opcode == ARM::CTSELECTv1i64 ||
opcode == ARM::CTSELECTv2f32 || opcode == ARM::CTSELECTv4f16 ||
opcode == ARM::CTSELECTv4bf16 || opcode == ARM::CTSELECTv16i8 ||
opcode == ARM::CTSELECTv8i16 || opcode == ARM::CTSELECTv4i32 ||
opcode == ARM::CTSELECTv2i64 || opcode == ARM::CTSELECTv4f32 ||
opcode == ARM::CTSELECTv2f64 || opcode == ARM::CTSELECTv8f16 ||
opcode == ARM::CTSELECTv8bf16) {
LLVM_DEBUG(dbgs() << "Opcode (vector) " << opcode << "replaced by: " << MI);
return expandCtSelectVector(MI);
}

if (opcode == ARM::CTSELECTint || opcode == ARM::CTSELECTf16 ||
opcode == ARM::CTSELECTbf16 || opcode == ARM::CTSELECTf32) {
if (Subtarget.isThumb1Only()) {
LLVM_DEBUG(dbgs() << "Opcode (thumb1 subtarget) " << opcode
<< "replaced by: " << MI);
return expandCtSelectThumb(MI);
} else {
LLVM_DEBUG(dbgs() << "Opcode " << opcode << "replaced by: " << MI);
return expandCtSelect(MI);
}
}

// This hook gets to expand COPY instructions before they become
// copyPhysReg() calls. Look for VMOVS instructions that can legally be
// widened to VMOVD. We prefer the VMOVD when possible because it may be
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/ARM/ARMBaseInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ class ARMBaseInstrInfo : public ARMGenInstrInfo {
const TargetRegisterInfo *TRI, Register VReg,
MachineInstr::MIFlag Flags = MachineInstr::NoFlags) const override;

bool expandCtSelectVector(MachineInstr &MI) const;

bool expandCtSelectThumb(MachineInstr &MI) const;

bool expandCtSelect(MachineInstr &MI) const;

bool expandPostRAPseudo(MachineInstr &MI) const override;

bool shouldSink(const MachineInstr &MI) const override;
Expand Down
Loading
Loading