@@ -345,9 +345,10 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
345345}
346346// / Default unknown type converter: Use a fully dynamic layout map.
347347BaseMemRefType
348- defaultUnknownTypeConverter (TensorType tensorType , Attribute memorySpace,
348+ defaultUnknownTypeConverter (Value value , Attribute memorySpace,
349349 const BufferizationOptions &options) {
350- return getMemRefTypeWithFullyDynamicLayout (tensorType, memorySpace);
350+ return getMemRefTypeWithFullyDynamicLayout (
351+ llvm::cast<TensorType>(value.getType ()), memorySpace);
351352}
352353
353354} // namespace
@@ -723,8 +724,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
723724 if (!memSpace.has_value ())
724725 return op->emitError (" could not infer memory space" );
725726
726- return getMemRefType (cast<TensorType>(value.getType ()), options,
727- /* layout=*/ {}, *memSpace);
727+ return getMemRefType (value, options, /* layout=*/ {}, *memSpace);
728728}
729729
730730bool bufferization::hasTensorSemantics (Operation *op) {
@@ -797,10 +797,12 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
797797// Bufferization-specific IRMapping support with debugging.
798798// ===----------------------------------------------------------------------===//
799799
800- BaseMemRefType bufferization::getMemRefType (TensorType tensorType ,
800+ BaseMemRefType bufferization::getMemRefType (Value value ,
801801 const BufferizationOptions &options,
802802 MemRefLayoutAttrInterface layout,
803803 Attribute memorySpace) {
804+ auto tensorType = llvm::cast<TensorType>(value.getType ());
805+
804806 // Case 1: Unranked memref type.
805807 if (auto unrankedTensorType =
806808 llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -817,7 +819,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
817819 memorySpace);
818820 }
819821
820- return options.unknownTypeConverterFn (tensorType , memorySpace, options);
822+ return options.unknownTypeConverterFn (value , memorySpace, options);
821823}
822824
823825BaseMemRefType
@@ -953,11 +955,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
953955 const BufferizationState &bufferizationState,
954956 SmallVector<Value> &invocationStack) {
955957 assert (llvm::isa<TensorType>(value.getType ()) && " expected tensor type" );
956- auto tensorType = cast<TensorType>(value.getType ());
957958
958959 // No further analysis is possible for a block argument.
959960 if (llvm::isa<BlockArgument>(value))
960- return bufferization::getMemRefType (tensorType , options);
961+ return bufferization::getMemRefType (value , options);
961962
962963 // Value is an OpResult.
963964 Operation *op = getOwnerOfValue (value);
@@ -980,7 +981,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
980981 if (!memSpace.has_value ())
981982 return op->emitError (" could not infer memory space" );
982983
983- return getMemRefType (tensorType , options, /* layout=*/ {}, *memSpace);
984+ return getMemRefType (value , options, /* layout=*/ {}, *memSpace);
984985}
985986
986987bool bufferization::detail::defaultIsRepetitiveRegion (
0 commit comments