diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index f2a64f5bf38a3..26904f1f40d12 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,20 +298,42 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   return std::make_tuple(packedOperand, indexingMap);
 }
 
-/// Pack a genericOp and return it.
+/// Helper function that packs a genericOp and returns it. It creates a new
+/// generic op with the packed operands and the packed output according to
+/// packInfo, and is used when we attempt to push an unpack op down or bubble
+/// a pack op up through the generic op. Implicitly, this only works when a
+/// packInfo can be obtained, which ensures the function is only applied to
+/// parallel, permuted dimensions.
 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                Value dest, AffineMap packedOutIndexingMap,
-                               const PackInfo &packInfo) {
+                               const PackInfo &packInfo,
+                               bool isFoldableUnpackPack) {
   Location loc = genericOp.getLoc();
   SmallVector<Value> inputOperands;
+  SmallVector<Value> inputOperandsFromUnpackedSource;
   SmallVector<AffineMap> indexingMaps;
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
+    if (auto unpackOp =
+            inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+    } else {
+      inputOperandsFromUnpackedSource.push_back(packedOperand);
+    }
     inputOperands.push_back(packedOperand);
     indexingMaps.push_back(packedIndexingMap);
   }
 
+  // If the pack and unpack ops can be folded:
+  // 1) use the source of the unpack op as the operand to fold the
+  //    unpack -> pack sequence.
+  // 2) replace the init tensor of the generic op with the destination of the
+  //    pack op.
+  if (isFoldableUnpackPack) {
+    inputOperands = inputOperandsFromUnpackedSource;
+    if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
+      dest = destPack.getDest();
+  }
+
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   SmallVector<utils::IteratorType> iterTypes =
       genericOp.getIteratorTypesArray();
 
@@ -447,8 +469,10 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
           .getDefiningOp<tensor::EmptyOp>()) {
     dest = packOpDest;
   }
+  // pack(unpack) isn't naively foldable because the unpack op can come from
+  // an arbitrary domain, so we need to keep both ops.
   return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
-                       *packInfo);
+                       *packInfo, /*isFoldableUnpackPack=*/false);
 }
 
 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
@@ -1085,8 +1109,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   }
 
   // Pack the genericOp.
+  // pack(unpack) is foldable in this case, because when pushing down the
+  // unpack we populate an additional pack op after it by default, which
+  // guarantees that the pair is foldable.
   GenericOp newGenericOp =
-      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
+      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
+                    /*isFoldableUnpackPack=*/true);
   Value newResult =
       newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 19d4524a2ec06..63f068d3f8681 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -455,13 +455,10 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
 // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
-// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_EMPTY_PACK]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]]]
-// CHECK-SAME: outs(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[UNPACKED_ARG0]]
@@ -485,22 +482,11 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-LABEL: func.func @unpack_on_input
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG1]]
@@ -524,22 +510,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-LABEL: func.func @unpack_element_type_change
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG1]]
@@ -564,19 +539,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
-// CHECK-SAME: outs(%[[DEST]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +777,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 }
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
 // CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: ins(%[[ARG0]]
 // CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
@@ -943,14 +907,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK: %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME: into %[[PACK_EMPTY]]
 // CHECK: %[[POOL:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
 // CHECK-SAME: outs(%[[INIT]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
 // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1381,48 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
 // CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %dest: tensor<?x64xf32>, %arg1: tensor<?x64xbf16>) -> tensor<?x64xbf16> {
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x64xbf16>
+  return %0 : tensor<?x64xbf16>
+}
+// CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x8x4x8xbf16>)
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
+// CHECK-SAME: into %[[ARG2]]
+// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
+
+// -----
+
+func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> {
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x64xf32>
+  return %0 : tensor<?x64xf32>
+}
+// CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
+// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
+// CHECK-SAME: into %[[ARG1]]
+// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
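
To make the effect of /*isFoldableUnpackPack=*/true concrete, here is a before/after sketch in MLIR of the @push_unpack_in_padded_domain_foldable test above. It is assembled from the test's CHECK lines; the %empty name and the elided generic attributes are illustrative only. Before propagation, the generic consumes the unpacked (padded-domain) tensor:

    %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8]
        into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
    %0 = linalg.generic ... ins(%unpack : tensor<?x64xf32>)
        outs(%arg1 : tensor<?x64xbf16>) {...} -> tensor<?x64xbf16>

After pushing the unpack down, the generic runs directly on the packed source, because the unpack -> pack pair that would otherwise be materialized around it is folded away, and a single unpack is emitted after the generic:

    %empty = tensor.empty() : tensor<8x8x4x8xbf16>
    %0 = linalg.generic ... ins(%arg0 : tensor<8x8x4x8xf32>)
        outs(%empty : tensor<8x8x4x8xbf16>) {...} -> tensor<8x8x4x8xbf16>
    %unpack = linalg.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [4, 8]
        into %arg2 : tensor<8x8x4x8xbf16> -> tensor<?x64xbf16>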