diff --git a/llvm/include/llvm/SandboxIR/Pass.h b/llvm/include/llvm/SandboxIR/Pass.h index 34776e3529e5c..211f10f5d57c5 100644 --- a/llvm/include/llvm/SandboxIR/Pass.h +++ b/llvm/include/llvm/SandboxIR/Pass.h @@ -15,6 +15,7 @@ namespace llvm::sandboxir { class Function; +class Region; /// The base class of a Sandbox IR Pass. class Pass { @@ -24,6 +25,7 @@ class Pass { const std::string Name; public: + /// \p Name can't contain any spaces or start with '-'. Pass(StringRef Name) : Name(Name) { assert(!Name.contains(' ') && "A pass name should not contain whitespaces!"); @@ -47,11 +49,21 @@ class Pass { /// A pass that runs on a sandbox::Function. class FunctionPass : public Pass { public: + /// \p Name can't contain any spaces or start with '-'. FunctionPass(StringRef Name) : Pass(Name) {} /// \Returns true if it modifies \p F. virtual bool runOnFunction(Function &F) = 0; }; +/// A pass that runs on a sandbox::Region. +class RegionPass : public Pass { +public: + /// \p Name can't contain any spaces or start with '-'. + RegionPass(StringRef Name) : Pass(Name) {} + /// \Returns true if it modifies \p R. + virtual bool runOnRegion(Region &R) = 0; +}; + } // namespace llvm::sandboxir #endif // LLVM_SANDBOXIR_PASS_H diff --git a/llvm/include/llvm/SandboxIR/PassManager.h b/llvm/include/llvm/SandboxIR/PassManager.h index 98b56ba08c4eb..54192c6bf1333 100644 --- a/llvm/include/llvm/SandboxIR/PassManager.h +++ b/llvm/include/llvm/SandboxIR/PassManager.h @@ -73,6 +73,12 @@ class FunctionPassManager final bool runOnFunction(Function &F) final; }; +class RegionPassManager final : public PassManager { +public: + RegionPassManager(StringRef Name) : PassManager(Name) {} + bool runOnRegion(Region &R) final; +}; + /// Owns the passes and provides an API to get a pass by its name. class PassRegistry { SmallVector, 8> Passes; diff --git a/llvm/lib/SandboxIR/PassManager.cpp b/llvm/lib/SandboxIR/PassManager.cpp index 4168420a01ce2..95bc5e56bb3ec 100644 --- a/llvm/lib/SandboxIR/PassManager.cpp +++ b/llvm/lib/SandboxIR/PassManager.cpp @@ -20,6 +20,16 @@ bool FunctionPassManager::runOnFunction(Function &F) { return Change; } +bool RegionPassManager::runOnRegion(Region &R) { + bool Change = false; + for (RegionPass *Pass : Passes) { + Change |= Pass->runOnRegion(R); + // TODO: run the verifier. + } + // TODO: Check ChangeAll against hashes before/after. + return Change; +} + FunctionPassManager & PassRegistry::parseAndCreatePassPipeline(StringRef Pipeline) { static constexpr const char EndToken = '\0'; diff --git a/llvm/unittests/SandboxIR/PassTest.cpp b/llvm/unittests/SandboxIR/PassTest.cpp index 10fe59b654a2e..b380ae9fd475a 100644 --- a/llvm/unittests/SandboxIR/PassTest.cpp +++ b/llvm/unittests/SandboxIR/PassTest.cpp @@ -13,6 +13,7 @@ #include "llvm/SandboxIR/Context.h" #include "llvm/SandboxIR/Function.h" #include "llvm/SandboxIR/PassManager.h" +#include "llvm/SandboxIR/Region.h" #include "llvm/Support/SourceMgr.h" #include "gtest/gtest.h" @@ -86,6 +87,68 @@ define void @foo() { #endif } +TEST_F(PassTest, RegionPass) { + auto *F = parseFunction(R"IR( +define i8 @foo(i8 %v0, i8 %v1) { + %t0 = add i8 %v0, 1 + %t1 = add i8 %t0, %v1, !sandboxvec !0 + %t2 = add i8 %t1, %v1, !sandboxvec !0 + ret i8 %t1 +} + +!0 = distinct !{!"sandboxregion"} +)IR", + "foo"); + + class TestPass final : public RegionPass { + unsigned &InstCount; + + public: + TestPass(unsigned &InstCount) + : RegionPass("test-pass"), InstCount(InstCount) {} + bool runOnRegion(Region &R) final { + for ([[maybe_unused]] auto &Inst : R) { + ++InstCount; + } + return false; + } + }; + unsigned InstCount = 0; + TestPass TPass(InstCount); + // Check getName(), + EXPECT_EQ(TPass.getName(), "test-pass"); + // Check runOnRegion(); + llvm::SmallVector> Regions = + Region::createRegionsFromMD(*F); + ASSERT_EQ(Regions.size(), 1u); + TPass.runOnRegion(*Regions[0]); + EXPECT_EQ(InstCount, 2u); +#ifndef NDEBUG + { + // Check print(). + std::string Buff; + llvm::raw_string_ostream SS(Buff); + TPass.print(SS); + EXPECT_EQ(Buff, "test-pass"); + } + { + // Check operator<<(). + std::string Buff; + llvm::raw_string_ostream SS(Buff); + SS << TPass; + EXPECT_EQ(Buff, "test-pass"); + } + // Check pass name assertions. + class TestNamePass final : public RegionPass { + public: + TestNamePass(llvm::StringRef Name) : RegionPass(Name) {} + bool runOnRegion(Region &F) { return false; } + }; + EXPECT_DEATH(TestNamePass("white space"), ".*whitespace.*"); + EXPECT_DEATH(TestNamePass("-dash"), ".*start with.*"); +#endif +} + TEST_F(PassTest, FunctionPassManager) { auto *F = parseFunction(R"IR( define void @foo() { @@ -136,6 +199,67 @@ define void @foo() { #endif // NDEBUG } +TEST_F(PassTest, RegionPassManager) { + auto *F = parseFunction(R"IR( +define i8 @foo(i8 %v0, i8 %v1) { + %t0 = add i8 %v0, 1 + %t1 = add i8 %t0, %v1, !sandboxvec !0 + %t2 = add i8 %t1, %v1, !sandboxvec !0 + ret i8 %t1 +} + +!0 = distinct !{!"sandboxregion"} +)IR", + "foo"); + + class TestPass1 final : public RegionPass { + unsigned &InstCount; + + public: + TestPass1(unsigned &InstCount) + : RegionPass("test-pass1"), InstCount(InstCount) {} + bool runOnRegion(Region &R) final { + for ([[maybe_unused]] auto &Inst : R) + ++InstCount; + return false; + } + }; + class TestPass2 final : public RegionPass { + unsigned &InstCount; + + public: + TestPass2(unsigned &InstCount) + : RegionPass("test-pass2"), InstCount(InstCount) {} + bool runOnRegion(Region &R) final { + for ([[maybe_unused]] auto &Inst : R) + ++InstCount; + return false; + } + }; + unsigned InstCount1 = 0; + unsigned InstCount2 = 0; + TestPass1 TPass1(InstCount1); + TestPass2 TPass2(InstCount2); + + RegionPassManager RPM("test-rpm"); + RPM.addPass(&TPass1); + RPM.addPass(&TPass2); + // Check runOnRegion(). + llvm::SmallVector> Regions = + Region::createRegionsFromMD(*F); + ASSERT_EQ(Regions.size(), 1u); + RPM.runOnRegion(*Regions[0]); + EXPECT_EQ(InstCount1, 2u); + EXPECT_EQ(InstCount2, 2u); +#ifndef NDEBUG + // Check dump(). + std::string Buff; + llvm::raw_string_ostream SS(Buff); + RPM.print(SS); + EXPECT_EQ(Buff, "test-rpm(test-pass1,test-pass2)"); +#endif // NDEBUG +} + TEST_F(PassTest, PassRegistry) { class TestPass1 final : public FunctionPass { public: