@@ -49,29 +49,47 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }
 
+// Note: this is a local adaptor to unify TensorType and TensorLikeType code
+// paths that both work with BufferizationOptions.
+static mlir::Attribute
+getDefaultMemorySpace(const BufferizationOptions &options,
+                      TensorLikeType type) {
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    return *options.defaultMemorySpaceFn(tensorType);
+  }
+  return nullptr;
+}
+
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
-static BaseMemRefType
+static BufferLikeType
 getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
-  auto tensorType =
-      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
-  assert(tensorType && "expected TensorType");
-
-  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
-      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
-
-  auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
-      index, BufferizationDialect::kBufferLayoutAttrName);
-  if (!layoutAttr)
-    return memrefType;
-
-  auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
-  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
-  return MemRefType::get(rankedMemrefType.getShape(),
-                         rankedMemrefType.getElementType(), layoutAttr,
-                         rankedMemrefType.getMemorySpace());
+  auto type =
+      dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
+  assert(type && "expected TensorLikeType");
+
+  // Note: For builtin tensors there is additional logic related to layout.
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    BufferLikeType memrefType = options.functionArgTypeConverterFn(
+        type, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+
+    auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
+        index, BufferizationDialect::kBufferLayoutAttrName);
+    if (!layoutAttr)
+      return memrefType;
+
+    auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
+    assert(rankedMemrefType &&
+           "buffer layout not supported on unranked tensors");
+    return cast<BufferLikeType>(MemRefType::get(
+        rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+        layoutAttr, rankedMemrefType.getMemorySpace()));
+  }
+
+  return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp,
+                                            options);
 }
 
 /// Return the FuncOp called by `callOp`.
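Review note: the adaptor exists because options.defaultMemorySpaceFn is typed on builtin TensorType only; for any other TensorLikeType it answers with a null attribute and defers the choice to functionArgTypeConverterFn. For context on the layout branch above, a minimal sketch of how a caller could pin a layout on argument 0 so that getBufferizedFunctionArgType picks it up; the offset/stride values are made up for illustration:

    // Sketch, not part of the patch: attach a caller-specified layout that
    // getArgAttrOfType<MemRefLayoutAttrInterface> reads back above.
    MLIRContext *ctx = funcOp.getContext();
    auto layout = StridedLayoutAttr::get(ctx, /*offset=*/ShapedType::kDynamic,
                                         /*strides=*/{1});
    funcOp.setArgAttr(0, BufferizationDialect::kBufferLayoutAttrName, layout);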
@@ -227,13 +245,13 @@ struct CallOpInterface
     FunctionType funcType = funcOp.getFunctionType();
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
-    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return cast<BufferLikeType>(bufferizedType);
+    if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
+      return bufferizedType;
 
     // Otherwise, call the type converter to compute the bufferized type.
-    auto tensorType = cast<TensorType>(resultType);
+    auto tensorType = cast<TensorLikeType>(resultType);
     return cast<BufferLikeType>(options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
         options));
   }
 
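Review note: the early dyn_cast<BufferLikeType> return is what lets this query work whether or not the callee has been bufferized yet. A sketch of the two cases, with hypothetical callee signatures:

    // Callee already bufferized:   func.func @f() -> memref<4xf32>
    //   -> the result is already buffer-like; returned as-is.
    // Callee not bufferized yet:   func.func @f() -> tensor<4xf32>
    //   (or any other TensorLikeType)
    //   -> converted through functionArgTypeConverterFn, with the default
    //      memory space resolved by the getDefaultMemorySpace adaptor
    //      (null for non-builtin tensor-likes).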
@@ -248,7 +266,7 @@ struct CallOpInterface
     SmallVector<Type> resultTypes;
     for (Value result : callOp.getResults()) {
       Type returnType = result.getType();
-      if (!isa<TensorType>(returnType)) {
+      if (!isa<TensorLikeType>(returnType)) {
         // Non-tensor values are returned.
         resultTypes.push_back(returnType);
         continue;
@@ -272,7 +290,7 @@ struct CallOpInterface
 
     for (OpOperand &opOperand : callOp->getOpOperands()) {
       // Non-tensor operands are just copied.
-      if (!isa<TensorType>(opOperand.get().getType())) {
+      if (!isa<TensorLikeType>(opOperand.get().getType())) {
         newOperands.push_back(opOperand.get());
         continue;
       }
@@ -285,8 +303,8 @@ struct CallOpInterface
       Value buffer = *maybeBuffer;
 
       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
-      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
-      if (!isa<BaseMemRefType>(memRefType)) {
+      auto bufferType = funcType.getInput(opOperand.getOperandNumber());
+      if (!isa<BufferLikeType>(bufferType)) {
         // The called function was not bufferized yet. This can happen when
         // there are cycles in the function call graph. Compute the bufferized
         // result type.
@@ -296,7 +314,7 @@ struct CallOpInterface
             state);
         if (failed(maybeBufferType))
           return failure();
-        memRefType = *maybeBufferType;
+        bufferType = *maybeBufferType;
       }
 
       // Since we don't yet have a clear layout story, to_buffer may
@@ -305,8 +323,8 @@ struct CallOpInterface
       // that will either canonicalize away or fail compilation until we can do
      // something better. Insert a reallocation + copy if it cannot be
       // statically guaranteed that a direct cast would be valid.
-      if (buffer.getType() != memRefType) {
-        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+      if (buffer.getType() != bufferType) {
+        auto memrefDstType = dyn_cast<MemRefType>(bufferType);
         assert(memrefDstType &&
                "buffer layout not supported on unranked tensors");
         FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
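Review note: for reviewers unfamiliar with castOrReallocMemRefValue, a sketch of its effect under assumed ranked types; the shapes are hypothetical and the argument list mirrors its usual invocation shape:

    // buffer        : memref<4xf32, strided<[1]>>   (assumed)
    // memrefDstType : memref<4xf32>                  (identity layout)
    // The helper emits a memref.cast when the conversion is statically known
    // to be valid, and otherwise falls back to memref.alloc + memref.copy,
    // so a layout mismatch either canonicalizes away or fails loudly rather
    // than miscasting.
    FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
        rewriter, buffer, memrefDstType, options);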
@@ -370,7 +388,7 @@ struct FuncOpInterface
   static bool supportsUnstructuredControlFlow() { return true; }
 
   bool hasTensorSemantics(Operation *op) const {
-    auto isaTensor = llvm::IsaPred<TensorType>;
+    auto isaTensor = llvm::IsaPred<TensorLikeType>;
 
     // A function has tensor semantics if it has tensor arguments/results.
     auto funcOp = cast<FuncOp>(op);
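Review note: llvm::IsaPred<T> is a generic predicate object equivalent to [](auto v) { return isa<T>(v); }, so swapping the template argument widens every check in this function at once. A spelled-out sketch of the signature part of the check (the real code also inspects block arguments and operand types in the body):

    bool hasTensorArgsOrResults =
        llvm::any_of(funcOp.getFunctionType().getInputs(),
                     llvm::IsaPred<TensorLikeType>) ||
        llvm::any_of(funcOp.getFunctionType().getResults(),
                     llvm::IsaPred<TensorLikeType>);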
@@ -406,8 +424,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return cast<BufferLikeType>(
-          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
+      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
+                                          options);
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
@@ -430,7 +448,7 @@ struct FuncOpInterface
     SmallVector<Type> argTypes;
     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
       Type argType = it.value();
-      if (isa<TensorType>(argType)) {
+      if (isa<TensorLikeType>(argType)) {
         argTypes.push_back(
             getBufferizedFunctionArgType(funcOp, it.index(), options));
         continue;
@@ -441,9 +459,9 @@ struct FuncOpInterface
     // Compute the result types.
     SmallVector<Type> retTypes;
     for (Type resultType : funcType.getResults()) {
-      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
-        BaseMemRefType resultType = options.functionArgTypeConverterFn(
-            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+      if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
+        BufferLikeType resultType = options.functionArgTypeConverterFn(
+            tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
             options);
         retTypes.push_back(resultType);
         continue;
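Review note: the net effect on a bufferized signature, sketched with assumed types and assuming the identity-layout boundary conversion (function-boundary-type-conversion=identity-layout-map):

    // Before: func.func @fn(%t: tensor<8xf32>) -> tensor<8xf32>
    // After:  func.func @fn(%m: memref<8xf32>) -> memref<8xf32>
    // A non-builtin TensorLikeType result runs through the same
    // functionArgTypeConverterFn, but with a null memory space from
    // getDefaultMemorySpace, leaving the placement decision to the converter.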
@@ -473,7 +491,7 @@ struct FuncOpInterface
     SmallVector<Value> returnValues;
     for (auto [returnVal, bufferizedType] :
          llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+      auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
       rewriter.setInsertionPoint(returnOp);
 
       // If not a tensor type just forward it.