Skip to content

Commit d9e17d3

Browse files
authored
[CIR][FlattenCFG] Fix use after free when flattening terminator (#843)
Per the operation walking documentation [1]: > A callback on a block or operation is allowed to erase that block or > operation if either: > * the walk is in post-order, or > * the walk is in pre-order and the walk is skipped after the erasure. We were doing neither when erasing terminator operations and replacing them with a branch, leading to a use after free and ASAN errors. This fixes the following tests with ASAN: ``` Clang :: CIR/CodeGen/switch-gnurange.cpp Clang :: CIR/Lowering/atomic-runtime.cpp Clang :: CIR/Lowering/loop.cir Clang :: CIR/Lowering/loops-with-break.cir Clang :: CIR/Lowering/loops-with-continue.cir Clang :: CIR/Lowering/switch.cir Clang :: CIR/Transforms/Target/x86_64/x86_64-call-conv-lowering-pass.cpp Clang :: CIR/Transforms/loop.cir Clang :: CIR/Transforms/switch.cir ``` These two tests still fail with ASAN after this, which I'm looking into: ``` Clang :: CIR/CodeGen/pointer-arith-ext.c Clang :: CIR/Transforms/Target/x86_64/x86_64-call-conv-lowering-pass.cpp ``` `CIR/CodeGen/global-new.cpp` is failing even on a non-ASAN Release build for me on the parent commit, so it's unrelated. [1] https://github.com/llvm/llvm-project/blob/0c55ad11ab3857056bb3917fdf087c4aa811b790/mlir/include/mlir/IR/Operation.h#L767-L770
1 parent ba8c248 commit d9e17d3

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

clang/include/clang/CIR/Interfaces/CIRLoopOpInterface.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,13 @@ def LoopOpInterface : OpInterface<"LoopOpInterface", [
7171
}],
7272
/*retTy=*/"mlir::WalkResult",
7373
/*methodName=*/"walkBodySkippingNestedLoops",
74-
/*args=*/(ins "::llvm::function_ref<void (Operation *)>":$callback),
74+
/*args=*/(ins "::llvm::function_ref<mlir::WalkResult (Operation *)>":$callback),
7575
/*methodBody=*/"",
7676
/*defaultImplementation=*/[{
7777
return $_op.getBody().template walk<WalkOrder::PreOrder>([&](Operation *op) {
7878
if (isa<LoopOpInterface>(op))
7979
return mlir::WalkResult::skip();
80-
callback(op);
81-
return mlir::WalkResult::advance();
80+
return callback(op);
8281
});
8382
}]
8483
>

clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
3636
/// Walks a region while skipping operations of type `Ops`. This ensures the
3737
/// callback is not applied to said operations and its children.
3838
template <typename... Ops>
39-
void walkRegionSkipping(mlir::Region &region,
40-
mlir::function_ref<void(mlir::Operation *)> callback) {
39+
void walkRegionSkipping(
40+
mlir::Region &region,
41+
mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
4142
region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
4243
if (isa<Ops...>(op))
4344
return mlir::WalkResult::skip();
44-
callback(op);
45-
return mlir::WalkResult::advance();
45+
return callback(op);
4646
});
4747
}
4848

@@ -541,15 +541,21 @@ class CIRLoopOpInterfaceFlattening
541541
// Lower continue statements.
542542
mlir::Block *dest = (step ? step : cond);
543543
op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
544-
if (isa<mlir::cir::ContinueOp>(op))
545-
lowerTerminator(op, dest, rewriter);
544+
if (!isa<mlir::cir::ContinueOp>(op))
545+
return mlir::WalkResult::advance();
546+
547+
lowerTerminator(op, dest, rewriter);
548+
return mlir::WalkResult::skip();
546549
});
547550

548551
// Lower break statements.
549552
walkRegionSkipping<mlir::cir::LoopOpInterface, mlir::cir::SwitchOp>(
550553
op.getBody(), [&](mlir::Operation *op) {
551-
if (isa<mlir::cir::BreakOp>(op))
552-
lowerTerminator(op, exit, rewriter);
554+
if (!isa<mlir::cir::BreakOp>(op))
555+
return mlir::WalkResult::advance();
556+
557+
lowerTerminator(op, exit, rewriter);
558+
return mlir::WalkResult::skip();
553559
});
554560

555561
// Lower optional body region yield.
@@ -705,8 +711,11 @@ class CIRSwitchOpFlattening
705711
// Handle break statements.
706712
walkRegionSkipping<mlir::cir::LoopOpInterface, mlir::cir::SwitchOp>(
707713
region, [&](mlir::Operation *op) {
708-
if (isa<mlir::cir::BreakOp>(op))
709-
lowerTerminator(op, exitBlock, rewriter);
714+
if (!isa<mlir::cir::BreakOp>(op))
715+
return mlir::WalkResult::advance();
716+
717+
lowerTerminator(op, exitBlock, rewriter);
718+
return mlir::WalkResult::skip();
710719
});
711720

712721
// Extract region contents before erasing the switch op.

0 commit comments

Comments
 (0)