@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
8282
8383// / Return the FuncOp called by `callOp`.
8484static FuncOp getCalledFunction (CallOpInterface callOp) {
85- SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
85+ SymbolRefAttr sym =
86+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
8687 if (!sym)
8788 return nullptr ;
8889 return dyn_cast_or_null<FuncOp>(
@@ -392,43 +393,45 @@ struct FuncOpInterface
392393 auto funcOp = cast<FuncOp>(op);
393394 FunctionType funcType = funcOp.getFunctionType ();
394395
395- // Construct the bufferized function type .
396+ // Compute the argument types .
396397 SmallVector<Type> argTypes;
397398 for (const auto &it : llvm::enumerate (funcType.getInputs ())) {
398399 Type argType = it.value ();
399- if (dyn_cast <TensorType>(argType)) {
400+ if (isa <TensorType>(argType)) {
400401 argTypes.push_back (
401402 getBufferizedFunctionArgType (funcOp, it.index (), options));
402403 continue ;
403404 }
404405 argTypes.push_back (argType);
405406 }
406407
407- // Bodiless functions are assumed opaque and we cannot know the
408- // bufferization contract they want to enforce. As a consequence, only
409- // support functions that don't return any tensors atm.
410- if (funcOp.isExternal ()) {
411- SmallVector<Type> retTypes;
412- for (Type resultType : funcType.getResults ()) {
413- if (auto tensorType = dyn_cast<TensorType>(resultType)) {
414- if (!options.bufferizeBodilessFunctionResults ) {
415- return funcOp->emitError () << " cannot bufferize bodiless function "
416- << " that returns a tensor" ;
417- }
418- retTypes.push_back (options.functionArgTypeConverterFn (
419- tensorType, *options.defaultMemorySpaceFn (tensorType), funcOp,
420- options));
421- } else {
422- retTypes.push_back (resultType);
423- }
408+ // Compute the result types.
409+ SmallVector<Type> retTypes;
410+ for (Type resultType : funcType.getResults ()) {
411+ if (auto tensorType = dyn_cast<TensorType>(resultType)) {
412+ BaseMemRefType resultType = options.functionArgTypeConverterFn (
413+ tensorType, *options.defaultMemorySpaceFn (tensorType), funcOp,
414+ options);
415+ retTypes.push_back (resultType);
416+ continue ;
424417 }
425- funcOp.setType (FunctionType::get (op->getContext (), argTypes, retTypes));
418+ retTypes.push_back (resultType);
419+ }
420+
421+ // Compute the new function type.
422+ auto newFuncType = FunctionType::get (op->getContext (), argTypes, retTypes);
423+
424+ // If the function has no body, set the new function type and we are done.
425+ if (funcOp.isExternal ()) {
426+ funcOp.setType (newFuncType);
426427 return success ();
427428 }
428429
429430 // TODO: Support functions with multiple returns.
430431 func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
431432 assert (returnOp && " expected func with single return op" );
433+ assert (returnOp->getNumOperands () == retTypes.size () &&
434+ " incorrect number of return values" );
432435 Location loc = returnOp.getLoc ();
433436
434437 // 1. Bufferize every block.
@@ -437,10 +440,10 @@ struct FuncOpInterface
437440 options)))
438441 return failure ();
439442
440- // 2. For each result, keep track of which inplace argument it reuses .
443+ // 2. Bufferize all operands of the return op .
441444 SmallVector<Value> returnValues;
442- for (OpOperand &returnOperand : returnOp-> getOpOperands ()) {
443- Value returnVal = returnOperand. get ();
445+ for (auto [returnVal, bufferizedType] :
446+ llvm::zip_equal (returnOp-> getOperands (), retTypes)) {
444447 auto tensorType = dyn_cast<TensorType>(returnVal.getType ());
445448 rewriter.setInsertionPoint (returnOp);
446449
@@ -450,23 +453,17 @@ struct FuncOpInterface
450453 continue ;
451454 }
452455
453- // Note: If `inferFunctionResultLayout = true`, cast are later folded
456+ // Note: If `inferFunctionResultLayout = true`, casts are later folded
454457 // away.
455- BaseMemRefType resultType = options.functionArgTypeConverterFn (
456- tensorType, *options.defaultMemorySpaceFn (tensorType), funcOp,
457- options);
458458 Value toMemrefOp = rewriter.create <bufferization::ToMemrefOp>(
459- loc, resultType , returnVal);
459+ loc, bufferizedType , returnVal);
460460 returnValues.push_back (toMemrefOp);
461461 }
462462
463- // 3. Rewrite the terminator without the in-place bufferizable values.
464463 returnOp.getOperandsMutable ().assign (returnValues);
465464
466- // 4. Rewrite the FuncOp type to buffer form.
467- funcOp.setType (FunctionType::get (op->getContext (), argTypes,
468- ValueRange (returnValues).getTypes ()));
469-
465+ // 3. Set the new function type.
466+ funcOp.setType (newFuncType);
470467 return success ();
471468 }
472469
0 commit comments