diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index e3084530bd11b..675a766ec98b3 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -548,7 +548,8 @@ def LowerPackOp : Op:$target); + let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target, + DefaultValuedAttr:$lowerPadLikeWithInsertSlice); let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op, Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op, Transform_ConcreteOpType<"linalg.transpose">:$transpose_op); @@ -588,7 +589,8 @@ def LowerUnPackOp : Op:$target); + let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target, + DefaultValuedAttr:$lowerUnpadLikeWithExtractSlice); let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op, Transform_ConcreteOpType<"linalg.transpose">:$transpose_op, Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op, diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 51967f83fee37..82558de0fbfe6 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1121,7 +1121,8 @@ struct LowerPackResult { /// Rewrite pack as pad + reshape + transpose. FailureOr lowerPack(RewriterBase &rewriter, - tensor::PackOp packOp); + tensor::PackOp packOp, + bool lowerPadLikeWithInsertSlice = true); struct LowerUnPackOpResult { tensor::EmptyOp emptyOp; @@ -1131,8 +1132,9 @@ struct LowerUnPackOpResult { }; /// Rewrite pack as empty + transpose + reshape + extract_slice. 
-FailureOr lowerUnPack(RewriterBase &rewriter, - tensor::UnPackOp unPackOp); +FailureOr +lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp, + bool lowerUnpadLikeWithExtractSlice = true); /// Struct to hold the result of a `pack` call. struct PackResult { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index ada80deacfdbf..06f58d4943394 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1171,7 +1171,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne( transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { rewriter.setInsertionPoint(target); - FailureOr res = lowerPack(rewriter, target); + bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice(); + FailureOr res = + lowerPack(rewriter, target, lowerPadLikeWithInsertSlice); if (failed(res)) { return mlir::emitSilenceableFailure(target->getLoc()) << "cannot lower to pad + expand + transpose"; @@ -1191,7 +1193,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne( transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { rewriter.setInsertionPoint(target); - FailureOr res = lowerUnPack(rewriter, target); + bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice(); + FailureOr res = + lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice); if (failed(res)) { DiagnosedSilenceableFailure diag = emitSilenceableError() diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index d92543d726462..f597faa16cf60 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -217,7 +217,8 @@ struct PackedOperandsDimList { } // namespace FailureOr linalg::lowerPack(RewriterBase &rewriter, - 
tensor::PackOp packOp) { + tensor::PackOp packOp, + bool lowerPadLikeWithInsertSlice) { // 1. Filter out NYI cases. auto packedTensorType = cast(packOp->getResultTypes().front()); @@ -295,7 +296,7 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: "); DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL();); - if (packOp.isLikePad()) { + if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) { // Pack ops which operate as simple pads may not produce legal // tensor.insert_slice operations when the packed type does not rank reduce // to the padded type. @@ -351,8 +352,9 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, return LowerPackResult{padOp, reshapeOp, transposeOp}; } -FailureOr linalg::lowerUnPack(RewriterBase &rewriter, - tensor::UnPackOp unPackOp) { +FailureOr +linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp, + bool lowerUnpadLikeWithExtractSlice) { Location loc = unPackOp->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(unPackOp); @@ -362,7 +364,7 @@ FailureOr linalg::lowerUnPack(RewriterBase &rewriter, OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1); auto destTensorType = cast(unPackOp.getDest().getType()); - if (unPackOp.isLikeUnPad()) { + if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) { // This unpack is just a plain unpad. // Just extract the slice from the higher ranked tensor. 
ArrayRef destShape = destTensorType.getShape(); diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index 7aadf19069563..5f8ff36a16578 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} { // ----- +// This is same as pack_as_pad but since we explicitly added {lowerPadLikeWithInsertSlice = false}, it should not +// be lowered to insert_slice. +// CHECK-LABEL: func.func @pack_as_pad_disabled_insert_slice( +func.func @pack_as_pad_disabled_insert_slice(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> { + %cst_0 = arith.constant 0.0 : f32 + // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose + // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32> + // CHECK-DAG: %[[PAD:.*]] = tensor.pad %[[ARG0]] + // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]] + // CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]] + // CHECK-DAG: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]] + %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1 + : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32> + return %pack : tensor<1x1x1x1x136x64x16x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %pack = transform.structured.match ops{["tensor.pack"]} in %module_op + : (!transform.any_op) -> !transform.op<"tensor.pack"> + transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}: (!transform.op<"tensor.pack">) + -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">) + transform.yield + } +} + +// ----- + // Check that we don't 
lower the following pack as a pad. // Although all the outer most dimensions in the resulting shape are 1s, // some of the original dimensions are not part of the inner_dims_pos, hence @@ -233,6 +261,38 @@ module attributes {transform.with_named_sequence} { // ----- +// This is same as unpack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not +// be lowered to extract_slice. +// CHECK-LABEL: func.func @unpack_as_pad_disabled_extract_slice( +func.func @unpack_as_pad_disabled_extract_slice(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> { + %cst_0 = arith.constant 0.0 : f32 + + // tensor.unpack is lowered to tensor.empty + linalg.transpose + tensor.collapse_shape + tensor.extract_slice + // CHECK-DAG: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32> + // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]] + // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]] + // CHECK-DAG: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]] + %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1 + : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32> + return %pack : tensor<129x47x16x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op + : (!transform.any_op) -> !transform.op<"tensor.unpack"> + transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}: (!transform.op<"tensor.unpack">) + -> (!transform.op<"tensor.empty">, + !transform.op<"linalg.transpose">, + !transform.op<"tensor.collapse_shape">, + !transform.op<"tensor.extract_slice">) + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @pack_with_outer_dims_perm( func.func 
@pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>, %dest: tensor<200x4x16x100x16x32xi32>) @@ -572,7 +632,7 @@ func.func @unpack_fully_dynamic(%source: tensor, %dest: tensor !transform.op<"tensor.unpack"> + : (!transform.any_op) -> !transform.op<"tensor.unpack"> transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">) -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, @@ -627,9 +687,9 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: @unpack_with_outer_dims_perm // CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32> // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32> -// CHECK: %[[TRAN:.*]] = linalg.transpose -// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>) -// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>) +// CHECK: %[[TRAN:.*]] = linalg.transpose +// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>) // CHECK-SAME: permutation = [1, 3, 0, 2] // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]] // CHECK-SAME: : tensor<4x8x2x32xf32> into tensor<32x64xf32> @@ -638,7 +698,7 @@ module attributes {transform.with_named_sequence} { // CHECK: linalg.copy ins(%[[SLICE]] // CHECK-SAME: : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32> func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> { - %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0] + %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32> return %unpack : tensor<32x64xf32> } diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir new file mode 100644 index 0000000000000..faf7ff9ad7ed0 --- /dev/null +++ 
b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir @@ -0,0 +1,240 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s + +// For pack op, we use lowerPadLikeWithInsertSlice = false to ensure no insert_slice is generated. +// This allows linalg.transpose to be fused as a producer operation. In below testcase, linalg.transpose +// as a producer operation is fused into the scf.forall loop. + +module { + // CHECK-label: func @fuse_pack_as_producer + // CHECK: scf.forall {{.*}} { + // CHECK: %[[PRODUCER:.*]] = linalg.transpose + // CHECK: linalg.generic {{.*}} ins(%[[PRODUCER]] + // CHECK: scf.forall.in_parallel + // CHECK: } + func.func @fuse_pack_as_producer(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>) + -> tensor<4x4x128x256xf32> { + %dest = tensor.empty() : tensor<1x1x128x256xf32> + %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256] + into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32> + + %out = tensor.empty() : tensor<4x4x128x256xf32> + %res = linalg.generic + {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>, + affine_map<(i, j, k, l) -> (i, j, k, l)>, + affine_map<(i, j, k, l) -> (i, j, k, l)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>) + outs(%out: tensor<4x4x128x256xf32>) { + ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32): + %r = arith.addf %pack_elem, %other_elem : f32 + linalg.yield %r : f32 + } -> tensor<4x4x128x256xf32> + + return %res : tensor<4x4x128x256xf32> + } + + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + // Find and lower pack operation. 
+ %pack = transform.structured.match ops{["tensor.pack"]} in %arg1 + : (!transform.any_op) -> !transform.op<"tensor.pack"> + %paded, %expanded, %transpose = transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false} + : (!transform.op<"tensor.pack">) + -> (!transform.op<"tensor.pad">, + !transform.op<"tensor.expand_shape">, + !transform.op<"linalg.transpose">) + + %root = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + // Tile the linalg operation with parallel forall loop tiling [4, 4]. + %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Fuse the transpose operation into the tiled loop. + transform.structured.fuse_into_containing_op %transpose into %forall_op + : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } + } +} + +// ----- +// For pack op, by default lowerPadLikeWithInsertSlice = true, which generates insert_slice and blocks fusion. +// In below testcase, tensor.insert_slice as a producer operation cannot be fused into the scf.forall loop. 
+ +module { + // CHECK-label: func @fuse_pack_as_producer_blocked_by_insert_slice + // CHECK: %[[PRODUCER:.*]] = tensor.insert_slice + // CHECK: scf.forall {{.*}} { + // CHECK: linalg.generic {{.*}} ins(%[[PRODUCER]] + // CHECK: scf.forall.in_parallel + // CHECK: } + func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>) + -> tensor<4x4x128x256xf32> { + %dest = tensor.empty() : tensor<1x1x128x256xf32> + %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256] + into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32> + + %out = tensor.empty() : tensor<4x4x128x256xf32> + %res = linalg.generic + {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>, + affine_map<(i, j, k, l) -> (i, j, k, l)>, + affine_map<(i, j, k, l) -> (i, j, k, l)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>) + outs(%out: tensor<4x4x128x256xf32>) { + ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32): + %r = arith.addf %pack_elem, %other_elem : f32 + linalg.yield %r : f32 + } -> tensor<4x4x128x256xf32> + + return %res : tensor<4x4x128x256xf32> + } + + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + // Find and lower pack operation. + %pack = transform.structured.match ops{["tensor.pack"]} in %arg1 + : (!transform.any_op) -> !transform.op<"tensor.pack"> + %paded, %expanded, %transpose = transform.structured.lower_pack %pack + : (!transform.op<"tensor.pack">) + -> (!transform.op<"tensor.pad">, + !transform.op<"tensor.expand_shape">, + !transform.op<"linalg.transpose">) + + %root = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + // Tile the linalg operation with parallel forall loop tiling [4, 4]. 
+ %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Fuse the transpose operation into the tiled loop. + transform.structured.fuse_into_containing_op %transpose into %forall_op + : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } + } +} + +// ----- +// For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated. +// This allows linalg.transpose to be fused as a consumer operation. In below testcase, linalg.transpose +// as a consumer operation is fused into the scf.forall loop. +module { + // CHECK-label: func @fuse_unpack_as_consumer + // CHECK: scf.forall {{.*}} { + // CHECK: %[[CONSUMER:.*]] = linalg.generic + // CHECK: linalg.transpose ins(%[[CONSUMER]] + // CHECK: scf.forall.in_parallel + // CHECK: } + func.func @fuse_unpack_as_consumer(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>) + -> tensor<128x256xf32> { + %out = tensor.empty() : tensor<1x1x128x256xf32> + %res = linalg.generic + {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>, + affine_map<(i, j, k, l) -> (i, j, k, l)>, + affine_map<(i, j, k, l) -> (0, 0, k, l)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>) + outs(%out: tensor<1x1x128x256xf32>) { + ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32): + %r = arith.addf %unpack_elem, %other_elem : f32 + linalg.yield %r : f32 + } -> tensor<1x1x128x256xf32> + + %dest = tensor.empty() : tensor<128x256xf32> + %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256] + into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32> + + return %unpack : tensor<128x256xf32> + } + + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op 
{transform.readonly}) { + // Find and lower unpack operation. + %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1 + : (!transform.any_op) -> !transform.op<"tensor.unpack"> + transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false} + : (!transform.op<"tensor.unpack">) + -> (!transform.op<"tensor.empty">, + !transform.op<"linalg.transpose">, + !transform.op<"tensor.collapse_shape">, + !transform.op<"tensor.extract_slice">) + + %root = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + // Tile the linalg operation with parallel forall loop tiling [4, 4]. + %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Fuse the consumer operation into the tiled loop. + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op + : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> + transform.test.fuse_consumer %slice_op + : (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op) + transform.yield + } + } +} + +// ----- +// For unpack op, by default lowerUnpadLikeWithExtractSlice = true, which generates extract_slice and blocks fusion. +// In below testcase, tensor.extract_slice as a consumer operation cannot be fused into the scf.forall loop. 
+module { + // CHECK-label: func @fuse_unpack_as_consumer_blocked_by_extract_slice + // CHECK: %[[CONSUMER:.*]] = scf.forall {{.*}} { + // CHECK: %[[ADDF:.*]] = linalg.generic + // CHECK: scf.forall.in_parallel + // CHECK: tensor.parallel_insert_slice %[[ADDF]] + // CHECK: } + // CHECK: tensor.extract_slice %[[CONSUMER]] + func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>) + -> tensor<128x256xf32> { + %out = tensor.empty() : tensor<1x1x128x256xf32> + %res = linalg.generic + {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>, + affine_map<(i, j, k, l) -> (i, j, k, l)>, + affine_map<(i, j, k, l) -> (0, 0, k, l)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>) + outs(%out: tensor<1x1x128x256xf32>) { + ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32): + %r = arith.addf %unpack_elem, %other_elem : f32 + linalg.yield %r : f32 + } -> tensor<1x1x128x256xf32> + + %dest = tensor.empty() : tensor<128x256xf32> + %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256] + into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32> + + return %unpack : tensor<128x256xf32> + } + + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + // Find and lower unpack operation. 
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1 + : (!transform.any_op) -> !transform.op<"tensor.unpack"> + transform.structured.lower_unpack %unpack + : (!transform.op<"tensor.unpack">) + -> (!transform.op<"tensor.empty">, + !transform.op<"linalg.transpose">, + !transform.op<"tensor.collapse_shape">, + !transform.op<"tensor.extract_slice">) + + %root = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + // Tile the linalg operation with parallel forall loop tiling [4, 4]. + %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Fuse the consumer operation into the tiled loop. + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op + : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> + // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice + // is not a qualified consumer operation. Forcing this will yield "could not fetch consumer + // to fuse" error. + transform.yield + } + } +}