Skip to content

Commit 19c26c0

Browse files
committed
Addressing review feedback
1 parent 67da598 commit 19c26c0

File tree

2 files changed

+58
-32
lines changed

2 files changed

+58
-32
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,11 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
300300

301301
static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
302302
int numDpsOuts = genericOp.getNumDpsInits();
303+
Block *block = genericOp.getBody();
304+
int numBlockArgs = block->getNumArguments();
305+
int initArgStartIndex = numBlockArgs - numDpsOuts;
303306
for (int i = 0; i < numDpsOuts; ++i) {
304-
Block *block = genericOp.getBody();
305-
int numBlockArgs = block->getNumArguments();
306-
int matchingInitArgIndex = numBlockArgs - numDpsOuts + i;
307+
int matchingInitArgIndex = initArgStartIndex + i;
307308
return block->getArgument(matchingInitArgIndex).use_empty();
308309
}
309310
return true;
@@ -312,18 +313,13 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
312313
/// Pack a genericOp and return it.
313314
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
314315
Value dest, AffineMap packedOutIndexingMap,
315-
const PackInfo &packInfo) {
316+
const PackInfo &packInfo,
317+
bool canUnpackPackFold) {
316318
Location loc = genericOp.getLoc();
317319
SmallVector<Value> inputOperands;
318320
SmallVector<Value> inputOperandsFromUnpackedSource;
319321
SmallVector<AffineMap> indexingMaps;
320322

321-
// Note: canUnpackPackFold needs to also guarantee the generic body
322-
// doesn't have gather semantics. Since such scenarios has been
323-
// rejected by both BubbleUpPackOpThroughGenericOp and
324-
// PushDownUnPackOpThroughGenericOp, we can safely assume
325-
// canUnpackPackFold is as long as init is not used.
326-
bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
327323
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
328324
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
329325
rewriter, loc, packInfo, genericOp, inputOperand);
@@ -338,10 +334,18 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
338334
indexingMaps.push_back(packedIndexingMap);
339335
}
340336

337+
// Note: Whether or not the unpack pack sequence can fold also depends on
338+
// the caller of this routine.
339+
// 1) In push down unpack op pattern, this is true because the pack op is
340+
// generated and we can guarantee they are compatible.
341+
// 2) In bubble up pack op pattern, this is not true because the unpack op
342+
// can be from an arbitrary domain so we need to keep both.
343+
canUnpackPackFold = canUnpackPackFold && isGenericOutsNotUsed(genericOp) &&
344+
!hasGatherSemantics(genericOp);
341345
// If the pack and unpack op can be folded:
342-
// 1) use unpack op source op for operand to fold unpack -> pack sequence
343-
// 2) init tensor of the generic op can be replaced by the new tensor.empty
344-
// as the generic out.
346+
// 1) use unpack op source op for operand to fold unpack -> pack sequence.
347+
// 2) init tensor of the generic op can be replaced by the destination of the
348+
// pack op.
345349
if (canUnpackPackFold) {
346350
inputOperands = inputOperandsFromUnpackedSource;
347351
if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
@@ -484,7 +488,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
484488
dest = packOpDest;
485489
}
486490
return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
487-
*packInfo);
491+
*packInfo, /*canUnpackPackFold=*/false);
488492
}
489493

490494
/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
@@ -1122,7 +1126,8 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
11221126

11231127
// Pack the genericOp.
11241128
GenericOp newGenericOp =
1125-
packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
1129+
packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
1130+
/*canUnpackPackFold=*/true);
11261131
Value newResult =
11271132
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
11281133

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,24 +1398,45 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
13981398

13991399
// -----
14001400

1401-
#map = affine_map<(d0, d1) -> (d0, d1)>
1402-
func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
1403-
%empty = tensor.empty() : tensor<32x64xf32>
1404-
%unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
1405-
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
1406-
^bb0(%in: f32, %out: f32):
1407-
%2 = arith.addf %in, %in : f32
1408-
linalg.yield %2 : f32
1409-
} -> tensor<32x64xf32>
1410-
%empty1 = tensor.empty() : tensor<8x8x4x8xf32>
1411-
%pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
1412-
return %pack : tensor<8x8x4x8xf32>
1401+
func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %dest: tensor<?x64xf32>, %arg1: tensor<?x64xbf16>) -> tensor<?x64xbf16> {
1402+
%unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
1403+
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xbf16>) {
1404+
^bb0(%in: f32, %out: bf16):
1405+
%1 = arith.truncf %in : f32 to bf16
1406+
linalg.yield %1 : bf16
1407+
} -> tensor<?x64xbf16>
1408+
return %0 : tensor<?x64xbf16>
14131409
}
14141410

1415-
// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up
1411+
// CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable
14161412
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1417-
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
1418-
// CHECK: %[[GENERIC:.+]] = linalg.generic
1413+
// CHECK: %[[EMPTY:.+]] = tensor.empty
1414+
// CHECK: %[[GENERIC:.+]] = linalg.generic
14191415
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
1420-
// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
1421-
// CHECK: return %[[GENERIC]] : tensor<8x8x4x8xf32>
1416+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
1417+
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
1418+
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
1419+
1420+
// -----
1421+
1422+
func.func @push_unpack_in_padded_domain_not_foldable(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> {
1423+
%unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
1424+
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xf32>) {
1425+
^bb0(%in: f32, %out: f32):
1426+
%1 = arith.addf %in, %out : f32
1427+
linalg.yield %1 : f32
1428+
} -> tensor<?x64xf32>
1429+
return %0 : tensor<?x64xf32>
1430+
}
1431+
1432+
// CHECK-LABEL: func.func @push_unpack_in_padded_domain_not_foldable
1433+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1434+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1435+
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
1436+
// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG1]]
1437+
// CHECK: %[[UNPACK1:.+]] = linalg.pack %[[UNPACK]]
1438+
// CHECK: %[[GENERIC:.+]] = linalg.generic
1439+
// CHECK-SAME: ins(%[[UNPACK1]] : tensor<?x8x4x8xf32>)
1440+
// CHECK-SAME: outs(%[[PACK]] : tensor<?x8x4x8xf32>)
1441+
// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
1442+
// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>

0 commit comments

Comments
 (0)