Skip to content

Commit 3e746bd

Browse files
Allowing RDV to call getArgOperandsMutable() (#160415)
## Problem `RemoveDeadValues` can legally drop dead function arguments on private `func.func` callees. But call-sites to such functions aren't fixed if the call operation keeps its call arguments in a **segmented operand group** (i.ie, uses `AttrSizedOperandSegments`), unless the call op implements `getArgOperandsMutable` and the RDV pass actually uses it. ## Fix When RDV decides to drop callee function args, it should, for each call-site that implements `CallOpInterface`, **shrink the call's argument segment** via `getArgOperandsMutable()` using the same dead-arg indices. This keeps both the flat operand list and the `operand_segment_sizes` attribute in sync (that's what `MutableOperandRange` does when bound to the segment). ## Note This change is a no-op for: * call ops without segment operands (they still get their flat operands erased via the generic path) * call ops whose calle args weren't dropped (public, external, non-`func-func`, unresolved symbol, etc) * `llvm.call`/`llvm.invoke` (RDV doesn't drop `llvm.func` args --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent acb826e commit 3e746bd

File tree

4 files changed

+165
-14
lines changed

4 files changed

+165
-14
lines changed

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ struct FunctionToCleanUp {
8888
struct OperationToCleanup {
8989
Operation *op;
9090
BitVector nonLive;
91+
Operation *callee =
92+
nullptr; // Optional: For CallOpInterface ops, stores the callee function
9193
};
9294

9395
struct BlockArgsToCleanup {
@@ -306,19 +308,19 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
306308
nonLiveSet.insert(arg);
307309
}
308310

309-
// Do (2).
311+
// Do (2). (Skip creating generic operand cleanup entries for call ops.
312+
// Call arguments will be removed in the call-site specific segment-aware
313+
// cleanup, avoiding generic eraseOperands bitvector mechanics.)
310314
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
311315
for (SymbolTable::SymbolUse use : uses) {
312316
Operation *callOp = use.getUser();
313317
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
314-
// The number of operands in the call op may not match the number of
315-
// arguments in the func op.
316-
BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
317-
SmallVector<OpOperand *> callOpOperands =
318-
operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
319-
for (int index : nonLiveArgs.set_bits())
320-
nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
321-
cl.operands.push_back({callOp, nonLiveCallOperands});
318+
// Push an empty operand cleanup entry so that call-site specific logic in
319+
// cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
320+
// intentionally all false to avoid generic erasure.
321+
// Store the funcOp as the callee to avoid expensive symbol lookup later.
322+
cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false),
323+
funcOp.getOperation()});
322324
}
323325

324326
// Do (3).
@@ -746,6 +748,10 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
746748

747749
// 3. Functions
748750
LDBG() << "Cleaning up " << list.functions.size() << " functions";
751+
// Record which function arguments were erased so we can shrink call-site
752+
// argument segments for CallOpInterface operations (e.g. ops using
753+
// AttrSizedOperandSegments) in the next phase.
754+
DenseMap<Operation *, BitVector> erasedFuncArgs;
749755
for (auto &f : list.functions) {
750756
LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
751757
LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
@@ -754,17 +760,52 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
754760
// Some functions may not allow erasing arguments or results. These calls
755761
// return failure in such cases without modifying the function, so it's okay
756762
// to proceed.
757-
(void)f.funcOp.eraseArguments(f.nonLiveArgs);
763+
if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
764+
// Record only if we actually erased something.
765+
if (f.nonLiveArgs.any())
766+
erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
767+
}
758768
(void)f.funcOp.eraseResults(f.nonLiveRets);
759769
}
760770

