From 6b4f5682647f746ee56112978c3e024003de6284 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Tue, 26 Nov 2024 08:20:17 -0800 Subject: [PATCH 1/3] Fix linalg crash during elementwise op fusion Signed-off-by: Ian Wood --- mlir/include/mlir/IR/AffineMap.h | 5 ++- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 3 +- mlir/lib/IR/AffineMap.cpp | 1 + .../Dialect/Linalg/fusion-elementwise.mlir | 31 +++++++++++++++++++ 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index e30950bbf292d..4cfc538a2fe72 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -595,7 +595,10 @@ AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map); /// potentially empty maps. Assumes each of the underlying map has 0 symbols. /// The resulting map has a number of dims equal to the max of `maps`' dims and /// the concatenated results as its results. -/// Returns an empty map if all input `maps` are empty. +/// +/// This method asserts when `maps` is empty. +/// TODO: this should return an empty map when `maps` is empty but there is no +/// way to get the MLIRContext needed to construct it. /// /// Example: /// When applied to the following list of 3 affine maps, diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index a934e47794051..6ddea53dfb997 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -93,7 +93,8 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs( // the bounds of the op can be still computed after dropping the selected // operand. inversePermutation returns an empty AffineMap in case the // concatanated indexing maps are not invertible. - return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); + return !indexingMaps.empty() && + inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); } /// Returns a set of indices of the producer's results which would diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index ea3c0723b0775..8056058511ebe 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -834,6 +834,7 @@ AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) { } AffineMap mlir::concatAffineMaps(ArrayRef maps) { + assert(maps.size()); unsigned numResults = 0, numDims = 0, numSymbols = 0; for (auto m : maps) numResults += m.getNumResults(); diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir index 8131e4054cc6b..bd9977f1410b9 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir @@ -28,3 +28,34 @@ func.func @drop_unused_producer_result(%arg0 : tensor, // CHECK: %[[FUSED_OP:.+]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : // CHECK: return %[[FUSED_OP]] + +// ----- + +#map = affine_map<(d0) -> (d0)> +func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> { + %cst_0 = arith.constant 0.000000e+00 : f32 + %cst_1 = arith.constant 1.000000e+00 : f32 + %0:2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} outs(%arg0, %arg1 : tensor<8xf32>, tensor<8xf32>) { + ^bb0(%out: f32, %out_2: f32): + %1 = linalg.index 0 : index + %2 = arith.index_cast %1 : index to i64 + %3 = arith.sitofp %2 : i64 to f32 + %4 = arith.divf %3, %cst_0 : f32 + linalg.yield %3, %4 : f32, f32 + } -> (tensor<8xf32>, tensor<8xf32>) + linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%0#1 : tensor<8xf32>) { + ^bb0(%in: f32): + %2 = arith.cmpf one, %in, %cst_1 : f32 + cf.assert %2, "Side effect op" + linalg.yield + } + func.return %arg1 : tensor<8xf32> +} + +// CHECK-LABEL: func @handle_unused_operands +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32> +// CHECK: %[[FUSED_OP:.+]] = linalg.generic +// CHECK-SAME: outs(%[[EMPTY]] : +// CHECK-NOT: linalg.generic From 5c96ba141f6587d8dd7ac5e3bbd21f25e65348cf Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Tue, 26 Nov 2024 23:26:35 -0800 Subject: [PATCH 2/3] Add MLIRContext to concatAffineMaps Signed-off-by: Ian Wood --- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 2 +- mlir/include/mlir/IR/AffineMap.h | 7 ++----- mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 3 ++- mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 6 ++++-- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 12 +++++++++--- mlir/lib/IR/AffineMap.cpp | 9 +++++---- 6 files changed, 23 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index c0eff99c85075..244db23925ab3 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -741,7 +741,7 @@ def LinalgStructuredInterface /*methodBody=*/"", /*defaultImplementation=*/[{ auto maps = $_op.getIndexingMapsArray(); - return concatAffineMaps(maps); + return concatAffineMaps(maps, $_op.getContext()); }] >, InterfaceMethod< diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index 4cfc538a2fe72..4bc40a7d4091a 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -595,10 +595,7 @@ AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map); /// potentially empty maps. Assumes each of the underlying map has 0 symbols. /// The resulting map has a number of dims equal to the max of `maps`' dims and /// the concatenated results as its results. -/// -/// This method asserts when `maps` is empty. -/// TODO: this should return an empty map when `maps` is empty but there is no -/// way to get the MLIRContext needed to construct it. +/// Returns an empty map if all input `maps` are empty. /// /// Example: /// When applied to the following list of 3 affine maps, @@ -616,7 +613,7 @@ AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map); /// ```mlir /// (i, j, k) -> (i, k, k, j, i, j) /// ``` -AffineMap concatAffineMaps(ArrayRef maps); +AffineMap concatAffineMaps(ArrayRef maps, MLIRContext *context); /// Returns the map that results from projecting out the dimensions specified in /// `projectedDimensions`. The projected dimensions are set to 0. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 0cffadf8fb64a..caf9cdb3a3eb4 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -54,7 +54,8 @@ bool linalg::detail::canOpOperandsBeDroppedImpl( // if the op has no loops. return linalgOp.getNumLoops() == 0; } - return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); + return inversePermutation(concatAffineMaps( + indexingMaps, linalgOp.getContext())) != AffineMap(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index bacc634f5ee55..bb50347596910 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -392,7 +392,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, // 1. Check if any of the iteration dimensions are unit-trip count. They will // end up being unit-trip count if they are used to index into a unit-dim // tensor/memref. - AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps)); + AffineMap invertedMap = + inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext())); if (!invertedMap) { return rewriter.notifyMatchFailure(genericOp, "invalid indexing maps for operation"); @@ -486,7 +487,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, // Abort if the indexing maps of the result operation are not invertible // (i.e. not legal) or if no dimension was reduced. if (newIndexingMaps == indexingMaps || - !inversePermutation(concatAffineMaps(newIndexingMaps))) + !inversePermutation( + concatAffineMaps(newIndexingMaps, rewriter.getContext()))) return failure(); Location loc = genericOp.getLoc(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 6ddea53dfb997..5e1d4b76edce8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -88,13 +88,18 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs( indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand)); } } + if (indexingMaps.empty()) { + // If there are no indexing maps, the operand can only be dropped + // if neither ops op have loops. + return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0; + } // The concatanation of the remained indexing maps must be invertible, so // the bounds of the op can be still computed after dropping the selected // operand. inversePermutation returns an empty AffineMap in case the // concatanated indexing maps are not invertible. - return !indexingMaps.empty() && - inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); + return inversePermutation(concatAffineMaps( + indexingMaps, producer.getContext())) != AffineMap(); } /// Returns a set of indices of the producer's results which would @@ -1996,7 +2001,8 @@ class FoldScalarOrSplatConstant : public OpRewritePattern { genericOp.getMatchingIndexingMap(&outputOperand)); // Check if the operation shapes to loops map is computable. - if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { + if (!inversePermutation( + concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) { return rewriter.notifyMatchFailure( genericOp, "fused op loop bound computation failed"); } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 8056058511ebe..719a81ec057f9 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -833,8 +833,10 @@ AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) { return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context); } -AffineMap mlir::concatAffineMaps(ArrayRef maps) { - assert(maps.size()); +AffineMap mlir::concatAffineMaps(ArrayRef maps, + MLIRContext *context) { + if (maps.empty()) + return AffineMap::get(context); unsigned numResults = 0, numDims = 0, numSymbols = 0; for (auto m : maps) numResults += m.getNumResults(); @@ -847,8 +849,7 @@ AffineMap mlir::concatAffineMaps(ArrayRef maps) { numSymbols += m.getNumSymbols(); numDims = std::max(m.getNumDims(), numDims); } - return AffineMap::get(numDims, numSymbols, results, - maps.front().getContext()); + return AffineMap::get(numDims, numSymbols, results, context); } /// Common implementation to project out dimensions or symbols from an affine From 362f726f56782b36ad1d4af5b18114c87015d2cf Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Tue, 26 Nov 2024 12:34:00 -0800 Subject: [PATCH 3/3] Update mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp Co-authored-by: Quinn Dawkins --- mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 5e1d4b76edce8..c44194a123158 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -90,7 +90,7 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs( } if (indexingMaps.empty()) { // If there are no indexing maps, the operand can only be dropped - // if neither ops op have loops. + // if neither op has loops. return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0; }