@@ -86,14 +86,24 @@ static FuncOp getCalledFunction(CallOpInterface callOp,
8686 symbolTable.lookupNearestSymbolFrom (callOp, sym));
8787}
8888
89- // / Get FuncAnalysisState.
89+ // / Get or create FuncAnalysisState.
9090static const FuncAnalysisState &
91- getFuncAnalysisState (const AnalysisState &state) {
91+ getOrCreateFuncAnalysisState (const AnalysisState &state) {
9292 assert (isa<OneShotAnalysisState>(state) && " expected OneShotAnalysisState" );
93- auto *result = static_cast <const OneShotAnalysisState &>(state)
94- .getExtension <FuncAnalysisState>();
95- assert (result && " FuncAnalysisState does not exist" );
96- return *result;
93+
94+ // Unfortunately, at the moment the BufferizableOpInterface methods do provide
95+ // a const reference to the AnalysisState class, and the only way to
96+ // dynamically add an extension is to const_cast it to a non-const reference.
97+ // Should the const qualifier be dropped from the interface?
98+ auto &oneShotAnalysisState =
99+ static_cast <OneShotAnalysisState &>(const_cast <AnalysisState &>(state));
100+
101+ auto *result = oneShotAnalysisState.getExtension <FuncAnalysisState>();
102+
103+ if (result)
104+ return *result;
105+
106+ return oneShotAnalysisState.addExtension <FuncAnalysisState>();
97107}
98108
99109// / Return the state (phase) of analysis of the FuncOp.
@@ -136,7 +146,7 @@ struct CallOpInterface
136146 bool bufferizesToMemoryRead (Operation *op, OpOperand &opOperand,
137147 const AnalysisState &state) const {
138148 func::CallOp callOp = cast<func::CallOp>(op);
139- const FuncAnalysisState &funcState = getFuncAnalysisState (state);
149+ const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
140150 FuncOp funcOp = getCalledFunction (callOp, funcState.symbolTable );
141151 assert (funcOp && " expected CallOp to a FuncOp" );
142152
@@ -151,7 +161,7 @@ struct CallOpInterface
151161 bool bufferizesToMemoryWrite (Operation *op, OpOperand &opOperand,
152162 const AnalysisState &state) const {
153163 func::CallOp callOp = cast<func::CallOp>(op);
154- const FuncAnalysisState &funcState = getFuncAnalysisState (state);
164+ const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
155165 FuncOp funcOp = getCalledFunction (callOp, funcState.symbolTable );
156166 assert (funcOp && " expected CallOp to a FuncOp" );
157167
@@ -166,7 +176,7 @@ struct CallOpInterface
166176 AliasingValueList getAliasingValues (Operation *op, OpOperand &opOperand,
167177 const AnalysisState &state) const {
168178 func::CallOp callOp = cast<func::CallOp>(op);
169- const FuncAnalysisState &funcState = getFuncAnalysisState (state);
179+ const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
170180 FuncOp funcOp = getCalledFunction (callOp, funcState.symbolTable );
171181 assert (funcOp && " expected CallOp to a FuncOp" );
172182 if (getFuncOpAnalysisState (state, funcOp) != FuncOpAnalysisState::Analyzed)
0 commit comments