@@ -345,10 +345,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
345345}
346346// / Default unknown type converter: Use a fully dynamic layout map.
347347BaseMemRefType
348- defaultUnknownTypeConverter (Value value , Attribute memorySpace,
348+ defaultUnknownTypeConverter (TensorType tensorType , Attribute memorySpace,
349349 const BufferizationOptions &options) {
350- return getMemRefTypeWithFullyDynamicLayout (
351- llvm::cast<TensorType>(value.getType ()), memorySpace);
350+ return getMemRefTypeWithFullyDynamicLayout (tensorType, memorySpace);
352351}
353352
354353} // namespace
@@ -724,7 +723,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
724723 if (!memSpace.has_value ())
725724 return op->emitError (" could not infer memory space" );
726725
727- return getMemRefType (value, options, /* layout=*/ {}, *memSpace);
726+ return getMemRefType (cast<TensorType>(value.getType ()), options,
727+ /* layout=*/ {}, *memSpace);
728728}
729729
730730bool bufferization::hasTensorSemantics (Operation *op) {
@@ -797,12 +797,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
797797// Bufferization-specific IRMapping support with debugging.
798798// ===----------------------------------------------------------------------===//
799799
800- BaseMemRefType bufferization::getMemRefType (Value value ,
800+ BaseMemRefType bufferization::getMemRefType (TensorType tensorType ,
801801 const BufferizationOptions &options,
802802 MemRefLayoutAttrInterface layout,
803803 Attribute memorySpace) {
804- auto tensorType = llvm::cast<TensorType>(value.getType ());
805-
806804 // Case 1: Unranked memref type.
807805 if (auto unrankedTensorType =
808806 llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +817,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
819817 memorySpace);
820818 }
821819
822- return options.unknownTypeConverterFn (value , memorySpace, options);
820+ return options.unknownTypeConverterFn (tensorType , memorySpace, options);
823821}
824822
825823BaseMemRefType
@@ -955,10 +953,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
955953 const BufferizationState &bufferizationState,
956954 SmallVector<Value> &invocationStack) {
957955 assert (llvm::isa<TensorType>(value.getType ()) && " expected tensor type" );
956+ auto tensorType = cast<TensorType>(value.getType ());
958957
959958 // No further analysis is possible for a block argument.
960959 if (llvm::isa<BlockArgument>(value))
961- return bufferization::getMemRefType (value , options);
960+ return bufferization::getMemRefType (tensorType , options);
962961
963962 // Value is an OpResult.
964963 Operation *op = getOwnerOfValue (value);
@@ -981,7 +980,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
981980 if (!memSpace.has_value ())
982981 return op->emitError (" could not infer memory space" );
983982
984- return getMemRefType (value , options, /* layout=*/ {}, *memSpace);
983+ return getMemRefType (tensorType , options, /* layout=*/ {}, *memSpace);
985984}
986985
987986bool bufferization::detail::defaultIsRepetitiveRegion (
0 commit comments