From f31cf5aae3fb03d2afbdbef5456717e83e8b209e Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Thu, 6 Mar 2025 21:10:21 -0800
Subject: [PATCH 1/2] [mlir][linalg] Allow fusing reshapes with parallel
 operands

Signed-off-by: Ian Wood
---
 mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index a45b5c43f5d33..337fd8f3a0ac1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -566,7 +566,6 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops for the reshaped operand are parallel loops.
   SmallVector<utils::IteratorType> iteratorTypes =
       linalgOp.getIteratorTypesArray();
   AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
@@ -577,11 +576,7 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         operandMap.getNumResults() > 0 &&
-         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
-           return isParallelIterator(
-               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
-         });
+         operandMap.getNumResults() > 0;
 }
 
 namespace {

From 0201daa3e0773993c3e5df21209971f2dc95b28e Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Fri, 7 Mar 2025 06:12:25 -0800
Subject: [PATCH 2/2] Add test to check reduction reshape fusion

Signed-off-by: Ian Wood
---
 mlir/test/Dialect/Linalg/reshape_fusion.mlir | 25 ++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 4da9c0851ac70..c8720ebd98c09 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -482,6 +482,31 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 // -----
 
+func.func @fuse_collapse_reduction(%arg0: tensor<10x10x20xf32>) -> tensor<100xf32> {
+  %c0 = arith.constant 0 : index
+  %c_0 = arith.constant 0.0 : f32
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<10x10x20xf32> into tensor<100x20xf32>
+  %2 = tensor.empty() : tensor<100xf32>
+  %3 = linalg.fill ins(%c_0 : f32) outs(%2 : tensor<100xf32>) -> tensor<100xf32>
+  %4 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%0 : tensor<100x20xf32>) outs(%3 : tensor<100xf32>) {
+  ^bb0(%arg1 : f32, %arg2: f32):
+    %4 = arith.addf %arg1, %arg2 : f32
+    linalg.yield %4 : f32
+  } -> tensor<100xf32>
+  return %4 : tensor<100xf32>
+}
+
+// CHECK: func @fuse_collapse_reduction
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10x10x20xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<10x10x20xf32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
 func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>