1010#include " iree/compiler/DispatchCreation/Passes.h"
1111#include " llvm/ADT/STLExtras.h"
1212#include " llvm/ADT/SetVector.h"
13+ #include " llvm/ADT/SmallVectorExtras.h"
1314#include " llvm/Support/Debug.h"
15+ #include " llvm/Support/LogicalResult.h"
1416#include " llvm/Support/raw_ostream.h"
1517#include " mlir/Analysis/SliceAnalysis.h"
1618#include " mlir/Dialect/Affine/Utils.h"
@@ -280,10 +282,15 @@ class CollapseInfo {
280282 // Debug print the current operation & reassociation indicies
281283 void dump () const ;
282284
283- // Update `collapsableLoops` by taking the set intersection with
284- // `otherCollapsable` and update the reassociation indicies accordingly.
285+ // Update CollapseInfo to ensure that all dimensions collapsable in `this` are
286+ // also collapsable in `consumerInfo`. This means:
287+ // 1. Any dimension not collapsable in `consumerInfo` should not be
288+ // collapsable in `this`
289+ // 2. For any pair of dimensions in `this`, if they are collapsable in
290+ // `consumerInfo`, they must be collapsable into the same dimension in
291+ // `consumerInfo` to be collapsable into the same dimension in `this`.
285292 // Returns true if the operation modified the number of collapsable loops.
286- bool updateCollapseViaIntersect ( const CollapsableLoopsSet &otherCollapsable );
293+ bool updateFromConsumer (OpOperand *operand, const CollapseInfo &consumerInfo );
287294
288295 // Update `collapsableLoops` by subtracting `uncollapsable` and update the
289296 // reassociation indicies accordingly.
@@ -293,13 +300,18 @@ class CollapseInfo {
293300 // Get `collapsableLoops` after applying the transformation provided by `map`.
294301 // Note: doesn't modify `collapsableLoops`, the tranformation is applied to a
295302 // copy.
296- FailureOr<CollapsableLoopsSet>
297- getTransformedCollapsableLoops (AffineMap map) const ;
303+ CollapsableLoopsSet getTransformedCollapsableLoops (AffineMap map) const ;
298304
299- // Clear internal data
300- void clear () {
305+ // Get `reassociation` after applying the transformation provided by `map`.
306+ SmallVector<ReassociationIndices>
307+ getTransformedReassociation (AffineMap map) const ;
308+
309+ // Clear internal data and returns if anything changed.
310+ bool clear () {
311+ bool isNotEmpty = reassociation.empty () || collapsableLoops.empty ();
301312 reassociation.clear ();
302313 collapsableLoops.clear ();
314+ return isNotEmpty;
303315 }
304316
305317 const CollapsableLoopsSet &getCollapsibleLoops () const {
@@ -386,12 +398,8 @@ void CollapseInfo::updateReassociation() {
386398// map = affine_map<(d0, d1, d2) -> (d1, d2, d5)>
387399//
388400// Therefore, the collapsable loops with respect to the consumer is {1, 2, 5}.
389- FailureOr< CollapseInfo::CollapsableLoopsSet>
401+ CollapseInfo::CollapsableLoopsSet
390402CollapseInfo::getTransformedCollapsableLoops (AffineMap map) const {
391- if (!map) {
392- return failure ();
393- }
394-
395403 CollapsableLoopsSet transformedLoops;
396404 for (auto index : collapsableLoops) {
397405 assert (index < map.getNumResults () && " index has no valid mapping" );
@@ -405,19 +413,114 @@ CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const {
405413 return transformedLoops;
406414}
407415
408- // Update `collapsableLoops` by taking the set intersection with
409- // `otherCollapsable` and update the reassociation indicies accordingly.
410- bool CollapseInfo::updateCollapseViaIntersect (
411- const CollapsableLoopsSet &otherCollapsable) {
412- CollapsableLoopsSet toRemove;
413- for (auto elem : collapsableLoops) {
414- if (!otherCollapsable.contains (elem)) {
415- toRemove.insert (elem);
416+ SmallVector<ReassociationIndices>
417+ CollapseInfo::getTransformedReassociation (AffineMap map) const {
418+ SmallVector<ReassociationIndices> transformedReassociation (
419+ reassociation.size ());
420+ for (const auto &[i, indicies] : llvm::enumerate (reassociation)) {
421+ for (auto elem : indicies) {
422+ auto dimExpr = dyn_cast<AffineDimExpr>(map.getResult (elem));
423+ if (!dimExpr) {
424+ break ;
425+ }
426+ transformedReassociation[i].push_back (dimExpr.getPosition ());
416427 }
417428 }
418- collapsableLoops.set_subtract (toRemove);
419- updateReassociation ();
420- return toRemove.size ();
429+ return transformedReassociation;
430+ }
431+
432+ bool CollapseInfo::updateFromConsumer (OpOperand *operand,
433+ const CollapseInfo &consumerInfo) {
434+ FailureOr<AffineMap> consumerToProducerMap =
435+ getConsumerLoopToProducerLoopsMap (*operand);
436+ if (failed (consumerToProducerMap)) {
437+ return this ->clear ();
438+ }
439+
440+ CollapsableLoopsSet consumerCollapsable =
441+ consumerInfo.getTransformedCollapsableLoops (
442+ consumerToProducerMap.value ());
443+
444+ SmallVector<ReassociationIndices> consumerReassoc =
445+ consumerInfo.getTransformedReassociation (consumerToProducerMap.value ());
446+
447+ // Get a map from original index to the index it gets collapsed into
448+ llvm::DenseMap<long , long > consumerCollapseMap;
449+ for (const auto &[idx, indicies] : llvm::enumerate (consumerReassoc)) {
450+ for (const auto elem : indicies) {
451+ consumerCollapseMap[elem] = idx;
452+ }
453+ }
454+
455+ // Remove all collapsable loops in `producer` that are not collapsable in
456+ // `consumer` (set intersect)
457+ bool didChange = collapsableLoops.remove_if (
458+ [&](long elem) -> bool { return !consumerCollapsable.contains (elem); });
459+
460+ // Now update the reassociation indicies given the updated `collapsableLoops`
461+ // and `consumerCollapsableMap`.
462+ // The idea is to reconstruct the reassociation indicies, and at each index:
463+ // (1) If `index` IS NOT in `collapsableLoops`, split `indicies` and don't add
464+ // `index` to either.
465+ //
466+ // (2) If `index` IS in `collapsableLoops` but `consumerCollapseMap` maps
467+ // `index` to a different collapsed loop then the other indicies, split
468+ // `indicies` and insert `index` into the new one.
469+ //
470+ // For example:
471+ // producer reassociation = [[0, 1], [2, 3]]
472+ // consumer reassociation = [0, 1, 2, 3]
473+ // then, consumer reassociation gets updated to [[0, 1], [2, 3]] because
474+ // [0, 1] and [2, 3] get collapsed into different loops
475+ //
476+ // (3) Otherwise, keep the index
477+ constexpr long kUninitialized = -1 ;
478+ SmallVector<ReassociationIndices> newReassociation;
479+ for (ReassociationIndicesRef indicies : reassociation) {
480+ // Track the loop index that `indicies` get collapsed into.
481+ long collapseIntoIdx = kUninitialized ;
482+
483+ // Holds dimensions that should be collapsed together
484+ ReassociationIndices newIndicies;
485+ for (int64_t index : indicies) {
486+ if (!collapsableLoops.contains (index)) {
487+ // (1) Because `index` isn't collapsable, the indicies in `newIndicies`
488+ // are no longer adjacent to the upcoming indicies. If there is >1 index
489+ // to collapse, add it to the new reassociation. Otherwise, discard it
490+ // because there is no dimension to collapse with.
491+ didChange = true ;
492+ if (newIndicies.size () > 1 ) {
493+ newReassociation.push_back (std::move (newIndicies));
494+ }
495+ newIndicies.clear ();
496+ collapseIntoIdx = kUninitialized ;
497+ } else if (collapseIntoIdx == kUninitialized ) {
498+ // (2) First occurance of collapsable loop, set collapseIntoIdx.
499+ collapseIntoIdx = consumerCollapseMap.at (index);
500+ newIndicies.push_back (index);
501+ } else if (consumerCollapseMap.at (index) != collapseIntoIdx) {
502+ // (3) `index` is collapsable but not collapsable into the other loops.
503+ // So, split them and look for other loops to collapse `index` into.
504+ didChange = true ;
505+ if (newIndicies.size () > 1 ) {
506+ newReassociation.push_back (std::move (newIndicies));
507+ }
508+ newIndicies.clear ();
509+ collapseIntoIdx = consumerCollapseMap[index];
510+ newIndicies.push_back (index);
511+ } else {
512+ // (4) `index` is collapsable and can be collapsed into
513+ // `collapseIntoIndex`.
514+ newIndicies.push_back (index);
515+ }
516+ }
517+
518+ if (newIndicies.size () > 1 ) {
519+ newReassociation.push_back (newIndicies);
520+ }
521+ }
522+ reassociation = std::move (newReassociation);
523+ return didChange;
421524}
422525
423526// Update `collapsableLoops` by subtracting `uncollapsable` and update the
@@ -679,12 +782,10 @@ static bool updateConsumersFromProducers(
679782 continue ;
680783 }
681784
682- CollapseInfo &producerInfo = opMap.find (producerOp)-> second ;
683- FailureOr< CollapseInfo::CollapsableLoopsSet> producerCollapsable =
785+ const CollapseInfo &producerInfo = opMap.at (producerOp);
786+ CollapseInfo::CollapsableLoopsSet producerCollapsable =
684787 producerInfo.getTransformedCollapsableLoops (mapping.value ());
685- if (!failed (producerCollapsable)) {
686- producerUncollapsable.set_subtract (producerCollapsable.value ());
687- }
788+ producerUncollapsable.set_subtract (producerCollapsable);
688789
689790 didChange |=
690791 consumerInfo.updateCollapseViaSubtract (producerUncollapsable);
@@ -707,7 +808,7 @@ static bool updateProducersFromConsumers(
707808 for (auto op : llvm::reverse (slice)) {
708809 auto genericConsumer = cast<linalg::GenericOp>(op);
709810 assert (opMap.contains (genericConsumer));
710- const CollapseInfo &consumerInfo = opMap.find (genericConsumer)-> second ;
811+ const CollapseInfo &consumerInfo = opMap.at (genericConsumer);
711812
712813 for (auto operand : genericConsumer.getDpsInputOperands ()) {
713814 auto definingOp = operand->get ().getDefiningOp ();
@@ -721,26 +822,10 @@ static bool updateProducersFromConsumers(
721822
722823 // Get a mapping from the consumer's iteration space to the producer's.
723824 CollapseInfo &producerInfo = opMap.find (genericProducer)->second ;
724- FailureOr<AffineMap> consumerToProducerMap =
725- getConsumerLoopToProducerLoopsMap (*operand);
726- if (failed (consumerToProducerMap)) {
727- didChange |= !producerInfo.getCollapsibleLoops ().empty ();
728- producerInfo.clear ();
729- continue ;
730- }
731825
732- // Use the map to get the consumer's collapsable loops in terms of the
733- // producer.
734- auto consumerCollapsable = consumerInfo.getTransformedCollapsableLoops (
735- consumerToProducerMap.value ());
736- if (failed (consumerCollapsable)) {
737- producerInfo.clear ();
738- continue ;
739- }
740826 // Only loops collapsable in both the consumer and producer may be
741827 // collapsed.
742- didChange |=
743- producerInfo.updateCollapseViaIntersect (consumerCollapsable.value ());
828+ didChange |= producerInfo.updateFromConsumer (operand, consumerInfo);
744829 }
745830 }
746831 return didChange;
0 commit comments