Skip to content

Commit 2de5728

Browse files
Address code review feedback
* Remove BufferizationState from TensorLikeType::getBufferType() * Rename castToMemRef to asMemRefType (+ add extra docs) * Improve ToTensorOp's docs * Apply minor suggestions
1 parent b05a291 commit 2de5728

File tree

11 files changed

+31
-32
lines changed

11 files changed

+31
-32
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,12 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
742742
bool defaultHasTensorSemantics(Operation *op);
743743

744744
/// This is a helper function used when buffer type is guaranteed to be memref.
745-
FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
745+
/// It performs two actions: failure state checking and an explicit llvm::cast<>
746+
/// from the buffer-like type interface to a BaseMemRefType. This allows easier
747+
/// management of differences in C++ types at the API boundaries. Valid buffer
748+
/// type is casted to the memref type. Otherwise, the failure state is
749+
/// propagated i.e. asMemRefType(mlir::failure()) returns mlir::failure().
750+
FailureOr<BaseMemRefType> asMemRefType(FailureOr<BufferLikeType> bufferType);
746751

747752
/// This function is a free-standing helper that relies on
748753
/// bufferization::TensorLikeTypeInterface to verify the types in tensor and

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
404404
let summary = "create a buffer-like type from a tensor-like type";
405405
let description = [{
406406
An operation that creates a tensor from a buffer. The result value is a
407-
tensor-like type whose shape and element type match the buffer-like operand.
407+
tensor-like type that must match the corresponding buffer-like operand as
408+
per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
409+
and BaseMemRefType), this means that shapes and element types match between
410+
the tensor and the buffer.
408411

409412
The opposite of this op is `to_buffer`. Together, these two ops are
410413
useful for source/target materializations when doing type conversions

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def Bufferization_TensorLikeTypeInterface
3131
/*methodName=*/"getBufferType",
3232
/*args=*/(ins
3333
"const ::mlir::bufferization::BufferizationOptions &":$options,
34-
"const ::mlir::bufferization::BufferizationState &":$state,
3534
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
3635
)
3736
>,

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ struct SelectOpInterface
164164
// buffers have different types, they differ only in their layout map. Cast
165165
// both of them to the most dynamic MemRef type.
166166
if (trueBuffer.getType() != falseBuffer.getType()) {
167-
auto targetType = bufferization::detail::castToMemRef(
167+
auto targetType = bufferization::detail::asMemRefType(
168168
bufferization::getBufferType(selectOp.getResult(), options, state));
169169
if (failed(targetType))
170170
return failure();
@@ -188,10 +188,10 @@ struct SelectOpInterface
188188
auto selectOp = cast<arith::SelectOp>(op);
189189
assert(value == selectOp.getResult() && "invalid value");
190190
auto trueType =
191-
bufferization::detail::castToMemRef(bufferization::getBufferType(
191+
bufferization::detail::asMemRefType(bufferization::getBufferType(
192192
selectOp.getTrueValue(), options, state, invocationStack));
193193
auto falseType =
194-
bufferization::detail::castToMemRef(bufferization::getBufferType(
194+
bufferization::detail::asMemRefType(bufferization::getBufferType(
195195
selectOp.getFalseValue(), options, state, invocationStack));
196196
if (failed(trueType) || failed(falseType))
197197
return failure();

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
214214
if (copy)
215215
return allocTensorOp.getResult();
216216
auto copyBufferType =
217-
detail::castToMemRef(getBufferType(tensor, options, state));
217+
detail::asMemRefType(getBufferType(tensor, options, state));
218218
if (failed(copyBufferType))
219219
return failure();
220220
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -720,8 +720,9 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
720720
return bufferizableOp.getBufferType(value, options, state, invocationStack);
721721

722722
// Op is not bufferizable.
723-
return cast<TensorLikeType>(value.getType())
724-
.getBufferType(options, state, [&]() { return op->emitError(); });
723+
return cast<TensorLikeType>(value.getType()).getBufferType(options, [&]() {
724+
return op->emitError();
725+
});
725726
}
726727

727728
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -965,7 +966,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
965966
// If the OpResult has an equivalent OpOperand, both OpResult and
966967
// OpOperand bufferize to the exact same buffer type.
967968
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
968-
return castToMemRef(getBufferType(equivalentOperand, options,
969+
return asMemRefType(getBufferType(equivalentOperand, options,
969970
bufferizationState, invocationStack));
970971
}
971972

@@ -1043,19 +1044,15 @@ bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
10431044
}
10441045

10451046
FailureOr<BaseMemRefType>
1046-
bufferization::detail::castToMemRef(FailureOr<BufferLikeType> bufferType) {
1047+
bufferization::detail::asMemRefType(FailureOr<BufferLikeType> bufferType) {
10471048
if (failed(bufferType))
10481049
return failure();
1049-
assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
10501050
return cast<BaseMemRefType>(*bufferType);
10511051
}
10521052

10531053
bool bufferization::detail::typesMatchAfterBufferization(Operation &op,
10541054
Value tensor,
10551055
Value buffer) {
1056-
assert(isa<TensorLikeType>(tensor.getType()) && "expected TensorLikeType");
1057-
assert(isa<BufferLikeType>(buffer.getType()) && "expected BufferLikeType");
1058-
10591056
return mlir::succeeded(
10601057
cast<TensorLikeType>(tensor.getType())
10611058
.verifyCompatibleBufferType(cast<BufferLikeType>(buffer.getType()),

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,8 @@ struct BuiltinTensorExternalModel
6060
Tensor> {
6161
llvm::FailureOr<BufferLikeType> getBufferType(
6262
mlir::Type tensor, const BufferizationOptions &options,
63-
const BufferizationState &state,
6463
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
6564
auto tensorType = cast<TensorType>(tensor);
66-
// Fall back to tensor -> memref conversion.
6765
auto memSpace = options.defaultMemorySpaceFn(tensorType);
6866
if (!memSpace.has_value())
6967
return emitError() << "could not infer memory space";
@@ -75,7 +73,6 @@ struct BuiltinTensorExternalModel
7573
mlir::LogicalResult verifyCompatibleBufferType(
7674
mlir::Type tensor, BufferLikeType bufferType,
7775
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
78-
// Fall back to tensor, memref checking.
7976
assert(isa<TensorType>(tensor) && "expected tensor type");
8077
assert(isa<BaseMemRefType>(bufferType) && "expected memref type");
8178

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
234234
memorySpace = *getMemorySpace();
235235
} else if (getCopy()) {
236236
auto copyBufferType =
237-
bufferization::detail::castToMemRef(bufferization::getBufferType(
237+
bufferization::detail::asMemRefType(bufferization::getBufferType(
238238
getCopy(), options, state, invocationStack));
239239
if (failed(copyBufferType))
240240
return failure();

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ struct IfOpInterface
293293
thenBufferType = cast<BaseMemRefType>(thenValue.getType());
294294
} else {
295295
auto maybeBufferType =
296-
bufferization::detail::castToMemRef(bufferization::getBufferType(
296+
bufferization::detail::asMemRefType(bufferization::getBufferType(
297297
thenValue, options, state, invocationStack));
298298
if (failed(maybeBufferType))
299299
return failure();
@@ -304,7 +304,7 @@ struct IfOpInterface
304304
elseBufferType = cast<BaseMemRefType>(elseValue.getType());
305305
} else {
306306
auto maybeBufferType =
307-
bufferization::detail::castToMemRef(bufferization::getBufferType(
307+
bufferization::detail::asMemRefType(bufferization::getBufferType(
308308
elseValue, options, state, invocationStack));
309309
if (failed(maybeBufferType))
310310
return failure();
@@ -408,7 +408,7 @@ struct IndexSwitchOpInterface
408408
return bufferType;
409409
auto maybeBufferType = bufferization::getBufferType(
410410
yieldedValue, options, state, invocationStack);
411-
return bufferization::detail::castToMemRef(maybeBufferType);
411+
return bufferization::detail::asMemRefType(maybeBufferType);
412412
};
413413

414414
// Compute buffer type of the default case.
@@ -527,7 +527,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
527527
const BufferizationOptions &options, const BufferizationState &state,
528528
SmallVector<Value> &invocationStack) {
529529
// Determine the buffer type of the init_arg.
530-
auto initArgBufferType = bufferization::detail::castToMemRef(
530+
auto initArgBufferType = bufferization::detail::asMemRefType(
531531
bufferization::getBufferType(initArg, options, state, invocationStack));
532532
if (failed(initArgBufferType))
533533
return failure();
@@ -555,7 +555,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
555555
// Note: This typically triggers a recursive call for the buffer type of
556556
// the iter_arg.
557557
auto maybeBufferType =
558-
bufferization::detail::castToMemRef(bufferization::getBufferType(
558+
bufferization::detail::asMemRefType(bufferization::getBufferType(
559559
yieldedValue, options, state, invocationStack));
560560
if (failed(maybeBufferType))
561561
return failure();
@@ -1083,7 +1083,7 @@ struct WhileOpInterface
10831083
// scf.condition was already bufferized.
10841084
return cast<BaseMemRefType>(conditionYieldedVal.getType());
10851085
}
1086-
return bufferization::detail::castToMemRef(bufferization::getBufferType(
1086+
return bufferization::detail::asMemRefType(bufferization::getBufferType(
10871087
conditionYieldedVal, options, state, invocationStack));
10881088
}
10891089

@@ -1312,13 +1312,13 @@ struct ForallOpInterface
13121312
if (auto bbArg = dyn_cast<BlockArgument>(value))
13131313
// A tensor block argument has the same bufferized type as the
13141314
// corresponding output operand.
1315-
return bufferization::detail::castToMemRef(
1315+
return bufferization::detail::asMemRefType(
13161316
bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
13171317
options, state, invocationStack));
13181318

13191319
// The bufferized result type is the same as the bufferized type of the
13201320
// corresponding output operand.
1321-
return bufferization::detail::castToMemRef(bufferization::getBufferType(
1321+
return bufferization::detail::asMemRefType(bufferization::getBufferType(
13221322
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
13231323
state, invocationStack));
13241324
}

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ struct CastOpInterface
5555
SmallVector<Value> &invocationStack) const {
5656
auto castOp = cast<tensor::CastOp>(op);
5757
auto maybeSrcBufferType =
58-
bufferization::detail::castToMemRef(bufferization::getBufferType(
58+
bufferization::detail::asMemRefType(bufferization::getBufferType(
5959
castOp.getSource(), options, state, invocationStack));
6060
if (failed(maybeSrcBufferType))
6161
return failure();
@@ -501,7 +501,7 @@ struct FromElementsOpInterface
501501
/*copy=*/false);
502502
if (failed(tensorAlloc))
503503
return failure();
504-
FailureOr<BaseMemRefType> memrefType = bufferization::detail::castToMemRef(
504+
FailureOr<BaseMemRefType> memrefType = bufferization::detail::asMemRefType(
505505
bufferization::getBufferType(*tensorAlloc, options, state));
506506
if (failed(memrefType))
507507
return failure();
@@ -760,7 +760,7 @@ struct PadOpInterface
760760
// Infer memory space from the source tensor.
761761
auto padOp = cast<tensor::PadOp>(op);
762762
auto maybeSrcBufferType =
763-
bufferization::detail::castToMemRef(bufferization::getBufferType(
763+
bufferization::detail::asMemRefType(bufferization::getBufferType(
764764
padOp.getSource(), options, state, invocationStack));
765765
if (failed(maybeSrcBufferType))
766766
return failure();

mlir/test/lib/Dialect/Test/TestTypeDefs.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,6 @@ def TestTensorType : Test_Type<"TestTensor",
432432
// TensorLikeTypeInterface:
433433
::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
434434
getBufferType(const ::mlir::bufferization::BufferizationOptions& options,
435-
const ::mlir::bufferization::BufferizationState& state,
436435
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
437436

438437
::mlir::LogicalResult verifyCompatibleBufferType(

0 commit comments

Comments
 (0)