@@ -49,29 +49,47 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }
 
+// Note: this is a local adaptor to unify TensorType and TensorLikeType code
+// paths that both work with BufferizationOptions.
+static mlir::Attribute
+getDefaultMemorySpace(const BufferizationOptions &options,
+                      TensorLikeType type) {
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    return *options.defaultMemorySpaceFn(tensorType);
+  }
+  return nullptr;
+}
+
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
-static BaseMemRefType
+static BufferLikeType
 getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
-  auto tensorType =
-      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
-  assert(tensorType && "expected TensorType");
-
-  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
-      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
-
-  auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
-      index, BufferizationDialect::kBufferLayoutAttrName);
-  if (!layoutAttr)
-    return memrefType;
-
-  auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
-  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
-  return MemRefType::get(rankedMemrefType.getShape(),
-                         rankedMemrefType.getElementType(), layoutAttr,
-                         rankedMemrefType.getMemorySpace());
+  auto type =
+      dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
+  assert(type && "expected TensorLikeType");
+
+  // Note: For builtin tensors there is additional logic related to layout.
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    BufferLikeType memrefType = options.functionArgTypeConverterFn(
+        type, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+
+    auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
+        index, BufferizationDialect::kBufferLayoutAttrName);
+    if (!layoutAttr)
+      return memrefType;
+
+    auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
+    assert(rankedMemrefType &&
+           "buffer layout not supported on unranked tensors");
+    return cast<BufferLikeType>(MemRefType::get(
+        rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+        layoutAttr, rankedMemrefType.getMemorySpace()));
+  }
+
+  return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp,
+                                            options);
 }
 
 /// Return the FuncOp called by `callOp`.
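
A minimal sketch (not part of the patch) of the two paths through the new
getDefaultMemorySpace adaptor. The demo function and its setup are assumptions,
including that builtin tensor types implement TensorLikeType in the given
context:

    static void demoDefaultMemorySpace(mlir::MLIRContext *ctx) {
      using namespace mlir;
      using namespace mlir::bufferization;
      BufferizationOptions options;
      // Builtin path: TensorType is routed through options.defaultMemorySpaceFn
      // (the default callback returns a null "no memory space" attribute).
      auto ranked = cast<TensorLikeType>(
          RankedTensorType::get({4, 4}, Float32Type::get(ctx)));
      Attribute space = getDefaultMemorySpace(options, ranked);
      (void)space;
      // Non-builtin path: any other TensorLikeType yields nullptr, deferring
      // the memory-space decision to options.functionArgTypeConverterFn.
    }
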
@@ -227,13 +245,13 @@ struct CallOpInterface
     FunctionType funcType = funcOp.getFunctionType();
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
-    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return cast<BufferLikeType>(bufferizedType);
+    if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
+      return bufferizedType;
 
     // Otherwise, call the type converter to compute the bufferized type.
-    auto tensorType = cast<TensorType>(resultType);
+    auto tensorType = cast<TensorLikeType>(resultType);
     return cast<BufferLikeType>(options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
         options));
   }
 
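
With the new signatures, functionArgTypeConverterFn receives a TensorLikeType
and must produce a BufferLikeType. A hypothetical custom converter (the lambda
body is an assumption, not part of the patch; it handles builtin tensors only):

    options.functionArgTypeConverterFn =
        [](TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp,
           const BufferizationOptions &opts) -> BufferLikeType {
      // Builtin-only sketch: map builtin tensors to memrefs with a static
      // identity layout in the requested memory space.
      auto tensorType = cast<TensorType>(type);
      return cast<BufferLikeType>(
          getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace));
    };
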
@@ -248,7 +266,7 @@ struct CallOpInterface
     SmallVector<Type> resultTypes;
     for (Value result : callOp.getResults()) {
       Type returnType = result.getType();
-      if (!isa<TensorType>(returnType)) {
+      if (!isa<TensorLikeType>(returnType)) {
         // Non-tensor values are returned.
         resultTypes.push_back(returnType);
         continue;
@@ -272,7 +290,7 @@ struct CallOpInterface
 
     for (OpOperand &opOperand : callOp->getOpOperands()) {
       // Non-tensor operands are just copied.
-      if (!isa<TensorType>(opOperand.get().getType())) {
+      if (!isa<TensorLikeType>(opOperand.get().getType())) {
         newOperands.push_back(opOperand.get());
         continue;
       }
@@ -285,8 +303,8 @@ struct CallOpInterface
       Value buffer = *maybeBuffer;
 
       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
-      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
-      if (!isa<BaseMemRefType>(memRefType)) {
+      auto bufferType = funcType.getInput(opOperand.getOperandNumber());
+      if (!isa<BufferLikeType>(bufferType)) {
         // The called function was not bufferized yet. This can happen when
         // there are cycles in the function call graph. Compute the bufferized
         // result type.
@@ -296,7 +314,7 @@ struct CallOpInterface
             state);
         if (failed(maybeBufferType))
           return failure();
-        memRefType = *maybeBufferType;
+        bufferType = *maybeBufferType;
       }
 
       // Since we don't yet have a clear layout story, to_buffer may
@@ -305,8 +323,8 @@ struct CallOpInterface
       // that will either canonicalize away or fail compilation until we can do
       // something better. Insert a reallocation + copy if it cannot be
       // statically guaranteed that a direct cast would be valid.
-      if (buffer.getType() != memRefType) {
-        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+      if (buffer.getType() != bufferType) {
+        auto memrefDstType = dyn_cast<MemRefType>(bufferType);
         assert(memrefDstType &&
                "buffer layout not supported on unranked tensors");
         FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
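
As a hedged illustration of the "statically guaranteed" condition above (an
assumption about usage, not code from the patch), memref cast compatibility can
be probed before deciding between a plain cast and a realloc-plus-copy;
castOrReallocMemRefValue additionally requires the cast to be provably valid at
compile time before it emits one:

    // Illustrative only: `buffer` and `memrefDstType` as in the hunk above.
    bool castIsLegal = memref::CastOp::areCastCompatible(
        {buffer.getType()}, {Type(memrefDstType)});
    // If the types are not even cast-compatible, only an allocation plus
    // memref.copy remains as a fallback.
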
@@ -370,7 +388,7 @@ struct FuncOpInterface
   static bool supportsUnstructuredControlFlow() { return true; }
 
   bool hasTensorSemantics(Operation *op) const {
-    auto isaTensor = llvm::IsaPred<TensorType>;
+    auto isaTensor = llvm::IsaPred<TensorLikeType>;
 
     // A function has tensor semantics if it has tensor arguments/results.
     auto funcOp = cast<FuncOp>(op);
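
For reference, llvm::IsaPred<T> (from llvm/ADT/STLExtras.h) is a stateless
predicate object, so it can be passed straight to range algorithms instead of
a hand-written lambda. A small sketch of the idiom, assuming a FuncOp in scope:

    // Does any function argument have a tensor-like type?
    bool anyTensorArg = llvm::any_of(funcOp.getArgumentTypes(),
                                     llvm::IsaPred<TensorLikeType>);
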
@@ -406,8 +424,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return cast<BufferLikeType>(
-          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
+      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
+                                          options);
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
@@ -430,7 +448,7 @@ struct FuncOpInterface
     SmallVector<Type> argTypes;
     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
       Type argType = it.value();
-      if (isa<TensorType>(argType)) {
+      if (isa<TensorLikeType>(argType)) {
         argTypes.push_back(
             getBufferizedFunctionArgType(funcOp, it.index(), options));
         continue;
@@ -441,9 +459,9 @@ struct FuncOpInterface
     // Compute the result types.
     SmallVector<Type> retTypes;
     for (Type resultType : funcType.getResults()) {
-      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
-        BaseMemRefType resultType = options.functionArgTypeConverterFn(
-            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+      if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
+        BufferLikeType resultType = options.functionArgTypeConverterFn(
+            tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
             options);
         retTypes.push_back(resultType);
         continue;
@@ -473,7 +491,7 @@ struct FuncOpInterface
     SmallVector<Value> returnValues;
     for (auto [returnVal, bufferizedType] :
          llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+      auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
       rewriter.setInsertionPoint(returnOp);
 
       // If not a tensor type just forward it.