47 changes: 44 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,20 +298,60 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
return std::make_tuple(packedOperand, indexingMap);
}

/// Returns true if none of the init (outs) block arguments of the generic op
/// are used in its body.
static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
int numDpsOuts = genericOp.getNumDpsInits();
Block *block = genericOp.getBody();
int numBlockArgs = block->getNumArguments();
int initArgStartIndex = numBlockArgs - numDpsOuts;
for (int i = 0; i < numDpsOuts; ++i) {
int matchingInitArgIndex = initArgStartIndex + i;
if (!block->getArgument(matchingInitArgIndex).use_empty())
return false;
}
return true;
}
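// Note: the loop above is equivalent to the following more idiomatic form (a
// sketch, assuming the standard LinalgOp/DPS accessors getDpsInitsMutable()
// and getMatchingBlockArgument()):
//   return llvm::all_of(genericOp.getDpsInitsMutable(),
//                       [&](OpOperand &initOperand) {
//     return genericOp.getMatchingBlockArgument(&initOperand).use_empty();
//   });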

/// Pack a genericOp and return it. `canUnpackPackFold` indicates whether the
/// caller allows folding an unpack -> pack pair created around an input (see
/// the note in the body).
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
Value dest, AffineMap packedOutIndexingMap,
const PackInfo &packInfo) {
const PackInfo &packInfo,
bool canUnpackPackFold) {
Location loc = genericOp.getLoc();
SmallVector<Value> inputOperands;
SmallVector<Value> inputOperandsFromUnpackedSource;
SmallVector<AffineMap> indexingMaps;

for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
rewriter, loc, packInfo, genericOp, inputOperand);

if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
} else {
inputOperandsFromUnpackedSource.push_back(packedOperand);
}

inputOperands.push_back(packedOperand);
indexingMaps.push_back(packedIndexingMap);
}

// Note: Whether the unpack-pack sequence can fold also depends on the
// caller of this routine:
// 1) In the push-down-unpack pattern it can, because the pack ops are
// generated here and are guaranteed to be compatible with the unpack.
// 2) In the bubble-up-pack pattern it cannot, because the unpack op may
// come from an arbitrary domain, so both ops must be kept.
canUnpackPackFold = canUnpackPackFold && isGenericOutsNotUsed(genericOp) &&
!hasGatherSemantics(genericOp);
// If the pack and unpack ops can be folded:
// 1) Use the unpack op's source as the operand, folding away the
// unpack -> pack sequence.
// 2) Replace the init tensor of the generic op with the destination of the
// pack op.
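// For example (an illustrative sketch, not taken from the tests):
//   %u = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [8]
//       into %d0 : tensor<4x8xf32> -> tensor<32xf32>
//   %p = linalg.pack %u inner_dims_pos = [0] inner_tiles = [8]
//       into %d1 : tensor<32xf32> -> tensor<4x8xf32>
// With matching tiles and no padding, %p is just %src, so the packed generic
// can read %src directly and both ops disappear.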
if (canUnpackPackFold) {
inputOperands = inputOperandsFromUnpackedSource;
if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
dest = destPack.getDest();
}

int64_t numInnerLoops = packInfo.getNumTiledLoops();
SmallVector<utils::IteratorType> iterTypes =
genericOp.getIteratorTypesArray();
@@ -448,7 +488,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
dest = packOpDest;
}
return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
*packInfo);
*packInfo, /*canUnpackPackFold=*/false);
}

/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
@@ -1086,7 +1126,8 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,

// Pack the genericOp.
GenericOp newGenericOp =
packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
/*canUnpackPackFold=*/true);
Value newResult =
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

89 changes: 54 additions & 35 deletions mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
// CHECK-LABEL: func.func @unpack_element_type_change
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[ARG0_PACK]]
// CHECK-SAME: outs(%[[ARG1_PACK]]
// CHECK-SAME: ins(%[[ARG0]]
// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK-SAME: outs(%[[DEST]]
// CHECK-SAME: ins(%[[ARG0]]
// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
}

// CHECK-LABEL: func.func @unpack_empty_inner_dims
// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK-SAME: ins(%[[ARG0]]
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []

@@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
// CHECK: %[[PACK_ARG0:.+]] = linalg.pack
// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
// CHECK-SAME: into %[[PACK_EMPTY]]
// CHECK: %[[POOL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK-SAME: outs(%[[INIT]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1395,48 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>

// -----

func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %dest: tensor<?x64xf32>, %arg1: tensor<?x64xbf16>) -> tensor<?x64xbf16> {
%unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xbf16>) {
^bb0(%in: f32, %out: bf16):
%1 = arith.truncf %in : f32 to bf16
linalg.yield %1 : bf16
} -> tensor<?x64xbf16>
return %0 : tensor<?x64xbf16>
}

// CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>

// -----

func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> {
%unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xf32>) {
^bb0(%in: f32, %out: f32):
%1 = arith.addf %in, %out : f32
linalg.yield %1 : f32
} -> tensor<?x64xf32>
return %0 : tensor<?x64xf32>
}

// CHECK-LABEL: func.func @push_unpack_in_padded_domain_not_foldable
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG1]]
// CHECK: %[[PACK1:.+]] = linalg.pack %[[UNPACK]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[PACK1]] : tensor<?x8x4x8xf32>)
// CHECK-SAME: outs(%[[PACK]] : tensor<?x8x4x8xf32>)
// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>