Skip to content

Commit 7ef1183

Browse files
[NFC] Remove type-inferring builder of ToTensorOp
The builder is ambiguous given customizable tensor-like -> buffer-like conversion and is thus removed. The places where reverse bufferization has to happen rely on the pre-existing functionality.
1 parent be85978 commit 7ef1183

File tree

12 files changed

+27
-58
lines changed

12 files changed

+27
-58
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -748,10 +748,6 @@ FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
748748
/// bufferization::ConversionInterface to verify the types in tensor and buffer
749749
/// worlds match.
750750
bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
751-
752-
/// This function is a free-standing helper that relies on
753-
/// bufferization::ConversionInterface to perform the conversion.
754-
Type getTensorFromBuffer(Type buffer);
755751
} // namespace detail
756752

757753
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ struct ConversionDialectInterface
3535
virtual LogicalResult typesMatch(
3636
TensorLikeType tensor, BufferLikeType buffer,
3737
function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
38-
39-
/// Hook to customize buffer-like -> tensor-like conversion, which is the
40-
/// opposite of bufferization.
41-
virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0;
4238
};
4339

4440
/// Interface collection for conversion between tensor-like and buffer-like
@@ -60,10 +56,6 @@ struct ConversionInterface
6056
LogicalResult
6157
typesMatch(TensorLikeType tensor, BufferLikeType buffer,
6258
function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
63-
64-
/// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the
65-
/// dialect associated with the value type.
66-
TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const;
6759
};
6860

6961
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -490,13 +490,6 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
490490
`:` type($buffer) `to` type($result)
491491
}];
492492

493-
let builders = [
494-
OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
495-
auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
496-
build($_builder, $_state, rtt, buffer, restrict, writeable);
497-
}]>
498-
];
499-
500493
let hasCanonicalizer = 1;
501494
let hasFolder = 1;
502495
}

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,9 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
172172
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
173173
tensor = shapedValue;
174174
} else if (llvm::isa<MemRefType>(shapedValue.getType())) {
175-
tensor = b.create<ToTensorOp>(loc, shapedValue);
175+
tensor = b.create<ToTensorOp>(
176+
loc, memref::getTensorTypeFromMemRefType(shapedValue.getType()),
177+
shapedValue);
176178
} else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
177179
llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
178180
return getOwnerOfValue(shapedValue)
@@ -1064,9 +1066,3 @@ bool bufferization::detail::typesMatchAfterBufferization(Operation &op,
10641066
cast<BufferLikeType>(buffer.getType()),
10651067
[&](const Twine &message) { return op.emitError(message); }));
10661068
}
1067-
1068-
Type bufferization::detail::getTensorFromBuffer(Type buffer) {
1069-
assert(isa<BufferLikeType>(buffer) && "expected BufferLikeType");
1070-
bufferization::ConversionInterface iface(buffer.getContext());
1071-
return iface.getTensorFromBuffer(cast<BufferLikeType>(buffer));
1072-
}

mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,5 @@ LogicalResult ConversionInterface::typesMatch(
5454
return success();
5555
}
5656

57-
TensorLikeType
58-
ConversionInterface::getTensorFromBuffer(BufferLikeType buffer) const {
59-
Dialect *dialect = &buffer.getDialect();
60-
if (const ConversionDialectInterface *iface = getInterfaceFor(dialect))
61-
return iface->getTensorFromBuffer(buffer);
62-
63-
return cast<TensorLikeType>(memref::getTensorTypeFromMemRefType(buffer));
64-
}
65-
6657
} // namespace bufferization
6758
} // namespace mlir

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,8 +643,9 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
643643
assert(getRestrict() &&
644644
"expected that ops with memrefs dest have 'restrict'");
645645
setRestrict(false);
646-
return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
647-
getWritable());
646+
return builder.create<ToTensorOp>(
647+
loc, memref::getTensorTypeFromMemRefType(getDest().getType()), getDest(),
648+
/*restrict=*/true, getWritable());
648649
}
649650

