From 67da59818a7bcce97898b3f6aadff11262c65b95 Mon Sep 17 00:00:00 2001 From: jerryyin Date: Fri, 2 May 2025 19:46:53 +0000 Subject: [PATCH 1/4] Folding unpack and pack sequence --- .../Transforms/DataLayoutPropagation.cpp | 36 ++++++++++ .../Linalg/data-layout-propagation.mlir | 68 +++++++++---------- 2 files changed, 69 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index f2a64f5bf38a3..893f9314396c8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -298,20 +298,56 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, return std::make_tuple(packedOperand, indexingMap); } +static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) { + int numDpsOuts = genericOp.getNumDpsInits(); + for (int i = 0; i < numDpsOuts; ++i) { + Block *block = genericOp.getBody(); + int numBlockArgs = block->getNumArguments(); + int matchingInitArgIndex = numBlockArgs - numDpsOuts + i; + return block->getArgument(matchingInitArgIndex).use_empty(); + } + return true; +} + /// Pack a genericOp and return it. static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest, AffineMap packedOutIndexingMap, const PackInfo &packInfo) { Location loc = genericOp.getLoc(); SmallVector inputOperands; + SmallVector inputOperandsFromUnpackedSource; SmallVector indexingMaps; + + // Note: canUnpackPackFold needs to also guarantee the generic body + // doesn't have gather semantics. Since such scenarios has been + // rejected by both BubbleUpPackOpThroughGenericOp and + // PushDownUnPackOpThroughGenericOp, we can safely assume + // canUnpackPackFold is as long as init is not used. + bool canUnpackPackFold = isGenericOutsNotUsed(genericOp); for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( rewriter, loc, packInfo, genericOp, inputOperand); + + if (auto unpackOp = inputOperand->get().getDefiningOp()) { + inputOperandsFromUnpackedSource.push_back(unpackOp.getSource()); + } else { + inputOperandsFromUnpackedSource.push_back(packedOperand); + } + inputOperands.push_back(packedOperand); indexingMaps.push_back(packedIndexingMap); } + // If The pack and unpack op can be folded: + // 1) use unpack op source op for operand to fold unpack -> pack sequence + // 2) init tensor of the generic op can be replaced by the new tensor.empty + // as the generic out. + if (canUnpackPackFold) { + inputOperands = inputOperandsFromUnpackedSource; + if (auto destPack = dest.getDefiningOp()) + dest = destPack.getDest(); + } + int64_t numInnerLoops = packInfo.getNumTiledLoops(); SmallVector iterTypes = genericOp.getIteratorTypesArray(); diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index 19d4524a2ec06..fde1c40fb3c12 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t // CHECK-LABEL: func.func @unpack_element_type_change // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32> -// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]] -// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16> -// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG1_PACK_EMPTY]] -// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> -// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG0_PACK_EMPTY]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16> // CHECK: %[[RES:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]] -// CHECK-SAME: ins(%[[ARG0_PACK]] -// CHECK-SAME: outs(%[[ARG1_PACK]] +// CHECK-SAME: ins(%[[ARG0]] +// CHECK-SAME: outs(%[[EMPTY]] // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]] // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] // CHECK-SAME: into %[[ARG1]] @@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5 // CHECK-LABEL: func.func @forward_tensor_empty // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32> -// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32> -// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]] -// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> -// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> -// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG0_PACK_EMPTY]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> // CHECK: %[[RES:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]] -// CHECK-SAME: ins(%[[PACKED_ARG0]] -// CHECK-SAME: outs(%[[DEST]] +// CHECK-SAME: ins(%[[ARG0]] +// CHECK-SAME: outs(%[[EMPTY]] // CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]] // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] // CHECK-SAME: into %[[FINAL_RES]] @@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x } // CHECK-LABEL: func.func @unpack_empty_inner_dims -// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] -// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>) // CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: ins(%[[PACKED_ARG0]] +// CHECK-SAME: ins(%[[ARG0]] // CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]] // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] @@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32> // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32> -// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32> -// CHECK: %[[PACK_ARG0:.+]] = linalg.pack -// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16] -// CHECK-SAME: into %[[PACK_EMPTY]] // CHECK: %[[POOL:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] -// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] // CHECK-SAME: outs(%[[INIT]] // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]] // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16] @@ -1421,3 +1395,27 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32> // CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32> + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> { + %empty = tensor.empty() : tensor<32x64xf32> + %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) { + ^bb0(%in: f32, %out: f32): + %2 = arith.addf %in, %in : f32 + linalg.yield %2 : f32 + } -> tensor<32x64xf32> + %empty1 = tensor.empty() : tensor<8x8x4x8xf32> + %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32> + return %pack : tensor<8x8x4x8xf32> +} + +// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32> +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x8x4x8xf32>) +// CHECK: return %[[GENERIC]] : tensor<8x8x4x8xf32> From 19c26c02190877ada859715fbaa266c73e33edcd Mon Sep 17 00:00:00 2001 From: jerryyin Date: Mon, 5 May 2025 18:14:50 +0000 Subject: [PATCH 2/4] Addressing review feedbacks --- .../Transforms/DataLayoutPropagation.cpp | 35 +++++++----- .../Linalg/data-layout-propagation.mlir | 55 +++++++++++++------ 2 files changed, 58 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 893f9314396c8..19b590b0d10e6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -300,10 +300,11 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) { int numDpsOuts = genericOp.getNumDpsInits(); + Block *block = genericOp.getBody(); + int numBlockArgs = block->getNumArguments(); + int initArgStartIndex = numBlockArgs - numDpsOuts; for (int i = 0; i < numDpsOuts; ++i) { - Block *block = genericOp.getBody(); - int numBlockArgs = block->getNumArguments(); - int matchingInitArgIndex = numBlockArgs - numDpsOuts + i; + int matchingInitArgIndex = initArgStartIndex + i; return block->getArgument(matchingInitArgIndex).use_empty(); } return true; @@ -312,18 +313,13 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) { /// Pack a genericOp and return it. static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest, AffineMap packedOutIndexingMap, - const PackInfo &packInfo) { + const PackInfo &packInfo, + bool canUnpackPackFold) { Location loc = genericOp.getLoc(); SmallVector inputOperands; SmallVector inputOperandsFromUnpackedSource; SmallVector indexingMaps; - // Note: canUnpackPackFold needs to also guarantee the generic body - // doesn't have gather semantics. Since such scenarios has been - // rejected by both BubbleUpPackOpThroughGenericOp and - // PushDownUnPackOpThroughGenericOp, we can safely assume - // canUnpackPackFold is as long as init is not used. - bool canUnpackPackFold = isGenericOutsNotUsed(genericOp); for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( rewriter, loc, packInfo, genericOp, inputOperand); @@ -338,10 +334,18 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, indexingMaps.push_back(packedIndexingMap); } + // Note: Whether or not the unpack pack sequence can fold also depends on + // the caller of this routine. + // 1) In push down unpack op pattern, this is true because the pack op is + // generated and we can guarantee they are compatible. + // 2) In bubble up pack op pattern, this is not true because the unpack op + // can be from an arbitrary domain so we need to keep both. + canUnpackPackFold = canUnpackPackFold && isGenericOutsNotUsed(genericOp) && + !hasGatherSemantics(genericOp); // If The pack and unpack op can be folded: - // 1) use unpack op source op for operand to fold unpack -> pack sequence - // 2) init tensor of the generic op can be replaced by the new tensor.empty - // as the generic out. + // 1) use unpack op source op for operand to fold unpack -> pack sequence. + // 2) init tensor of the generic op can be replaced by the destination of the + // pack op. if (canUnpackPackFold) { inputOperands = inputOperandsFromUnpackedSource; if (auto destPack = dest.getDefiningOp()) @@ -484,7 +488,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, dest = packOpDest; } return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, - *packInfo); + *packInfo, /*canUnpackPackFold=*/false); } /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method. @@ -1122,7 +1126,8 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, // Pack the genericOp. GenericOp newGenericOp = - packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo); + packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo, + /*canUnpackPackFold=*/true); Value newResult = newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index fde1c40fb3c12..a7749e7b2034f 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1398,24 +1398,45 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x // ----- -#map = affine_map<(d0, d1) -> (d0, d1)> -func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> { - %empty = tensor.empty() : tensor<32x64xf32> - %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32> - %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) { - ^bb0(%in: f32, %out: f32): - %2 = arith.addf %in, %in : f32 - linalg.yield %2 : f32 - } -> tensor<32x64xf32> - %empty1 = tensor.empty() : tensor<8x8x4x8xf32> - %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32> - return %pack : tensor<8x8x4x8xf32> +func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %dest: tensor, %arg1: tensor) -> tensor { + %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %dest : tensor<8x8x4x8xf32> -> tensor + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + linalg.yield %1 : bf16 + } -> tensor + return %0 : tensor } -// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up +// CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32> -// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: %[[EMPTY:.+]] = tensor.empty +// CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>) -// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x8x4x8xf32>) -// CHECK: return %[[GENERIC]] : tensor<8x8x4x8xf32> +// CHECK-SAME: outs(%[[EMPTY]] : tensor) +// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]] +// CHECK: return %[[UNPACK]] : tensor + +// ----- + +func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor) -> tensor { + %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 : tensor<8x8x4x8xf32> -> tensor + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @push_unpack_in_padded_domain_not_foldable +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] +// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG1]] +// CHECK: %[[UNPACK1:.+]] = linalg.pack %[[UNPACK]] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[UNPACK1]] : tensor) +// CHECK-SAME: outs(%[[PACK]] : tensor) +// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]] +// CHECK: return %[[UNPACK2]] : tensor From 4885aad9745e74cb652f50ccf90f5a69596c7b39 Mon Sep 17 00:00:00 2001 From: jerryyin Date: Tue, 6 May 2025 20:34:39 +0000 Subject: [PATCH 3/4] Unconditionally fold pack(unpack) for push down unpack pass --- .../Transforms/DataLayoutPropagation.cpp | 43 +++++++------------ .../Linalg/data-layout-propagation.mlir | 42 ++++++------------ 2 files changed, 29 insertions(+), 56 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 19b590b0d10e6..7b0abffb41f6d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -298,55 +298,37 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, return std::make_tuple(packedOperand, indexingMap); } -static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) { - int numDpsOuts = genericOp.getNumDpsInits(); - Block *block = genericOp.getBody(); - int numBlockArgs = block->getNumArguments(); - int initArgStartIndex = numBlockArgs - numDpsOuts; - for (int i = 0; i < numDpsOuts; ++i) { - int matchingInitArgIndex = initArgStartIndex + i; - return block->getArgument(matchingInitArgIndex).use_empty(); - } - return true; -} - -/// Pack a genericOp and return it. +/// This function is a helper subroutine to pack a genericOp and return it. It +/// will create a new generic op with the packed operand and the packed output +/// according to packInfo when we attempt to push down unpack or bubble up pack +/// around it. Implicitly this will only work when a packInfo can be obtained. +/// This make sure that we are only using this function on parallel permuted +/// dimensions. static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest, AffineMap packedOutIndexingMap, const PackInfo &packInfo, - bool canUnpackPackFold) { + bool isFoldableUnpackPack) { Location loc = genericOp.getLoc(); SmallVector inputOperands; SmallVector inputOperandsFromUnpackedSource; SmallVector indexingMaps; - for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( rewriter, loc, packInfo, genericOp, inputOperand); - if (auto unpackOp = inputOperand->get().getDefiningOp()) { inputOperandsFromUnpackedSource.push_back(unpackOp.getSource()); } else { inputOperandsFromUnpackedSource.push_back(packedOperand); } - inputOperands.push_back(packedOperand); indexingMaps.push_back(packedIndexingMap); } - // Note: Whether or not the unpack pack sequence can fold also depends on - // the caller of this routine. - // 1) In push down unpack op pattern, this is true because the pack op is - // generated and we can guarantee they are compatible. - // 2) In bubble up pack op pattern, this is not true because the unpack op - // can be from an arbitrary domain so we need to keep both. - canUnpackPackFold = canUnpackPackFold && isGenericOutsNotUsed(genericOp) && - !hasGatherSemantics(genericOp); // If The pack and unpack op can be folded: // 1) use unpack op source op for operand to fold unpack -> pack sequence. // 2) init tensor of the generic op can be replaced by the destination of the // pack op. - if (canUnpackPackFold) { + if (isFoldableUnpackPack) { inputOperands = inputOperandsFromUnpackedSource; if (auto destPack = dest.getDefiningOp()) dest = destPack.getDest(); @@ -487,8 +469,10 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, .getDefiningOp()) { dest = packOpDest; } + // Here pack(unpack) isn't naively foldable because the unpack op can be from + // an arbitrary domain so we need to keep both. return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, - *packInfo, /*canUnpackPackFold=*/false); + *packInfo, /*isFoldableUnpackPack=*/false); } /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method. @@ -1125,9 +1109,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, } // Pack the genericOp. + // pack(unpack) is foldable in this case. This is because in pushing down the + // unpack, by default we will populate an additional pack op after the unpack. + // This guarantees them to be foldable. GenericOp newGenericOp = packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo, - /*canUnpackPackFold=*/true); + /*isFoldableUnpackPack=*/true); Value newResult = newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index a7749e7b2034f..63f068d3f8681 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -455,13 +455,10 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56 // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]] // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] // CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]] -// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> -// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG0_EMPTY_PACK]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> // CHECK: %[[RES:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP]]] -// CHECK-SAME: outs(%[[PACKED_ARG0]] +// CHECK-SAME: outs(%[[EMPTY]] // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]] // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] // CHECK-SAME: into %[[UNPACKED_ARG0]] @@ -485,22 +482,11 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56 // CHECK-LABEL: func.func @unpack_on_input // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32> -// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]] -// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> -// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG1_PACK_EMPTY]] -// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> -// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]] -// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] -// CHECK-SAME: into %[[ARG0_PACK_EMPTY]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32> // CHECK: %[[RES:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]] -// CHECK-SAME: ins(%[[ARG0_PACK]] -// CHECK-SAME: outs(%[[ARG1_PACK]] +// CHECK-SAME: ins(%[[ARG0]] +// CHECK-SAME: outs(%[[EMPTY]] // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]] // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] // CHECK-SAME: into %[[ARG1]] @@ -1407,19 +1393,21 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de } -> tensor return %0 : tensor } - // CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] // CHECK: %[[EMPTY:.+]] = tensor.empty // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>) // CHECK-SAME: outs(%[[EMPTY]] : tensor) // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]] +// CHECK-SAME: into %[[ARG2]] // CHECK: return %[[UNPACK]] : tensor // ----- -func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor) -> tensor { +func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor) -> tensor { %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 : tensor<8x8x4x8xf32> -> tensor %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor) outs(%arg1 : tensor) { ^bb0(%in: f32, %out: f32): @@ -1428,15 +1416,13 @@ func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>, } -> tensor return %0 : tensor } - -// CHECK-LABEL: func.func @push_unpack_in_padded_domain_not_foldable +// CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] -// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG1]] -// CHECK: %[[UNPACK1:.+]] = linalg.pack %[[UNPACK]] +// CHECK: %[[EMPTY:.+]] = tensor.empty // CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK-SAME: ins(%[[UNPACK1]] : tensor) -// CHECK-SAME: outs(%[[PACK]] : tensor) +// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor) // CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]] +// CHECK-SAME: into %[[ARG1]] // CHECK: return %[[UNPACK2]] : tensor From a9c1dccc3f73793bdd9e1f51ab3a6e15403a8338 Mon Sep 17 00:00:00 2001 From: jerryyin Date: Wed, 7 May 2025 13:24:14 +0000 Subject: [PATCH 4/4] Fixing comments --- mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 7b0abffb41f6d..26904f1f40d12 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -324,7 +324,7 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, indexingMaps.push_back(packedIndexingMap); } - // If The pack and unpack op can be folded: + // If the pack and unpack op can be folded: // 1) use unpack op source op for operand to fold unpack -> pack sequence. // 2) init tensor of the generic op can be replaced by the destination of the // pack op. @@ -469,7 +469,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, .getDefiningOp()) { dest = packOpDest; } - // Here pack(unpack) isn't naively foldable because the unpack op can be from + // pack(unpack) isn't naively foldable because the unpack op can be from // an arbitrary domain so we need to keep both. return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo, /*isFoldableUnpackPack=*/false);