diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index 7fe97d984b958..a88b0003f55bd 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -218,6 +218,8 @@ class Context { public: Context(LLVMContext &LLVMCtx); ~Context(); + /// Clears function-level state. + void clear(); Tracker &getTracker() { return IRTracker; } /// Convenience function for `getTracker().save()` diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h index 09369dbb496fc..7ea9386f08bee 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h @@ -24,10 +24,18 @@ class SandboxVectorizerPass : public PassInfoMixin { TargetTransformInfo *TTI = nullptr; AAResults *AA = nullptr; ScalarEvolution *SE = nullptr; - + // NOTE: We define the Context as a pass-scope object instead of local object + // in runOnFunction() because the passes defined in the pass-manager need + // access to it for registering/deregistering callbacks during construction + // and destruction. std::unique_ptr Ctx; // A pipeline of SandboxIR function passes run by the vectorizer. + // NOTE: We define this as a pass-scope object to avoid recreating the + // pass-pipeline every time in runOnFunction(). The downside is that the + // Context also needs to be defined as a pass-scope object because the passes + // within FPM may register/unregister callbacks, so they need access to + // Context. sandboxir::FunctionPassManager FPM; bool runImpl(Function &F); diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index 440210f5a1bf7..830f2832853fe 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -611,6 +611,12 @@ Context::Context(LLVMContext &LLVMCtx) Context::~Context() {} +void Context::clear() { + // TODO: Ideally we should clear only function-scope objects, and keep global + // objects, like Constants to avoid recreating them. + LLVMValueToValueMap.clear(); +} + Module *Context::getModule(llvm::Module *LLVMM) const { auto It = LLVMModuleToModuleMap.find(LLVMM); if (It != LLVMModuleToModuleMap.end()) diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp index 542fcde71e83c..798a0ad915375 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp @@ -88,7 +88,10 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) { sandboxir::Function &F = *Ctx->createFunction(&LLVMF); sandboxir::Analyses A(*AA, *SE, *TTI); bool Change = FPM.runOnFunction(F, A); - // TODO: This is a function pass, so we won't be needing the function-level - // Sandbox IR objects in the future. So we should clear them. + // Given that sandboxir::Context `Ctx` is defined at a pass-level scope, the + // maps from LLVM IR to Sandbox IR may go stale as later passes remove LLVM IR + // objects. To avoid issues caused by this clear the context's state. + // NOTE: The alternative would be to define Ctx and FPM within runOnFunction() + Ctx->clear(); return Change; } diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt index bbfbcc730a4cb..104a27975cfc0 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/CMakeLists.txt @@ -12,6 +12,7 @@ add_llvm_unittest(SandboxVectorizerTests InstrMapsTest.cpp IntervalTest.cpp LegalityTest.cpp + SandboxVectorizerTest.cpp SchedulerTest.cpp SeedCollectorTest.cpp VecUtilsTest.cpp diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerTest.cpp new file mode 100644 index 0000000000000..e9b304618a06c --- /dev/null +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerTest.cpp @@ -0,0 +1,63 @@ +//===- SandboxVectorizerTest.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/PassInstrumentation.h" +#include "llvm/SandboxIR/Function.h" +#include "llvm/SandboxIR/Instruction.h" +#include "llvm/Support/SourceMgr.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using namespace llvm; + +struct SandboxVectorizerTest : public testing::Test { + LLVMContext C; + std::unique_ptr M; + + void parseIR(LLVMContext &C, const char *IR) { + SMDiagnostic Err; + M = parseAssemblyString(IR, Err, C); + if (!M) + Err.print("SandboxVectorizerTest", errs()); + } +}; + +// Check that we can run the pass on the same function more than once without +// issues. This basically checks that Sandbox IR Context gets cleared after we +// run the function pass. +TEST_F(SandboxVectorizerTest, ContextCleared) { + parseIR(C, R"IR( +define void @foo() { + ret void +} +)IR"); + auto &LLVMF = *M->getFunction("foo"); + SandboxVectorizerPass SVecPass; + FunctionAnalysisManager AM; + AM.registerPass([] { return TargetIRAnalysis(); }); + AM.registerPass([] { return AAManager(); }); + AM.registerPass([] { return ScalarEvolutionAnalysis(); }); + AM.registerPass([] { return PassInstrumentationAnalysis(); }); + AM.registerPass([] { return TargetLibraryAnalysis(); }); + AM.registerPass([] { return AssumptionAnalysis(); }); + AM.registerPass([] { return DominatorTreeAnalysis(); }); + AM.registerPass([] { return LoopAnalysis(); }); + SVecPass.run(LLVMF, AM); + // This shouldn't crash. + SVecPass.run(LLVMF, AM); +}