Skip to content

Commit 19a9c64

Browse files
Adding changes to RDV +small repro case for dialect with callOp and the AttrSizedOperandSegments trait
1 parent 4a4bdde commit 19a9c64

File tree

4 files changed

+160
-14
lines changed

4 files changed

+160
-14
lines changed

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -306,19 +306,17 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
306306
nonLiveSet.insert(arg);
307307
}
308308

309-
// Do (2).
309+
// Do (2). (Skip creating generic operand cleanup entries for call ops.
310+
// Call arguments will be removed in the call-site specific segment-aware
311+
// cleanup, avoiding generic eraseOperands bitvector mechanics.)
310312
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
311313
for (SymbolTable::SymbolUse use : uses) {
312314
Operation *callOp = use.getUser();
313315
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
314-
// The number of operands in the call op may not match the number of
315-
// arguments in the func op.
316-
BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
317-
SmallVector<OpOperand *> callOpOperands =
318-
operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
319-
for (int index : nonLiveArgs.set_bits())
320-
nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
321-
cl.operands.push_back({callOp, nonLiveCallOperands});
316+
// Push an empty operand cleanup entry so that call-site specific logic in
317+
// cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
318+
// intentionally all false to avoid generic erasure.
319+
cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false)});
322320
}
323321

324322
// Do (3).
@@ -746,6 +744,10 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
746744

747745
// 3. Functions
748746
LDBG() << "Cleaning up " << list.functions.size() << " functions";
747+
// Record which function arguments were erased so we can shrink call-site
748+
// argument segments for CallOpInterface operations (e.g. ops using
749+
// AttrSizedOperandSegments) in the next phase.
750+
DenseMap<Operation *, BitVector> erasedFuncArgs;
749751
for (auto &f : list.functions) {
750752
LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
751753
LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
@@ -754,17 +756,51 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
754756
// Some functions may not allow erasing arguments or results. These calls
755757
// return failure in such cases without modifying the function, so it's okay
756758
// to proceed.
757-
(void)f.funcOp.eraseArguments(f.nonLiveArgs);
759+
if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
760+
// Record only if we actually erased something.
761+
if (f.nonLiveArgs.any())
762+
erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
763+
}
758764
(void)f.funcOp.eraseResults(f.nonLiveRets);
759765
}
760766

