@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
196196
197197 // If there is no preceding definition, the tensor contents are
198198 // undefined.
199- if (findDefinitionsCached (opResult).empty ())
199+ if (opResult.getUses ().empty ())
200+ return WalkResult::skip ();
201+ // It does not really matter which use to take to search about
202+ // the value's definitions.
203+ OpOperand *opOperand = &(*opResult.getUses ().begin ());
204+ if (findDefinitionsCached (opOperand).empty ())
200205 for (OpOperand &use : opResult.getUses ())
201206 undefinedTensorUses.insert (&use);
202207 }
@@ -464,20 +469,22 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
464469// / indexing. I.e., the tensor types do not change along the use-def chain,
465470// / apart from static <-> dynamic dim casts.
466471static bool hasEquivalentValueInReverseUseDefChain (AnalysisState &state,
467- Value start, Value other) {
472+ OpOperand *start,
473+ OpOperand *other) {
468474 TraversalConfig config;
469475 config.followEquivalentOnly = true ;
470476 config.alwaysIncludeLeaves = false ;
471477 config.followSameTypeOrCastsOnly = true ;
472478 return !state
473479 .findValueInReverseUseDefChain (
474- start, [&](Value v) { return v == other; }, config)
480+ start, [&](Value v) { return v == other-> get () ; }, config)
475481 .empty ();
476482}
477483
478- // / Return "true" if `value` is originating from a subset that is equivalent to
479- // / the subset that `subsetOp` inserts into.
480- static bool matchesInsertDestination (const AnalysisState &state, Value value,
484+ // / Return "true" if `opOperand` is originating from a subset that is equivalent
485+ // / to the subset that `subsetOp` inserts into.
486+ static bool matchesInsertDestination (const AnalysisState &state,
487+ OpOperand *opOperand,
481488 SubsetInsertionOpInterface subsetOp) {
482489 auto matchingSubset = [&](Value val) {
483490 if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
490497 // There may be multiple leaves at which the reverse SSA use-def chain lookup
491498 // terminates. All of them must be equivalent subsets.
492499 SetVector<Value> backwardSlice =
493- state.findValueInReverseUseDefChain (value , matchingSubset);
500+ state.findValueInReverseUseDefChain (opOperand , matchingSubset);
494501 return static_cast <bool >(llvm::all_of (backwardSlice, matchingSubset));
495502}
496503
@@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
516523 // {inplace= [true] }
517524
518525 if (uRead == &subsetOp.getDestinationOperand () &&
519- matchesInsertDestination (state, uConflictingWrite-> get () , subsetOp))
526+ matchesInsertDestination (state, uConflictingWrite, subsetOp))
520527 // Case 1: The main insight is that InsertSliceOp reads only part of
521528 // the destination tensor. The overwritten area is not read. If
522529 // uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
533540
534541 if (uRead == &subsetOp.getSourceOperand () &&
535542 uConflictingWrite == &subsetOp.getDestinationOperand () &&
536- matchesInsertDestination (state, uRead-> get () , subsetOp))
543+ matchesInsertDestination (state, uRead, subsetOp))
537544 // Case 2: The read of the source tensor and the write to the dest
538545 // tensor via an InsertSliceOp is not a conflict if the read is
539546 // reading exactly that part of an equivalent tensor that the
@@ -567,8 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
567574 if (uConflictingWrite == &subsetOp.getDestinationOperand () &&
568575 state.areEquivalentBufferizedValues (
569576 uRead->get (), subsetOp.getSourceOperand ().get ()) &&
570- matchesInsertDestination (state, subsetOp.getSourceOperand ().get (),
571- subsetOp))
577+ matchesInsertDestination (state, &subsetOp.getSourceOperand (), subsetOp))
572578 return true ;
573579
574580 return false ;
@@ -600,9 +606,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
600606 // even though that op just bufferizes to an allocation but does define
601607 // the contents of the buffer.
602608 SetVector<Value> definitionsOrLeaves =
603- state.findValueInReverseUseDefChain (
604- uConflictingWrite-> get (),
605- [&](Value v) { return state. bufferizesToMemoryWrite (v); });
609+ state.findValueInReverseUseDefChain (uConflictingWrite, [&](Value v) {
610+ return state. bufferizesToMemoryWrite (v);
611+ });
606612 assert (!definitionsOrLeaves.empty () &&
607613 " expected at least one definition or leaf" );
608614
@@ -641,8 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
641647 // In the above example, if uRead is the OpOperand of reading_op, the
642648 // definition is %0. Note that operations that create an alias but do not
643649 // bufferize to a memory write (such as ExtractSliceOp) are skipped.
644- const SetVector<Value> &definitions =
645- state.findDefinitionsCached (uRead->get ());
650+ const SetVector<Value> &definitions = state.findDefinitionsCached (uRead);
646651 if (definitions.empty ()) {
647652 // Fast path: No conflict if there are no definitions.
648653 LLVM_DEBUG (llvm::dbgs ()
@@ -713,10 +718,10 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
713718 if (auto bufferizableOp = options.dynCastBufferizableOp (readingOp)) {
714719 if (bufferizableOp.bufferizesToElementwiseAccess (
715720 state, {uRead, uConflictingWrite})) {
716- if (hasEquivalentValueInReverseUseDefChain (
717- state, uRead-> get (), uConflictingWrite-> get () ) ||
721+ if (hasEquivalentValueInReverseUseDefChain (state, uRead,
722+ uConflictingWrite) ||
718723 hasEquivalentValueInReverseUseDefChain (
719- state, uConflictingWrite-> get () , uRead-> get () )) {
724+ state, uConflictingWrite, uRead)) {
720725 LLVM_DEBUG (
721726 llvm::dbgs ()
722727 << " no conflict: op bufferizes to element-wise access\n " );
@@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
965970// Bufferization analyses.
966971// ===----------------------------------------------------------------------===//
967972
968- // Find the values that define the contents of the given value .
973+ // Find the values that define the contents of the given opOperand .
969974const llvm::SetVector<Value> &
970- OneShotAnalysisState::findDefinitionsCached (Value value) {
975+ OneShotAnalysisState::findDefinitionsCached (OpOperand *opOperand) {
976+ Value value = opOperand->get ();
971977 if (!cachedDefinitions.count (value))
972- cachedDefinitions[value] = findDefinitions (value );
978+ cachedDefinitions[value] = findDefinitions (opOperand );
973979 return cachedDefinitions[value];
974980}
975981
0 commit comments