@@ -49,29 +49,47 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
4949#endif // NDEBUG
5050}
5151
52+ // Note: this is a local adaptor to unify TensorType and TensorLikeType code
53+ // paths that both work with BufferizationOptions.
54+ static mlir::Attribute
55+ getDefaultMemorySpace (const BufferizationOptions &options,
56+ TensorLikeType type) {
57+ if (auto tensorType = dyn_cast<TensorType>(type)) {
58+ return *options.defaultMemorySpaceFn (tensorType);
59+ }
60+ return nullptr ;
61+ }
62+
5263// / Return the index-th bufferized function argument type. This assumes that the
5364// / specified argument is a tensor. If the tensor is ranked, a layout map may be
5465// / specified by the user (as per `options.functionArgTypeConverterFn`).
55- static BaseMemRefType
66+ static BufferLikeType
5667getBufferizedFunctionArgType (FuncOp funcOp, int64_t index,
5768 const BufferizationOptions &options) {
58- auto tensorType =
59- dyn_cast<TensorType>(funcOp.getFunctionType ().getInput (index));
60- assert (tensorType && " expected TensorType" );
61-
62- BaseMemRefType memrefType = options.functionArgTypeConverterFn (
63- tensorType, *options.defaultMemorySpaceFn (tensorType), funcOp, options);
64-
65- auto layoutAttr = funcOp.getArgAttrOfType <MemRefLayoutAttrInterface>(
66- index, BufferizationDialect::kBufferLayoutAttrName );
67- if (!layoutAttr)
68- return memrefType;
69-
70- auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
71- assert (rankedMemrefType && " buffer layout not supported on unranked tensors" );
72- return MemRefType::get (rankedMemrefType.getShape (),
73- rankedMemrefType.getElementType (), layoutAttr,
74- rankedMemrefType.getMemorySpace ());
69+ auto type =
70+ dyn_cast<TensorLikeType>(funcOp.getFunctionType ().getInput (index));
71+ assert (type && " expected TensorLikeType" );
72+
73+ // Note: For builtin tensors there is additional logic related to layout.
74+ if (auto tensorType = dyn_cast<TensorType>(type)) {
75+ BufferLikeType memrefType = options.functionArgTypeConverterFn (
76+ type, *options.defaultMemorySpaceFn (tensorType), funcOp, options);
77+
78+ auto layoutAttr = funcOp.getArgAttrOfType <MemRefLayoutAttrInterface>(
79+ index, BufferizationDialect::kBufferLayoutAttrName );
80+ if (!layoutAttr)
81+ return memrefType;
82+
83+ auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
84+ assert (rankedMemrefType &&
85+ " buffer layout not supported on unranked tensors" );
86+ return cast<BufferLikeType>(MemRefType::get (
87+ rankedMemrefType.getShape (), rankedMemrefType.getElementType (),
88+ layoutAttr, rankedMemrefType.getMemorySpace ()));
89+ }
90+
91+ return options.functionArgTypeConverterFn (type, /* memSpace=*/ nullptr , funcOp,
92+ options);
7593}
7694
7795// / Return the FuncOp called by `callOp`.
@@ -227,13 +245,13 @@ struct CallOpInterface
227245 FunctionType funcType = funcOp.getFunctionType ();
228246 Type resultType =
229247 funcType.getResult (cast<OpResult>(value).getResultNumber ());
230- if (auto bufferizedType = dyn_cast<BaseMemRefType >(resultType))
231- return cast<BufferLikeType>( bufferizedType) ;
248+ if (auto bufferizedType = dyn_cast<BufferLikeType >(resultType))
249+ return bufferizedType;
232250
233251 // Otherwise, call the type converter to compute the bufferized type.
234- auto tensorType = cast<TensorType >(resultType);
252+ auto tensorType = cast<TensorLikeType >(resultType);
235253 return cast<BufferLikeType>(options.functionArgTypeConverterFn (
236- tensorType, *options. defaultMemorySpaceFn ( tensorType), funcOp,
254+ tensorType, getDefaultMemorySpace (options, tensorType), funcOp,
237255 options));
238256 }
239257
@@ -248,7 +266,7 @@ struct CallOpInterface
248266 SmallVector<Type> resultTypes;
249267 for (Value result : callOp.getResults ()) {
250268 Type returnType = result.getType ();
251- if (!isa<TensorType >(returnType)) {
269+ if (!isa<TensorLikeType >(returnType)) {
252270 // Non-tensor values are returned.
253271 resultTypes.push_back (returnType);
254272 continue ;
@@ -272,7 +290,7 @@ struct CallOpInterface
272290
273291 for (OpOperand &opOperand : callOp->getOpOperands ()) {
274292 // Non-tensor operands are just copied.
275- if (!isa<TensorType >(opOperand.get ().getType ())) {
293+ if (!isa<TensorLikeType >(opOperand.get ().getType ())) {
276294 newOperands.push_back (opOperand.get ());
277295 continue ;
278296 }
@@ -285,8 +303,8 @@ struct CallOpInterface
285303 Value buffer = *maybeBuffer;
286304
287305 // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
288- auto memRefType = funcType.getInput (opOperand.getOperandNumber ());
289- if (!isa<BaseMemRefType>(memRefType )) {
306+ auto bufferType = funcType.getInput (opOperand.getOperandNumber ());
307+ if (!isa<BufferLikeType>(bufferType )) {
290308 // The called function was not bufferized yet. This can happen when
291309 // there cycles in the function call graph. Compute the bufferized
292310 // result type.
@@ -295,7 +313,7 @@ struct CallOpInterface
295313 funcOp.getArgument (opOperand.getOperandNumber ()), options);
296314 if (failed (maybeBufferType))
297315 return failure ();
298- memRefType = *maybeBufferType;
316+ bufferType = *maybeBufferType;
299317 }
300318
301319 // Since we don't yet have a clear layout story, to_buffer may
@@ -304,8 +322,8 @@ struct CallOpInterface
304322 // that will either canonicalize away or fail compilation until we can do
305323 // something better. Insert a reallocation + copy if it cannot be
306324 // statically guaranteed that a direct cast would be valid.
307- if (buffer.getType () != memRefType ) {
308- auto memrefDstType = dyn_cast<MemRefType>(memRefType );
325+ if (buffer.getType () != bufferType ) {
326+ auto memrefDstType = dyn_cast<MemRefType>(bufferType );
309327 assert (memrefDstType &&
310328 " buffer layout not supported on unranked tensors" );
311329 FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue (
@@ -368,7 +386,7 @@ struct FuncOpInterface
368386 static bool supportsUnstructuredControlFlow () { return true ; }
369387
370388 bool hasTensorSemantics (Operation *op) const {
371- auto isaTensor = llvm::IsaPred<TensorType >;
389+ auto isaTensor = llvm::IsaPred<TensorLikeType >;
372390
373391 // A function has tensor semantics if it has tensor arguments/results.
374392 auto funcOp = cast<FuncOp>(op);
@@ -404,8 +422,8 @@ struct FuncOpInterface
404422
405423 // Function arguments are special.
406424 if (bbArg.getOwner () == &funcOp.getBody ().front ())
407- return cast<BufferLikeType>(
408- getBufferizedFunctionArgType (funcOp, bbArg. getArgNumber (), options) );
425+ return getBufferizedFunctionArgType (funcOp, bbArg. getArgNumber (),
426+ options);
409427
410428 return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
411429 getBufferType (op, value, options, state, invocationStack);
@@ -428,7 +446,7 @@ struct FuncOpInterface
428446 SmallVector<Type> argTypes;
429447 for (const auto &it : llvm::enumerate (funcType.getInputs ())) {
430448 Type argType = it.value ();
431- if (isa<TensorType >(argType)) {
449+ if (isa<TensorLikeType >(argType)) {
432450 argTypes.push_back (
433451 getBufferizedFunctionArgType (funcOp, it.index (), options));
434452 continue ;
@@ -439,9 +457,9 @@ struct FuncOpInterface
439457 // Compute the result types.
440458 SmallVector<Type> retTypes;
441459 for (Type resultType : funcType.getResults ()) {
442- if (auto tensorType = dyn_cast<TensorType >(resultType)) {
443- BaseMemRefType resultType = options.functionArgTypeConverterFn (
444- tensorType, *options. defaultMemorySpaceFn ( tensorType), funcOp,
460+ if (auto tensorType = dyn_cast<TensorLikeType >(resultType)) {
461+ BufferLikeType resultType = options.functionArgTypeConverterFn (
462+ tensorType, getDefaultMemorySpace (options, tensorType), funcOp,
445463 options);
446464 retTypes.push_back (resultType);
447465 continue ;
@@ -471,7 +489,7 @@ struct FuncOpInterface
471489 SmallVector<Value> returnValues;
472490 for (auto [returnVal, bufferizedType] :
473491 llvm::zip_equal (returnOp->getOperands (), retTypes)) {
474- auto tensorType = dyn_cast<TensorType >(returnVal.getType ());
492+ auto tensorType = dyn_cast<TensorLikeType >(returnVal.getType ());
475493 rewriter.setInsertionPoint (returnOp);
476494
477495 // If not a tensor type just forward it.
0 commit comments