Skip to content

Commit 983b2c6

Browse files
committed
Dynamically add FuncAnalysisState extension
1 parent 8499232 commit 983b2c6

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,24 @@ static FuncOp getCalledFunction(CallOpInterface callOp,
8686
symbolTable.lookupNearestSymbolFrom(callOp, sym));
8787
}
8888

89-
/// Get FuncAnalysisState.
89+
/// Get or create FuncAnalysisState.
9090
static const FuncAnalysisState &
91-
getFuncAnalysisState(const AnalysisState &state) {
91+
getOrCreateFuncAnalysisState(const AnalysisState &state) {
9292
assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
93-
auto *result = static_cast<const OneShotAnalysisState &>(state)
94-
.getExtension<FuncAnalysisState>();
95-
assert(result && "FuncAnalysisState does not exist");
96-
return *result;
93+
94+
// Unfortunately, at the moment the BufferizableOpInterface methods do provide
95+
// a const reference to the AnalysisState class, and the only way to
96+
// dynamically add an extension is to const_cast it to a non-const reference.
97+
// Should the const qualifier be dropped from the interface?
98+
auto &oneShotAnalysisState =
99+
static_cast<OneShotAnalysisState &>(const_cast<AnalysisState &>(state));
100+
101+
auto *result = oneShotAnalysisState.getExtension<FuncAnalysisState>();
102+
103+
if (result)
104+
return *result;
105+
106+
return oneShotAnalysisState.addExtension<FuncAnalysisState>();
97107
}
98108

99109
/// Return the state (phase) of analysis of the FuncOp.
@@ -136,7 +146,7 @@ struct CallOpInterface
136146
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
137147
const AnalysisState &state) const {
138148
func::CallOp callOp = cast<func::CallOp>(op);
139-
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
149+
const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
140150
FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
141151
assert(funcOp && "expected CallOp to a FuncOp");
142152

@@ -151,7 +161,7 @@ struct CallOpInterface
151161
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
152162
const AnalysisState &state) const {
153163
func::CallOp callOp = cast<func::CallOp>(op);
154-
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
164+
const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
155165
FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
156166
assert(funcOp && "expected CallOp to a FuncOp");
157167

@@ -166,7 +176,7 @@ struct CallOpInterface
166176
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
167177
const AnalysisState &state) const {
168178
func::CallOp callOp = cast<func::CallOp>(op);
169-
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
179+
const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
170180
FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
171181
assert(funcOp && "expected CallOp to a FuncOp");
172182
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)

0 commit comments

Comments
 (0)