@@ -347,9 +347,10 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
347347}
348348// / Default unknown type converter: Use a fully dynamic layout map.
349349BaseMemRefType
350- defaultUnknownTypeConverter (TensorType tensorType , Attribute memorySpace,
350+ defaultUnknownTypeConverter (Value value , Attribute memorySpace,
351351 const BufferizationOptions &options) {
352- return getMemRefTypeWithFullyDynamicLayout (tensorType, memorySpace);
352+ return getMemRefTypeWithFullyDynamicLayout (
353+ llvm::cast<TensorType>(value.getType ()), memorySpace);
353354}
354355
355356} // namespace
@@ -725,8 +726,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
725726 if (!memSpace.has_value ())
726727 return op->emitError (" could not infer memory space" );
727728
728- return getMemRefType (cast<TensorType>(value.getType ()), options,
729- /* layout=*/ {}, *memSpace);
729+ return getMemRefType (value, options, /* layout=*/ {}, *memSpace);
730730}
731731
732732bool bufferization::hasTensorSemantics (Operation *op) {
@@ -799,10 +799,12 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
799799// Bufferization-specific IRMapping support with debugging.
800800// ===----------------------------------------------------------------------===//
801801
802- BaseMemRefType bufferization::getMemRefType (TensorType tensorType ,
802+ BaseMemRefType bufferization::getMemRefType (Value value ,
803803 const BufferizationOptions &options,
804804 MemRefLayoutAttrInterface layout,
805805 Attribute memorySpace) {
806+ auto tensorType = llvm::cast<TensorType>(value.getType ());
807+
806808 // Case 1: Unranked memref type.
807809 if (auto unrankedTensorType =
808810 llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +821,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
819821 memorySpace);
820822 }
821823
822- return options.unknownTypeConverterFn (tensorType , memorySpace, options);
824+ return options.unknownTypeConverterFn (value , memorySpace, options);
823825}
824826
825827BaseMemRefType
@@ -955,11 +957,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
955957 const BufferizationState &bufferizationState,
956958 SmallVector<Value> &invocationStack) {
957959 assert (llvm::isa<TensorType>(value.getType ()) && " expected tensor type" );
958- auto tensorType = cast<TensorType>(value.getType ());
959960
960961 // No further analysis is possible for a block argument.
961962 if (llvm::isa<BlockArgument>(value))
962- return bufferization::getMemRefType (tensorType , options);
963+ return bufferization::getMemRefType (value , options);
963964
964965 // Value is an OpResult.
965966 Operation *op = getOwnerOfValue (value);
@@ -982,7 +983,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
982983 if (!memSpace.has_value ())
983984 return op->emitError (" could not infer memory space" );
984985
985- return getMemRefType (tensorType , options, /* layout=*/ {}, *memSpace);
986+ return getMemRefType (value , options, /* layout=*/ {}, *memSpace);
986987}
987988
988989bool bufferization::detail::defaultIsRepetitiveRegion (
0 commit comments