diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index d79399b6588be..c906f3bdcc632 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -399,6 +399,18 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp, if (!genericOp->getResult(0).hasOneUse()) return failure(); + // TODO: Add an option for allowing padding values. It could introduce + // undefined behavior if we unconditionally propagate pack op through all + // the ops. E.g., if the padding value is zero and there are division ops in + // a generic op. Some values of padding area could be NaN (0/0). + if (packOp.getPaddingValue()) + return failure(); + + OpOperand *opOperand = genericOp.getDpsInitOperand(0); + auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp); + if (failed(packInfo)) + return failure(); + // We want to move the pack not the generic. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(genericOp); @@ -422,18 +434,6 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp, return failure(); } - // TODO: Add an option for allowing padding values. It could introduce - // undefined behavior if we unconditionally propagate pack op through all - // the ops. E.g., if the padding value is zero and there are division ops in - // a generic op. Some values of padding area could be NaN (0/0). - if (packOp.getPaddingValue()) - return failure(); - - OpOperand *opOperand = genericOp.getDpsInitOperand(0); - auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp); - if (failed(packInfo)) - return failure(); - // Rebuild the indexing map for the corresponding init operand. auto [packedOutOperand, packedOutIndexingMap] = getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index cb8064411bbae..b2b29b2b2fee2 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -46,6 +46,34 @@ func.func @dynamic_elem_pack(%arg0: tensor, %dest: tensor) // ----- +#map0 = affine_map<(d0, d1) -> (d0, d1)> +func.func @dynamic_elem_pack_padding_value(%arg0: tensor, %dest: tensor) -> tensor +{ + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 3.000000e+00 : f32 + %0 = tensor.dim %arg0, %c0 : tensor + %1 = tensor.dim %arg0, %c1 : tensor + %2 = tensor.empty(%0, %1) : tensor + %3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor) + outs(%2 : tensor) { + ^bb0(%arg3: f32, %arg4: f32): + %4 = arith.addf %arg3, %arg3 : f32 + linalg.yield %4 : f32 + } -> tensor + %4 = tensor.pack %3 padding_value(%cst : f32) + inner_dims_pos = [0, 1] + inner_tiles = [8, 2] + into %dest : tensor -> tensor + return %4 : tensor +} +// CHECK-LABEL: func.func @dynamic_elem_pack_padding_value +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: tensor.pack %[[GENERIC]] + +// ----- + #map0 = affine_map<(d0, d1) -> (d0, d1)> func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{ %init = tensor.empty() : tensor<128x256xi32>