Skip to content

Commit 255ba1c

Browse files
[mlir][AffineMap] NFC - Refactor getProjectedMap and split into projectDims and projectSymbols
The default behavior of getProjectedMap may be surprising as it implicitly compresses the dims and the unused symbols. Make these explicit in the API and refactor to more idiomatic implementations with better reuse. Differential Revision: https://reviews.llvm.org/D146611
1 parent c8117eb commit 255ba1c

File tree

2 files changed

+144
-74
lines changed

2 files changed

+144
-74
lines changed

mlir/include/mlir/IR/AffineMap.h

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,9 @@ struct MutableAffineMap {
403403
/// Simplifies an affine map by simplifying its underlying AffineExpr results.
404404
AffineMap simplifyAffineMap(AffineMap map);
405405

406+
/// Drop the dims that are listed in `unusedDims`.
407+
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims);
408+
406409
/// Drop the dims that are not used.
407410
AffineMap compressUnusedDims(AffineMap map);
408411

@@ -411,8 +414,9 @@ AffineMap compressUnusedDims(AffineMap map);
411414
/// dims and symbols.
412415
SmallVector<AffineMap> compressUnusedDims(ArrayRef<AffineMap> maps);
413416

414-
/// Drop the dims that are not listed in `unusedDims`.
415-
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims);
417+
/// Drop the symbols that are listed in `unusedSymbols`.
418+
AffineMap compressSymbols(AffineMap map,
419+
const llvm::SmallBitVector &unusedSymbols);
416420

417421
/// Drop the symbols that are not used.
418422
AffineMap compressUnusedSymbols(AffineMap map);
@@ -422,10 +426,6 @@ AffineMap compressUnusedSymbols(AffineMap map);
422426
/// dims and symbols.
423427
SmallVector<AffineMap> compressUnusedSymbols(ArrayRef<AffineMap> maps);
424428

425-
/// Drop the symbols that are not listed in `unusedSymbols`.
426-
AffineMap compressSymbols(AffineMap map,
427-
const llvm::SmallBitVector &unusedSymbols);
428-
429429
/// Returns a map with the same dimension and symbol count as `map`, but whose
430430
/// results are the unique affine expressions of `map`.
431431
AffineMap removeDuplicateExprs(AffineMap map);
@@ -469,7 +469,7 @@ AffineMap inversePermutation(AffineMap map);
469469
/// Return the reverse map of a projected permutation where the projected
470470
/// dimensions are transformed into 0s.
471471
///
472-
/// Prerequisites: `map` must be a projected permuation.
472+
/// Prerequisites: `map` must be a projected permutation.
473473
///
474474
/// Example 1:
475475
///
@@ -559,9 +559,38 @@ AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
559559
/// projected_dimensions : {1}
560560
/// result : affine_map<(d0, d1) -> (d0, 0)>
561561
///
562-
/// This function also compresses unused symbols away.
562+
/// This function also compresses the dims when the boolean flag is true.
563+
AffineMap projectDims(AffineMap map,
564+
const llvm::SmallBitVector &projectedDimensions,
565+
bool compressDimsFlag = false);
566+
/// Symbol counterpart of `projectDims`.
567+
/// This function also compresses the symbols when the boolean flag is true.
568+
AffineMap projectSymbols(AffineMap map,
569+
const llvm::SmallBitVector &projectedSymbols,
570+
bool compressSymbolsFlag = false);
571+
/// Calls `projectDims(map, projectedDimensions, compressDimsFlag)`.
572+
/// If `compressSymbolsFlag` is true, additionally call `compressUnusedSymbols`.
563573
AffineMap getProjectedMap(AffineMap map,
564-
const llvm::SmallBitVector &projectedDimensions);
574+
const llvm::SmallBitVector &projectedDimensions,
575+
bool compressDimsFlag = true,
576+
bool compressSymbolsFlag = true);
577+
578+
// Return a bitvector where each bit set indicates a dimension that is not used
579+
// by any of the maps in the input array `maps`.
580+
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
581+
582+
// Return a bitvector where each bit set indicates a symbol that is not used
583+
// by any of the maps in the input array `maps`.
584+
llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps);
585+
586+
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
587+
map.print(os);
588+
return os;
589+
}
590+
591+
//===----------------------------------------------------------------------===//
592+
// Templated helper functions.
593+
//===----------------------------------------------------------------------===//
565594

566595
/// Apply a permutation from `map` to `source` and return the result.
567596
template <typename T>
@@ -584,7 +613,7 @@ SmallVector<T> applyPermutationMap(AffineMap map, llvm::ArrayRef<T> source) {
584613
return result;
585614
}
586615

587-
/// Calculates maxmimum dimension and symbol positions from the expressions
616+
/// Calculates maximum dimension and symbol positions from the expressions
588617
/// in `exprsLists` and stores them in `maxDim` and `maxSym` respectively.
589618
template <typename AffineExprContainer>
590619
static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
@@ -601,15 +630,6 @@ static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
601630
}
602631
}
603632

