@@ -86,20 +86,6 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
8686 return state.addExtension <FuncAnalysisState>();
8787}
8888
89- // / Return the unique ReturnOp that terminates `funcOp`.
90- // / Return nullptr if there is no such unique ReturnOp.
91- static func::ReturnOp getAssumedUniqueReturnOp (func::FuncOp funcOp) {
92- func::ReturnOp returnOp;
93- for (Block &b : funcOp.getBody ()) {
94- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator ())) {
95- if (returnOp)
96- return nullptr ;
97- returnOp = candidateOp;
98- }
99- }
100- return returnOp;
101- }
102-
10389namespace {
10490
10591// / Annotate IR with the results of the analysis. For testing purposes only.
@@ -146,24 +132,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
146132 return success ();
147133 }
148134
149- // Support only single return-terminated block in the function.
150- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
151- assert (returnOp && " expected func with single return op" );
152-
153- for (OpOperand &returnVal : returnOp->getOpOperands ())
154- if (isa<RankedTensorType>(returnVal.get ().getType ()))
155- for (BlockArgument bbArg : funcOp.getArguments ())
156- if (isa<RankedTensorType>(bbArg.getType ())) {
157- int64_t returnIdx = returnVal.getOperandNumber ();
158- int64_t bbArgIdx = bbArg.getArgNumber ();
159- if (state.areEquivalentBufferizedValues (returnVal.get (), bbArg)) {
160- funcState.equivalentFuncArgs [funcOp][returnIdx] = bbArgIdx;
161- if (state.getOptions ().testAnalysisOnly )
162- annotateEquivalentReturnBbArg (returnVal, bbArg);
135+ // Find all func.return ops.
136+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
137+ assert (!returnOps.empty () && " expected at least one ReturnOp" );
138+
139+ // Build alias sets. Merge all aliases from all func.return ops.
140+ for (BlockArgument bbArg : funcOp.getArguments ()) {
141+ if (isa<RankedTensorType>(bbArg.getType ())) {
142+ int64_t bbArgIdx = bbArg.getArgNumber ();
143+ // Store aliases in a set, so that we don't add the same alias twice.
144+ SetVector<int64_t > aliases;
145+ for (func::ReturnOp returnOp : returnOps) {
146+ for (OpOperand &returnVal : returnOp->getOpOperands ()) {
147+ if (isa<RankedTensorType>(returnVal.get ().getType ())) {
148+ int64_t returnIdx = returnVal.getOperandNumber ();
149+ if (state.areAliasingBufferizedValues (returnVal.get (), bbArg))
150+ aliases.insert (returnIdx);
163151 }
164- if (state.areAliasingBufferizedValues (returnVal.get (), bbArg))
165- funcState.aliasingReturnVals [funcOp][bbArgIdx].push_back (returnIdx);
166152 }
153+ }
154+ for (int64_t alias : aliases)
155+ funcState.aliasingReturnVals [funcOp][bbArgIdx].push_back (alias);
156+ }
157+ }
158+
159+ // Build equivalence sets.
160+ // Helper function that finds an equivalent block argument index for the
161+ // given OpOperand. Return std::nullopt if no equivalent block argument could
162+ // be found.
163+ auto findEquivalentBlockArgIdx =
164+ [&](OpOperand &opOperand) -> std::optional<int64_t > {
165+ Value v = opOperand.get ();
166+ if (!isa<TensorType>(v.getType ()))
167+ return std::nullopt ;
168+ for (BlockArgument bbArg : funcOp.getArguments ()) {
169+ if (isa<RankedTensorType>(bbArg.getType ())) {
170+ if (state.areEquivalentBufferizedValues (v, bbArg)) {
171+ if (state.getOptions ().testAnalysisOnly )
172+ annotateEquivalentReturnBbArg (opOperand, bbArg);
173+ return bbArg.getArgNumber ();
174+ }
175+ }
176+ }
177+ return std::nullopt ;
178+ };
179+
180+ int64_t numResults = returnOps.front ()->getNumOperands ();
181+ for (int64_t i = 0 ; i < numResults; ++i) {
182+ // Find the equivalent block argument index for the i-th operand of the
183+ // first func.return op.
184+ std::optional<int64_t > maybeEquiv =
185+ findEquivalentBlockArgIdx (returnOps.front ()->getOpOperand (i));
186+ if (!maybeEquiv.has_value ())
187+ continue ;
188+ int64_t bbArgIdx = *maybeEquiv;
189+ bool allEquiv = true ;
190+
191+ // Check if all other func.return ops have the same equivalent block
192+ // argument for the i-th operand. In contrast to aliasing information,
193+ // which is just "merged", equivalence information must match across all
194+ // func.return ops.
195+ for (func::ReturnOp returnOp : ArrayRef (returnOps).drop_front ()) {
196+ std::optional<int64_t > maybeEquiv =
197+ findEquivalentBlockArgIdx (returnOp->getOpOperand (i));
198+ if (maybeEquiv != bbArgIdx) {
199+ allEquiv = false ;
200+ break ;
201+ }
202+ }
203+
204+ // All func.return ops have the same equivalent block argument for the i-th
205+ // operand.
206+ if (allEquiv)
207+ funcState.equivalentFuncArgs [funcOp][i] = bbArgIdx;
208+ }
167209
168210 return success ();
169211}
@@ -302,14 +344,6 @@ static LogicalResult getFuncOpsOrderedByCalls(
302344 // For each FuncOp, the number of func::CallOp it contains.
303345 DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
304346 WalkResult res = moduleOp.walk ([&](func::FuncOp funcOp) -> WalkResult {
305- if (!funcOp.getBody ().empty ()) {
306- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
307- if (!returnOp)
308- return funcOp->emitError ()
309- << " cannot bufferize a FuncOp with tensors and "
310- " without a unique ReturnOp" ;
311- }
312-
313347 // Collect function calls and populate the caller map.
314348 numberCallOpsContainedInFuncOp[funcOp] = 0 ;
315349 return funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
@@ -351,6 +385,42 @@ static LogicalResult getFuncOpsOrderedByCalls(
351385 return success ();
352386}
353387
388+ // / Helper function that extracts the source from a memref.cast. If the given
389+ // / value is not a memref.cast result, simply returns the given value.
390+ static Value unpackCast (Value v) {
391+ auto castOp = v.getDefiningOp <memref::CastOp>();
392+ if (!castOp)
393+ return v;
394+ return castOp.getSource ();
395+ }
396+
397+ // / Helper function that returns the return types (skipping casts) of the given
398+ // / func.return ops. This function returns as many types as the return ops have
399+ // / operands. If the i-th operand is not the same for all func.return ops, then
400+ // / the i-th returned type is an "empty" type.
401+ static SmallVector<Type> getReturnTypes (SmallVector<func::ReturnOp> returnOps) {
402+ assert (!returnOps.empty () && " expected at least one ReturnOp" );
403+ int numOperands = returnOps.front ()->getNumOperands ();
404+
405+ // Helper function that unpacks memref.cast ops and returns the type.
406+ auto getSourceType = [&](Value v) { return unpackCast (v).getType (); };
407+
408+ SmallVector<Type> result;
409+ for (int i = 0 ; i < numOperands; ++i) {
410+ // Get the type of the i-th operand of the first func.return ops.
411+ Type t = getSourceType (returnOps.front ()->getOperand (i));
412+
413+ // Check if all other func.return ops have a matching operand type.
414+ for (int j = 1 ; j < static_cast <int >(returnOps.size ()); ++j)
415+ if (getSourceType (returnOps[j]->getOperand (i)) != t)
416+ t = Type ();
417+
418+ result.push_back (t);
419+ }
420+
421+ return result;
422+ }
423+
354424// / Fold return values that are memref casts and update function return types.
355425// /
356426// / During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -359,21 +429,33 @@ static LogicalResult getFuncOpsOrderedByCalls(
359429// / entire function body, a more concise memref type can potentially be used for
360430// / the return type of the function.
361431static void foldMemRefCasts (func::FuncOp funcOp) {
432+ // There is nothing to do for bodiless ops.
362433 if (funcOp.getBody ().empty ())
363434 return ;
364435
365- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
366- SmallVector<Type> resultTypes;
436+ // Compute the common result types of all return ops.
437+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
438+ SmallVector<Type> resultTypes = getReturnTypes (returnOps);
367439
368- for (OpOperand &operand : returnOp->getOpOperands ()) {
369- if (auto castOp = operand.get ().getDefiningOp <memref::CastOp>()) {
370- operand.set (castOp.getSource ());
371- resultTypes.push_back (castOp.getSource ().getType ());
372- } else {
373- resultTypes.push_back (operand.get ().getType ());
440+ // Remove direct casts.
441+ for (func::ReturnOp returnOp : returnOps) {
442+ for (OpOperand &operand : returnOp->getOpOperands ()) {
443+ // Bail if no common result type was found.
444+ if (resultTypes[operand.getOperandNumber ()]) {
445+ operand.set (unpackCast (operand.get ()));
446+ }
374447 }
375448 }
376449
450+ // Fill in the missing result types that were not the same among all
451+ // func.return ops.
452+ for (int i = 0 ; i < static_cast <int >(resultTypes.size ()); ++i) {
453+ if (resultTypes[i])
454+ continue ;
455+ resultTypes[i] = funcOp.getFunctionType ().getResult (i);
456+ }
457+
458+ // Update the function type.
377459 auto newFuncType = FunctionType::get (
378460 funcOp.getContext (), funcOp.getFunctionType ().getInputs (), resultTypes);
379461 funcOp.setType (newFuncType);
0 commit comments