[MLIR] Avoid resolving callable outside the analysis scope in DeadCodeAnalysis #155088
|
@llvm/pr-subscribers-mlir
Author: Mehdi Amini (joker-eph)
Changes
We are using the symbol table machinery to look up a callable, but when the analysis scope is a function, such a lookup will resolve outside of the scope. This can lead to race conditions, since other passes may operate in parallel on the sibling functions.
Fix #154948
Full diff: https://github.com/llvm/llvm-project/pull/155088.diff
2 Files Affected:
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 2250db823b551..c7c405e1423cb 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -229,6 +229,13 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
/// considered an external callable.
Operation *analysisScope;
+ /// Whether the analysis scope has a symbol table. This is used to avoid
+ /// resolving callables outside the analysis scope.
+ /// It is updated when recursing into a region in case where the top-level
+ /// operation does not have a symbol table, but one is encountered in a nested
+ /// region.
+ bool hasSymbolTable = false;
+
/// A symbol table used for O(1) symbol lookups during simplification.
SymbolTableCollection symbolTable;
};
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 9424eff3e6b6f..131c49c44171b 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -22,6 +22,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
@@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
<< OpWithFlags(top, OpPrintingFlags().skipRegions());
analysisScope = top;
+ hasSymbolTable = top->hasTrait<OpTrait::SymbolTable>();
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
LDBG() << "[init] Processing symbol table op: "
<< OpWithFlags(symTable, OpPrintingFlags().skipRegions());
@@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
return failure();
}
// Recurse on nested operations.
- for (Region &region : op->getRegions()) {
- LDBG() << "[init] Recursing into region of op: "
- << OpWithFlags(op, OpPrintingFlags().skipRegions());
- for (Operation &nestedOp : region.getOps()) {
- LDBG() << "[init] Recursing into nested op: "
- << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
- if (failed(initializeRecursively(&nestedOp)))
- return failure();
+ if (op->getNumRegions()) {
+ // If we haven't seen a symbol table yet, check if the current operation
+ // has one. If so, update the flag to allow for resolving callables in
+ // nested regions.
+ bool savedHasSymbolTable = hasSymbolTable;
+ auto restoreHasSymbolTable =
+ llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; });
+ if (!hasSymbolTable && op->hasTrait<OpTrait::SymbolTable>())
+ hasSymbolTable = true;
+
+ for (Region &region : op->getRegions()) {
+ LDBG() << "[init] Recursing into region of op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ for (Operation &nestedOp : region.getOps()) {
+ LDBG() << "[init] Recursing into nested op: "
+ << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
+ if (failed(initializeRecursively(&nestedOp)))
+ return failure();
+ }
}
}
LDBG() << "[init] Finished initializeRecursively for op: "
@@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
LDBG() << "visitCallOperation: "
<< OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
- Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+
+ Operation *callableOp = nullptr;
+ if (hasSymbolTable)
+ callableOp = call.resolveCallableInTable(&symbolTable);
+ else
+ LDBG()
+ << "No symbol table present in analysis scope, can't resolve callable";
// A call to a externally-defined callable has unknown predecessors.
const auto isExternalCallable = [this](Operation *op) {
|
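The save-and-restore of `hasSymbolTable` in the diff above can be illustrated outside of MLIR. The following is a minimal, self-contained sketch under stated assumptions: `ToyOp`, `ToyAnalysis`, `FlagRestorer`, `makeCall`, and `makeOp` are hypothetical stand-ins invented for illustration (they are not MLIR classes), and `FlagRestorer` plays the role that `llvm::make_scope_exit` plays in the patch. It only models the idea that a callable lookup is attempted when a symbol table has been seen at or above the current operation within the analysis scope.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::Operation (not the MLIR API).
struct ToyOp {
  std::string name;
  bool isCall = false;
  bool isSymbolTable = false;
  std::vector<ToyOp> nested;
};

// Helper (assumption): builds a call-like op with no nested ops.
ToyOp makeCall(std::string name) {
  ToyOp op;
  op.name = std::move(name);
  op.isCall = true;
  return op;
}

// Helper (assumption): builds a region-holding op.
ToyOp makeOp(std::string name, bool isSymbolTable, std::vector<ToyOp> nested) {
  ToyOp op;
  op.name = std::move(name);
  op.isSymbolTable = isSymbolTable;
  op.nested = std::move(nested);
  return op;
}

// RAII guard that restores a flag to its saved value on scope exit,
// including on early returns.
class FlagRestorer {
public:
  explicit FlagRestorer(bool &flag) : flag(flag), saved(flag) {}
  ~FlagRestorer() { flag = saved; }
  FlagRestorer(const FlagRestorer &) = delete;
  FlagRestorer &operator=(const FlagRestorer &) = delete;

private:
  bool &flag;
  bool saved;
};

struct ToyAnalysis {
  bool hasSymbolTable = false;
  // Names of the calls for which a symbol lookup would be attempted.
  std::vector<std::string> lookedUp;

  void initialize(const ToyOp &top) {
    hasSymbolTable = top.isSymbolTable;
    initializeRecursively(top);
  }

  void initializeRecursively(const ToyOp &op) {
    // Only resolve callables when a symbol table is inside the scope.
    if (op.isCall && hasSymbolTable)
      lookedUp.push_back(op.name);
    if (op.nested.empty())
      return;
    // Enable lookups below a nested symbol table, then restore the flag.
    FlagRestorer restore(hasSymbolTable);
    if (!hasSymbolTable && op.isSymbolTable)
      hasSymbolTable = true;
    for (const ToyOp &child : op.nested)
      initializeRecursively(child);
  }
};
```

In this toy model, running the analysis on a lone function never attempts a lookup, while running it on a module does. A call placed after a nested symbol table in a function-scoped run is skipped again, because the guard has reset the flag by then.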
Thank you @joker-eph ! I tested this against the original error and this solves it.
Sorry, I approved a little too quickly. I re-ran it and, as I mentioned on the original issue, the error moves to a different place.
I am attaching the traces from TSan.
Read of size 8 at 0x722800004078 by thread T8:
#0 mlir::Attribute::getDialect() const /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/Attributes.h:59:12 (libIREECompiler.so+0x7dc2799) (BuildId: 366a285014486d45)
#1 mlir::Attribute::getContext() const /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/Attributes.cpp:37:53 (libIREECompiler.so+0x7dc2799)
#2 mlir::Operation::getContext() /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:216:48 (libIREECompiler.so+0xfe88741) (BuildId: 366a285014486d45)
#3 mlir::RegisteredOperationName::Model<mlir::iree_compiler::IREE::Util::FuncOp>::getInherentAttr(mlir::Operation*, llvm::StringRef) /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/OperationSupport.h:570:56 (libIREECompiler.so+0xfe88741)
#4 mlir::OperationName::getInherentAttr(mlir::Operation*, llvm::StringRef) const /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/OperationSupport.h:393:23 (libIREECompiler.so+0x7eaf7c4) (BuildId: 366a285014486d45)
#5 mlir::Operation::getInherentAttr(llvm::StringRef) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/Operation.cpp:341:20 (libIREECompiler.so+0x7eaf7c4)
#6 mlir::Operation::getAttr(mlir::StringAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:536:51 (libIREECompiler.so+0x7ed8b86) (BuildId: 366a285014486d45)
#7 mlir::StringAttr mlir::Operation::getAttrOfType<mlir::StringAttr>(mlir::StringAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:551:46 (libIREECompiler.so+0x7ed8b86)
#8 getNameIfSymbol(mlir::Operation*, mlir::StringAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:31:14 (libIREECompiler.so+0x7ed8b86)
#9 mlir::SymbolTable::lookupSymbolIn(mlir::Operation*, mlir::StringAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:395:9 (libIREECompiler.so+0x7ed8b86)
#10 mlir::SymbolTable::lookupSymbolIn(mlir::Operation*, mlir::SymbolRefAttr, llvm::SmallVectorImpl<mlir::Operation*>&)::$_0::operator()(mlir::Operation*, mlir::StringAttr) const /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:446:12 (libIREECompiler.so+0x7ee03b0) (BuildId: 366a285014486d45)
#11 mlir::Operation* llvm::function_ref<mlir::Operation* (mlir::Operation*, mlir::StringAttr)>::callback_fn<mlir::SymbolTable::lookupSymbolIn(mlir::Operation*, mlir::SymbolRefAttr, llvm::SmallVectorImpl<mlir::Operation*>&)::$_0>(long, mlir::Operation*, mlir::StringAttr) /home/eochoalo/code/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46:12 (libIREECompiler.so+0x7ee03b0)
#12 llvm::function_ref<mlir::Operation* (mlir::Operation*, mlir::StringAttr)>::operator()(mlir::Operation*, mlir::StringAttr) const /home/eochoalo/code/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69:12 (libIREECompiler.so+0x7ed8e9a) (BuildId: 366a285014486d45)
#13 lookupSymbolInImpl(mlir::Operation*, mlir::SymbolRefAttr, llvm::SmallVectorImpl<mlir::Operation*>&, llvm::function_ref<mlir::Operation* (mlir::Operation*, mlir::StringAttr)>) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:416:19 (libIREECompiler.so+0x7ed8e9a)
#14 mlir::SymbolTable::lookupSymbolIn(mlir::Operation*, mlir::SymbolRefAttr, llvm::SmallVectorImpl<mlir::Operation*>&) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:448:10 (libIREECompiler.so+0x7ed9438) (BuildId: 366a285014486d45)
#15 mlir::SymbolTable::lookupSymbolIn(mlir::Operation*, mlir::SymbolRefAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:402:14 (libIREECompiler.so+0x7ed9438)
#16 mlir::SymbolTable::lookupNearestSymbolFrom(mlir::Operation*, mlir::SymbolRefAttr) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/IR/SymbolTable.cpp:462:26 (libIREECompiler.so+0x7ed9438)
#17 mlir::call_interface_impl::resolveCallable(mlir::CallOpInterface, mlir::SymbolTableCollection*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Interfaces/CallInterfaces.cpp:197:10 (libIREECompiler.so+0x10791d36) (BuildId: 366a285014486d45)
#18 mlir::detail::CallOpInterfaceTrait<mlir::iree_compiler::IREE::Util::CallOp>::resolveCallable() /home/eochoalo/code/iree-build/llvm-project/tools/mlir/include/mlir/Interfaces/CallInterfaces.h.inc:252:14 (libIREECompiler.so+0xfe81bad) (BuildId: 366a285014486d45)
#19 mlir::detail::CallOpInterfaceInterfaceTraits::Model<mlir::iree_compiler::IREE::Util::CallOp>::resolveCallable(mlir::detail::CallOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*) /home/eochoalo/code/iree-build/llvm-project/tools/mlir/include/mlir/Interfaces/CallInterfaces.h.inc:445:56 (libIREECompiler.so+0xfe81bad)
#20 mlir::CallOpInterface::resolveCallable() /home/eochoalo/code/iree-build/llvm-project/tools/mlir/include/mlir/Interfaces/CallInterfaces.cpp.inc:90:14 (libIREECompiler.so+0x107921f0) (BuildId: 366a285014486d45)
#21 mlir::dataflow::AbstractSparseForwardDataFlowAnalysis::visitCallOperation(mlir::CallOpInterface, llvm::ArrayRef<mlir::dataflow::AbstractSparseLattice const*>, llvm::ArrayRef<mlir::dataflow::AbstractSparseLattice*>) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp:232:53 (libIREECompiler.so+0x10756f68) (BuildId: 366a285014486d45)
#22 mlir::dataflow::AbstractSparseForwardDataFlowAnalysis::visitOperation(mlir::Operation*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp:148:12 (libIREECompiler.so+0x107561b7) (BuildId: 366a285014486d45)
#23 mlir::dataflow::AbstractSparseForwardDataFlowAnalysis::visit(mlir::ProgramPoint*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp:106:12 (libIREECompiler.so+0x10756dd2) (BuildId: 366a285014486d45)
#24 mlir::DataFlowSolver::initializeAndRun(mlir::Operation*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Analysis/DataFlowFramework.cpp:132:26 (libIREECompiler.so+0x10715b59) (BuildId: 366a285014486d45)
#25 (anonymous namespace)::ArithUnsignedWhenEquivalentPass::runOnOperation() /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp:130:23 (libIREECompiler.so+0xe7141b3) (BuildId: 366a285014486d45)
...
Previous write of size 8 at 0x722800004078 by thread T5:
#0 mlir::Operation::setLoc(mlir::Location) /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/Operation.h:226:40 (libIREECompiler.so+0x1005670b) (BuildId: 366a285014486d45)
#1 (anonymous namespace)::ModifyOperationRewrite::rollback() /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:683:9 (libIREECompiler.so+0x1005670b)
#2 mlir::ConversionPatternRewriter::cancelOpModification(mlir::Operation*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:2239:10 (libIREECompiler.so+0x10035a89) (BuildId: 366a285014486d45)
#3 (anonymous namespace)::OperationLegalizer::legalizeWithFold(mlir::Operation*, mlir::ConversionPatternRewriter&) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:2558:14 (libIREECompiler.so+0x100442bb) (BuildId: 366a285014486d45)
#4 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:2484:19 (libIREECompiler.so+0x10036ff0) (BuildId: 366a285014486d45)
#5 mlir::OperationConverter::convert(mlir::ConversionPatternRewriter&, mlir::Operation*) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3108:26 (libIREECompiler.so+0x1003624f) (BuildId: 366a285014486d45)
#6 mlir::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3207:16 (libIREECompiler.so+0x10037855) (BuildId: 366a285014486d45)
#7 applyConversion(llvm::ArrayRef<mlir::Operation*>, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig, (anonymous namespace)::OpConversionMode)::$_0::operator()() const /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3918:30 (libIREECompiler.so+0x1004aace) (BuildId: 366a285014486d45)
#8 void llvm::function_ref<void ()>::callback_fn<applyConversion(llvm::ArrayRef<mlir::Operation*>, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig, (anonymous namespace)::OpConversionMode)::$_0>(long) /home/eochoalo/code/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46:12 (libIREECompiler.so+0x1004aace)
#9 llvm::function_ref<void ()>::operator()() const /home/eochoalo/code/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:69:12 (libIREECompiler.so+0x1003e62a) (BuildId: 366a285014486d45)
#10 void mlir::MLIRContext::executeAction<ApplyConversionAction>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>) /home/eochoalo/code/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280:7 (libIREECompiler.so+0x1003e62a)
#11 applyConversion(llvm::ArrayRef<mlir::Operation*>, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig, (anonymous namespace)::OpConversionMode) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3915:8 (libIREECompiler.so+0x1003e62a)
#12 mlir::applyPartialConversion(llvm::ArrayRef<mlir::Operation*>, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3931:10 (libIREECompiler.so+0x1003e709) (BuildId: 366a285014486d45)
#13 mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3938:10 (libIREECompiler.so+0x1003e709)
#14 (anonymous namespace)::LowerAffine::runOnOperation() /home/eochoalo/code/iree/third_party/llvm-project/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp:564:16 (libIREECompiler.so+0xc812528) (BuildId: 366a285014486d45)
Let me try to produce a test case.
|
Thanks, this occurs with this resolve call: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp#L231-L232 |
We are using the symbol table machinery to look up a callable, but when the analysis scope is a function, such a lookup will resolve outside of the scope. This can lead to race conditions, since other passes may operate in parallel on the sibling functions. In DeadCodeAnalysis, the callable would be discarded right after the lookup (we check the analysis scope), so avoiding the lookup is NFC. For the DataFlow solver, we look at the top-level operation, and if it isn't a SymbolTable we disable the interprocedural optimization in the solver config directly. This strategy isn't NFC, but it seems reasonable and does not cause any change in behavior in practice in tree. Fix llvm#154948
Force-pushed from 98eb52b to 19cd550.
|
I updated the patch, can you try it out? I'm interested in the test case though! |
|
@joker-eph I've tested the patch and it works. Thanks Mehdi. Here is the reduced test case:
// compile mlir-opt with TSan
// mlir-opt --pass-pipeline="builtin.module(func.func(lower-affine,arith-unsigned-when-equivalent))"
module {
func.func public @expect_true_of_false() {
call @_expect_true_of_false() : () -> ()
return
}
func.func private @_expect_true_of_false() {
%c0_i32 = arith.constant 0 : i32
return
}
} |
We are using the symbol table machinery to look up a callable, but when the analysis scope is a function, such a lookup will resolve outside of the scope. This can lead to race conditions, since other passes may operate in parallel on the sibling functions.
The callable would be discarded right after the lookup (we check the analysis scope), so avoiding the lookup is NFC.
For the DataFlow solver, we look at the top-level operation, and if
it isn't a SymbolTable we disable the interprocedural optimization in the
solver config directly.
This strategy isn't NFC, but it seems reasonable and does not cause any
change in behavior in practice in tree.
Fix #154948
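The solver-side strategy described above (turn off interprocedural analysis when the top-level operation is not a symbol table) might look roughly like the following. This is a sketch under stated assumptions, not the code from the patch: `ToyTopLevelOp`, `ToySolverConfig`, and `ToySolver` are hypothetical stand-ins, and restoring the requested setting after the run is a design choice made here for illustration.

```cpp
#include <cassert>

// Hypothetical stand-in for the top-level operation handed to the solver.
struct ToyTopLevelOp {
  bool isSymbolTable;
};

// Hypothetical stand-in for a solver configuration object.
class ToySolverConfig {
public:
  ToySolverConfig &setInterprocedural(bool enable) {
    interprocedural = enable;
    return *this;
  }
  bool isInterprocedural() const { return interprocedural; }

private:
  bool interprocedural = true;
};

class ToySolver {
public:
  explicit ToySolver(ToySolverConfig config = ToySolverConfig())
      : config(config) {}

  // Returns whether this run was effectively interprocedural.
  bool initializeAndRun(const ToyTopLevelOp &top) {
    const bool requested = config.isInterprocedural();
    // Without a symbol table at the top level, resolving a callee would
    // reach outside the analysis scope, so fall back to intraprocedural.
    if (requested && !top.isSymbolTable)
      config.setInterprocedural(false);
    const bool effective = config.isInterprocedural();
    // ... the analyses would be initialized and run here ...
    // Assumption: put the user's requested setting back for later runs.
    config.setInterprocedural(requested);
    return effective;
  }

  const ToySolverConfig &getConfig() const { return config; }

private:
  ToySolverConfig config;
};
```

With this shape, a function-scoped run such as the `func.func(...)` pipeline in the reduced test case would never ask the symbol table machinery to look at sibling functions, while a module-scoped run keeps its interprocedural behavior.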