604-
inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
605-
map.print(os);
606-
return os;
607-
}
608-
609-
// Return a bitvector where each bit set indicates a dimension that is not used
610-
// by any of the maps in the input array `maps`.
611-
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
612-
613633
} // namespace mlir
614634

615635
namespace llvm {

mlir/lib/IR/AffineMap.cpp

Lines changed: 105 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
#include "mlir/IR/BuiltinTypes.h"
1313
#include "mlir/Support/LogicalResult.h"
1414
#include "mlir/Support/MathExtras.h"
15+
#include "llvm/ADT/STLExtras.h"
1516
#include "llvm/ADT/SmallBitVector.h"
1617
#include "llvm/ADT/SmallSet.h"
1718
#include "llvm/ADT/StringRef.h"
1819
#include "llvm/Support/raw_ostream.h"
1920
#include <numeric>
2021
#include <optional>
22+
#include <type_traits>
2123

2224
using namespace mlir;
2325

@@ -569,32 +571,13 @@ AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
569571
return getSliceMap(getNumResults() - numResults, numResults);
570572
}
571573

572-
AffineMap mlir::compressDims(AffineMap map,
573-
const llvm::SmallBitVector &unusedDims) {
574-
unsigned numDims = 0;
575-
SmallVector<AffineExpr> dimReplacements;
576-
dimReplacements.reserve(map.getNumDims());
577-
MLIRContext *context = map.getContext();
578-
for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
579-
if (unusedDims.test(dim))
580-
dimReplacements.push_back(getAffineConstantExpr(0, context));
581-
else
582-
dimReplacements.push_back(getAffineDimExpr(numDims++, context));
583-
}
584-
SmallVector<AffineExpr> resultExprs;
585-
resultExprs.reserve(map.getNumResults());
586-
for (auto e : map.getResults())
587-
resultExprs.push_back(e.replaceDims(dimReplacements));
588-
return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context);
589-
}
590-
591-
AffineMap mlir::compressUnusedDims(AffineMap map) {
592-
return compressDims(map, getUnusedDimsBitVector({map}));
593-
}
594-
595-
static SmallVector<AffineMap>
596-
compressUnusedImpl(ArrayRef<AffineMap> maps,
597-
llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
574+
/// Implementation detail to compress multiple affine maps with a compressionFun
575+
/// that is expected to be either compressUnusedDims or compressUnusedSymbols.
576+
/// The implementation keeps track of num dims and symbols across the different
577+
/// affine maps.
578+
static SmallVector<AffineMap> compressUnusedListImpl(
579+
ArrayRef<AffineMap> maps,
580+
llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
598581
if (maps.empty())
599582
return SmallVector<AffineMap>();
600583
SmallVector<AffineExpr> allExprs;
@@ -622,41 +605,31 @@ compressUnusedImpl(ArrayRef<AffineMap> maps,
622605
return res;
623606
}
624607

608+
AffineMap mlir::compressDims(AffineMap map,
609+
const llvm::SmallBitVector &unusedDims) {
610+
return projectDims(map, unusedDims, /*compressDimsFlag=*/true);
611+
}
612+
613+
AffineMap mlir::compressUnusedDims(AffineMap map) {
614+
return compressDims(map, getUnusedDimsBitVector({map}));
615+
}
616+
625617
SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
626-
return compressUnusedImpl(maps,
627-
[](AffineMap m) { return compressUnusedDims(m); });
618+
return compressUnusedListImpl(
619+
maps, [](AffineMap m) { return compressUnusedDims(m); });
628620
}
629621

630622
AffineMap mlir::compressSymbols(AffineMap map,
631623
const llvm::SmallBitVector &unusedSymbols) {
632-
unsigned numSymbols = 0;
633-
SmallVector<AffineExpr> symReplacements;
634-
symReplacements.reserve(map.getNumSymbols());
635-
MLIRContext *context = map.getContext();
636-
for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
637-
if (unusedSymbols.test(sym))
638-
symReplacements.push_back(getAffineConstantExpr(0, context));
639-
else
640-
symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
641-
}
642-
SmallVector<AffineExpr> resultExprs;
643-
resultExprs.reserve(map.getNumResults());
644-
for (auto e : map.getResults())
645-
resultExprs.push_back(e.replaceSymbols(symReplacements));
646-
return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context);
624+
return projectSymbols(map, unusedSymbols, /*compressSymbolsFlag=*/true);
647625
}
648626

649627
AffineMap mlir::compressUnusedSymbols(AffineMap map) {
650-
llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true);
651-
map.walkExprs([&](AffineExpr expr) {
652-
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
653-
unusedSymbols.reset(symExpr.getPosition());
654-
});
655-
return compressSymbols(map, unusedSymbols);
628+
return compressSymbols(map, getUnusedSymbolsBitVector({map}));
656629
}
657630

658631
SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
659-
return compressUnusedImpl(
632+
return compressUnusedListImpl(
660633
maps, [](AffineMap m) { return compressUnusedSymbols(m); });
661634
}
662635

@@ -741,15 +714,80 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
741714
maps.front().getContext());
742715
}
743716

717+
/// Common implementation to project out dimensions or symbols from an affine
718+
/// map based on the template type.
719+
/// Additionally, if 'compress' is true, the projected out dimensions or symbols
720+
/// are also dropped from the resulting map.
721+
template <typename AffineDimOrSymExpr>
722+
static AffineMap projectCommonImpl(AffineMap map,
723+
const llvm::SmallBitVector &toProject,
724+
bool compress) {
725+
static_assert(llvm::is_one_of<AffineDimOrSymExpr, AffineDimExpr,
726+
AffineSymbolExpr>::value,
727+
"expected AffineDimExpr or AffineSymbolExpr");
728+
729+
constexpr bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value;
730+
int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols();
731+
SmallVector<AffineExpr> replacements;
732+
replacements.reserve(numDimOrSym);
733+
734+
auto createNewDimOrSym = (isDim) ? getAffineDimExpr : getAffineSymbolExpr;
735+
auto replaceDims = [](AffineExpr e, ArrayRef<AffineExpr> replacements) {
736+
return e.replaceDims(replacements);
737+
};
738+
auto replaceSymbols = [](AffineExpr e, ArrayRef<AffineExpr> replacements) {
739+
return e.replaceSymbols(replacements);
740+
};
741+
auto replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols;
742+
743+
MLIRContext *context = map.getContext();
744+
int64_t newNumDimOrSym = 0;
745+
for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) {
746+
if (toProject.test(dimOrSym)) {
747+
replacements.push_back(getAffineConstantExpr(0, context));
748+
continue;
749+
}
750+
int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym;
751+
replacements.push_back(createNewDimOrSym(newPos, context));
752+
}
753+
SmallVector<AffineExpr> resultExprs;
754+
resultExprs.reserve(map.getNumResults());
755+
for (auto e : map.getResults())
756+
resultExprs.push_back(replaceNewDimOrSym(e, replacements));
757+
758+
int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims();
759+
int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.getNumSymbols();
760+
return AffineMap::get(numDims, numSyms, resultExprs, context);
761+
}
762+
763+
AffineMap mlir::projectDims(AffineMap map,
764+
const llvm::SmallBitVector &projectedDimensions,
765+
bool compressDimsFlag) {
766+
return projectCommonImpl<AffineDimExpr>(map, projectedDimensions,
767+
compressDimsFlag);
768+
}
769+
770+
AffineMap mlir::projectSymbols(AffineMap map,
771+
const llvm::SmallBitVector &projectedSymbols,
772+
bool compressSymbolsFlag) {
773+
return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols,
774+
compressSymbolsFlag);
775+
}
776+
744777
AffineMap mlir::getProjectedMap(AffineMap map,
745-
const llvm::SmallBitVector &unusedDims) {
746-
return compressUnusedSymbols(compressDims(map, unusedDims));
778+
const llvm::SmallBitVector &projectedDimensions,
779+
bool compressDimsFlag,
780+
bool compressSymbolsFlag) {
781+
map = projectDims(map, projectedDimensions, compressDimsFlag);
782+
if (compressSymbolsFlag)
783+
map = compressUnusedSymbols(map);
784+
return map;
747785
}
748786

749787
llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
750788
unsigned numDims = maps[0].getNumDims();
751789
llvm::SmallBitVector numDimsBitVector(numDims, true);
752-
for (const auto &m : maps) {
790+
for (AffineMap m : maps) {
753791
for (unsigned i = 0; i < numDims; ++i) {
754792
if (m.isFunctionOfDim(i))
755793
numDimsBitVector.reset(i);
@@ -758,6 +796,18 @@ llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
758796
return numDimsBitVector;
759797
}
760798

799+
llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) {
800+
unsigned numSymbols = maps[0].getNumSymbols();
801+
llvm::SmallBitVector numSymbolsBitVector(numSymbols, true);
802+
for (AffineMap m : maps) {
803+
for (unsigned i = 0; i < numSymbols; ++i) {
804+
if (m.isFunctionOfSymbol(i))
805+
numSymbolsBitVector.reset(i);
806+
}
807+
}
808+
return numSymbolsBitVector;
809+
}
810+
761811
//===----------------------------------------------------------------------===//
762812
// MutableAffineMap.
763813
//===----------------------------------------------------------------------===//
@@ -784,8 +834,8 @@ bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
784834
return false;
785835
}
786836

787-
// Simplifies the result affine expressions of this map. The expressions have to
788-
// be pure for the simplification implemented.
837+
// Simplifies the result affine expressions of this map. The expressions
838+
// have to be pure for the simplification implemented.
789839
void MutableAffineMap::simplify() {
790840
// Simplify each of the results if possible.
791841
// TODO: functional-style map

0 commit comments

Comments
 (0)