@@ -86,18 +86,13 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
8686 return state.addExtension <FuncAnalysisState>();
8787}
8888
89- // / Return the unique ReturnOp that terminates `funcOp`.
90- // / Return nullptr if there is no such unique ReturnOp.
91- static func::ReturnOp getAssumedUniqueReturnOp (func::FuncOp funcOp) {
92- func::ReturnOp returnOp;
93- for (Block &b : funcOp.getBody ()) {
94- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator ())) {
95- if (returnOp)
96- return nullptr ;
97- returnOp = candidateOp;
98- }
99- }
100- return returnOp;
89+ // / Return all top-level func.return ops in the given function.
90+ static SmallVector<func::ReturnOp> getReturnOps (FuncOp funcOp) {
91+ SmallVector<func::ReturnOp> result;
92+ for (Block &b : funcOp.getBody ())
93+ if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator ()))
94+ result.push_back (returnOp);
95+ return result;
10196}
10297
10398namespace {
@@ -146,24 +141,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
146141 return success ();
147142 }
148143
149- // Support only single return-terminated block in the function.
150- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
151- assert (returnOp && " expected func with single return op" );
152-
153- for (OpOperand &returnVal : returnOp->getOpOperands ())
154- if (isa<RankedTensorType>(returnVal.get ().getType ()))
155- for (BlockArgument bbArg : funcOp.getArguments ())
156- if (isa<RankedTensorType>(bbArg.getType ())) {
157- int64_t returnIdx = returnVal.getOperandNumber ();
158- int64_t bbArgIdx = bbArg.getArgNumber ();
159- if (state.areEquivalentBufferizedValues (returnVal.get (), bbArg)) {
160- funcState.equivalentFuncArgs [funcOp][returnIdx] = bbArgIdx;
161- if (state.getOptions ().testAnalysisOnly )
162- annotateEquivalentReturnBbArg (returnVal, bbArg);
144+ // Find all func.return ops.
145+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
146+ assert (!returnOps.empty () && " expected at least one ReturnOp" );
147+
148+ // Build alias sets. Merge all aliases from all func.return ops.
149+ for (BlockArgument bbArg : funcOp.getArguments ()) {
150+ if (isa<RankedTensorType>(bbArg.getType ())) {
151+ int64_t bbArgIdx = bbArg.getArgNumber ();
152+ // Store aliases in a set, so that we don't add the same alias twice.
153+ SetVector<int64_t > aliases;
154+ for (func::ReturnOp returnOp : returnOps) {
155+ for (OpOperand &returnVal : returnOp->getOpOperands ()) {
156+ if (isa<RankedTensorType>(returnVal.get ().getType ())) {
157+ int64_t returnIdx = returnVal.getOperandNumber ();
158+ if (state.areAliasingBufferizedValues (returnVal.get (), bbArg))
159+ aliases.insert (returnIdx);
163160 }
164- if (state.areAliasingBufferizedValues (returnVal.get (), bbArg))
165- funcState.aliasingReturnVals [funcOp][bbArgIdx].push_back (returnIdx);
166161 }
162+ }
163+ for (int64_t alias : aliases)
164+ funcState.aliasingReturnVals [funcOp][bbArgIdx].push_back (alias);
165+ }
166+ }
167+
168+ // Build equivalence sets.
169+ // Helper function that finds an equivalent block argument index for the
170+ // given OpOperand. Return std::nullopt if no equivalent block argument could
171+ // be found.
172+ auto findEquivalentBlockArgIdx =
173+ [&](OpOperand &opOperand) -> std::optional<int64_t > {
174+ Value v = opOperand.get ();
175+ if (!isa<TensorType>(v.getType ()))
176+ return std::nullopt ;
177+ for (BlockArgument bbArg : funcOp.getArguments ()) {
178+ if (isa<RankedTensorType>(bbArg.getType ())) {
179+ if (state.areEquivalentBufferizedValues (v, bbArg)) {
180+ if (state.getOptions ().testAnalysisOnly )
181+ annotateEquivalentReturnBbArg (opOperand, bbArg);
182+ return bbArg.getArgNumber ();
183+ }
184+ }
185+ }
186+ return std::nullopt ;
187+ };
188+
189+ int64_t numResults = returnOps.front ()->getNumOperands ();
190+ for (int64_t i = 0 ; i < numResults; ++i) {
191+ // Find the equivalent block argument index for the i-th operand of the
192+ // first func.return op.
193+ std::optional<int64_t > maybeEquiv =
194+ findEquivalentBlockArgIdx (returnOps.front ()->getOpOperand (i));
195+ if (!maybeEquiv.has_value ())
196+ continue ;
197+ int64_t bbArgIdx = *maybeEquiv;
198+ bool allEquiv = true ;
199+
200+ // Check if all other func.return ops have the same equivalent block
201+ // argument for the i-th operand. In contrast to aliasing information,
202+ // which is just "merged", equivalence information must match across all
203+ // func.return ops.
204+ for (func::ReturnOp returnOp : ArrayRef (returnOps).drop_front ()) {
205+ std::optional<int64_t > maybeEquiv =
206+ findEquivalentBlockArgIdx (returnOp->getOpOperand (i));
207+ if (maybeEquiv != bbArgIdx) {
208+ allEquiv = false ;
209+ break ;
210+ }
211+ }
212+
213+ // All func.return ops have the same equivalent block argument for the i-th
214+ // operand.
215+ if (allEquiv)
216+ funcState.equivalentFuncArgs [funcOp][i] = bbArgIdx;
217+ }
167218
168219 return success ();
169220}
@@ -299,14 +350,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
299350 // For each FuncOp, the number of func::CallOp it contains.
300351 DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
301352 WalkResult res = moduleOp.walk ([&](func::FuncOp funcOp) -> WalkResult {
302- if (!funcOp.getBody ().empty ()) {
303- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
304- if (!returnOp)
305- return funcOp->emitError ()
306- << " cannot bufferize a FuncOp with tensors and "
307- " without a unique ReturnOp" ;
308- }
309-
310353 // Collect function calls and populate the caller map.
311354 numberCallOpsContainedInFuncOp[funcOp] = 0 ;
312355 return funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
@@ -342,6 +385,42 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
342385 return success ();
343386}
344387
388+ // / Helper function that extracts the source from a memref.cast. If the given
389+ // / value is not a memref.cast result, simply returns the given value.
390+ static Value unpackCast (Value v) {
391+ auto castOp = v.getDefiningOp <memref::CastOp>();
392+ if (!castOp)
393+ return v;
394+ return castOp.getSource ();
395+ }
396+
397+ // / Helper function that returns the return types (skipping casts) of the given
398+ // / func.return ops. This function returns as many types as the return ops have
399+ // / operands. If the i-th operand is not the same for all func.return ops, then
400+ // / the i-th returned type is an "empty" type.
401+ static SmallVector<Type> getReturnTypes (SmallVector<func::ReturnOp> returnOps) {
402+ assert (!returnOps.empty () && " expected at least one ReturnOp" );
403+ int numOperands = returnOps.front ()->getNumOperands ();
404+
405+ // Helper function that unpacks memref.cast ops and returns the type.
406+ auto getSourceType = [&](Value v) { return unpackCast (v).getType (); };
407+
408+ SmallVector<Type> result;
409+ for (int i = 0 ; i < numOperands; ++i) {
410+ // Get the type of the i-th operand of the first func.return ops.
411+ Type t = getSourceType (returnOps.front ()->getOperand (i));
412+
413+ // Check if all other func.return ops have a matching operand type.
414+ for (int j = 1 ; j < static_cast <int >(returnOps.size ()); ++j)
415+ if (getSourceType (returnOps[j]->getOperand (i)) != t)
416+ t = Type ();
417+
418+ result.push_back (t);
419+ }
420+
421+ return result;
422+ }
423+
345424// / Fold return values that are memref casts and update function return types.
346425// /
347426// / During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -350,21 +429,33 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
350429// / entire function body, a more concise memref type can potentially be used for
351430// / the return type of the function.
352431static void foldMemRefCasts (func::FuncOp funcOp) {
432+ // There is nothing to do for bodiless ops.
353433 if (funcOp.getBody ().empty ())
354434 return ;
355435
356- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
357- SmallVector<Type> resultTypes;
436+ // Compute the common result types of all return ops.
437+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
438+ SmallVector<Type> resultTypes = getReturnTypes (returnOps);
358439
359- for (OpOperand &operand : returnOp->getOpOperands ()) {
360- if (auto castOp = operand.get ().getDefiningOp <memref::CastOp>()) {
361- operand.set (castOp.getSource ());
362- resultTypes.push_back (castOp.getSource ().getType ());
363- } else {
364- resultTypes.push_back (operand.get ().getType ());
440+ // Remove direct casts.
441+ for (func::ReturnOp returnOp : returnOps) {
442+ for (OpOperand &operand : returnOp->getOpOperands ()) {
443+ // Bail if no common result type was found.
444+ if (resultTypes[operand.getOperandNumber ()]) {
445+ operand.set (unpackCast (operand.get ()));
446+ }
365447 }
366448 }
367449
450+ // Fill in the missing result types that were not the same among all
451+ // func.return ops.
452+ for (int i = 0 ; i < static_cast <int >(resultTypes.size ()); ++i) {
453+ if (resultTypes[i])
454+ continue ;
455+ resultTypes[i] = funcOp.getFunctionType ().getResult (i);
456+ }
457+
458+ // Update the function type.
368459 auto newFuncType = FunctionType::get (
369460 funcOp.getContext (), funcOp.getFunctionType ().getInputs (), resultTypes);
370461 funcOp.setType (newFuncType);
0 commit comments