Skip to content

Commit 5c96ba1

Browse files
committed
Add MLIRContext to concatAffineMaps
Signed-off-by: Ian Wood <[email protected]>
1 parent 6b4f568 commit 5c96ba1

File tree

6 files changed

+23
-16
lines changed

6 files changed

+23
-16
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def LinalgStructuredInterface
741741
/*methodBody=*/"",
742742
/*defaultImplementation=*/[{
743743
auto maps = $_op.getIndexingMapsArray();
744-
return concatAffineMaps(maps);
744+
return concatAffineMaps(maps, $_op.getContext());
745745
}]
746746
>,
747747
InterfaceMethod<

mlir/include/mlir/IR/AffineMap.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -595,10 +595,7 @@ AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map);
595595
/// potentially empty maps. Assumes each of the underlying map has 0 symbols.
596596
/// The resulting map has a number of dims equal to the max of `maps`' dims and
597597
/// the concatenated results as its results.
598-
///
599-
/// This method asserts when `maps` is empty.
600-
/// TODO: this should return an empty map when `maps` is empty but there is no
601-
/// way to get the MLIRContext needed to construct it.
598+
/// Returns an empty map if all input `maps` are empty.
602599
///
603600
/// Example:
604601
/// When applied to the following list of 3 affine maps,
@@ -616,7 +613,7 @@ AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map);
616613
/// ```mlir
617614
/// (i, j, k) -> (i, k, k, j, i, j)
618615
/// ```
619-
AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
616+
AffineMap concatAffineMaps(ArrayRef<AffineMap> maps, MLIRContext *context);
620617

621618
/// Returns the map that results from projecting out the dimensions specified in
622619
/// `projectedDimensions`. The projected dimensions are set to 0.

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
5454
// if the op has no loops.
5555
return linalgOp.getNumLoops() == 0;
5656
}
57-
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
57+
return inversePermutation(concatAffineMaps(
58+
indexingMaps, linalgOp.getContext())) != AffineMap();
5859
}
5960

6061
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
392392
// 1. Check if any of the iteration dimensions are unit-trip count. They will
393393
// end up being unit-trip count if they are used to index into a unit-dim
394394
// tensor/memref.
395-
AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
395+
AffineMap invertedMap =
396+
inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
396397
if (!invertedMap) {
397398
return rewriter.notifyMatchFailure(genericOp,
398399
"invalid indexing maps for operation");
@@ -486,7 +487,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
486487
// Abort if the indexing maps of the result operation are not invertible
487488
// (i.e. not legal) or if no dimension was reduced.
488489
if (newIndexingMaps == indexingMaps ||
489-
!inversePermutation(concatAffineMaps(newIndexingMaps)))
490+
!inversePermutation(
491+
concatAffineMaps(newIndexingMaps, rewriter.getContext())))
490492
return failure();
491493

492494
Location loc = genericOp.getLoc();

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,18 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
8888
indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
8989
}
9090
}
91+
if (indexingMaps.empty()) {
92+
// If there are no indexing maps, the operand can only be dropped
93+
// if neither ops op have loops.
94+
return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
95+
}
9196

9297
// The concatanation of the remained indexing maps must be invertible, so
9398
// the bounds of the op can be still computed after dropping the selected
9499
// operand. inversePermutation returns an empty AffineMap in case the
95100
// concatanated indexing maps are not invertible.
96-
return !indexingMaps.empty() &&
97-
inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
101+
return inversePermutation(concatAffineMaps(
102+
indexingMaps, producer.getContext())) != AffineMap();
98103
}
99104

100105
/// Returns a set of indices of the producer's results which would
@@ -1996,7 +2001,8 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
19962001
genericOp.getMatchingIndexingMap(&outputOperand));
19972002

19982003
// Check if the operation shapes to loops map is computable.
1999-
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
2004+
if (!inversePermutation(
2005+
concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
20002006
return rewriter.notifyMatchFailure(
20012007
genericOp, "fused op loop bound computation failed");
20022008
}

mlir/lib/IR/AffineMap.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -833,8 +833,10 @@ AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
833833
return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
834834
}
835835

836-
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
837-
assert(maps.size());
836+
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps,
837+
MLIRContext *context) {
838+
if (maps.empty())
839+
return AffineMap::get(context);
838840
unsigned numResults = 0, numDims = 0, numSymbols = 0;
839841
for (auto m : maps)
840842
numResults += m.getNumResults();
@@ -847,8 +849,7 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
847849
numSymbols += m.getNumSymbols();
848850
numDims = std::max(m.getNumDims(), numDims);
849851
}
850-
return AffineMap::get(numDims, numSymbols, results,
851-
maps.front().getContext());
852+
return AffineMap::get(numDims, numSymbols, results, context);
852853
}
853854

854855
/// Common implementation to project out dimensions or symbols from an affine

0 commit comments

Comments
 (0)