761771
// 4. Operands
762772
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
763773
for (OperationToCleanup &o : list.operands) {
764-
if (o.op->getNumOperands() > 0) {
765-
LDBG() << "Erasing " << o.nonLive.count()
766-
<< " non-live operands from operation: "
767-
<< OpWithFlags(o.op, OpPrintingFlags().skipRegions());
774+
// Handle call-specific cleanup only when we have a cached callee reference.
775+
// This avoids expensive symbol lookup and is defensive against future
776+
// changes.
777+
bool handledAsCall = false;
778+
if (o.callee && isa<CallOpInterface>(o.op)) {
779+
auto call = cast<CallOpInterface>(o.op);
780+
auto it = erasedFuncArgs.find(o.callee);
781+
if (it != erasedFuncArgs.end()) {
782+
const BitVector &deadArgIdxs = it->second;
783+
MutableOperandRange args = call.getArgOperandsMutable();
784+
// First, erase the call arguments corresponding to erased callee
785+
// args. We iterate backwards to preserve indices.
786+
for (unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
787+
args.erase(argIdx);
788+
// If this operand cleanup entry also has a generic nonLive bitvector,
789+
// clear bits for call arguments we already erased above to avoid
790+
// double-erasing (which could impact other segments of ops with
791+
// AttrSizedOperandSegments).
792+
if (o.nonLive.any()) {
793+
// Map the argument logical index to the operand number(s) recorded.
794+
int operandOffset = call.getArgOperands().getBeginOperandIndex();
795+
for (int argIdx : deadArgIdxs.set_bits()) {
796+
int operandNumber = operandOffset + argIdx;
797+
if (operandNumber < static_cast<int>(o.nonLive.size()))
798+
o.nonLive.reset(operandNumber);
799+
}
800+
}
801+
handledAsCall = true;
802+
}
803+
}
804+
// Perform generic operand erasure for:
805+
// - Non-call operations
806+
// - Call operations without cached callee (where handledAsCall is false)
807+
// But skip call operations that were already handled via segment-aware path
808+
if (!handledAsCall && o.nonLive.any()) {
768809
o.op->eraseOperands(o.nonLive);
769810
}
770811
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt --split-input-file --remove-dead-values --mlir-print-op-generic %s | FileCheck %s --check-prefix=GEN
2+
3+
// -----
4+
// Private callee: both args become dead after internal DCE; RDV drops callee
5+
// args and shrinks the *args* segment on the call-site to zero; sizes kept in
6+
// sync.
7+
8+
module {
9+
func.func private @callee(%x: i32, %y: i32) {
10+
%u = arith.addi %x, %x : i32 // %y is dead
11+
return
12+
}
13+
14+
func.func @caller(%a: i32, %b: i32) {
15+
// args segment initially has 2 operands.
16+
"test.call_with_segments"(%a, %b) { callee = @callee,
17+
operandSegmentSizes = array<i32: 0, 2, 0> } : (i32, i32) -> ()
18+
return
19+
}
20+
}
21+
22+
// GEN: "test.call_with_segments"() <{callee = @callee, operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> ()
23+
// ^ args shrank from 2 -> 0

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,3 +431,47 @@ void TestDialect::getCanonicalizationPatterns(
431431
RewritePatternSet &results) const {
432432
results.add(&dialectCanonicalizationPattern);
433433
}
434+
435+
//===----------------------------------------------------------------------===//
436+
// TestCallWithSegmentsOp
437+
//===----------------------------------------------------------------------===//
438+
// The op `test.call_with_segments` models a call-like operation whose operands
439+
// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`.
440+
// Only the middle segment represents the actual call arguments. The op uses
441+
// the AttrSizedOperandSegments trait, so we can derive segment boundaries from
442+
// the generated `operandSegmentSizes` attribute. We provide custom helpers to
443+
// expose the logical call arguments as both a read-only range and a mutable
444+
// range bound to the proper segment so that insertion/erasure updates the
445+
// attribute automatically.
446+
447+
// Segment layout indices in the DenseI32ArrayAttr: [prefix, args, suffix].
448+
static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1;
449+
450+
Operation::operand_range CallWithSegmentsOp::getArgOperands() {
451+
// Leverage generated getters for segment sizes: slice between prefix and
452+
// suffix using current operand list.
453+
return getOperation()->getOperands().slice(getPrefix().size(),
454+
getArgs().size());
455+
}
456+
457+
MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() {
458+
Operation *op = getOperation();
459+
460+
// Obtain the canonical segment size attribute name for this op.
461+
auto segName =
462+
CallWithSegmentsOp::getOperandSegmentSizesAttrName(op->getName());
463+
auto sizesAttr = op->getAttrOfType<DenseI32ArrayAttr>(segName);
464+
assert(sizesAttr && "missing operandSegmentSizes attribute on op");
465+
466+
// Compute the start and length of the args segment from the prefix size and
467+
// args size stored in the attribute.
468+
auto sizes = sizesAttr.asArrayRef();
469+
unsigned start = static_cast<unsigned>(sizes[0]); // prefix size
470+
unsigned len = static_cast<unsigned>(sizes[1]); // args size
471+
472+
NamedAttribute segNamed(segName, sizesAttr);
473+
MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex,
474+
segNamed};
475+
476+
return MutableOperandRange(op, start, len, {binding});
477+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3746,4 +3746,47 @@ def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
37463746
}];
37473747
}
37483748

3749+
def CallWithSegmentsOp : TEST_Op<"call_with_segments",
3750+
[AttrSizedOperandSegments,
3751+
DeclareOpInterfaceMethods<CallOpInterface>]> {
3752+
let summary = "test call op with segmented args";
3753+
let arguments = (ins
3754+
FlatSymbolRefAttr:$callee,
3755+
Variadic<AnyType>:$prefix, // non-arg segment (e.g., 'in')
3756+
Variadic<AnyType>:$args, // <-- the call *arguments* segment
3757+
Variadic<AnyType>:$suffix // non-arg segment (e.g., 'out')
3758+
);
3759+
let results = (outs);
3760+
let assemblyFormat = [{
3761+
$callee `(` $prefix `:` type($prefix) `)`
3762+
`(` $args `:` type($args) `)`
3763+
`(` $suffix `:` type($suffix) `)` attr-dict
3764+
}];
3765+
3766+
// Provide stub implementations for the ArgAndResultAttrsOpInterface.
3767+
let extraClassDeclaration = [{
3768+
::mlir::ArrayAttr getArgAttrsAttr() { return {}; }
3769+
::mlir::ArrayAttr getResAttrsAttr() { return {}; }
3770+
void setArgAttrsAttr(::mlir::ArrayAttr) {}
3771+
void setResAttrsAttr(::mlir::ArrayAttr) {}
3772+
::mlir::Attribute removeArgAttrsAttr() { return {}; }
3773+
::mlir::Attribute removeResAttrsAttr() { return {}; }
3774+
}];
3775+
3776+
let extraClassDefinition = [{
3777+
::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
3778+
if (auto sym = (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee"))
3779+
return ::mlir::CallInterfaceCallable(sym);
3780+
return ::mlir::CallInterfaceCallable();
3781+
}
3782+
void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
3783+
if (auto sym = callee.dyn_cast<::mlir::SymbolRefAttr>())
3784+
(*this)->setAttr("callee", sym);
3785+
else
3786+
(*this)->removeAttr("callee");
3787+
}
3788+
}];
3789+
}
3790+
3791+
37493792
#endif // TEST_OPS

0 commit comments

Comments
 (0)