1313// ===----------------------------------------------------------------------===//
1414
1515#include " llvm/CodeGen/LivePhysRegs.h"
16+ #include " llvm/CodeGen/MachineDominators.h"
1617#include " llvm/CodeGen/MachineFunctionPass.h"
1718#include " llvm/CodeGen/MachineInstrBuilder.h"
19+ #include " llvm/CodeGen/MachineLoopInfo.h"
1820#include " llvm/CodeGen/MachineRegisterInfo.h"
1921#include " llvm/IR/Module.h"
2022
@@ -40,14 +42,14 @@ class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
4042
4143 bool runOnMachineFunction (MachineFunction &MF) override ;
4244
43- std::pair<MachineBasicBlock *, MachineInstr *>
44- getSecurityCheckerBasicBlock (MachineFunction &MF);
45+ void getAnalysisUsage (AnalysisUsage &AU) const override ;
4546
46- MachineInstr *cloneLoadStackGuard (MachineBasicBlock *CurMBB,
47- MachineInstr *CheckCall );
47+ std::pair< MachineInstr *, MachineInstr *>
48+ findSecurityCheckAndLoadStackGuard (MachineFunction &MF );
4849
49- void getGuardCheckSequence (MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
50- MachineInstr *SeqMI[5 ]);
50+ MachineInstr *cloneLoadStackGuard (MachineFunction &MF, MachineInstr *MI);
51+
52+ bool getGuardCheckSequence (MachineInstr *CheckCall, MachineInstr *SeqMI[5 ]);
5153
5254 void finishBlock (MachineBasicBlock *MBB);
5355
@@ -64,93 +66,113 @@ FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
6466 return new AArch64WinFixupBufferSecurityCheckPass ();
6567}
6668
67- std::pair<MachineBasicBlock *, MachineInstr *>
68- AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock (
69+ void AArch64WinFixupBufferSecurityCheckPass::getAnalysisUsage (
70+ AnalysisUsage &AU) const {
71+ AU.addUsedIfAvailable <MachineDominatorTreeWrapperPass>();
72+ AU.addPreserved <MachineDominatorTreeWrapperPass>();
73+ AU.addPreserved <MachineLoopInfoWrapperPass>();
74+ MachineFunctionPass::getAnalysisUsage (AU);
75+ }
76+
77+ std::pair<MachineInstr *, MachineInstr *>
78+ AArch64WinFixupBufferSecurityCheckPass::findSecurityCheckAndLoadStackGuard (
6979 MachineFunction &MF) {
80+
81+ MachineInstr *SecurityCheckCall = nullptr ;
82+ MachineInstr *LoadStackGuard = nullptr ;
83+
7084 for (auto &MBB : MF) {
7185 for (auto &MI : MBB) {
86+ if (!LoadStackGuard && MI.getOpcode () == TargetOpcode::LOAD_STACK_GUARD) {
87+ LoadStackGuard = &MI;
88+ }
89+
7290 if (MI.isCall () && MI.getNumExplicitOperands () == 1 ) {
7391 auto MO = MI.getOperand (0 );
7492 if (MO.isGlobal ()) {
7593 auto Callee = dyn_cast<Function>(MO.getGlobal ());
7694 if (Callee && Callee->getName () == " __security_check_cookie" ) {
77- return std::make_pair (&MBB, &MI) ;
95+ SecurityCheckCall = &MI;
7896 }
7997 }
8098 }
99+
100+ // If both are found, return them
101+ if (LoadStackGuard && SecurityCheckCall) {
102+ return std::make_pair (LoadStackGuard, SecurityCheckCall);
103+ }
81104 }
82105 }
106+
83107 return std::make_pair (nullptr , nullptr );
84108}
85109
86- MachineInstr *AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard (
87- MachineBasicBlock *CurMBB, MachineInstr *CheckCall) {
88- // Ensure that we have a valid MachineBasicBlock and CheckCall
89- if (!CurMBB || !CheckCall)
90- return nullptr ;
110+ MachineInstr *
111+ AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard (MachineFunction &MF,
112+ MachineInstr *MI) {
91113
92- MachineFunction &MF = *CurMBB->getParent ();
114+ MachineInstr *ClonedInstr = MF.CloneMachineInstr (MI);
115+
116+ // Get the register class of the original destination register
117+ Register OrigReg = MI->getOperand (0 ).getReg ();
93118 MachineRegisterInfo &MRI = MF.getRegInfo ();
119+ const TargetRegisterClass *RegClass = MRI.getRegClass (OrigReg);
94120
95- // Initialize reverse iterator starting just before CheckCall
96- MachineBasicBlock::reverse_iterator DIt (CheckCall);
97- MachineBasicBlock::reverse_iterator DEnd = CurMBB->rend ();
98-
99- // Reverse iterate from CheckCall to find LOAD_STACK_GUARD
100- for (; DIt != DEnd; ++DIt) {
101- MachineInstr &MI = *DIt;
102- if (MI.getOpcode () == TargetOpcode::LOAD_STACK_GUARD) {
103- // Clone the LOAD_STACK_GUARD instruction
104- MachineInstr *ClonedInstr = MF.CloneMachineInstr (&MI);
105-
106- // Get the register class of the original destination register
107- Register OrigReg = MI.getOperand (0 ).getReg ();
108- const TargetRegisterClass *RegClass = MRI.getRegClass (OrigReg);
109-
110- // Create a new virtual register in the same register class
111- Register NewReg = MRI.createVirtualRegister (RegClass);
112-
113- // Update operand 0 (destination) of the cloned instruction
114- MachineOperand &DestOperand = ClonedInstr->getOperand (0 );
115- if (DestOperand.isReg () && DestOperand.isDef ()) {
116- DestOperand.setReg (NewReg); // Set the new virtual register
117- }
121+ // Create a new virtual register in the same register class
122+ Register NewReg = MRI.createVirtualRegister (RegClass);
118123
119- // Return the modified cloned instruction
120- return ClonedInstr;
121- }
124+ // Update operand 0 (destination) of the cloned instruction
125+ MachineOperand &DestOperand = ClonedInstr->getOperand (0 );
126+ if (DestOperand.isReg () && DestOperand.isDef ()) {
127+ DestOperand.setReg (NewReg); // Set the new virtual register
122128 }
123129
124- // If no LOAD_STACK_GUARD instruction was found, return nullptr
125- return nullptr ;
130+ return ClonedInstr;
126131}
127132
128- void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence (
129- MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
130- MachineInstr *SeqMI[5 ]) {
133+ bool AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence (
134+ MachineInstr *CheckCall, MachineInstr *SeqMI[5 ]) {
135+
136+ MachineBasicBlock *MBB = CheckCall->getParent ();
131137
132138 MachineBasicBlock::iterator UIt (CheckCall);
133139 MachineBasicBlock::reverse_iterator DIt (CheckCall);
134140
135141 // Move forward to find the stack adjustment after the call
136- // to __security_check_cookie
137142 ++UIt;
143+ if (UIt == MBB->end () || UIt->getOpcode () != AArch64::ADJCALLSTACKUP) {
144+ return false ;
145+ }
138146 SeqMI[4 ] = &*UIt;
139147
140148 // Assign the BL instruction (call to __security_check_cookie)
141149 SeqMI[3 ] = CheckCall;
142150
143- // COPY function slot cookie
151+ // Move backward to find the COPY instruction for the function slot cookie
152+ // argument passing
144153 ++DIt;
154+ if (DIt == MBB->rend () || DIt->getOpcode () != AArch64::COPY) {
155+ return false ;
156+ }
145157 SeqMI[2 ] = &*DIt;
146158
147159 // Move backward to find the instruction that loads the security cookie from
148160 // the stack
149161 ++DIt;
162+ if (DIt == MBB->rend () || DIt->getOpcode () != AArch64::LDRXui) {
163+ return false ;
164+ }
150165 SeqMI[1 ] = &*DIt;
151166
152- ++DIt; // Find ADJCALLSTACKDOWN
167+ // Move backward to find the stack adjustment before the call
168+ ++DIt;
169+ if (DIt == MBB->rend () || DIt->getOpcode () != AArch64::ADJCALLSTACKDOWN) {
170+ return false ;
171+ }
153172 SeqMI[0 ] = &*DIt;
173+
174+ // If all instructions are matched and stored, the sequence is valid
175+ return true ;
154176}
155177
156178void AArch64WinFixupBufferSecurityCheckPass::finishBlock (
@@ -185,21 +207,23 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
185207 if (!GV)
186208 return Changed;
187209
188- const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
189-
190- // Check if security check cookie call was installed or not
191- auto [CurMBB, CheckCall] = getSecurityCheckerBasicBlock (MF);
192- if (!CheckCall)
210+ // Find LOAD_STACK_GUARD and __security_check_cookie instructions
211+ auto [StackGuard, CheckCall] = findSecurityCheckAndLoadStackGuard (MF);
212+ if (!CheckCall || !StackGuard)
193213 return Changed;
194214
195- // Get sequence of instruction in CurMBB responsible for calling
215+ // Get sequence of instructions in current basic block responsible for calling
196216 // __security_check_cookie
197217 MachineInstr *SeqMI[5 ];
198- getGuardCheckSequence (CurMBB, CheckCall, SeqMI);
218+ if (!getGuardCheckSequence (CheckCall, SeqMI))
219+ return Changed;
220+
221+ const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
222+ MachineBasicBlock *CurMBB = CheckCall->getParent ();
199223
200224 // Find LOAD_STACK_GUARD in CurrMBB and build a new LOAD_STACK_GUARD
201225 // instruction with new destination register
202- MachineInstr *ClonedInstr = cloneLoadStackGuard (CurMBB, CheckCall );
226+ MachineInstr *ClonedInstr = cloneLoadStackGuard (MF, StackGuard );
203227 if (!ClonedInstr)
204228 return Changed;
205229
@@ -216,13 +240,14 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
216240 CurMBB->splice (InsertPt, CurMBB, std::next (InsertPt));
217241
218242 // Create a new virtual register for the CMP instruction result
219- Register DiscardReg =
220- MF. getRegInfo () .createVirtualRegister (&AArch64::GPR64RegClass);
243+ MachineRegisterInfo &MRI = MF. getRegInfo ();
244+ Register DiscardReg = MRI .createVirtualRegister (&AArch64::GPR64RegClass);
221245
222246 // Emit the CMP instruction to compare stack cookie with global cookie
223247 BuildMI (*CurMBB, InsertPt, DebugLoc (), TII->get (AArch64::SUBSXrr))
224- .addReg (DiscardReg, RegState::Define | RegState::Dead) // Result discarded
225- .addReg (CookieLoadReg) // First operand: stack cookie
248+ .addReg (DiscardReg,
249+ RegState::Define | RegState::Dead) // Result discarded
250+ .addReg (CookieLoadReg) // First operand: stack cookie
226251 .addReg (GlobalCookieReg); // Second operand: global cookie
227252
228253 // Create FailMBB basic block to call __security_check_cookie
@@ -258,6 +283,15 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
258283 CurMBB->addSuccessor (NewRetMBB);
259284 CurMBB->addSuccessor (FailMBB);
260285
286+ MachineDominatorTreeWrapperPass *WrapperPass =
287+ getAnalysisIfAvailable<MachineDominatorTreeWrapperPass>();
288+ MachineDominatorTree *MDT =
289+ WrapperPass ? &WrapperPass->getDomTree () : nullptr ;
290+ if (MDT) {
291+ MDT->addNewBlock (FailMBB, CurMBB);
292+ MDT->addNewBlock (NewRetMBB, CurMBB);
293+ }
294+
261295 finishFunction (FailMBB, NewRetMBB);
262296
263297 return !Changed;
0 commit comments