Skip to content

Commit c743086

Browse files
[mlir][bufferization] Use Type instead of Value in unknown conversion (llvm#144658)
Generally, bufferization should be able to create a memref from a tensor without needing to know more than just a mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of mlir::Value. Additionally, apply the same rationale to getMemRefType() helper function. Both changes are prerequisites to enable custom types support in one-shot bufferization.
1 parent 23d53cd commit c743086

File tree

4 files changed

+17
-18
lines changed

4 files changed

+17
-18
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ struct BufferizationOptions {
265265
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
266266
func::FuncOp, const BufferizationOptions &)>;
267267
/// Tensor -> MemRef type converter.
268-
/// Parameters: Value, memory space, bufferization options
268+
/// Parameters: tensor type, memory space, bufferization options
269269
using UnknownTypeConverterFn = std::function<BaseMemRefType(
270-
Value, Attribute memorySpace, const BufferizationOptions &)>;
270+
TensorType, Attribute memorySpace, const BufferizationOptions &)>;
271271
// Produce a MemorySpace attribute from a tensor type
272272
using DefaultMemorySpaceFn =
273273
std::function<std::optional<Attribute>(TensorType t)>;
@@ -638,7 +638,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
638638
return newOp;
639639
}
640640

641-
/// Return a MemRefType to which the type of the given value can be bufferized.
641+
/// Return a MemRefType to which the TensorType can be bufferized.
642642
///
643643
/// If possible, op bufferization implementations should not use this function
644644
/// and instead infer precise memref types for tensor results by themselves.
@@ -650,7 +650,8 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
650650
/// Note: Canonicalization patterns could clean up layout maps and infer more
651651
/// precise layout maps after bufferization. However, many possible
652652
/// canonicalizations are currently not implemented.
653-
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
653+
BaseMemRefType getMemRefType(TensorType tensorType,
654+
const BufferizationOptions &options,
654655
MemRefLayoutAttrInterface layout = {},
655656
Attribute memorySpace = nullptr);
656657

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -320,10 +320,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
320320
}
321321
/// Default unknown type converter: Use a fully dynamic layout map.
322322
BaseMemRefType
323-
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
323+
defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
324324
const BufferizationOptions &options) {
325-
return getMemRefTypeWithFullyDynamicLayout(
326-
llvm::cast<TensorType>(value.getType()), memorySpace);
325+
return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
327326
}
328327

329328
} // namespace
@@ -769,12 +768,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
769768
// Bufferization-specific IRMapping support with debugging.
770769
//===----------------------------------------------------------------------===//
771770

772-
BaseMemRefType bufferization::getMemRefType(Value value,
771+
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
773772
const BufferizationOptions &options,
774773
MemRefLayoutAttrInterface layout,
775774
Attribute memorySpace) {
776-
auto tensorType = llvm::cast<TensorType>(value.getType());
777-
778775
// Case 1: Unranked memref type.
779776
if (auto unrankedTensorType =
780777
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -791,7 +788,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
791788
memorySpace);
792789
}
793790

794-
return options.unknownTypeConverterFn(value, memorySpace, options);
791+
return options.unknownTypeConverterFn(tensorType, memorySpace, options);
795792
}
796793

797794
BaseMemRefType
@@ -926,10 +923,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
926923
Value value, const BufferizationOptions &options,
927924
SmallVector<Value> &invocationStack) {
928925
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
926+
auto tensorType = cast<TensorType>(value.getType());
929927

930928
// No further analysis is possible for a block argument.
931929
if (llvm::isa<BlockArgument>(value))
932-
return bufferization::getMemRefType(value, options);
930+
return bufferization::getMemRefType(tensorType, options);
933931

934932
// Value is an OpResult.
935933
Operation *op = getOwnerOfValue(value);
@@ -951,7 +949,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
951949
if (!memSpace.has_value())
952950
return op->emitError("could not infer memory space");
953951

954-
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
952+
return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
955953
}
956954

957955
bool bufferization::detail::defaultIsRepetitiveRegion(

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@ struct OneShotBufferizePass
124124
"'unknown-type-conversion'");
125125
return signalPassFailure();
126126
}
127-
opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
127+
opt.unknownTypeConverterFn = [=](TensorType tensorType,
128+
Attribute memorySpace,
128129
const BufferizationOptions &options) {
129-
auto tensorType = cast<TensorType>(value.getType());
130130
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
131131
return bufferization::getMemRefTypeWithStaticIdentityLayout(
132132
tensorType, memorySpace);

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
218218
OneShotBufferizationOptions options;
219219
options.bufferizeFunctionBoundaries = true;
220220
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
221-
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
221+
options.unknownTypeConverterFn = [](TensorType tensorType,
222+
Attribute memorySpace,
222223
const BufferizationOptions &options) {
223-
return getMemRefTypeWithStaticIdentityLayout(
224-
cast<TensorType>(value.getType()), memorySpace);
224+
return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
225225
};
226226
if (analysisOnly) {
227227
options.testAnalysisOnly = true;

0 commit comments

Comments
 (0)