4444#include " llvm/CodeGen/ShadowStackGCLowering.h"
4545#include " llvm/CodeGen/SjLjEHPrepare.h"
4646#include " llvm/CodeGen/StackProtector.h"
47+ #include " llvm/CodeGen/TargetPassConfig.h"
4748#include " llvm/CodeGen/UnreachableBlockElim.h"
4849#include " llvm/CodeGen/WasmEHPrepare.h"
4950#include " llvm/CodeGen/WinEHPrepare.h"
@@ -176,73 +177,80 @@ template <typename DerivedT> class CodeGenPassBuilder {
176177 // Function object to maintain state while adding codegen IR passes.
177178 class AddIRPass {
178179 public:
179- AddIRPass (ModulePassManager &MPM) : MPM(MPM) {}
180+ AddIRPass (ModulePassManager &MPM, const DerivedT &PB ) : MPM(MPM), PB(PB ) {}
180181 ~AddIRPass () {
181182 if (!FPM.isEmpty ())
182183 MPM.addPass (createModuleToFunctionPassAdaptor (std::move (FPM)));
183184 }
184185
185- template <typename PassT> void operator ()(PassT &&Pass) {
186+ template <typename PassT>
187+ void operator ()(PassT &&Pass, StringRef Name = PassT::name()) {
186188 static_assert ((is_detected<is_function_pass_t , PassT>::value ||
187189 is_detected<is_module_pass_t , PassT>::value) &&
188190 " Only module pass and function pass are supported." );
189191
192+ if (!PB.runBeforeAdding (Name))
193+ return ;
194+
190195 // Add Function Pass
191196 if constexpr (is_detected<is_function_pass_t , PassT>::value) {
192197 FPM.addPass (std::forward<PassT>(Pass));
198+
199+ for (auto &C : PB.AfterCallbacks )
200+ C (Name);
193201 } else {
194202 // Add Module Pass
195203 if (!FPM.isEmpty ()) {
196204 MPM.addPass (createModuleToFunctionPassAdaptor (std::move (FPM)));
197205 FPM = FunctionPassManager ();
198206 }
207+
199208 MPM.addPass (std::forward<PassT>(Pass));
209+
210+ for (auto &C : PB.AfterCallbacks )
211+ C (Name);
200212 }
201213 }
202214
203215 private:
204216 ModulePassManager &MPM;
205217 FunctionPassManager FPM;
218+ const DerivedT &PB;
206219 };
207220
208221 // Function object to maintain state while adding codegen machine passes.
209222 class AddMachinePass {
210223 public:
211- AddMachinePass (MachineFunctionPassManager &PM) : PM(PM) {}
224+ AddMachinePass (MachineFunctionPassManager &PM, const DerivedT &PB)
225+ : PM(PM), PB(PB) {}
212226
213227 template <typename PassT> void operator ()(PassT &&Pass) {
214228 static_assert (
215229 is_detected<has_key_t , PassT>::value,
216230 " Machine function pass must define a static member variable `Key`." );
217- for (auto &C : BeforeCallbacks)
218- if (!C (&PassT::Key))
219- return ;
231+
232+ if (!PB.runBeforeAdding (PassT::name ()))
233+ return ;
234+
220235 PM.addPass (std::forward<PassT>(Pass));
221- for (auto &C : AfterCallbacks)
222- C (&PassT::Key);
236+
237+ for (auto &C : PB.AfterCallbacks )
238+ C (PassT::name ());
223239 }
224240
225241 template <typename PassT> void insertPass (MachinePassKey *ID, PassT Pass) {
226- AfterCallbacks.emplace_back (
242+ PB. AfterCallbacks .emplace_back (
227243 [this , ID, Pass = std::move (Pass)](MachinePassKey *PassID) {
228244 if (PassID == ID)
229245 this ->PM .addPass (std::move (Pass));
230246 });
231247 }
232248
233- void disablePass (MachinePassKey *ID) {
234- BeforeCallbacks.emplace_back (
235- [ID](MachinePassKey *PassID) { return PassID != ID; });
236- }
237-
238249 MachineFunctionPassManager releasePM () { return std::move (PM); }
239250
240251 private:
241252 MachineFunctionPassManager &PM;
242- SmallVector<llvm::unique_function<bool (MachinePassKey *)>, 4 >
243- BeforeCallbacks;
244- SmallVector<llvm::unique_function<void (MachinePassKey *)>, 4 >
245- AfterCallbacks;
253+ const DerivedT &PB;
246254 };
247255
248256 LLVMTargetMachine &TM;
@@ -473,20 +481,43 @@ template <typename DerivedT> class CodeGenPassBuilder {
473481 const DerivedT &derived () const {
474482 return static_cast <const DerivedT &>(*this );
475483 }
484+
485+ bool runBeforeAdding (StringRef Name) const {
486+ bool ShouldAdd = true ;
487+ for (auto &C : BeforeCallbacks)
488+ ShouldAdd &= C (Name);
489+ return ShouldAdd;
490+ }
491+
492+ void setStartStopPasses (const TargetPassConfig::StartStopInfo &Info) const ;
493+
494+ Error verifyStartStop (const TargetPassConfig::StartStopInfo &Info) const ;
495+
496+ mutable SmallVector<llvm::unique_function<bool (StringRef)>, 4 >
497+ BeforeCallbacks;
498+ mutable SmallVector<llvm::unique_function<void (StringRef)>, 4 > AfterCallbacks;
499+
500+ // / Helper variable for `-start-before/-start-after/-stop-before/-stop-after`
501+ mutable bool Started = true ;
502+ mutable bool Stopped = true ;
476503};
477504
478505template <typename Derived>
479506Error CodeGenPassBuilder<Derived>::buildPipeline(
480507 ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
481508 raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
482509 CodeGenFileType FileType) const {
483- AddIRPass addIRPass (MPM);
510+ auto StartStopInfo = TargetPassConfig::getStartStopInfo (*PIC);
511+ if (!StartStopInfo)
512+ return StartStopInfo.takeError ();
513+ setStartStopPasses (*StartStopInfo);
514+ AddIRPass addIRPass (MPM, derived ());
484515 // `ProfileSummaryInfo` is always valid.
485516 addIRPass (RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
486517 addIRPass (RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
487518 addISelPasses (addIRPass);
488519
489- AddMachinePass addPass (MFPM);
520+ AddMachinePass addPass (MFPM, derived () );
490521 if (auto Err = addCoreISelPasses (addPass))
491522 return std::move (Err);
492523
@@ -499,6 +530,68 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
499530 });
500531
501532 addPass (FreeMachineFunctionPass ());
533+ return verifyStartStop (*StartStopInfo);
534+ }
535+
536+ template <typename Derived>
537+ void CodeGenPassBuilder<Derived>::setStartStopPasses(
538+ const TargetPassConfig::StartStopInfo &Info) const {
539+ if (!Info.StartPass .empty ()) {
540+ Started = false ;
541+ BeforeCallbacks.emplace_back ([this , &Info, AfterFlag = Info.StartAfter ,
542+ Count = 0u ](StringRef ClassName) mutable {
543+ if (Count == Info.StartInstanceNum ) {
544+ if (AfterFlag) {
545+ AfterFlag = false ;
546+ Started = true ;
547+ }
548+ return Started;
549+ }
550+
551+ auto PassName = PIC->getPassNameForClassName (ClassName);
552+ if (Info.StartPass == PassName && ++Count == Info.StartInstanceNum )
553+ Started = !Info.StartAfter ;
554+
555+ return Started;
556+ });
557+ }
558+
559+ if (!Info.StopPass .empty ()) {
560+ Stopped = false ;
561+ BeforeCallbacks.emplace_back ([this , &Info, AfterFlag = Info.StopAfter ,
562+ Count = 0u ](StringRef ClassName) mutable {
563+ if (Count == Info.StopInstanceNum ) {
564+ if (AfterFlag) {
565+ AfterFlag = false ;
566+ Stopped = true ;
567+ }
568+ return !Stopped;
569+ }
570+
571+ auto PassName = PIC->getPassNameForClassName (ClassName);
572+ if (Info.StopPass == PassName && ++Count == Info.StopInstanceNum )
573+ Stopped = !Info.StopAfter ;
574+ return !Stopped;
575+ });
576+ }
577+ }
578+
579+ template <typename Derived>
580+ Error CodeGenPassBuilder<Derived>::verifyStartStop(
581+ const TargetPassConfig::StartStopInfo &Info) const {
582+ if (Started && Stopped)
583+ return Error::success ();
584+
585+ if (!Started)
586+ return make_error<StringError>(
587+ " Can't find start pass \" " +
588+ PIC->getPassNameForClassName (Info.StartPass ) + " \" ." ,
589+ std::make_error_code (std::errc::invalid_argument));
590+ if (!Stopped)
591+ return make_error<StringError>(
592+ " Can't find stop pass \" " +
593+ PIC->getPassNameForClassName (Info.StopPass ) + " \" ." ,
594+ std::make_error_code (std::errc::invalid_argument));
502595 return Error::success ();
503596}
504597
0 commit comments