@@ -165,7 +165,8 @@ Operation *bufferization::getOwnerOfValue(Value value) {
165165// / allocated.
166166FailureOr<Value> bufferization::allocateTensorForShapedValue (
167167 OpBuilder &b, Location loc, Value shapedValue,
168- const BufferizationOptions &options, bool copy) {
168+ const BufferizationOptions &options, const BufferizationState &state,
169+ bool copy) {
169170 Value tensor;
170171 if (llvm::isa<RankedTensorType>(shapedValue.getType ())) {
171172 tensor = shapedValue;
@@ -210,7 +211,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
210211 // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
211212 if (copy)
212213 return allocTensorOp.getResult ();
213- FailureOr<BaseMemRefType> copyBufferType = getBufferType (tensor, options);
214+ FailureOr<BaseMemRefType> copyBufferType =
215+ getBufferType (tensor, options, state);
214216 if (failed (copyBufferType))
215217 return failure ();
216218 std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace ();
@@ -222,7 +224,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
222224}
223225
224226LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts (
225- RewriterBase &rewriter, const AnalysisState &state) {
227+ RewriterBase &rewriter, const AnalysisState &analysisState,
228+ const BufferizationState &bufferizationState) {
226229 OpBuilder::InsertionGuard g (rewriter);
227230 Operation *op = getOperation ();
228231 SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -235,16 +238,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
235238 Type operandType = opOperand.get ().getType ();
236239 if (!llvm::isa<TensorType>(operandType))
237240 continue ;
238- if (state .isInPlace (opOperand))
241+ if (analysisState .isInPlace (opOperand))
239242 continue ;
240243 if (llvm::isa<UnrankedTensorType>(operandType))
241244 return op->emitError (" copying of unranked tensors is not implemented" );
242245
243- AliasingValueList aliasingValues = state.getAliasingValues (opOperand);
246+ AliasingValueList aliasingValues =
247+ analysisState.getAliasingValues (opOperand);
244248 if (aliasingValues.getNumAliases () == 1 &&
245249 isa<OpResult>(aliasingValues.getAliases ()[0 ].value ) &&
246- !state.bufferizesToMemoryWrite (opOperand) &&
247- state.getAliasingOpOperands (aliasingValues.getAliases ()[0 ].value )
250+ !analysisState.bufferizesToMemoryWrite (opOperand) &&
251+ analysisState
252+ .getAliasingOpOperands (aliasingValues.getAliases ()[0 ].value )
248253 .getNumAliases () == 1 &&
249254 !isa<UnrankedTensorType>(
250255 aliasingValues.getAliases ()[0 ].value .getType ())) {
@@ -256,12 +261,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
256261 // cannot be copied at the moment).
257262 Value value = aliasingValues.getAliases ()[0 ].value ;
258263 outOfPlaceValues.push_back (value);
259- if (!state .canOmitTensorCopy (opOperand))
264+ if (!analysisState .canOmitTensorCopy (opOperand))
260265 copiedOpValues.insert (value);
261266 } else {
262267 // In all other cases, make a copy of the OpOperand.
263268 outOfPlaceOpOperands.push_back (&opOperand);
264- if (!state .canOmitTensorCopy (opOperand))
269+ if (!analysisState .canOmitTensorCopy (opOperand))
265270 copiedOpOperands.insert (&opOperand);
266271 }
267272 }
@@ -270,8 +275,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
270275 rewriter.setInsertionPoint (op);
271276 for (OpOperand *opOperand : outOfPlaceOpOperands) {
272277 FailureOr<Value> copy = allocateTensorForShapedValue (
273- rewriter, op->getLoc (), opOperand->get (), state .getOptions (),
274- copiedOpOperands.contains (opOperand));
278+ rewriter, op->getLoc (), opOperand->get (), analysisState .getOptions (),
279+ bufferizationState, copiedOpOperands.contains (opOperand));
275280 if (failed (copy))
276281 return failure ();
277282 rewriter.modifyOpInPlace (op, [&]() { opOperand->set (*copy); });
@@ -281,8 +286,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
281286 rewriter.setInsertionPointAfter (op);
282287 for (Value value : outOfPlaceValues) {
283288 FailureOr<Value> copy = allocateTensorForShapedValue (
284- rewriter, op->getLoc (), value, state .getOptions (),
285- copiedOpValues.count (value));
289+ rewriter, op->getLoc (), value, analysisState .getOptions (),
290+ bufferizationState, copiedOpValues.count (value));
286291 if (failed (copy))
287292 return failure ();
288293 SmallVector<OpOperand *> uses = llvm::to_vector (
@@ -665,7 +670,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
665670}
666671
667672FailureOr<Value> bufferization::getBuffer (RewriterBase &rewriter, Value value,
668- const BufferizationOptions &options) {
673+ const BufferizationOptions &options,
674+ const BufferizationState &state) {
669675#ifndef NDEBUG
670676 auto tensorType = llvm::dyn_cast<TensorType>(value.getType ());
671677 assert (tensorType && " unexpected non-tensor type" );
@@ -678,7 +684,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
678684 // Insert to_buffer op.
679685 OpBuilder::InsertionGuard g (rewriter);
680686 setInsertionPointAfter (rewriter, value);
681- FailureOr<BaseMemRefType> memrefType = getBufferType (value, options);
687+ FailureOr<BaseMemRefType> memrefType = getBufferType (value, options, state );
682688 if (failed (memrefType))
683689 return failure ();
684690 ensureToBufferOpIsValid (value, *memrefType);
@@ -689,14 +695,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
689695
690696// / Return the buffer type for a given Value (tensor) after bufferization.
691697FailureOr<BaseMemRefType>
692- bufferization::getBufferType (Value value, const BufferizationOptions &options) {
698+ bufferization::getBufferType (Value value, const BufferizationOptions &options,
699+ const BufferizationState &state) {
693700 SmallVector<Value> invocationStack;
694- return getBufferType (value, options, invocationStack);
701+ return getBufferType (value, options, state, invocationStack);
695702}
696703
697704// / Return the buffer type for a given Value (tensor) after bufferization.
698705FailureOr<BaseMemRefType>
699706bufferization::getBufferType (Value value, const BufferizationOptions &options,
707+ const BufferizationState &state,
700708 SmallVector<Value> &invocationStack) {
701709 assert (llvm::isa<TensorType>(value.getType ()) &&
702710 " unexpected non-tensor type" );
@@ -708,7 +716,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
708716 Operation *op = getOwnerOfValue (value);
709717 auto bufferizableOp = options.dynCastBufferizableOp (op);
710718 if (bufferizableOp)
711- return bufferizableOp.getBufferType (value, options, invocationStack);
719+ return bufferizableOp.getBufferType (value, options, state, invocationStack);
712720
713721 // Op is not bufferizable.
714722 auto memSpace =
@@ -944,6 +952,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
944952
945953FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType (
946954 Value value, const BufferizationOptions &options,
955+ const BufferizationState &bufferizationState,
947956 SmallVector<Value> &invocationStack) {
948957 assert (llvm::isa<TensorType>(value.getType ()) && " expected tensor type" );
949958
@@ -954,14 +963,15 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
954963 // Value is an OpResult.
955964 Operation *op = getOwnerOfValue (value);
956965 auto opResult = llvm::cast<OpResult>(value);
957- AnalysisState state (options);
958- AliasingOpOperandList aliases = state .getAliasingOpOperands (opResult);
966+ AnalysisState analysisState (options);
967+ AliasingOpOperandList aliases = analysisState .getAliasingOpOperands (opResult);
959968 if (aliases.getNumAliases () > 0 &&
960969 aliases.getAliases ()[0 ].relation == BufferRelation::Equivalent) {
961970 // If the OpResult has an equivalent OpOperand, both OpResult and
962971 // OpOperand bufferize to the exact same buffer type.
963972 Value equivalentOperand = aliases.getAliases ().front ().opOperand ->get ();
964- return getBufferType (equivalentOperand, options, invocationStack);
973+ return getBufferType (equivalentOperand, options, bufferizationState,
974+ invocationStack);
965975 }
966976
967977 // If we do not know the memory space and there is no default memory space,
0 commit comments