diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 7f9ba1bdd2692..bf66ed01ef158 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -637,6 +637,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
     }
 
     ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
+    ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
     int64_t padRank = sourceShape.size();
 
     auto isStaticZero = [](OpFoldResult f) {
@@ -647,16 +648,18 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
                                                  allowedUnitDims.end());
     llvm::SmallDenseSet<unsigned> unitDims;
     SmallVector<int64_t> newShape;
+    SmallVector<int64_t> newResultShape;
     SmallVector<OpFoldResult> newLowPad;
     SmallVector<OpFoldResult> newHighPad;
-    for (const auto [dim, size, low, high] :
-         zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
-                   padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
+    for (const auto [dim, size, outSize, low, high] : zip_equal(
+             llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
+             resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
       if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
           isStaticZero(high)) {
         unitDims.insert(dim);
       } else {
         newShape.push_back(size);
+        newResultShape.push_back(outSize);
         newLowPad.push_back(low);
         newHighPad.push_back(high);
       }
@@ -686,8 +689,10 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
         collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
                       reassociationMap, options.rankReductionStrategy);
 
-    auto newPadOp = tensor::PadOp::create(
-        rewriter, padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
+    auto newResultType = RankedTensorType::get(
+        newResultShape, padOp.getResultType().getElementType());
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad,
         newHighPad, paddingVal, padOp.getNofold());
 
     Value dest = padOp.getResult();
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index a00c798197e5a..5f42938244db6 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1076,6 +1076,44 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 
 // -----
 
+func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<1x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?xf32> to tensor<1x16xf32>
+  return %padded : tensor<1x16xf32>
+}
+// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARGS:.*]] : tensor<1x?xf32> into tensor<?xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] {
+// CHECK: ^bb0(%[[IDX:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<?xf32> to tensor<16xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32>
+// CHECK: return %[[EXPANDED]] : tensor<1x16xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+module {
+  func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
+    %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
+    %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
+    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %2 = arith.mulf %in, %in_0 : f32
+      %3 = arith.addf %out, %2 : f32
+      linalg.yield %3 : f32
+    } -> tensor<?x1x61x1xf32>
+    return %1 : tensor<?x1x61x1xf32>
+  }
+}
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
 
@@ -1097,23 +1135,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 // CHECK: return %[[VAL_14]] : tensor<?x1x61x1xf32>
 // CHECK: }
 
-#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-module {
-  func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
-    %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
-    %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
-    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
-    ^bb0(%in: f32, %in_0: f32, %out: f32):
-      %2 = arith.mulf %in, %in_0 : f32
-      %3 = arith.addf %out, %2 : f32
-      linalg.yield %3 : f32
-    } -> tensor<?x1x61x1xf32>
-    return %1 : tensor<?x1x61x1xf32>
-  }
-}
-
 // -----
 
 func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {