From 012f6d46ff6c187470d6ca102be513e7a5a78a21 Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Fri, 22 Nov 2024 15:52:12 +0000
Subject: [PATCH 1/7] [NFC] Add allowInsertSliceLowering to packOp and
 allowExtractSliceLowering to UnPackOp

---
 .../Linalg/TransformOps/LinalgTransformOps.td      |  6 ++++--
 .../mlir/Dialect/Linalg/Transforms/Transforms.h    |  8 +++++---
 .../Linalg/TransformOps/LinalgTransformOps.cpp     |  8 ++++++--
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp  | 12 +++++++-----
 4 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e3084530bd11b..ea96da77b6c33 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -548,7 +548,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
-  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
+                   DefaultValuedAttr<BoolAttr, "true">:$allowInsertSliceLowering);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
                       Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
@@ -588,7 +589,8 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
-  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
+                   DefaultValuedAttr<BoolAttr, "true">:$allowExtractSliceLowering);
   let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
                       Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 51967f83fee37..fd27e7929764d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1121,7 +1121,8 @@ struct LowerPackResult {
 
 /// Rewrite pack as pad + reshape + transpose.
 FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
-                                     tensor::PackOp packOp);
+                                     tensor::PackOp packOp,
+                                     bool allowInsertSliceLowering = true);
 
 struct LowerUnPackOpResult {
   tensor::EmptyOp emptyOp;
@@ -1131,8 +1132,9 @@ struct LowerUnPackOpResult {
 };
 
 /// Rewrite unpack as empty + transpose + reshape + extract_slice.
-FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
-                                           tensor::UnPackOp unPackOp);
+FailureOr<LowerUnPackOpResult>
+lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+            bool allowExtractSliceLowering = true);
 
 /// Struct to hold the result of a `pack` call.
 struct PackResult {

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ada80deacfdbf..5117a5c58c381 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1171,7 +1171,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
+  bool allowInsertSliceLowering = getAllowInsertSliceLowering();
+  FailureOr<LowerPackResult> res =
+      lowerPack(rewriter, target, allowInsertSliceLowering);
   if (failed(res)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "cannot lower to pad + expand + transpose";
@@ -1191,7 +1193,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
+  bool allowExtractSliceLowering = getAllowExtractSliceLowering();
+  FailureOr<LowerUnPackOpResult> res =
+      lowerUnPack(rewriter, target, allowExtractSliceLowering);
   if (failed(res)) {
     DiagnosedSilenceableFailure diag =
         emitSilenceableError()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index d92543d726462..0717dad4c2852 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -217,7 +217,8 @@ struct PackedOperandsDimList {
 } // namespace
 
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
-                                             tensor::PackOp packOp) {
+                                             tensor::PackOp packOp,
+                                             bool allowInsertSliceLowering) {
   // 1. Filter out NYI cases.
   auto packedTensorType =
       cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -295,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
              llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
              DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
 
-  if (packOp.isLikePad()) {
+  if (allowInsertSliceLowering && packOp.isLikePad()) {
     // Pack ops which operate as simple pads may not produce legal
     // tensor.insert_slice operations when the packed type does not rank reduce
     // to the padded type.
@@ -351,8 +352,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   return LowerPackResult{padOp, reshapeOp, transposeOp};
 }
 
-FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
-                                                   tensor::UnPackOp unPackOp) {
+FailureOr<LowerUnPackOpResult>
+linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+                    bool allowExtractSliceLowering) {
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -362,7 +364,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
   auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
 
-  if (unPackOp.isLikeUnPad()) {
+  if (allowExtractSliceLowering && unPackOp.isLikeUnPad()) {
     // This unpack is just a plain unpad.
     // Just extract the slice from the higher ranked tensor.
    ArrayRef<int64_t> destShape = destTensorType.getShape();

From 46b72028918f13f8faf7ee474d6da14f15a246ef Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Fri, 22 Nov 2024 17:51:40 +0000
Subject: [PATCH 2/7] Add test cases for allowInsertSliceLowering and
 allowExtractSliceLowering

---
 .../Dialect/Linalg/transform-lower-pack.mlir  | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 7aadf19069563..2e6a5ea97aaa3 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as pack_as_pad but since we explicitly added {allowInsertSliceLowering = false}, it should not
+// be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_disallowed_as_pad(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
+func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+  // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
+  // CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack {allowInsertSliceLowering = false}: (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check that we don't lower the following pack as a pad.
 // Although all the outer most dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as unpack_as_pad but since we explicitly added {allowExtractSliceLowering = false}, it should not
+// be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
+func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack {allowExtractSliceLowering = false}: (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)

From 0fa54017fd955f7637f9c8289896b4691518537f Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Mon, 25 Nov 2024 15:57:32 +0000
Subject: [PATCH 3/7] Address review requests

- Renamed allowInsertSliceLowering to lowerPadLikeWithInsertSlice
- Renamed allowExtractSliceLowering to lowerUnpadLikeWithExtractSlice
- Removed the redundant unit tests since this is an NFC change

This reverts commit 46b72028918f13f8faf7ee474d6da14f15a246ef.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  4 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  4 +-
 .../TransformOps/LinalgTransformOps.cpp       |  8 +--
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  8 +--
 .../Dialect/Linalg/transform-lower-pack.mlir  | 56 -------------------
 5 files changed, 12 insertions(+), 68 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ea96da77b6c33..675a766ec98b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -549,7 +549,7 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
   let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
-                   DefaultValuedAttr<BoolAttr, "true">:$allowInsertSliceLowering);
+                   DefaultValuedAttr<BoolAttr, "true">:$lowerPadLikeWithInsertSlice);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
                       Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
@@ -590,7 +590,7 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
   let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
-                   DefaultValuedAttr<BoolAttr, "true">:$allowExtractSliceLowering);
+                   DefaultValuedAttr<BoolAttr, "true">:$lowerUnpadLikeWithExtractSlice);
   let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
                       Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fd27e7929764d..82558de0fbfe6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1122,7 +1122,7 @@ struct LowerPackResult {
 
 /// Rewrite pack as pad + reshape + transpose.
 FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
                                      tensor::PackOp packOp,
-                                     bool allowInsertSliceLowering = true);
+                                     bool lowerPadLikeWithInsertSlice = true);
 
 struct LowerUnPackOpResult {
   tensor::EmptyOp emptyOp;
@@ -1134,7 +1134,7 @@ struct LowerUnPackOpResult {
 };
 
 /// Rewrite unpack as empty + transpose + reshape + extract_slice.
 FailureOr<LowerUnPackOpResult>
 lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
-            bool allowExtractSliceLowering = true);
+            bool lowerUnpadLikeWithExtractSlice = true);
 
 /// Struct to hold the result of a `pack` call.
 struct PackResult {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5117a5c58c381..06f58d4943394 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1171,9 +1171,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  bool allowInsertSliceLowering = getAllowInsertSliceLowering();
+  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
   FailureOr<LowerPackResult> res =
-      lowerPack(rewriter, target, allowInsertSliceLowering);
+      lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
   if (failed(res)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "cannot lower to pad + expand + transpose";
@@ -1193,9 +1193,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  bool allowExtractSliceLowering = getAllowExtractSliceLowering();
+  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
   FailureOr<LowerUnPackOpResult> res =
-      lowerUnPack(rewriter, target, allowExtractSliceLowering);
+      lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
   if (failed(res)) {
     DiagnosedSilenceableFailure diag =
         emitSilenceableError()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0717dad4c2852..f597faa16cf60 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -218,7 +218,7 @@ struct PackedOperandsDimList {
 
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              tensor::PackOp packOp,
-                                             bool allowInsertSliceLowering) {
+                                             bool lowerPadLikeWithInsertSlice) {
   // 1. Filter out NYI cases.
   auto packedTensorType =
       cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -296,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
              llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
              DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
 
-  if (allowInsertSliceLowering && packOp.isLikePad()) {
+  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
     // Pack ops which operate as simple pads may not produce legal
     // tensor.insert_slice operations when the packed type does not rank reduce
     // to the padded type.
@@ -354,7 +354,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 
 FailureOr<LowerUnPackOpResult>
 linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
-                    bool allowExtractSliceLowering) {
+                    bool lowerUnpadLikeWithExtractSlice) {
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -364,7 +364,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
   auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
 
-  if (allowExtractSliceLowering && unPackOp.isLikeUnPad()) {
+  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
     // This unpack is just a plain unpad.
     // Just extract the slice from the higher ranked tensor.
     ArrayRef<int64_t> destShape = destTensorType.getShape();
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 2e6a5ea97aaa3..7aadf19069563 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -96,34 +96,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// This is the same as pack_as_pad but since we explicitly added {allowInsertSliceLowering = false}, it should not
-// be lowered to insert_slice.
-// CHECK-LABEL: func.func @pack_disallowed_as_pad(
-// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
-// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
-func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
-  %cst_0 = arith.constant 0.0 : f32
-  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
-  // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
-  // CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
-  // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
-  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
-    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
-  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
-      : (!transform.any_op) -> !transform.op<"tensor.pack">
-    transform.structured.lower_pack %pack {allowInsertSliceLowering = false}: (!transform.op<"tensor.pack">)
-      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
-    transform.yield
-  }
-}
-
-// -----
-
 // Check that we don't lower the following pack as a pad.
 // Although all the outer most dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence
@@ -261,34 +233,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// This is the same as unpack_as_pad but since we explicitly added {allowExtractSliceLowering = false}, it should not
-// be lowered to extract_slice.
-// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
-func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
-  %cst_0 = arith.constant 0.0 : f32
-
-  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
-  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
-  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
-    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
-  return %pack : tensor<129x47x16x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
-      : (!transform.any_op) -> !transform.op<"tensor.unpack">
-    transform.structured.lower_unpack %unpack {allowExtractSliceLowering = false}: (!transform.op<"tensor.unpack">)
-      -> (!transform.op<"tensor.empty">,
-          !transform.op<"linalg.transpose">,
-          !transform.op<"tensor.collapse_shape">,
-          !transform.op<"tensor.extract_slice">)
-    transform.yield
-  }
-}
-
-// -----
-
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)

From 671829ee1c89b1d5ff82c866f7074a969414698f Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Mon, 2 Dec 2024 21:44:20 +0000
Subject: [PATCH 4/7] Add test cases for lowerPadLikeWithInsertSlice and
 lowerUnpadLikeWithExtractSlice

---
 .../Dialect/Linalg/transform-lower-pack.mlir  | 60 +++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 7aadf19069563..2f7f2ff5211bf 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as pack_as_pad but since we explicitly added {lowerPadLikeWithInsertSlice = false}, it should not
+// be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_disallowed_as_pad(
+func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
+  // CHECK: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+  // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
+  // CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
+  // CHECK: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}: (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check that we don't lower the following pack as a pad.
 // Although all the outer most dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as unpack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not 
+// be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
+func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // tensor.unpack is lowered to linalg.transpose + tensor.collapse_shape + tensor.extract_slice
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+  // CHECK: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}: (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)

From 3ad8cd64f687afa5cadc57e40aa228d6a587ed03 Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Thu, 5 Dec 2024 15:42:03 +0000
Subject: [PATCH 5/7] Add tests to verify pack-producer and unpack-consumer
 fusion

---
 .../transform-tile-and-fuse-pack-unpack.mlir  | 121 ++++++++++++++++++
 1 file changed, 121 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir

diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
new file mode 100644
index 0000000000000..31c28a852eef2
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -0,0 +1,121 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s
+
+// For pack op, we use lowerPadLikeWithInsertSlice = false to ensure no insert_slice is generated.
+// This allows linalg.transpose to be fused as a producer operation. Alternatively, without this attribute
+// insert_slice will be generated and fusion blocked.
+
+module {
+  // CHECK-LABEL: func @fuse_pack_as_producer
+  // CHECK: scf.forall {{.*}} {
+  // CHECK: linalg.transpose
+  // CHECK: linalg.generic
+  // CHECK: scf.forall.in_parallel
+  // CHECK: }
+  func.func @fuse_pack_as_producer(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<4x4x128x256xf32> {
+    %dest = tensor.empty() : tensor<1x1x128x256xf32>
+    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+    %out = tensor.empty() : tensor<4x4x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<4x4x128x256xf32>) {
+      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %pack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<4x4x128x256xf32>
+
+    return %res : tensor<4x4x128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower pack operation.
+      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.pack">
+      %paded, %expanded, %transpose = transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
+        : (!transform.op<"tensor.pack">)
+        -> (!transform.op<"tensor.pad">,
+            !transform.op<"tensor.expand_shape">,
+            !transform.op<"linalg.transpose">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the transpose operation into the tiled loop.
+      transform.structured.fuse_into_containing_op %transpose into %forall_op
+        : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+// For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
+// This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
+// extract_slice will be generated and fusion blocked.
+
+module {
+  // CHECK-LABEL: func @fuse_unpack_as_consumer
+  // CHECK: scf.forall {{.*}} {
+  // CHECK: linalg.generic
+  // CHECK: linalg.transpose
+  // CHECK: scf.forall.in_parallel
+  // CHECK: }
+  func.func @fuse_unpack_as_consumer(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<128x256xf32> {
+    %out = tensor.empty() : tensor<1x1x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<1x1x128x256xf32>) {
+      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %unpack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<1x1x128x256xf32>
+
+    %dest = tensor.empty() : tensor<128x256xf32>
+    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+    return %unpack : tensor<128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower unpack operation.
+      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}
+        : (!transform.op<"tensor.unpack">)
+        -> (!transform.op<"tensor.empty">,
+            !transform.op<"linalg.transpose">,
+            !transform.op<"tensor.collapse_shape">,
+            !transform.op<"tensor.extract_slice">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the consumer operation into the tiled loop.
+      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+        : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+      transform.test.fuse_consumer %slice_op
+        : (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}

From 68b53288b08777ff1e895a7d0f54cdd97373bc25 Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Mon, 9 Dec 2024 16:12:35 +0000
Subject: [PATCH 6/7] Add additional negative test cases

- Added additional test cases to demonstrate that insert/extract slice
  will block producer/consumer fusion
- Readability enhancements
---
 .../Dialect/Linalg/transform-lower-pack.mlir  |  34 ++---
 .../transform-tile-and-fuse-pack-unpack.mlir  | 117 ++++++++++++++++++
 2 files changed, 134 insertions(+), 17 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 2f7f2ff5211bf..5f8ff36a16578 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -98,15 +98,15 @@ module attributes {transform.with_named_sequence} {
 
 // This is the same as pack_as_pad but since we explicitly added {lowerPadLikeWithInsertSlice = false}, it should not
 // be lowered to insert_slice.
-// CHECK-LABEL: func.func @pack_disallowed_as_pad(
-func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+// CHECK-LABEL: func.func @pack_as_pad_disabled_insert_slice(
+func.func @pack_as_pad_disabled_insert_slice(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
   // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
   // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
-  // CHECK: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+  // CHECK-DAG: %[[PAD:.*]] = tensor.pad %[[ARG0]]
   // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
   // CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
-  // CHECK: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
+  // CHECK-DAG: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
   %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
     : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
   return %pack : tensor<1x1x1x1x136x64x16x16xf32>
@@ -261,18 +261,18 @@ module attributes {transform.with_named_sequence} {
 
-// This is the same as unpack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not 
+// This is the same as unpack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not
 // be lowered to extract_slice.
-// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
-func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+// CHECK-LABEL: func.func @unpack_as_pad_disabled_extract_slice(
+func.func @unpack_as_pad_disabled_extract_slice(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.unpack is lowered to linalg.transpose + tensor.collapse_shape + tensor.extract_slice
-  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
-  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
-  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
-  // CHECK: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
+  // CHECK-DAG: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+  // CHECK-DAG: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
   %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
     : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
   return %pack : tensor<129x47x16x16xf32>
@@ -632,7 +632,7 @@ func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
-      : (!transform.any_op) -> !transform.op<"tensor.unpack"> 
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
     transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
       -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
@@ -687,9 +687,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: @unpack_with_outer_dims_perm
 // CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
-// CHECK: %[[TRAN:.*]] = linalg.transpose 
-// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>) 
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>) 
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
 // CHECK-SAME: permutation = [1, 3, 0, 2]
 // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
 // CHECK-SAME: : tensor<4x8x2x32xf32> into tensor<32x64xf32>
@@ -698,7 +698,7 @@
 // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [32, 64] [1, 1]
 // CHECK: linalg.copy ins(%[[SLICE]]
 // CHECK-SAME: : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
 func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
-  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0] 
+  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
     inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
   return %unpack : tensor<32x64xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index 31c28a852eef2..ffed9ab6e0653 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -58,6 +58,62 @@ module {
   }
 }
 
+// -----
+// For pack op, by default lowerPadLikeWithInsertSlice = true, which generates insert_slice and blocks fusion.
+
+module {
+  // CHECK-LABEL: func @fuse_pack_as_producer_blocked_by_insert_slice
+  // CHECK: tensor.insert_slice
+  // CHECK: scf.forall {{.*}} {
+  // CHECK: scf.forall.in_parallel
+  // CHECK: }
+  func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<4x4x128x256xf32> {
+    %dest = tensor.empty() : tensor<1x1x128x256xf32>
+    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+    %out = tensor.empty() : tensor<4x4x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<4x4x128x256xf32>) {
+      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %pack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<4x4x128x256xf32>
+
+    return %res : tensor<4x4x128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower pack operation.
+      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.pack">
+      %paded, %expanded, %transpose = transform.structured.lower_pack %pack
+        : (!transform.op<"tensor.pack">)
+        -> (!transform.op<"tensor.pad">,
+            !transform.op<"tensor.expand_shape">,
+            !transform.op<"linalg.transpose">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the transpose operation into the tiled loop.
+      transform.structured.fuse_into_containing_op %transpose into %forall_op
+        : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
 // -----
 // For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
 // This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
 // extract_slice will be generated and fusion blocked.
@@ -119,3 +175,64 @@ module {
   }
 }
+
+// -----
+// For unpack op, by default lowerUnpadLikeWithExtractSlice = true, which generates extract_slice and blocks fusion.
+
+module {
+  // CHECK-LABEL: func @fuse_unpack_as_consumer_blocked_by_extract_slice
+  // CHECK: scf.forall {{.*}} {
+  // CHECK: linalg.generic
+  // CHECK: scf.forall.in_parallel
+  // CHECK: }
+  // CHECK: tensor.extract_slice
+  func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<128x256xf32> {
+    %out = tensor.empty() : tensor<1x1x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<1x1x128x256xf32>) {
+      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %unpack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<1x1x128x256xf32>
+
+    %dest = tensor.empty() : tensor<128x256xf32>
+    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+    return %unpack : tensor<128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower unpack operation.
+      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      transform.structured.lower_unpack %unpack
+        : (!transform.op<"tensor.unpack">)
+        -> (!transform.op<"tensor.empty">,
+            !transform.op<"linalg.transpose">,
+            !transform.op<"tensor.collapse_shape">,
+            !transform.op<"tensor.extract_slice">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the consumer operation into the tiled loop.
+      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+        : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+      // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+      // is not a qualified consumer operation. Forcing this will yield a "could not fetch consumer
+      // to fuse" error.
+      transform.yield
+    }
+  }
+}

From f6c54e82d01aa6ec026cc8eeae6dc5da3aa8f7d4 Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Mon, 9 Dec 2024 22:00:44 +0000
Subject: [PATCH 7/7] Add additional LIT variables

This helps to clearly demonstrate producer fusion in the pack case and
consumer fusion in the unpack case.
---
 .../transform-tile-and-fuse-pack-unpack.mlir  | 30 ++++++++++---------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index ffed9ab6e0653..faf7ff9ad7ed0 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -1,14 +1,14 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s
 
 // For pack op, we use lowerPadLikeWithInsertSlice = false to ensure no insert_slice is generated.
-// This allows linalg.transpose to be fused as a producer operation. Alternatively, without this attribute
-// insert_slice will be generated and fusion blocked.
+// This allows linalg.transpose to be fused as a producer operation. In the testcase below, linalg.transpose
+// as a producer operation is fused into the scf.forall loop.
 
 module {
   // CHECK-LABEL: func @fuse_pack_as_producer
   // CHECK: scf.forall {{.*}} {
-  // CHECK: linalg.transpose
-  // CHECK: linalg.generic
+  // CHECK: %[[PRODUCER:.*]] = linalg.transpose
+  // CHECK: linalg.generic {{.*}} ins(%[[PRODUCER]]
   // CHECK: scf.forall.in_parallel
   // CHECK: }
@@ -60,11 +60,13 @@ module {
 
 // -----
 // For pack op, by default lowerPadLikeWithInsertSlice = true, which generates insert_slice and blocks fusion.
+// In the testcase below, tensor.insert_slice as a producer operation cannot be fused into the scf.forall loop.
 
 module {
   // CHECK-LABEL: func @fuse_pack_as_producer_blocked_by_insert_slice
-  // CHECK: tensor.insert_slice
+  // CHECK: %[[PRODUCER:.*]] = tensor.insert_slice
   // CHECK: scf.forall {{.*}} {
+  // CHECK: linalg.generic {{.*}} ins(%[[PRODUCER]]
   // CHECK: scf.forall.in_parallel
   // CHECK: }
@@ -116,14 +118,13 @@ module {
 
 // -----
 // For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
-// This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
-// extract_slice will be generated and fusion blocked.
-
+// This allows linalg.transpose to be fused as a consumer operation. In the testcase below, linalg.transpose
+// as a consumer operation is fused into the scf.forall loop.
 module {
   // CHECK-LABEL: func @fuse_unpack_as_consumer
-  // CHECK: scf.forall {{.*}} {
-  // CHECK: linalg.generic
-  // CHECK: linalg.transpose
+  // CHECK: scf.forall {{.*}} {
+  // CHECK: %[[CONSUMER:.*]] = linalg.generic
+  // CHECK: linalg.transpose ins(%[[CONSUMER]]
   // CHECK: scf.forall.in_parallel
   // CHECK: }
@@ -179,14 +180,15 @@ module {
 
 // -----
 // For unpack op, by default lowerUnpadLikeWithExtractSlice = true, which generates extract_slice and blocks fusion.
-
+// In the testcase below, tensor.extract_slice as a consumer operation cannot be fused into the scf.forall loop.
 module {
   // CHECK-LABEL: func @fuse_unpack_as_consumer_blocked_by_extract_slice
-  // CHECK: scf.forall {{.*}} {
-  // CHECK: linalg.generic
+  // CHECK: %[[CONSUMER:.*]] = scf.forall {{.*}} {
+  // CHECK: %[[ADDF:.*]] = linalg.generic
   // CHECK: scf.forall.in_parallel
+  // CHECK: tensor.parallel_insert_slice %[[ADDF]]
   // CHECK: }
-  // CHECK: tensor.extract_slice
+  // CHECK: tensor.extract_slice %[[CONSUMER]]
   func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
       -> tensor<128x256xf32> {
     %out = tensor.empty() : tensor<1x1x128x256xf32>
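
Usage note, not part of the patch series above: the patches exercise the new flags through the transform dialect, but the C++ entry points renamed in PATCH 3 can also be called directly. Below is a minimal sketch of such a caller; the helper name and the collect-then-rewrite driver are illustrative assumptions, and only linalg::lowerPack and its lowerPadLikeWithInsertSlice parameter come from the patches themselves.

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Dialect/Tensor/IR/Tensor.h"
    #include "llvm/ADT/SmallVector.h"

    using namespace mlir;

    // Hypothetical helper (illustration only): lower every tensor.pack under
    // `root` while keeping the pad + expand_shape + transpose chain, even for
    // pad-like packs, so the resulting linalg.transpose stays visible as a
    // fusion candidate for later tiling.
    static LogicalResult lowerAllPacksKeepingTranspose(RewriterBase &rewriter,
                                                       Operation *root) {
      // Collect first; lowerPack rewrites the IR, so avoid mutating mid-walk.
      SmallVector<tensor::PackOp> packOps;
      root->walk([&](tensor::PackOp op) { packOps.push_back(op); });
      for (tensor::PackOp packOp : packOps) {
        rewriter.setInsertionPoint(packOp);
        // lowerPadLikeWithInsertSlice = false disables the tensor.insert_slice
        // shortcut that would otherwise be taken for pad-like packs.
        if (failed(linalg::lowerPack(rewriter, packOp,
                                     /*lowerPadLikeWithInsertSlice=*/false)))
          return failure();
      }
      return success();
    }

The same pattern applies to linalg::lowerUnPack with lowerUnpadLikeWithExtractSlice = false when the linalg.transpose should remain available for consumer fusion.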