650651
bool MaterializeInDestinationOp::isEquivalentSubset(

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ Value linalg::bufferizeToAllocation(
252252
// Create bufferization.to_tensor with "restrict" and "writable". The returned
253253
// tensor is a new buffer allocation, so it does not alias with any buffer.
254254
Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
255-
loc, alloc, /*restrict=*/true, /*writable=*/true);
255+
loc, padOp.getResult().getType(), alloc, /*restrict=*/true,
256+
/*writable=*/true);
256257
rewriter.replaceOp(padOp, toTensorOp);
257258
return alloc;
258259
}
@@ -340,7 +341,8 @@ Value linalg::bufferizeToAllocation(
340341
// Create bufferization.to_tensor with "restrict" and "writable". The returned
341342
// tensor is a new buffer allocation, so it does not alias with any buffer.
342343
Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
343-
loc, alloc, /*restrict=*/true, /*writable=*/true);
344+
loc, allocTensorOp.getResult().getType(), alloc, /*restrict=*/true,
345+
/*writable=*/true);
344346
rewriter.replaceOp(allocTensorOp, toTensorOp);
345347
return alloc;
346348
}
@@ -567,7 +569,8 @@ Value linalg::bufferizeToAllocation(
567569
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
568570
}
569571
rewriter.modifyOpInPlace(op, [&]() {
570-
auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
572+
auto toTensorOp = rewriter.create<ToTensorOp>(
573+
op->getLoc(), operand->get().getType(), alloc);
571574
operand->set(toTensorOp);
572575
if (options.bufferizeDestinationOnly) {
573576
rewriter.modifyOpInPlace(toTensorOp, [&]() {

mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ struct AssumingOpInterface
6767
for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
6868
if (isa<TensorType>(it.value())) {
6969
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
70-
assumingOp.getLoc(), newOp->getResult(it.index())));
70+
assumingOp.getLoc(), it.value(), newOp->getResult(it.index())));
7171
} else {
7272
newResults.push_back(newOp->getResult(it.index()));
7373
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
651651
tokens.clear();
652652

653653
// Done.
654-
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
654+
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, y.getType(), memY);
655655
return success();
656656
}
657657

@@ -752,7 +752,7 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
752752
tokens.clear();
753753

754754
// Done.
755-
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
755+
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, c.getType(), bufC);
756756
return success();
757757
}
758758

@@ -925,9 +925,12 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
925925
tokens.clear();
926926

927927
// Done.
928-
Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
929-
Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
930-
Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
928+
Value vt = rewriter.create<bufferization::ToTensorOp>(
929+
loc, memref::getTensorTypeFromMemRefType(valH.getType()), valH);
930+
Value rt = rewriter.create<bufferization::ToTensorOp>(
931+
loc, memref::getTensorTypeFromMemRefType(rowH.getType()), rowH);
932+
Value ct = rewriter.create<bufferization::ToTensorOp>(
933+
loc, memref::getTensorTypeFromMemRefType(colH.getType()), colH);
931934
rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
932935
vt);
933936
return success();
@@ -1043,7 +1046,7 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
10431046
tokens.clear();
10441047

10451048
// Done.
1046-
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
1049+
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, C.getType(), bufC);
10471050
return success();
10481051
}
10491052

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1471,7 +1471,8 @@ struct SparseDisassembleOpConverter
14711471
// Converts MemRefs back to Tensors.
14721472
SmallVector<Value> retValues = llvm::to_vector(
14731473
llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1474-
return rewriter.create<bufferization::ToTensorOp>(loc, v);
1474+
return rewriter.create<bufferization::ToTensorOp>(
1475+
loc, memref::getTensorTypeFromMemRefType(v.getType()), v);
14751476
}));
14761477
// Appends the actual memory length used in each buffer returned.
14771478
retValues.append(retLen.begin(), retLen.end());

0 commit comments

Comments
 (0)