@@ -86,18 +86,13 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
8686 return state.addExtension <FuncAnalysisState>();
8787}
8888
89- // / Return the unique ReturnOp that terminates `funcOp`.
90- // / Return nullptr if there is no such unique ReturnOp.
91- static func::ReturnOp getAssumedUniqueReturnOp (func::FuncOp funcOp) {
92- func::ReturnOp returnOp;
93- for (Block &b : funcOp.getBody ()) {
94- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator ())) {
95- if (returnOp)
96- return nullptr ;
97- returnOp = candidateOp;
98- }
99- }
100- return returnOp;
89+ // / Return all top-level func.return ops in the given function.
90+ static SmallVector<func::ReturnOp> getReturnOps (FuncOp funcOp) {
91+ SmallVector<func::ReturnOp> result;
92+ for (Block &b : funcOp.getBody ())
93+ if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator ()))
94+ result.push_back (returnOp);
95+ return result;
10196}
10297
10398namespace {
@@ -146,24 +141,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
146141 return success ();
147142 }
148143
149- // Support only single return-terminated block in the function.
150- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
151- assert (returnOp && " expected func with single return op" );
152-
153- for (OpOperand &returnVal : returnOp->getOpOperands ())
154- if (isa<RankedTensorType>(returnVal.get ().getType ()))
155- for (BlockArgument bbArg : funcOp.getArguments ())
156- if (isa<RankedTensorType>(bbArg.getType ())) {
157- int64_t returnIdx = returnVal.getOperandNumber ();
158- int64_t bbArgIdx = bbArg.getArgNumber ();
159- if (state.areEquivalentBufferizedValues (returnVal.get (), bbArg)) {
160- funcState.equivalentFuncArgs [funcOp][returnIdx] = bbArgIdx;
161- if (state.getOptions ().testAnalysisOnly )
162- annotateEquivalentReturnBbArg (returnVal, bbArg);
144+ // Find all func.return ops.
145+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
146+ assert (!returnOps.empty () && " expected at least one ReturnOp" );
147+
148+ // Build alias sets. Merge all aliases from all func.return ops.
149+ for (BlockArgument bbArg : funcOp.getArguments ()) {
150+ if (isa<RankedTensorType>(bbArg.getType ())) {
151+ int64_t bbArgIdx = bbArg.getArgNumber ();
152+ // Store aliases in a set, so that we don't add the same alias twice.
153+ SetVector<int64_t > aliases;
154+ for (func::ReturnOp returnOp : returnOps) {
155+ for (OpOperand &returnVal : returnOp->getOpOperands ()) {
156+ if (isa<RankedTensorType>(returnVal.get ().getType ())) {
157+ int64_t returnIdx = returnVal.getOperandNumber ();
158+ if (state.areAliasingBufferizedValues (returnVal.get (), bbArg))
159+ aliases.insert (returnIdx);
163160 }
164- if (state.areAliasingBufferizedValues (returnVal.get (), bbArg))
165- funcState.aliasingReturnVals [funcOp][bbArgIdx].push_back (returnIdx);
166161 }
162+ }
163+ for (int64_t alias : aliases)
164+ funcState.aliasingReturnVals [funcOp][bbArgIdx].push_back (alias);
165+ }
166+ }
167+
168+ // Build equivalence sets.
169+ // Helper function that finds an equivalent block argument index for the
170+ // given OpOperand. Return std::nullopt if no equivalent block argument could
171+ // be found.
172+ auto findEquivalentBlockArgIdx =
173+ [&](OpOperand &opOperand) -> std::optional<int64_t > {
174+ Value v = opOperand.get ();
175+ if (!isa<TensorType>(v.getType ()))
176+ return std::nullopt ;
177+ for (BlockArgument bbArg : funcOp.getArguments ()) {
178+ if (isa<RankedTensorType>(bbArg.getType ())) {
179+ if (state.areEquivalentBufferizedValues (v, bbArg)) {
180+ if (state.getOptions ().testAnalysisOnly )
181+ annotateEquivalentReturnBbArg (opOperand, bbArg);
182+ return bbArg.getArgNumber ();
183+ }
184+ }
185+ }
186+ return std::nullopt ;
187+ };
188+
189+ int64_t numResults = returnOps.front ()->getNumOperands ();
190+ for (int64_t i = 0 ; i < numResults; ++i) {
191+ // Find the equivalent block argument index for the i-th operand of the
192+ // first func.return op.
193+ std::optional<int64_t > maybeEquiv =
194+ findEquivalentBlockArgIdx (returnOps.front ()->getOpOperand (i));
195+ if (!maybeEquiv.has_value ())
196+ continue ;
197+ int64_t bbArgIdx = *maybeEquiv;
198+ bool allEquiv = true ;
199+
200+ // Check if all other func.return ops have the same equivalent block
201+ // argument for the i-th operand. In contrast to aliasing information,
202+ // which is just "merged", equivalence information must match across all
203+ // func.return ops.
204+ for (func::ReturnOp returnOp : ArrayRef (returnOps).drop_front ()) {
205+ std::optional<int64_t > maybeEquiv =
206+ findEquivalentBlockArgIdx (returnOp->getOpOperand (i));
207+ if (maybeEquiv != bbArgIdx) {
208+ allEquiv = false ;
209+ break ;
210+ }
211+ }
212+
213+ // All func.return ops have the same equivalent block argument for the i-th
214+ // operand.
215+ if (allEquiv)
216+ funcState.equivalentFuncArgs [funcOp][i] = bbArgIdx;
217+ }
167218
168219 return success ();
169220}
@@ -302,14 +353,6 @@ static LogicalResult getFuncOpsOrderedByCalls(
302353 // For each FuncOp, the number of func::CallOp it contains.
303354 DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
304355 WalkResult res = moduleOp.walk ([&](func::FuncOp funcOp) -> WalkResult {
305- if (!funcOp.getBody ().empty ()) {
306- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
307- if (!returnOp)
308- return funcOp->emitError ()
309- << " cannot bufferize a FuncOp with tensors and "
310- " without a unique ReturnOp" ;
311- }
312-
313356 // Collect function calls and populate the caller map.
314357 numberCallOpsContainedInFuncOp[funcOp] = 0 ;
315358 return funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
@@ -351,6 +394,42 @@ static LogicalResult getFuncOpsOrderedByCalls(
351394 return success ();
352395}
353396
397+ // / Helper function that extracts the source from a memref.cast. If the given
398+ // / value is not a memref.cast result, simply returns the given value.
399+ static Value unpackCast (Value v) {
400+ auto castOp = v.getDefiningOp <memref::CastOp>();
401+ if (!castOp)
402+ return v;
403+ return castOp.getSource ();
404+ }
405+
406+ // / Helper function that returns the return types (skipping casts) of the given
407+ // / func.return ops. This function returns as many types as the return ops have
408+ // / operands. If the i-th operand is not the same for all func.return ops, then
409+ // / the i-th returned type is an "empty" type.
410+ static SmallVector<Type> getReturnTypes (SmallVector<func::ReturnOp> returnOps) {
411+ assert (!returnOps.empty () && " expected at least one ReturnOp" );
412+ int numOperands = returnOps.front ()->getNumOperands ();
413+
414+ // Helper function that unpacks memref.cast ops and returns the type.
415+ auto getSourceType = [&](Value v) { return unpackCast (v).getType (); };
416+
417+ SmallVector<Type> result;
418+ for (int i = 0 ; i < numOperands; ++i) {
419+ // Get the type of the i-th operand of the first func.return ops.
420+ Type t = getSourceType (returnOps.front ()->getOperand (i));
421+
422+ // Check if all other func.return ops have a matching operand type.
423+ for (int j = 1 ; j < static_cast <int >(returnOps.size ()); ++j)
424+ if (getSourceType (returnOps[j]->getOperand (i)) != t)
425+ t = Type ();
426+
427+ result.push_back (t);
428+ }
429+
430+ return result;
431+ }
432+
354433// / Fold return values that are memref casts and update function return types.
355434// /
356435// / During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -359,21 +438,33 @@ static LogicalResult getFuncOpsOrderedByCalls(
359438// / entire function body, a more concise memref type can potentially be used for
360439// / the return type of the function.
361440static void foldMemRefCasts (func::FuncOp funcOp) {
441+ // There is nothing to do for bodiless ops.
362442 if (funcOp.getBody ().empty ())
363443 return ;
364444
365- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
366- SmallVector<Type> resultTypes;
445+ // Compute the common result types of all return ops.
446+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
447+ SmallVector<Type> resultTypes = getReturnTypes (returnOps);
367448
368- for (OpOperand &operand : returnOp->getOpOperands ()) {
369- if (auto castOp = operand.get ().getDefiningOp <memref::CastOp>()) {
370- operand.set (castOp.getSource ());
371- resultTypes.push_back (castOp.getSource ().getType ());
372- } else {
373- resultTypes.push_back (operand.get ().getType ());
449+ // Remove direct casts.
450+ for (func::ReturnOp returnOp : returnOps) {
451+ for (OpOperand &operand : returnOp->getOpOperands ()) {
452+ // Bail if no common result type was found.
453+ if (resultTypes[operand.getOperandNumber ()]) {
454+ operand.set (unpackCast (operand.get ()));
455+ }
374456 }
375457 }
376458
459+ // Fill in the missing result types that were not the same among all
460+ // func.return ops.
461+ for (int i = 0 ; i < static_cast <int >(resultTypes.size ()); ++i) {
462+ if (resultTypes[i])
463+ continue ;
464+ resultTypes[i] = funcOp.getFunctionType ().getResult (i);
465+ }
466+
467+ // Update the function type.
377468 auto newFuncType = FunctionType::get (
378469 funcOp.getContext (), funcOp.getFunctionType ().getInputs (), resultTypes);
379470 funcOp.setType (newFuncType);
0 commit comments