diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index c0eff99c85075..244db23925ab3 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -741,7 +741,7 @@ def LinalgStructuredInterface /*methodBody=*/"", /*defaultImplementation=*/[{ auto maps = $_op.getIndexingMapsArray(); - return concatAffineMaps(maps); + return concatAffineMaps(maps, $_op.getContext()); }] >, InterfaceMethod< diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index e30950bbf292d..4bc40a7d4091a 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -613,7 +613,7 @@ AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map); /// ```mlir /// (i, j, k) -> (i, k, k, j, i, j) /// ``` -AffineMap concatAffineMaps(ArrayRef maps); +AffineMap concatAffineMaps(ArrayRef maps, MLIRContext *context); /// Returns the map that results from projecting out the dimensions specified in /// `projectedDimensions`. The projected dimensions are set to 0. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 0cffadf8fb64a..caf9cdb3a3eb4 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -54,7 +54,8 @@ bool linalg::detail::canOpOperandsBeDroppedImpl( // if the op has no loops. return linalgOp.getNumLoops() == 0; } - return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); + return inversePermutation(concatAffineMaps( + indexingMaps, linalgOp.getContext())) != AffineMap(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index bacc634f5ee55..bb50347596910 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -392,7 +392,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, // 1. Check if any of the iteration dimensions are unit-trip count. They will // end up being unit-trip count if they are used to index into a unit-dim // tensor/memref. - AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps)); + AffineMap invertedMap = + inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext())); if (!invertedMap) { return rewriter.notifyMatchFailure(genericOp, "invalid indexing maps for operation"); @@ -486,7 +487,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, // Abort if the indexing maps of the result operation are not invertible // (i.e. not legal) or if no dimension was reduced. if (newIndexingMaps == indexingMaps || - !inversePermutation(concatAffineMaps(newIndexingMaps))) + !inversePermutation( + concatAffineMaps(newIndexingMaps, rewriter.getContext()))) return failure(); Location loc = genericOp.getLoc(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index a934e47794051..c44194a123158 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -88,12 +88,18 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs( indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand)); } } + if (indexingMaps.empty()) { + // If there are no indexing maps, the operand can only be dropped + // if neither op has loops. + return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0; + } // The concatanation of the remained indexing maps must be invertible, so // the bounds of the op can be still computed after dropping the selected // operand. inversePermutation returns an empty AffineMap in case the // concatanated indexing maps are not invertible. - return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); + return inversePermutation(concatAffineMaps( + indexingMaps, producer.getContext())) != AffineMap(); } /// Returns a set of indices of the producer's results which would @@ -1995,7 +2001,8 @@ class FoldScalarOrSplatConstant : public OpRewritePattern { genericOp.getMatchingIndexingMap(&outputOperand)); // Check if the operation shapes to loops map is computable. - if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { + if (!inversePermutation( + concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) { return rewriter.notifyMatchFailure( genericOp, "fused op loop bound computation failed"); } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index ea3c0723b0775..719a81ec057f9 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -833,7 +833,10 @@ AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) { return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context); } -AffineMap mlir::concatAffineMaps(ArrayRef maps) { +AffineMap mlir::concatAffineMaps(ArrayRef maps, + MLIRContext *context) { + if (maps.empty()) + return AffineMap::get(context); unsigned numResults = 0, numDims = 0, numSymbols = 0; for (auto m : maps) numResults += m.getNumResults(); @@ -846,8 +849,7 @@ AffineMap mlir::concatAffineMaps(ArrayRef maps) { numSymbols += m.getNumSymbols(); numDims = std::max(m.getNumDims(), numDims); } - return AffineMap::get(numDims, numSymbols, results, - maps.front().getContext()); + return AffineMap::get(numDims, numSymbols, results, context); } /// Common implementation to project out dimensions or symbols from an affine diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir index 8131e4054cc6b..bd9977f1410b9 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir @@ -28,3 +28,34 @@ func.func @drop_unused_producer_result(%arg0 : tensor, // CHECK: %[[FUSED_OP:.+]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : // CHECK: return %[[FUSED_OP]] + +// ----- + +#map = affine_map<(d0) -> (d0)> +func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> { + %cst_0 = arith.constant 0.000000e+00 : f32 + %cst_1 = arith.constant 1.000000e+00 : f32 + %0:2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} outs(%arg0, %arg1 : tensor<8xf32>, tensor<8xf32>) { + ^bb0(%out: f32, %out_2: f32): + %1 = linalg.index 0 : index + %2 = arith.index_cast %1 : index to i64 + %3 = arith.sitofp %2 : i64 to f32 + %4 = arith.divf %3, %cst_0 : f32 + linalg.yield %3, %4 : f32, f32 + } -> (tensor<8xf32>, tensor<8xf32>) + linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%0#1 : tensor<8xf32>) { + ^bb0(%in: f32): + %2 = arith.cmpf one, %in, %cst_1 : f32 + cf.assert %2, "Side effect op" + linalg.yield + } + func.return %arg1 : tensor<8xf32> +} + +// CHECK-LABEL: func @handle_unused_operands +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32> +// CHECK: %[[FUSED_OP:.+]] = linalg.generic +// CHECK-SAME: outs(%[[EMPTY]] : +// CHECK-NOT: linalg.generic