@@ -314,19 +314,20 @@ class CollapseInfo {
314314 void dump () const ;
315315
316316 // Update CollapseInfo to ensure that all dimensions collapsable in `this` are
317- // also collapsable in `consumerInfo `. This means:
318- // 1. Any dimension not collapsable in `consumerInfo ` should not be
317+ // also collapsable in `otherInfo `. This means:
318+ // 1. Any dimension not collapsable in `otherInfo ` should not be
319319 // collapsable in `this`
320320 // 2. For any pair of dimensions in `this`, if they are collapsable in
321- // `consumerInfo `, they must be collapsable into the same dimension in
322- // `consumerInfo ` to be collapsable into the same dimension in `this`.
321+ // `otherInfo `, they must be collapsable into the same dimension in
322+ // `otherInfo ` to be collapsable into the same dimension in `this`.
323323 // Returns true if the operation modified the number of collapsable loops.
324- bool updateFromConsumer (OpOperand *operand, const CollapseInfo &consumerInfo);
324+ bool updateFromOther (FailureOr<AffineMap> otherToThisMap,
325+ const CollapseInfo &otherInfo);
325326
326- // Update `collapsableLoops` by subtracting `uncollapsable` and update the
327- // reassociation indicies accordingly.
328- // Returns true if the operation modified the number of collapsable loops .
329- bool updateCollapseViaSubtract ( const CollapsableLoopsSet &uncollapsable );
327+ // Update `this` (which is the info for `op`) when either a producer or
328+ // consumer is not collapsible. This is done by considering all the dims
329+ // of `op` that are used to access `operand` to be uncollapsible.
330+ bool updateFromUncollapsible (Operation *op, OpOperand *operand );
330331
331332 // Get `collapsableLoops` after applying the transformation provided by `map`.
331332 // Note: doesn't modify `collapsableLoops`, the transformation is applied to a
@@ -460,48 +461,56 @@ CollapseInfo::getTransformedReassociation(AffineMap map) const {
460461 return transformedReassociation;
461462}
462463
463- bool CollapseInfo::updateFromConsumer (OpOperand *operand,
464- const CollapseInfo &consumerInfo) {
465- FailureOr<AffineMap> consumerToProducerMap =
466- getConsumerLoopToProducerLoopsMap (*operand);
467- if (failed (consumerToProducerMap)) {
464+ bool CollapseInfo::updateFromOther (FailureOr<AffineMap> otherToThisMap,
465+ const CollapseInfo &otherInfo) {
466+ if (failed (otherToThisMap)) {
468467 return this ->clear ();
469468 }
470469
471- CollapsableLoopsSet consumerCollapsable =
472- consumerInfo.getTransformedCollapsableLoops (
473- consumerToProducerMap.value ());
470+ CollapsableLoopsSet otherCollapsible =
471+ otherInfo.getTransformedCollapsableLoops (otherToThisMap.value ());
474472
475- SmallVector<ReassociationIndices> consumerReassoc =
476- consumerInfo .getTransformedReassociation (consumerToProducerMap .value ());
473+ SmallVector<ReassociationIndices> otherReassoc =
474+ otherInfo .getTransformedReassociation (otherToThisMap .value ());
477475
478476 // Get a map from original index to the index it gets collapsed into
479- llvm::DenseMap<long , long > consumerCollapseMap ;
480- for (const auto &[idx, indicies] : llvm::enumerate (consumerReassoc )) {
477+ llvm::DenseMap<long , long > otherCollapseMap ;
478+ for (const auto &[idx, indicies] : llvm::enumerate (otherReassoc )) {
481479 for (const auto elem : indicies) {
482- consumerCollapseMap [elem] = idx;
480+ otherCollapseMap [elem] = idx;
483481 }
484482 }
485483
486- // Remove all collapsable loops in `producer` that are not collapsable in
487- // `consumer` (set intersect)
488- bool didChange = collapsableLoops.remove_if (
489- [&](long elem) -> bool { return !consumerCollapsable.contains (elem); });
484+ // Remove all collapsable loops in `this` that exist in `other` but are not
485+ // collapsable there; loops absent from `other` are kept.
486+ bool didChange = collapsableLoops.remove_if ([&](long elem) -> bool {
487+ // Exists and is collapsable
488+ if (otherCollapsible.contains (elem)) {
489+ return false ;
490+ }
491+
492+ // Does not exist in `other`.
493+ if (!otherToThisMap->isFunctionOfDim (elem)) {
494+ return false ;
495+ }
496+
497+ return true ;
498+ });
490499
491500 // Now update the reassociation indicies given the updated `collapsableLoops`
492- // and `consumerCollapsableMap `.
501+ // and `otherCollapseMap`.
493502 // The idea is to reconstruct the reassociation indicies, and at each index:
494503 // (1) If `index` IS NOT in `collapsableLoops`, split `indicies` and don't add
495504 // `index` to either.
496505 //
497- // (2) If `index` IS in `collapsableLoops` but `consumerCollapseMap ` maps
506+ // (2) If `index` IS in `collapsableLoops` but `otherCollapseMap ` maps
498507 // `index` to a different collapsed loop than the other indices, split
499508 // `indicies` and insert `index` into the new one.
500509 //
501510 // For example:
502- // producer reassociation = [[0, 1], [2, 3]]
503- // consumer reassociation = [0, 1, 2, 3]
504- // then, consumer reassociation gets updated to [[0, 1], [2, 3]] because
511+ // `other` reassociation = [[0, 1], [2, 3]]
512+ // `this` reassociation = [[0, 1, 2, 3]]
513+ // then, `this` reassociation gets updated to [[0, 1], [2, 3]] because
505514 // [0, 1] and [2, 3] get collapsed into different loops
506515 //
507516 // (3) Otherwise, keep the index
@@ -525,22 +534,25 @@ bool CollapseInfo::updateFromConsumer(OpOperand *operand,
525534 }
526535 newIndicies.clear ();
527536 collapseIntoIdx = kUninitialized ;
537+ } else if (!otherCollapseMap.contains (index)) {
538+ // (2) `index` does not exist in `other`.
539+ newIndicies.push_back (index);
528540 } else if (collapseIntoIdx == kUninitialized ) {
529- // (2 ) First occurance of collapsable loop, set collapseIntoIdx.
530- collapseIntoIdx = consumerCollapseMap .at (index);
541+ // (3) First occurrence of collapsable loop, set collapseIntoIdx.
542+ collapseIntoIdx = otherCollapseMap .at (index);
531543 newIndicies.push_back (index);
532- } else if (consumerCollapseMap .at (index) != collapseIntoIdx) {
533- // (3 ) `index` is collapsable but not collapsable into the other loops.
544+ } else if (otherCollapseMap .at (index) != collapseIntoIdx) {
545+ // (4 ) `index` is collapsable but not collapsable into the other loops.
534546 // So, split them and look for other loops to collapse `index` into.
535547 didChange = true ;
536548 if (newIndicies.size () > 1 ) {
537549 newReassociation.push_back (std::move (newIndicies));
538550 }
539551 newIndicies.clear ();
540- collapseIntoIdx = consumerCollapseMap [index];
552+ collapseIntoIdx = otherCollapseMap [index];
541553 newIndicies.push_back (index);
542554 } else {
543- // (4 ) `index` is collapsable and can be collapsed into
555+ // (5 ) `index` is collapsable and can be collapsed into
544556 // `collapseIntoIndex`.
545557 newIndicies.push_back (index);
546558 }
@@ -554,10 +566,17 @@ bool CollapseInfo::updateFromConsumer(OpOperand *operand,
554566 return didChange;
555567}
556568
557- // Update `collapsableLoops` by subtracting `uncollapsable` and update the
558- // reassociation indicies accordingly.
559- bool CollapseInfo::updateCollapseViaSubtract (
560- const CollapsableLoopsSet &uncollapsable) {
569+ bool CollapseInfo::updateFromUncollapsible (Operation *op, OpOperand *operand) {
570+ auto fusionOp = cast<LinalgFusionOpInterface>(op);
571+ AffineMap map = operand->getOwner () == op
572+ ? fusionOp.getMatchingIndexingMap (operand)
573+ : fusionOp.getIndexingMapMatchingResult (
574+ cast<OpResult>(operand->get ()));
575+
576+ CollapseInfo::CollapsableLoopsSet uncollapsable;
577+ for (auto expr : map.getResults ()) {
578+ uncollapsable.insert (cast<AffineDimExpr>(expr).getPosition ());
579+ }
561580 auto initialSize = collapsableLoops.size ();
562581 collapsableLoops.set_subtract (uncollapsable);
563582 updateReassociation ();
@@ -791,35 +810,18 @@ updateConsumersFromProducers(ArrayRef<Operation *> slice,
791810 continue ;
792811 }
793812
794- // Track the dimensions that are not collapsable by this current op.
795- // Initialize this with all loops in thel producer. Note: the dims are
796- // relative to the consumers iteration space, not the producers. This
797- // cannot be done via union of producer and consumer collapsable loops
798- // because the consumer may have loops that the producer does not.
799- CollapseInfo::CollapsableLoopsSet producerUncollapsable;
800- for (auto expr :
801- consumerOp.getMatchingIndexingMap (operand).getResults ()) {
802- producerUncollapsable.insert (cast<AffineDimExpr>(expr).getPosition ());
803- }
804-
805- FailureOr<AffineMap> mapping =
806- getProducerLoopToConsumerLoopsMap (*operand);
807-
808- // If there is no mapping or we can't find the op, the tensor is
809- // not collapsable. So, all dimensions of the producer are uncollapsable.
810- if (!opMap.contains (producerOp) || failed (mapping)) {
811- didChange |=
812- consumerInfo.updateCollapseViaSubtract (producerUncollapsable);
813+ // If we can't find the op, the tensor is not collapsable. So, consider
814+ // all the dimensions of the producer to be uncollapsable.
815+ if (!opMap.contains (producerOp)) {
816+ didChange |= consumerInfo.updateFromUncollapsible (consumerOp, operand);
813817 continue ;
814818 }
815819
816820 const CollapseInfo &producerInfo = opMap.at (producerOp);
817- CollapseInfo::CollapsableLoopsSet producerCollapsable =
818- producerInfo.getTransformedCollapsableLoops (mapping.value ());
819- producerUncollapsable.set_subtract (producerCollapsable);
820-
821+ FailureOr<AffineMap> consumerToProducerMap =
822+ getProducerLoopToConsumerLoopsMap (*operand);
821823 didChange |=
822- consumerInfo.updateCollapseViaSubtract (producerUncollapsable );
824+ consumerInfo.updateFromOther (consumerToProducerMap, producerInfo );
823825 }
824826 }
825827 return didChange;
@@ -837,21 +839,31 @@ updateProducersFromConsumers(ArrayRef<Operation *> slice,
838840 // Iterate over `slice` in reverse so that we visit each `op`'s consumer
838840 // before visiting `op`.
839841 for (auto op : llvm::reverse (slice)) {
840- auto consumerOp = cast<DestinationStyleOpInterface >(op);
841- const CollapseInfo &consumerInfo = opMap.at (consumerOp) ;
842+ auto producerOp = cast<LinalgFusionOpInterface >(op);
843+ CollapseInfo &producerInfo = opMap.find (producerOp)-> second ;
842844
843- for (auto *operand : consumerOp.getDpsInputOperands ()) {
844- auto definingOp = operand->get ().getDefiningOp ();
845- if (!definingOp || !opMap.contains (definingOp)) {
845+ for (auto &operand : producerOp->getUses ()) {
846+ auto *consumerOp = operand.getOwner ();
847+ if (consumerOp->hasTrait <OpTrait::IsTerminator>()) {
848+ continue ;
849+ }
850+
851+ // If we can't find the op, the tensor is not collapsable. So, consider
852+ // all the dimensions of the consumer to be uncollapsable.
853+ if (!opMap.contains (consumerOp)) {
854+ didChange |= producerInfo.updateFromUncollapsible (producerOp, &operand);
846855 continue ;
847856 }
848857
849858 // Get a mapping from the consumer's iteration space to the producer's.
850- CollapseInfo &producerInfo = opMap.find (definingOp)-> second ;
859+ const CollapseInfo &consumerInfo = opMap.at (consumerOp) ;
851860
852861 // Only loops collapsable in both the consumer and producer may be
853862 // collapsed.
854- didChange |= producerInfo.updateFromConsumer (operand, consumerInfo);
863+ FailureOr<AffineMap> consumerToProducerMap =
864+ getConsumerLoopToProducerLoopsMap (operand);
865+ didChange |=
866+ producerInfo.updateFromOther (consumerToProducerMap, consumerInfo);
855867 }
856868 }
857869 return didChange;
0 commit comments