@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
8282
8383// / Return the FuncOp called by `callOp`.
8484static FuncOp getCalledFunction (CallOpInterface callOp) {
85- SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
85+ SymbolRefAttr sym =
86+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
8687 if (!sym)
8788 return nullptr ;
8889 return dyn_cast_or_null<FuncOp>(
@@ -392,36 +393,45 @@ struct FuncOpInterface
392393 auto funcOp = cast<FuncOp>(op);
393394 FunctionType funcType = funcOp.getFunctionType ();
394395
395- // Construct the bufferized function type .
396+ // Compute the argument types .
396397 SmallVector<Type> argTypes;
397398 for (const auto &it : llvm::enumerate (funcType.getInputs ())) {
398399 Type argType = it.value ();
399- if (dyn_cast <TensorType>(argType)) {
400+ if (isa <TensorType>(argType)) {
400401 argTypes.push_back (
401402 getBufferizedFunctionArgType (funcOp, it.index (), options));
402403 continue ;
403404 }
404405 argTypes.push_back (argType);
405406 }
406407
407- // Bodiless functions are assumed opaque and we cannot know the
408- // bufferization contract they want to enforce. As a consequence, only
409- // support functions that don't return any tensors atm.
410- if (funcOp.isExternal ()) {
411- SmallVector<Type> retTypes;
412- for (Type resultType : funcType.getResults ()) {
413- if (isa<TensorType>(resultType))
414- return funcOp->emitError () << " cannot bufferize bodiless function "
415- << " that returns a tensor" ;
408+ // Compute the result types.
409+ SmallVector<Type> retTypes;
410+ for (Type resultType : funcType.getResults ()) {
411+ if (auto tensorType = dyn_cast<TensorType>(resultType)) {
412+ BaseMemRefType resultType = options.functionArgTypeConverterFn (
413+ tensorType, *options.defaultMemorySpaceFn (tensorType), funcOp,
414+ options);
416415 retTypes.push_back (resultType);
416+ continue ;
417417 }
418- funcOp.setType (FunctionType::get (op->getContext (), argTypes, retTypes));
418+ retTypes.push_back (resultType);
419+ }
420+
421+ // Compute the new function type.
422+ auto newFuncType = FunctionType::get (op->getContext (), argTypes, retTypes);
423+
424+ // If the function has no body, set the new function type and we are done.
425+ if (funcOp.isExternal ()) {
426+ funcOp.setType (newFuncType);
419427 return success ();
420428 }
421429
422430 // TODO: Support functions with multiple returns.
423431 func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
424432 assert (returnOp && " expected func with single return op" );
433+ assert (returnOp->getNumOperands () == retTypes.size () &&
434+ " incorrect number of return values" );
425435 Location loc = returnOp.getLoc ();
426436
427437 // 1. Bufferize every block.
@@ -430,10 +440,10 @@ struct FuncOpInterface
430440 options)))
431441 return failure ();
432442
433- // 2. For each result, keep track of which inplace argument it reuses .
443+ // 2. Bufferize all operands of the return op .
434444 SmallVector<Value> returnValues;
435- for (OpOperand &returnOperand : returnOp-> getOpOperands ()) {
436- Value returnVal = returnOperand. get ();
445+ for (auto [returnVal, bufferizedType] :
446+ llvm::zip_equal (returnOp-> getOperands (), retTypes)) {
437447 auto tensorType = dyn_cast<TensorType>(returnVal.getType ());
438448 rewriter.setInsertionPoint (returnOp);
439449
@@ -443,23 +453,17 @@ struct FuncOpInterface
443453 continue ;
444454 }
445455
446- // Note: If `inferFunctionResultLayout = true`, cast are later folded
456+ // Note: If `inferFunctionResultLayout = true`, casts are later folded
447457 // away.
448- BaseMemRefType resultType = options.functionArgTypeConverterFn (
449- tensorType, *options.defaultMemorySpaceFn (tensorType), funcOp,
450- options);
451458 Value toMemrefOp = rewriter.create <bufferization::ToMemrefOp>(
452- loc, resultType , returnVal);
459+ loc, bufferizedType , returnVal);
453460 returnValues.push_back (toMemrefOp);
454461 }
455462
456- // 3. Rewrite the terminator without the in-place bufferizable values.
457463 returnOp.getOperandsMutable ().assign (returnValues);
458464
459- // 4. Rewrite the FuncOp type to buffer form.
460- funcOp.setType (FunctionType::get (op->getContext (), argTypes,
461- ValueRange (returnValues).getTypes ()));
462-
465+ // 3. Set the new function type.
466+ funcOp.setType (newFuncType);
463467 return success ();
464468 }
465469
0 commit comments