761767
// 4. Operands
762768
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
763769
for (OperationToCleanup &o : list.operands) {
764-
if (o.op->getNumOperands() > 0) {
765-
LDBG() << "Erasing " << o.nonLive.count()
766-
<< " non-live operands from operation: "
767-
<< OpWithFlags(o.op, OpPrintingFlags().skipRegions());
770+
if (auto call = dyn_cast<CallOpInterface>(o.op)) {
771+
if (SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>()) {
772+
Operation *callee = SymbolTable::lookupNearestSymbolFrom(o.op, sym);
773+
auto it = erasedFuncArgs.find(callee);
774+
if (it != erasedFuncArgs.end()) {
775+
const BitVector &deadArgIdxs = it->second;
776+
MutableOperandRange args = call.getArgOperandsMutable();
777+
// First, erase the call arguments corresponding to erased callee args.
778+
for (int i = static_cast<int>(args.size()) - 1; i >= 0; --i) {
779+
if (i < static_cast<int>(deadArgIdxs.size()) && deadArgIdxs.test(i))
780+
args.erase(i);
781+
}
782+
// If this operand cleanup entry also has a generic nonLive bitvector,
783+
// clear bits for call arguments we already erased above to avoid
784+
// double-erasing (which could impact other segments of ops with
785+
// AttrSizedOperandSegments).
786+
if (o.nonLive.any()) {
787+
// Map the argument logical index to the operand number(s) recorded.
788+
SmallVector<OpOperand *> callOperands =
789+
operandsToOpOperands(call.getArgOperands());
790+
for (int argIdx : deadArgIdxs.set_bits()) {
791+
if (argIdx < static_cast<int>(callOperands.size())) {
792+
unsigned operandNumber = callOperands[argIdx]->getOperandNumber();
793+
if (operandNumber < o.nonLive.size())
794+
o.nonLive.reset(operandNumber);
795+
}
796+
}
797+
}
798+
}
799+
}
800+
}
801+
// Only perform generic operand erasure for non-call ops; for call ops we
802+
// already handled argument removals via the segment-aware path above.
803+
if (!isa<CallOpInterface>(o.op) && o.nonLive.any()) {
768804
o.op->eraseOperands(o.nonLive);
769805
}
770806
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt --split-input-file --remove-dead-values --mlir-print-op-generic %s | FileCheck %s --check-prefix=GEN
2+
3+
// -----
4+
// Private callee: both args become dead after internal DCE; RDV drops callee
5+
// args and shrinks the *args* segment on the call-site to zero; sizes kept in
6+
// sync.
7+
8+
module {
9+
func.func private @callee(%x: i32, %y: i32) {
10+
%u = arith.addi %x, %x : i32 // %y is dead
11+
return
12+
}
13+
14+
func.func @caller(%a: i32, %b: i32) {
15+
// args segment initially has 2 operands.
16+
"test.call_with_segments"(%a, %b) { callee = @callee,
17+
operandSegmentSizes = array<i32: 0, 2, 0> } : (i32, i32) -> ()
18+
return
19+
}
20+
}
21+
22+
// GEN: "test.call_with_segments"() <{callee = @callee, operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> ()
23+
// ^ args shrank from 2 -> 0

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,3 +431,47 @@ void TestDialect::getCanonicalizationPatterns(
431431
RewritePatternSet &results) const {
432432
results.add(&dialectCanonicalizationPattern);
433433
}
434+
435+
//===----------------------------------------------------------------------===//
436+
// TestCallWithSegmentsOp
437+
//===----------------------------------------------------------------------===//
438+
// The op `test.call_with_segments` models a call-like operation whose operands
439+
// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`.
440+
// Only the middle segment represents the actual call arguments. The op uses
441+
// the AttrSizedOperandSegments trait, so we can derive segment boundaries from
442+
// the generated `operandSegmentSizes` attribute. We provide custom helpers to
443+
// expose the logical call arguments as both a read-only range and a mutable
444+
// range bound to the proper segment so that insertion/erasure updates the
445+
// attribute automatically.
446+
447+
// Segment layout indices in the DenseI32ArrayAttr: [prefix, args, suffix].
448+
static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1;
449+
450+
Operation::operand_range CallWithSegmentsOp::getArgOperands() {
451+
// Leverage generated getters for segment sizes: slice between prefix and
452+
// suffix using current operand list.
453+
return getOperation()->getOperands().slice(getPrefix().size(),
454+
getArgs().size());
455+
}
456+
457+
MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() {
458+
Operation *op = getOperation();
459+
460+
// Obtain the canonical segment size attribute name for this op.
461+
auto segName =
462+
CallWithSegmentsOp::getOperandSegmentSizesAttrName(op->getName());
463+
auto sizesAttr = op->getAttrOfType<DenseI32ArrayAttr>(segName);
464+
assert(sizesAttr && "missing operandSegmentSizes attribute on op");
465+
466+
// Compute the start and length of the args segment from the prefix size and
467+
// args size stored in the attribute.
468+
auto sizes = sizesAttr.asArrayRef();
469+
unsigned start = static_cast<unsigned>(sizes[0]); // prefix size
470+
unsigned len = static_cast<unsigned>(sizes[1]); // args size
471+
472+
NamedAttribute segNamed(segName, sizesAttr);
473+
MutableOperandRange::OperandSegment binding{kTestCallWithSegmentsArgsSegIndex,
474+
segNamed};
475+
476+
return MutableOperandRange(op, start, len, {binding});
477+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3745,4 +3745,47 @@ def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
37453745
}];
37463746
}
37473747

3748+
def CallWithSegmentsOp : TEST_Op<"call_with_segments",
3749+
[AttrSizedOperandSegments,
3750+
DeclareOpInterfaceMethods<CallOpInterface>]> {
3751+
let summary = "test call op with segmented args";
3752+
let arguments = (ins
3753+
FlatSymbolRefAttr:$callee,
3754+
Variadic<AnyType>:$prefix, // non-arg segment (e.g., 'in')
3755+
Variadic<AnyType>:$args, // <-- the call *arguments* segment
3756+
Variadic<AnyType>:$suffix // non-arg segment (e.g., 'out')
3757+
);
3758+
let results = (outs);
3759+
let assemblyFormat = [{
3760+
$callee `(` $prefix `:` type($prefix) `)`
3761+
`(` $args `:` type($args) `)`
3762+
`(` $suffix `:` type($suffix) `)` attr-dict
3763+
}];
3764+
3765+
// Provide stub implementations for the ArgAndResultAttrsOpInterface.
3766+
let extraClassDeclaration = [{
3767+
::mlir::ArrayAttr getArgAttrsAttr() { return {}; }
3768+
::mlir::ArrayAttr getResAttrsAttr() { return {}; }
3769+
void setArgAttrsAttr(::mlir::ArrayAttr) {}
3770+
void setResAttrsAttr(::mlir::ArrayAttr) {}
3771+
::mlir::Attribute removeArgAttrsAttr() { return {}; }
3772+
::mlir::Attribute removeResAttrsAttr() { return {}; }
3773+
}];
3774+
3775+
let extraClassDefinition = [{
3776+
::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
3777+
if (auto sym = (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee"))
3778+
return ::mlir::CallInterfaceCallable(sym);
3779+
return ::mlir::CallInterfaceCallable();
3780+
}
3781+
void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
3782+
if (auto sym = callee.dyn_cast<::mlir::SymbolRefAttr>())
3783+
(*this)->setAttr("callee", sym);
3784+
else
3785+
(*this)->removeAttr("callee");
3786+
}
3787+
}];
3788+
}
3789+
3790+
37483791
#endif // TEST_OPS

0 commit comments

Comments
 (0)