Skip to content

Commit 4885aad

Browse files
committed
Unconditionally fold pack(unpack) for push down unpack pass
1 parent 19c26c0 commit 4885aad

File tree

2 files changed

+29
-56
lines changed

2 files changed

+29
-56
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -298,55 +298,37 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
298298
return std::make_tuple(packedOperand, indexingMap);
299299
}
300300

301-
static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
302-
int numDpsOuts = genericOp.getNumDpsInits();
303-
Block *block = genericOp.getBody();
304-
int numBlockArgs = block->getNumArguments();
305-
int initArgStartIndex = numBlockArgs - numDpsOuts;
306-
for (int i = 0; i < numDpsOuts; ++i) {
307-
int matchingInitArgIndex = initArgStartIndex + i;
308-
return block->getArgument(matchingInitArgIndex).use_empty();
309-
}
310-
return true;
311-
}
312-
313-
/// Pack a genericOp and return it.
301+
/// This function is a helper subroutine to pack a genericOp and return it. It
302+
/// will create a new generic op with the packed operand and the packed output
303+
/// according to packInfo when we attempt to push down unpack or bubble up pack
304+
/// around it. Implicitly this will only work when a packInfo can be obtained.
305+
/// This makes sure that we are only using this function on parallel permuted
306+
/// dimensions.
314307
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
315308
Value dest, AffineMap packedOutIndexingMap,
316309
const PackInfo &packInfo,
317-
bool canUnpackPackFold) {
310+
bool isFoldableUnpackPack) {
318311
Location loc = genericOp.getLoc();
319312
SmallVector<Value> inputOperands;
320313
SmallVector<Value> inputOperandsFromUnpackedSource;
321314
SmallVector<AffineMap> indexingMaps;
322-
323315
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
324316
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
325317
rewriter, loc, packInfo, genericOp, inputOperand);
326-
327318
if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
328319
inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
329320
} else {
330321
inputOperandsFromUnpackedSource.push_back(packedOperand);
331322
}
332-
333323
inputOperands.push_back(packedOperand);
334324
indexingMaps.push_back(packedIndexingMap);
335325
}
336326

337-
// Note: Whether or not the unpack pack sequence can fold also depends on
338-
// the caller of this routine.
339-
// 1) In push down unpack op pattern, this is true because the pack op is
340-
// generated and we can guarantee they are compatible.
341-
// 2) In bubble up pack op pattern, this is not true because the unpack op
342-
// can be from an arbitrary domain so we need to keep both.
343-
canUnpackPackFold = canUnpackPackFold && isGenericOutsNotUsed(genericOp) &&
344-
!hasGatherSemantics(genericOp);
345327
// If the pack and unpack op can be folded:
346328
// 1) use unpack op source op for operand to fold unpack -> pack sequence.
347329
// 2) init tensor of the generic op can be replaced by the destination of the
348330
// pack op.
349-
if (canUnpackPackFold) {
331+
if (isFoldableUnpackPack) {
350332
inputOperands = inputOperandsFromUnpackedSource;
351333
if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
352334
dest = destPack.getDest();
@@ -487,8 +469,10 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
487469
.getDefiningOp<tensor::EmptyOp>()) {
488470
dest = packOpDest;
489471
}
472+
// Here pack(unpack) isn't naively foldable because the unpack op can be from
473+
// an arbitrary domain so we need to keep both.
490474
return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
491-
*packInfo, /*canUnpackPackFold=*/false);
475+
*packInfo, /*isFoldableUnpackPack=*/false);
492476
}
493477

494478
/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
@@ -1125,9 +1109,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
11251109
}
11261110

11271111
// Pack the genericOp.
1112+
// pack(unpack) is foldable in this case. This is because in pushing down the
1113+
// unpack, by default we will populate an additional pack op after the unpack.
1114+
// This guarantees that they are foldable.
11281115
GenericOp newGenericOp =
11291116
packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
1130-
/*canUnpackPackFold=*/true);
1117+
/*isFoldableUnpackPack=*/true);
11311118
Value newResult =
11321119
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
11331120

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -455,13 +455,10 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
455455
// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
456456
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
457457
// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
458-
// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
459-
// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
460-
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
461-
// CHECK-SAME: into %[[ARG0_EMPTY_PACK]]
458+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
462459
// CHECK: %[[RES:.+]] = linalg.generic
463460
// CHECK-SAME: indexing_maps = [#[[$MAP]]]
464-
// CHECK-SAME: outs(%[[PACKED_ARG0]]
461+
// CHECK-SAME: outs(%[[EMPTY]]
465462
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
466463
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
467464
// CHECK-SAME: into %[[UNPACKED_ARG0]]
@@ -485,22 +482,11 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
485482
// CHECK-LABEL: func.func @unpack_on_input
486483
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
487484
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
488-
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
489-
// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
490-
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
491-
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
492-
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
493-
// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
494-
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
495-
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
496-
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
497-
// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
498-
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
499-
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
485+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
500486
// CHECK: %[[RES:.+]] = linalg.generic
501487
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
502-
// CHECK-SAME: ins(%[[ARG0_PACK]]
503-
// CHECK-SAME: outs(%[[ARG1_PACK]]
488+
// CHECK-SAME: ins(%[[ARG0]]
489+
// CHECK-SAME: outs(%[[EMPTY]]
504490
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
505491
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
506492
// CHECK-SAME: into %[[ARG1]]
@@ -1407,19 +1393,21 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
14071393
} -> tensor<?x64xbf16>
14081394
return %0 : tensor<?x64xbf16>
14091395
}
1410-
14111396
// CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable
14121397
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1398+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1399+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
14131400
// CHECK: %[[EMPTY:.+]] = tensor.empty
14141401
// CHECK: %[[GENERIC:.+]] = linalg.generic
14151402
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
14161403
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
14171404
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
1405+
// CHECK-SAME: into %[[ARG2]]
14181406
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
14191407

14201408
// -----
14211409

1422-
func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> {
1410+
func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> {
14231411
%unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
14241412
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xf32>) {
14251413
^bb0(%in: f32, %out: f32):
@@ -1428,15 +1416,13 @@ func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>,
14281416
} -> tensor<?x64xf32>
14291417
return %0 : tensor<?x64xf32>
14301418
}
1431-
1432-
// CHECK-LABEL: func.func @push_unpack_in_padded_domain_not_foldable
1419+
// CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used
14331420
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
14341421
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1435-
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
1436-
// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG1]]
1437-
// CHECK: %[[UNPACK1:.+]] = linalg.pack %[[UNPACK]]
1422+
// CHECK: %[[EMPTY:.+]] = tensor.empty
14381423
// CHECK: %[[GENERIC:.+]] = linalg.generic
1439-
// CHECK-SAME: ins(%[[UNPACK1]] : tensor<?x8x4x8xf32>)
1440-
// CHECK-SAME: outs(%[[PACK]] : tensor<?x8x4x8xf32>)
1424+
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
1425+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xf32>)
14411426
// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
1427+
// CHECK-SAME: into %[[ARG1]]
14421428
// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>

0 commit comments

Comments
 (0)