Skip to content

Commit 47bd0c7

Browse files
Max191 (Max Dawkins)
authored and committed
[mlir] Match before rewrite in BubbleUpPackOpThroughGenericOp
Signed-off-by: Max Dawkins <[email protected]>
1 parent a3e2075 commit 47bd0c7

File tree

2 files changed

+40
-12
lines changed

2 files changed

+40
-12
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,18 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
399399
if (!genericOp->getResult(0).hasOneUse())
400400
return failure();
401401

402+
// TODO: Add an option for allowing padding values. It could introduce
403+
// undefined behavior if we unconditionally propagate pack op through all
404+
// the ops. E.g., if the padding value is zero and there are division ops in
405+
// a generic op. Some values of padding area could be NaN (0/0).
406+
if (packOp.getPaddingValue())
407+
return failure();
408+
409+
OpOperand *opOperand = genericOp.getDpsInitOperand(0);
410+
auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
411+
if (failed(packInfo))
412+
return failure();
413+
402414
// We want to move the pack not the generic.
403415
OpBuilder::InsertionGuard guard(rewriter);
404416
rewriter.setInsertionPoint(genericOp);
@@ -422,18 +434,6 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
422434
return failure();
423435
}
424436

425-
// TODO: Add an option for allowing padding values. It could introduce
426-
// undefined behavior if we unconditionally propagate pack op through all
427-
// the ops. E.g., if the padding value is zero and there are division ops in
428-
// a generic op. Some values of padding area could be NaN (0/0).
429-
if (packOp.getPaddingValue())
430-
return failure();
431-
432-
OpOperand *opOperand = genericOp.getDpsInitOperand(0);
433-
auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
434-
if (failed(packInfo))
435-
return failure();
436-
437437
// Rebuild the indexing map for the corresponding init operand.
438438
auto [packedOutOperand, packedOutIndexingMap] =
439439
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,34 @@ func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>)
4646

4747
// -----
4848

49+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
50+
func.func @dynamic_elem_pack_padding_value(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
51+
{
52+
%c0 = arith.constant 0 : index
53+
%c1 = arith.constant 1 : index
54+
%cst = arith.constant 3.000000e+00 : f32
55+
%0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
56+
%1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
57+
%2 = tensor.empty(%0, %1) : tensor<?x?xf32>
58+
%3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
59+
ins(%arg0 : tensor<?x?xf32>)
60+
outs(%2 : tensor<?x?xf32>) {
61+
^bb0(%arg3: f32, %arg4: f32):
62+
%4 = arith.addf %arg3, %arg3 : f32
63+
linalg.yield %4 : f32
64+
} -> tensor<?x?xf32>
65+
%4 = tensor.pack %3 padding_value(%cst : f32)
66+
inner_dims_pos = [0, 1]
67+
inner_tiles = [8, 2]
68+
into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
69+
return %4 : tensor<?x?x8x2xf32>
70+
}
71+
// CHECK-LABEL: func.func @dynamic_elem_pack_padding_value
72+
// CHECK: %[[GENERIC:.+]] = linalg.generic
73+
// CHECK: tensor.pack %[[GENERIC]]
74+
75+
// -----
76+
4977
#map0 = affine_map<(d0, d1) -> (d0, d1)>
5078
func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{
5179
%init = tensor.empty() : tensor<128x256xi32>

0 commit comments

Comments
 (0)