Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -598,13 +598,14 @@ class BufferizationState {
FailureOr<Value>
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
const BufferizationOptions &options,
bool copy = true);
BufferizationState &state, bool copy = true);

/// Lookup the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
/// from which the memref operand is returned.
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options);
const BufferizationOptions &options,
BufferizationState &state);

/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR.
Expand All @@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options);
const BufferizationOptions &options,
BufferizationState &state);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be a const reference? Here and in all the other places that this PR touches?


/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR. This function (and not the other overload
Expand All @@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options,
BufferizationState &state,
SmallVector<Value> &invocationStack);

/// Return "true" if the given op has tensor semantics and should be bufferized.
Expand Down Expand Up @@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// places.
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
BufferizationState &state,
SmallVector<Value> &invocationStack);

/// This is the default implementation of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"resolveConflicts",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
"const ::mlir::bufferization::AnalysisState &":$state),
"const ::mlir::bufferization::AnalysisState &":$analysisState,
"::mlir::bufferization::BufferizationState &":$bufferizationState),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Const possible here?

/*methodBody=*/"",
/*defaultImplementation=*/[{
auto bufferizableOp =
::llvm::cast<BufferizableOpInterface>($_op.getOperation());
return bufferizableOp.resolveTensorOpOperandConflicts(
rewriter, state);
rewriter, analysisState, bufferizationState);
}]
>,
InterfaceMethod<
Expand Down Expand Up @@ -523,6 +524,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
"::mlir::bufferization::BufferizationState &":$state,
"::llvm::SmallVector<::mlir::Value> &":$invocationStack),
/*methodBody=*/"",
/*defaultImplementation=*/[{
Expand All @@ -531,7 +533,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
assert(invocationStack.back() == value &&
"inconsistant invocation stack");
return ::mlir::bufferization::detail::defaultGetBufferType(
value, options, invocationStack);
value, options, state, invocationStack);
}]
>,
InterfaceMethod<
Expand Down Expand Up @@ -616,7 +618,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/// form of `bufferization.alloc_tensor` ops.
::llvm::LogicalResult resolveTensorOpOperandConflicts(
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::AnalysisState &state);
const ::mlir::bufferization::AnalysisState &analysisState,
::mlir::bufferization::BufferizationState &bufferizationState);

/// Return `true` if the given OpOperand creates an alias but does neither
/// read nor write. This implies that `bufferizesToMemoryRead` and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",

FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
BufferizationState &state,
SmallVector<Value> &invocationStack);

RankedTensorType getType() {
Expand Down Expand Up @@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [

FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
BufferizationState &state, SmallVector<Value> &invocationStack) {
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
}
}];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel

FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
BufferizationState &state,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
// case the bufferized result type is different from the bufferized type of
// the aliasing OpOperand (if any).
if (isa<OpResult>(value))
return bufferization::detail::defaultGetBufferType(value, options,
return bufferization::detail::defaultGetBufferType(value, options, state,
invocationStack);

// Compute the buffer type of the block argument by computing the bufferized
Expand All @@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
callerType = memrefType;
} else {
FailureOr<BaseMemRefType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options,
bufferization::getBufferType(opOperand->get(), options, state,
invocationStack);
if (failed(maybeCallerType))
return failure();
Expand All @@ -81,9 +82,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
if (bufferType == callerType)
continue;

// If the computed buffer type does not match the computed buffer type
// of the earlier forwarded operands, fall back to a buffer type with a
// fully dynamic layout map.
// If the computed buffer type does not match the computed buffer type
// of the earlier forwarded operands, fall back to a buffer type with a
// fully dynamic layout map.
#ifndef NDEBUG
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
assert(bufferType.hasRank() && callerType.hasRank() &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
/// computed with `BufferizableOpIntercace::getBufferType`.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
const BufferizationOptions &options);
const BufferizationOptions &options,
BufferizationState &state);

} // namespace bufferization
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
/// additional buffer allocations.
LogicalResult insertTensorCopies(Operation *op,
const OneShotBufferizationOptions &options,
BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);

/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
/// After applying this transform, the IR can be bufferized without inserting
/// additional buffer allocations.
LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
LogicalResult insertTensorCopies(Operation *op,
const AnalysisState &analysisState,
BufferizationState &bufferizationState);

/// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
/// ops.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ struct IndexCastOpInterface
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());

FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
FailureOr<Value> source =
getBuffer(rewriter, castOp.getIn(), options, state);
if (failed(source))
return failure();
auto sourceType = cast<BaseMemRefType>(source->getType());
Expand Down Expand Up @@ -151,9 +152,9 @@ struct SelectOpInterface
// the moment (one for each tensor). When copying the op result, only one
// copy would be needed.
FailureOr<Value> maybeTrueBuffer =
getBuffer(rewriter, selectOp.getTrueValue(), options);
getBuffer(rewriter, selectOp.getTrueValue(), options, state);
FailureOr<Value> maybeFalseBuffer =
getBuffer(rewriter, selectOp.getFalseValue(), options);
getBuffer(rewriter, selectOp.getFalseValue(), options, state);
if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
return failure();
Value trueBuffer = *maybeTrueBuffer;
Expand All @@ -164,7 +165,7 @@ struct SelectOpInterface
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
auto targetType =
bufferization::getBufferType(selectOp.getResult(), options);
bufferization::getBufferType(selectOp.getResult(), options, state);
if (failed(targetType))
return failure();
if (trueBuffer.getType() != *targetType)
Expand All @@ -182,13 +183,14 @@ struct SelectOpInterface

FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
options, invocationStack);
auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
options, invocationStack);
auto trueType = bufferization::getBufferType(
selectOp.getTrueValue(), options, state, invocationStack);
auto falseType = bufferization::getBufferType(
selectOp.getFalseValue(), options, state, invocationStack);
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
Expand Down
51 changes: 30 additions & 21 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ Operation *bufferization::getOwnerOfValue(Value value) {
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
const BufferizationOptions &options, bool copy) {
const BufferizationOptions &options, BufferizationState &state, bool copy) {
Value tensor;
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
tensor = shapedValue;
Expand Down Expand Up @@ -210,7 +210,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
FailureOr<BaseMemRefType> copyBufferType =
getBufferType(tensor, options, state);
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
Expand All @@ -222,7 +223,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
RewriterBase &rewriter, const AnalysisState &state) {
RewriterBase &rewriter, const AnalysisState &analysisState,
BufferizationState &bufferizationState) {
OpBuilder::InsertionGuard g(rewriter);
Operation *op = getOperation();
SmallVector<OpOperand *> outOfPlaceOpOperands;
Expand All @@ -235,16 +237,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
Type operandType = opOperand.get().getType();
if (!llvm::isa<TensorType>(operandType))
continue;
if (state.isInPlace(opOperand))
if (analysisState.isInPlace(opOperand))
continue;
if (llvm::isa<UnrankedTensorType>(operandType))
return op->emitError("copying of unranked tensors is not implemented");

AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
AliasingValueList aliasingValues =
analysisState.getAliasingValues(opOperand);
if (aliasingValues.getNumAliases() == 1 &&
isa<OpResult>(aliasingValues.getAliases()[0].value) &&
!state.bufferizesToMemoryWrite(opOperand) &&
state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
!analysisState.bufferizesToMemoryWrite(opOperand) &&
analysisState
.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
.getNumAliases() == 1 &&
!isa<UnrankedTensorType>(
aliasingValues.getAliases()[0].value.getType())) {
Expand All @@ -256,12 +260,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// cannot be copied at the moment).
Value value = aliasingValues.getAliases()[0].value;
outOfPlaceValues.push_back(value);
if (!state.canOmitTensorCopy(opOperand))
if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpValues.insert(value);
} else {
// In all other cases, make a copy of the OpOperand.
outOfPlaceOpOperands.push_back(&opOperand);
if (!state.canOmitTensorCopy(opOperand))
if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpOperands.insert(&opOperand);
}
}
Expand All @@ -270,8 +274,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPoint(op);
for (OpOperand *opOperand : outOfPlaceOpOperands) {
FailureOr<Value> copy = allocateTensorForShapedValue(
rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
copiedOpOperands.contains(opOperand));
rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
bufferizationState, copiedOpOperands.contains(opOperand));
if (failed(copy))
return failure();
rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
Expand All @@ -281,8 +285,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPointAfter(op);
for (Value value : outOfPlaceValues) {
FailureOr<Value> copy = allocateTensorForShapedValue(
rewriter, op->getLoc(), value, state.getOptions(),
copiedOpValues.count(value));
rewriter, op->getLoc(), value, analysisState.getOptions(),
bufferizationState, copiedOpValues.count(value));
if (failed(copy))
return failure();
SmallVector<OpOperand *> uses = llvm::to_vector(
Expand Down Expand Up @@ -665,7 +669,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options) {
const BufferizationOptions &options,
BufferizationState &state) {
#ifndef NDEBUG
auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
Expand All @@ -678,7 +683,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
// Insert to_buffer op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
if (failed(memrefType))
return failure();
ensureToBufferOpIsValid(value, *memrefType);
Expand All @@ -689,14 +694,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
bufferization::getBufferType(Value value, const BufferizationOptions &options,
BufferizationState &state) {
SmallVector<Value> invocationStack;
return getBufferType(value, options, invocationStack);
return getBufferType(value, options, state, invocationStack);
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
BufferizationState &state,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) &&
"unexpected non-tensor type");
Expand All @@ -708,7 +715,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
Operation *op = getOwnerOfValue(value);
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (bufferizableOp)
return bufferizableOp.getBufferType(value, options, invocationStack);
return bufferizableOp.getBufferType(value, options, state, invocationStack);

// Op is not bufferizable.
auto memSpace =
Expand Down Expand Up @@ -944,6 +951,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(

FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");

Expand All @@ -954,14 +962,15 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
auto opResult = llvm::cast<OpResult>(value);
AnalysisState state(options);
AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
AnalysisState analysisState(options);
AliasingOpOperandList aliases = analysisState.getAliasingOpOperands(opResult);
if (aliases.getNumAliases() > 0 &&
aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
return getBufferType(equivalentOperand, options, invocationStack);
return getBufferType(equivalentOperand, options, bufferizationState,
invocationStack);
}

// If we do not know the memory space and there is no default memory space,
Expand Down
Loading