#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include <optional>
+#include <type_traits>

using namespace mlir;

@@ -569,32 +571,13 @@ AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
  return getSliceMap(getNumResults() - numResults, numResults);
}

-AffineMap mlir::compressDims(AffineMap map,
-                             const llvm::SmallBitVector &unusedDims) {
-  unsigned numDims = 0;
-  SmallVector<AffineExpr> dimReplacements;
-  dimReplacements.reserve(map.getNumDims());
-  MLIRContext *context = map.getContext();
-  for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
-    if (unusedDims.test(dim))
-      dimReplacements.push_back(getAffineConstantExpr(0, context));
-    else
-      dimReplacements.push_back(getAffineDimExpr(numDims++, context));
-  }
-  SmallVector<AffineExpr> resultExprs;
-  resultExprs.reserve(map.getNumResults());
-  for (auto e : map.getResults())
-    resultExprs.push_back(e.replaceDims(dimReplacements));
-  return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context);
-}
-
-AffineMap mlir::compressUnusedDims(AffineMap map) {
-  return compressDims(map, getUnusedDimsBitVector({map}));
-}
-
-static SmallVector<AffineMap>
-compressUnusedImpl(ArrayRef<AffineMap> maps,
-                   llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
+/// Implementation detail to compress multiple affine maps with a compressionFun
+/// that is expected to be either compressUnusedDims or compressUnusedSymbols.
+/// The implementation keeps track of num dims and symbols across the different
+/// affine maps.
+static SmallVector<AffineMap> compressUnusedListImpl(
+    ArrayRef<AffineMap> maps,
+    llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
  if (maps.empty())
    return SmallVector<AffineMap>();
  SmallVector<AffineExpr> allExprs;
@@ -622,41 +605,31 @@ compressUnusedImpl(ArrayRef<AffineMap> maps,
  return res;
}

+AffineMap mlir::compressDims(AffineMap map,
+                             const llvm::SmallBitVector &unusedDims) {
+  return projectDims(map, unusedDims, /*compressDimsFlag=*/true);
+}
+
+AffineMap mlir::compressUnusedDims(AffineMap map) {
+  return compressDims(map, getUnusedDimsBitVector({map}));
+}
+
SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
-  return compressUnusedImpl(maps,
-                            [](AffineMap m) { return compressUnusedDims(m); });
+  return compressUnusedListImpl(
+      maps, [](AffineMap m) { return compressUnusedDims(m); });
}

AffineMap mlir::compressSymbols(AffineMap map,
                                const llvm::SmallBitVector &unusedSymbols) {
-  unsigned numSymbols = 0;
-  SmallVector<AffineExpr> symReplacements;
-  symReplacements.reserve(map.getNumSymbols());
-  MLIRContext *context = map.getContext();
-  for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
-    if (unusedSymbols.test(sym))
-      symReplacements.push_back(getAffineConstantExpr(0, context));
-    else
-      symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
-  }
-  SmallVector<AffineExpr> resultExprs;
-  resultExprs.reserve(map.getNumResults());
-  for (auto e : map.getResults())
-    resultExprs.push_back(e.replaceSymbols(symReplacements));
-  return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context);
+  return projectSymbols(map, unusedSymbols, /*compressSymbolsFlag=*/true);
}

AffineMap mlir::compressUnusedSymbols(AffineMap map) {
-  llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true);
-  map.walkExprs([&](AffineExpr expr) {
-    if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
-      unusedSymbols.reset(symExpr.getPosition());
-  });
-  return compressSymbols(map, unusedSymbols);
+  return compressSymbols(map, getUnusedSymbolsBitVector({map}));
}

SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
-  return compressUnusedImpl(
+  return compressUnusedListImpl(
      maps, [](AffineMap m) { return compressUnusedSymbols(m); });
}

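As an illustration of the compression entry points rewritten above, here is a minimal usage sketch (not part of the patch; it assumes `using namespace mlir;` and an existing `MLIRContext *ctx`). `compressUnusedDims` replaces every dimension that no result refers to and renumbers the remaining ones:

  // Sketch only: build (d0, d1, d2) -> (d2, d0); d1 never appears in a result.
  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr d2 = getAffineDimExpr(2, ctx);
  AffineMap map = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d2, d0}, ctx);
  // Dropping the unused d1 renumbers d2 to d1; the expected result is
  // (d0, d1) -> (d1, d0) with only two dimensions.
  AffineMap compressed = compressUnusedDims(map);

The `ArrayRef<AffineMap>` overloads route through `compressUnusedListImpl`, so a dimension (or symbol) is only dropped when it is unused by every map in the list, which keeps the compressed maps mutually consistent.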
@@ -741,15 +714,80 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
                        maps.front().getContext());
}

+/// Common implementation to project out dimensions or symbols from an affine
+/// map based on the template type.
+/// Additionally, if 'compress' is true, the projected out dimensions or symbols
+/// are also dropped from the resulting map.
+template <typename AffineDimOrSymExpr>
+static AffineMap projectCommonImpl(AffineMap map,
+                                   const llvm::SmallBitVector &toProject,
+                                   bool compress) {
+  static_assert(llvm::is_one_of<AffineDimOrSymExpr, AffineDimExpr,
+                                AffineSymbolExpr>::value,
+                "expected AffineDimExpr or AffineSymbolExpr");
+
+  constexpr bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value;
+  int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols();
+  SmallVector<AffineExpr> replacements;
+  replacements.reserve(numDimOrSym);
+
+  auto createNewDimOrSym = (isDim) ? getAffineDimExpr : getAffineSymbolExpr;
+  auto replaceDims = [](AffineExpr e, ArrayRef<AffineExpr> replacements) {
+    return e.replaceDims(replacements);
+  };
+  auto replaceSymbols = [](AffineExpr e, ArrayRef<AffineExpr> replacements) {
+    return e.replaceSymbols(replacements);
+  };
+  auto replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols;
+
+  MLIRContext *context = map.getContext();
+  int64_t newNumDimOrSym = 0;
+  for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) {
+    if (toProject.test(dimOrSym)) {
+      replacements.push_back(getAffineConstantExpr(0, context));
+      continue;
+    }
+    int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym;
+    replacements.push_back(createNewDimOrSym(newPos, context));
+  }
+  SmallVector<AffineExpr> resultExprs;
+  resultExprs.reserve(map.getNumResults());
+  for (auto e : map.getResults())
+    resultExprs.push_back(replaceNewDimOrSym(e, replacements));
+
+  int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims();
+  int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.getNumSymbols();
+  return AffineMap::get(numDims, numSyms, resultExprs, context);
+}
+
+AffineMap mlir::projectDims(AffineMap map,
+                            const llvm::SmallBitVector &projectedDimensions,
+                            bool compressDimsFlag) {
+  return projectCommonImpl<AffineDimExpr>(map, projectedDimensions,
+                                          compressDimsFlag);
+}
+
+AffineMap mlir::projectSymbols(AffineMap map,
+                               const llvm::SmallBitVector &projectedSymbols,
+                               bool compressSymbolsFlag) {
+  return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols,
+                                             compressSymbolsFlag);
+}
+
AffineMap mlir::getProjectedMap(AffineMap map,
-                                const llvm::SmallBitVector &unusedDims) {
-  return compressUnusedSymbols(compressDims(map, unusedDims));
+                                const llvm::SmallBitVector &projectedDimensions,
+                                bool compressDimsFlag,
+                                bool compressSymbolsFlag) {
+  map = projectDims(map, projectedDimensions, compressDimsFlag);
+  if (compressSymbolsFlag)
+    map = compressUnusedSymbols(map);
+  return map;
}

llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
  unsigned numDims = maps[0].getNumDims();
  llvm::SmallBitVector numDimsBitVector(numDims, true);
-  for (const auto &m : maps) {
+  for (AffineMap m : maps) {
    for (unsigned i = 0; i < numDims; ++i) {
      if (m.isFunctionOfDim(i))
        numDimsBitVector.reset(i);
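A brief sketch of the new projection behavior (illustrative only, not part of the patch; it assumes `using namespace mlir;` and an existing `MLIRContext *ctx`). With the compress flag off, a projected dimension is replaced by the constant 0 but keeps its position; with the flag on, it is also removed and the remaining dimensions are renumbered:

  // Sketch only: (d0, d1, d2) -> (d0 + d1, d2), then project out d1.
  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr d1 = getAffineDimExpr(1, ctx);
  AffineExpr d2 = getAffineDimExpr(2, ctx);
  AffineMap map = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d0 + d1, d2}, ctx);
  llvm::SmallBitVector toProject(3);
  toProject.set(1);
  // Without compression, d1 is replaced by 0 but the map keeps three dims:
  // expected (d0, d1, d2) -> (d0, d2).
  AffineMap kept = projectDims(map, toProject, /*compressDimsFlag=*/false);
  // With compression, the projected dim disappears and d2 is renumbered:
  // expected (d0, d1) -> (d0, d1).
  AffineMap dropped = projectDims(map, toProject, /*compressDimsFlag=*/true);

getProjectedMap applies the same dimension projection and, when compressSymbolsFlag is set, additionally drops any symbols that become unused as a result.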
@@ -758,6 +796,18 @@ llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
  return numDimsBitVector;
}

+llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) {
+  unsigned numSymbols = maps[0].getNumSymbols();
+  llvm::SmallBitVector numSymbolsBitVector(numSymbols, true);
+  for (AffineMap m : maps) {
+    for (unsigned i = 0; i < numSymbols; ++i) {
+      if (m.isFunctionOfSymbol(i))
+        numSymbolsBitVector.reset(i);
+    }
+  }
+  return numSymbolsBitVector;
+}
+
//===----------------------------------------------------------------------===//
// MutableAffineMap.
//===----------------------------------------------------------------------===//
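The new getUnusedSymbolsBitVector mirrors getUnusedDimsBitVector on the symbol side. A small usage sketch (not part of the patch; assumes `using namespace mlir;` and an existing `MLIRContext *ctx`):

  // Sketch only: (d0)[s0, s1] -> (d0 + s1); s0 never appears in a result.
  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr s1 = getAffineSymbolExpr(1, ctx);
  AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/2, {d0 + s1}, ctx);
  // Bit 0 is set (s0 unused), bit 1 is clear (s1 used).
  llvm::SmallBitVector unusedSyms = getUnusedSymbolsBitVector({map});
  // Compressing renumbers s1 to s0: expected (d0)[s0] -> (d0 + s0).
  AffineMap compressed = compressUnusedSymbols(map);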
@@ -784,8 +834,8 @@ bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
  return false;
}

-// Simplifies the result affine expressions of this map. The expressions have to
-// be pure for the simplification implemented.
+// Simplifies the result affine expressions of this map. The expressions
+// have to be pure for the simplification implemented.
void MutableAffineMap::simplify() {
  // Simplify each of the results if possible.
  // TODO: functional-style map