diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index defac8308b909..659eabd2e9388 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4337,11 +4337,16 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { dest = rewriter.create(loc, newDestType, packOp.getDest()); } - Value newOp = rewriter.create( - loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(), - packOp.getPaddingValue(), packOp.getOuterDimsPerm()); + auto clonedPackOp = cast(rewriter.clone(*packOp)); + Value res = clonedPackOp.getResult(); + rewriter.startOpModification(clonedPackOp); + clonedPackOp.getSourceMutable().assign(source); + clonedPackOp.getDestMutable().assign(dest); + res.setType(dest.getType()); + rewriter.finalizeOpModification(clonedPackOp); + rewriter.replaceOpWithNewOp( - packOp, packOp.getResult().getType(), newOp); + packOp, packOp.getResult().getType(), clonedPackOp); return success(); } diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 86754c1c37536..0aa2d33ef17ed 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -2357,7 +2357,7 @@ func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16> %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16> %tensor_empty1 = tensor.empty() : tensor<224x512xbf16> %packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16> - %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16> + %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16> return %unpacked : tensor<224x512xbf16> } @@ -2707,3 +2707,14 @@ func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tenso %0:2 = test.destination_style_op ins(%cast : tensor) outs(%cast_0 : tensor) -> tensor, index return %0#1 : index } + +// ----- + +// CHECK-LABEL: func.func @pack_dont_drop_attributes( +// CHECK: tensor.pack {{.*}} {test_attr} +func.func @pack_dont_drop_attributes(%arg0: tensor, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> { + %c32_i64 = arith.constant 32 : i64 + %cst = arith.constant 0.000000e+00 : f16 + %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor -> tensor<128x?x100x16x1xf16> + return %pack : tensor<128x?x100x16x1xf16> +}