Skip to content

Commit 128caa1

Browse files
[mlir][bufferization] Refine tensor-buffer compatibility checks (#167705)
Generally, to_tensor and to_buffer already perform sufficient verification. However, there are some unnecessarily strict constraints: * builtin tensor requires its buffer counterpart to always be memref * to_buffer on ranked tensor requires to always return memref These checks are assertions (i.e. preconditions), however, they actually prevent an apparently useful bufferization where builtin tensors could become custom buffers. Lift these assertions, maintaining the verification procedure unchanged, to allow builtin -> custom bufferizations at operation boundary level.
1 parent 8603552 commit 128caa1

File tree

5 files changed

+110
-20
lines changed

5 files changed

+110
-20
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -680,16 +680,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
680680
return false;
681681
}
682682

683-
// bufferization.to_buffer is not allowed to change the rank.
684-
static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
685-
#ifndef NDEBUG
686-
auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
687-
assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
688-
rankedTensorType.getRank()) &&
689-
"to_buffer would be invalid: mismatching ranks");
690-
#endif
691-
}
692-
693683
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
694684
const BufferizationOptions &options,
695685
const BufferizationState &state) {
@@ -708,7 +698,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
708698
FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
709699
if (failed(bufferType))
710700
return failure();
711-
ensureToBufferOpIsValid(value, *bufferType);
701+
712702
return bufferization::ToBufferOp::create(rewriter, value.getLoc(),
713703
*bufferType, value)
714704
.getResult();

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@ struct BuiltinTensorExternalModel
5454
mlir::LogicalResult verifyCompatibleBufferType(
5555
mlir::Type tensor, BufferLikeType bufferType,
5656
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
57-
assert(isa<TensorType>(tensor) && "expected tensor type");
58-
assert(isa<BaseMemRefType>(bufferType) && "expected memref type");
59-
6057
auto tensorType = cast<ShapedType>(tensor);
6158
auto memrefType = cast<ShapedType>(bufferType);
6259

mlir/test/Dialect/Bufferization/invalid.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() {
127127
// expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}}
128128
arith.constant {bufferization.manual_deallocation} 0 : index
129129
}
130+
131+
// -----
132+
133+
func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) {
134+
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
135+
// expected-error @below{{shapes do not match}}
136+
%b = bufferization.to_buffer %t
137+
: tensor<1x2x3x4xf32> to memref<1x2x3xf32>
138+
return
139+
}
140+
141+
// -----
142+
143+
func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) {
144+
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
145+
// expected-error @below{{shapes do not match}}
146+
%t = bufferization.to_tensor %b
147+
: memref<1x2x3xf32> to tensor<1x2x3x4xf32>
148+
return
149+
}
150+
151+
// -----
152+
153+
func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) {
154+
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
155+
// expected-error @below{{shapes do not match}}
156+
%b = bufferization.to_buffer %t
157+
: tensor<1x2x3x4xf32> to memref<1x2x4x3xf32>
158+
return
159+
}
160+
161+
// -----
162+
163+
func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) {
164+
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
165+
// expected-error @below{{shapes do not match}}
166+
%t = bufferization.to_tensor %b
167+
: memref<1x2x4x3xf32> to tensor<1x2x3x4xf32>
168+
return
169+
}
170+
171+
// -----
172+
173+
func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) {
174+
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
175+
// expected-error @below{{element types do not match}}
176+
%b = bufferization.to_buffer %t
177+
: tensor<1x2x3x4xf32> to memref<1x2x3x4xf16>
178+
return
179+
}
180+
181+
// -----
182+
183+
func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) {
184+
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
185+
// expected-error @below{{element types do not match}}
186+
%t2 = bufferization.to_tensor %b
187+
: memref<1x2x3x4xf16> to tensor<1x2x3x4xf32>
188+
return
189+
}

mlir/test/Dialect/Bufferization/ops.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>,
8383
bufferization.dealloc
8484
return %0#0, %0#1 : i1, i1
8585
}
86+
87+
// CHECK: func.func @test_builtin_custom_builtin_type_conversion
88+
// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32>
89+
func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>)
90+
-> tensor<42xf32> {
91+
// CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
92+
// CHECK-SAME: to !test.test_memref<[42], f32>
93+
%buffer = bufferization.to_buffer %t
94+
: tensor<42xf32> to !test.test_memref<[42], f32>
95+
96+
// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
97+
// CHECK-SAME: to tensor<42xf32>
98+
%tensor = bufferization.to_tensor %buffer
99+
: !test.test_memref<[42], f32> to tensor<42xf32>
100+
101+
// CHECK: return %[[tensor]]
102+
return %tensor : tensor<42xf32>
103+
}
104+
105+
// CHECK: func.func @test_custom_builtin_custom_type_conversion
106+
// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>)
107+
// CHECK-SAME: -> !test.test_tensor<[42], f32>
108+
func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>)
109+
-> !test.test_tensor<[42], f32> {
110+
// CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
111+
// CHECK-SAME: to memref<42xf32>
112+
%buffer = bufferization.to_buffer %t
113+
: !test.test_tensor<[42], f32> to memref<42xf32>
114+
115+
// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
116+
// CHECK-SAME: to !test.test_tensor<[42], f32>
117+
%tensor = bufferization.to_tensor %buffer
118+
: memref<42xf32> to !test.test_tensor<[42], f32>
119+
120+
// CHECK: return %[[tensor]]
121+
return %tensor : !test.test_tensor<[42], f32>
122+
}

mlir/test/lib/Dialect/Test/TestTypes.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -569,11 +569,17 @@ TestTensorType::getBufferType(
569569
::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
570570
::mlir::bufferization::BufferLikeType bufferType,
571571
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
572-
auto testMemref = dyn_cast<TestMemrefType>(bufferType);
573-
if (!testMemref)
574-
return emitError() << "expected TestMemrefType";
572+
if (auto testMemref = dyn_cast<TestMemrefType>(bufferType)) {
573+
const bool valid = getShape() == testMemref.getShape() &&
574+
getElementType() == testMemref.getElementType();
575+
return mlir::success(valid);
576+
}
577+
578+
if (auto builtinMemref = dyn_cast<MemRefType>(bufferType)) {
579+
const bool valid = getShape() == builtinMemref.getShape() &&
580+
getElementType() == builtinMemref.getElementType();
581+
return mlir::success(valid);
582+
}
575583

576-
const bool valid = getShape() == testMemref.getShape() &&
577-
getElementType() == testMemref.getElementType();
578-
return mlir::success(valid);
584+
return emitError() << "expected MemRefType or TestMemrefType";
579585
}

0 commit comments

Comments
 (0)