Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/IR/MLIRContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ class MLIRContext {
disableMultithreading(!enable);
}

/// Set the flag specifying if thread-local storage should be used
/// by storage allocators in this context.
void disableThreadLocalStorage(bool disable = true);
void enableThreadLocalStorage(bool enable = true) {
disableThreadLocalStorage(!enable);
}

/// Set a new thread pool to be used in this context. This method requires
/// that multithreading is disabled for this context prior to the call. This
/// allows to share a thread pool across multiple contexts, as well as
Expand Down
48 changes: 45 additions & 3 deletions mlir/lib/IR/AttributeDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
#include "mlir/Support/ThreadLocalCache.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/Support/ErrorHandling.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need error handling?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is left over from a previous commit, apologies. I would like to clarify though, since clangd tells me there are other headers unused here, shall I remove those too as an en-passant tidy? Notably,

  • "mlir/Support/StorageUniquer.h"
  • "llvm/ADT/PointerIntPair.h"
  • "llvm/Support/TrailingObjects.h"

#include "llvm/Support/TrailingObjects.h"
#include <mutex>

namespace mlir {
namespace detail {
Expand Down Expand Up @@ -401,7 +403,8 @@ class DistinctAttributeUniquer {
/// is freed after the destruction of the distinct attribute allocator.
class DistinctAttributeAllocator {
public:
DistinctAttributeAllocator() = default;
DistinctAttributeAllocator(bool threadingIsEnabled)
: threadingIsEnabled(threadingIsEnabled), useThreadLocalAllocator(true) {};

DistinctAttributeAllocator(DistinctAttributeAllocator &&) = delete;
DistinctAttributeAllocator(const DistinctAttributeAllocator &) = delete;
Expand All @@ -411,12 +414,51 @@ class DistinctAttributeAllocator {
/// Allocates a distinct attribute storage using a thread local bump pointer
/// allocator to enable synchronization free parallel allocations.
DistinctAttrStorage *allocate(Attribute referencedAttr) {
return new (allocatorCache.get().Allocate<DistinctAttrStorage>())
DistinctAttrStorage(referencedAttr);
if (!useThreadLocalAllocator && threadingIsEnabled) {
std::scoped_lock<std::mutex> lock(allocatorMutex);
return allocateImpl(referencedAttr);
}
return allocateImpl(referencedAttr);
}

/// Sets flags to use thread local bump pointer allocators or a single
/// non-thread safe bump pointer allocator depending on if multi-threading is
/// enabled. Use this to disable the use of thread local storage and bypass
/// thread safety synchronization overhead.
void disableMultiThreading(bool disable = true) {
threadingIsEnabled = !disable;
}

/// Sets flags to disable using thread local bump pointer allocators and use a
/// single thread-safe allocator. Use this to persist allocated storage beyond
/// the lifetime of a child thread calling this function while ensuring
/// thread-safe allocation.
void disableThreadLocalStorage(bool disable = true) {
useThreadLocalAllocator = !disable;
}

private:
DistinctAttrStorage *allocateImpl(Attribute referencedAttr) {
return new (getAllocatorInUse().Allocate<DistinctAttrStorage>())
DistinctAttrStorage(referencedAttr);
}

/// If threading is disabled on the owning MLIR context, a normal non
/// thread-local, non-thread safe bump pointer allocator is used instead to
/// prevent use-after-free errors whenever attribute storage created on a
/// crash recover thread is accessed after the thread joins.
llvm::BumpPtrAllocator &getAllocatorInUse() {
if (useThreadLocalAllocator)
return allocatorCache.get();
return allocator;
}

ThreadLocalCache<llvm::BumpPtrAllocator> allocatorCache;
llvm::BumpPtrAllocator allocator;
std::mutex allocatorMutex;

bool threadingIsEnabled : 1;
bool useThreadLocalAllocator : 1;
};
} // namespace detail
} // namespace mlir
Expand Down
8 changes: 7 additions & 1 deletion mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ class MLIRContextImpl {

public:
MLIRContextImpl(bool threadingIsEnabled)
: threadingIsEnabled(threadingIsEnabled) {
: threadingIsEnabled(threadingIsEnabled),
distinctAttributeAllocator(threadingIsEnabled) {
if (threadingIsEnabled) {
ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
threadPool = ownedThreadPool.get();
Expand Down Expand Up @@ -596,6 +597,7 @@ void MLIRContext::disableMultithreading(bool disable) {
// Update the threading mode for each of the uniquers.
impl->affineUniquer.disableMultithreading(disable);
impl->attributeUniquer.disableMultithreading(disable);
impl->distinctAttributeAllocator.disableMultiThreading(disable);
impl->typeUniquer.disableMultithreading(disable);

// Destroy thread pool (stop all threads) if it is no longer needed, or create
Expand Down Expand Up @@ -717,6 +719,10 @@ bool MLIRContext::isOperationRegistered(StringRef name) {
return RegisteredOperationName::lookup(name, this).has_value();
}

void MLIRContext::disableThreadLocalStorage(bool disable) {
getImpl().distinctAttributeAllocator.disableThreadLocalStorage(disable);
}

void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
auto &impl = context->getImpl();
assert(impl.multiThreadedExecutionContext == 0 &&
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/Pass/PassCrashRecovery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,15 @@ struct FileReproducerStream : public mlir::ReproducerStream {

LogicalResult PassManager::runWithCrashRecovery(Operation *op,
AnalysisManager am) {
// Notify the context to disable the use of thread-local storage while the
// pass manager is running in a crash recovery context thread. Re-enable the
// thread local storage upon function exit. This is required to persist any
// attribute storage allocated during passes beyond the lifetime of the
// recovery context thread.
MLIRContext *ctx = getContext();
ctx->disableThreadLocalStorage();
auto guard =
llvm::make_scope_exit([ctx]() { ctx->enableThreadLocalStorage(); });
crashReproGenerator->initialize(getPasses(), op, verifyPasses);

// Safely invoke the passes within a recovery context.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Test that the enable-debug-info-scope-on-llvm-func pass can create its
// distinct attributes when running in the crash reproducer thread.

// RUN: mlir-opt --mlir-disable-threading --mlir-pass-pipeline-crash-reproducer=. \
// RUN: --pass-pipeline="builtin.module(ensure-debug-info-scope-on-llvm-func)" \
// RUN: --mlir-print-debuginfo %s | FileCheck %s

// RUN: mlir-opt --mlir-pass-pipeline-crash-reproducer=. \
// RUN: --pass-pipeline="builtin.module(ensure-debug-info-scope-on-llvm-func)" \
// RUN: --mlir-print-debuginfo %s | FileCheck %s

module {
llvm.func @func_no_debug() {
llvm.return loc(unknown)
} loc(unknown)
} loc(unknown)

// CHECK-LABEL: llvm.func @func_no_debug()
// CHECK: llvm.return loc(#loc
// CHECK: loc(#loc[[LOC:[0-9]+]])
// CHECK: #di_compile_unit = #llvm.di_compile_unit<id = distinct[{{.*}}]<>,
// CHECK: #di_subprogram = #llvm.di_subprogram<id = distinct[{{.*}}]<>
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// This test verifies that when running with crash reproduction enabled, distinct
// attribute storage is not allocated in thread-local storage. Since crash
// reproduction runs the pass manager in a separate thread, using thread-local
// storage for distinct attributes causes use-after-free errors once the thread
// that runs the pass manager joins.

// RUN: mlir-opt --mlir-disable-threading --mlir-pass-pipeline-crash-reproducer=. %s -test-distinct-attrs | FileCheck %s
// RUN: mlir-opt --mlir-pass-pipeline-crash-reproducer=. %s -test-distinct-attrs | FileCheck %s

// CHECK: #[[DIST0:.*]] = distinct[0]<42 : i32>
// CHECK: #[[DIST1:.*]] = distinct[1]<42 : i32>
#distinct = distinct[0]<42 : i32>

// CHECK: @foo_1
func.func @foo_1() {
// CHECK: "test.op"() {distinct.input = #[[DIST0]], distinct.output = #[[DIST1]]}
"test.op"() {distinct.input = #distinct} : () -> ()
}