[mlir][linalg] unfold projected permutation. #114704
@@ -0,0 +1,240 @@
//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <map>
#include <optional>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This pattern decomposes the input operand(s) of a linalg.generic that has
/// a `transpose`, a `broadcast`, or a mixture of the two, into an explicit
/// transpose and broadcast. Having them folded into the linalg.generic is a
/// good optimization, but sometimes we may want to unwrap, i.e. `unfold`,
/// them as an explicit transpose and broadcast. This rewrite pattern does
/// that for each input operand. It is useful, for instance, when trying to
/// recognize named ops.
///
/// The transpose, broadcast, or mixture of both is expressed in the affine
/// map of the operand; technically, it is a `projected permutation`.
///
/// Example
///
/// ```mlir
///
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
///    { indexing_maps = [#projection, #identity, #identity],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"]}
///    ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
///    outs(%z : tensor<5x9x7x8x10xf32>) {
///      ^bb0(%in: f32, %in_1: f32, %out: f32):
///        %div = arith.divf %in, %in_1 : f32
///        linalg.yield %div : f32
///    } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the above IR, the map of operand `%x` is a projected permutation. It
/// can be unfolded as:
///
/// ```mlir
/// ...
/// %x_trans = linalg.transpose
///              ins(%x : tensor<7x8x9xf32>)
///              outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %x_trans_bc = linalg.broadcast
///                 ins(%x_trans : tensor<9x7x8xf32>)
///                 outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
/// %2 = linalg.div
///        ins(%x_trans_bc, %y :
///            tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
///        outs(%z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
/// ```
///
/// Note that linalg.generic has been 'specialized' to linalg.div.
///
/// When unfolding, it is more efficient to transpose first and then
/// broadcast. However, if the transpose is done first, the permutation map
/// needs to be expressed in terms of the reduced dimensions, as the
/// broadcast has not happened yet. Also, the sizes of the broadcast
/// dimensions of a linalg.generic come from the other operands (those not
/// broadcast along that particular dimension). We work this out by computing
/// the convex-polyhedron shape of the linalg.generic iteration space from
/// the shapes of all the operands, both inputs and outputs.
///
struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;
};
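
To make the doc comment's example concrete: below is a minimal standalone C++ sketch of the transpose/broadcast split for the map `(d0, d1, d2, d3, d4) -> (d2, d3, d1)`, with `std::vector` standing in for `AffineMap`. The permutation derivation is an assumption reconstructed from the expected `permutation = [2, 0, 1]` output, since the tail of `computeTransposeBroadcast` below is not fully visible in this diff view.

```cpp
// Illustrative sketch only -- not part of the patch.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t numDims = 5;                // iteration-space dims d0..d4
  std::vector<int64_t> results = {2, 3, 1}; // dims referenced by the operand map

  // A transpose is present iff the referenced dims are not already sorted.
  std::vector<int64_t> sorted = results;
  std::sort(sorted.begin(), sorted.end());
  const bool hasTranspose = results != sorted;

  // Iteration-space dims never referenced by the map are broadcast dims.
  std::vector<int64_t> broadcast;
  for (int64_t i = 0, j = 0; i < numDims; ++i) {
    if (j < static_cast<int64_t>(sorted.size()) && sorted[j] == i)
      ++j;
    else
      broadcast.push_back(i);
  }

  // linalg.transpose semantics: output dim i reads input dim permutation[i].
  // We want the output dims in sorted order, so look up each sorted dim's
  // position in the original result list: {1,2,3} vs {2,3,1} -> [2, 0, 1].
  std::vector<int64_t> permutation;
  if (hasTranspose)
    for (int64_t d : sorted)
      permutation.push_back(
          std::find(results.begin(), results.end(), d) - results.begin());

  std::cout << "permutation =";
  for (int64_t p : permutation)
    std::cout << ' ' << p;
  std::cout << "\nbroadcast =";
  for (int64_t b : broadcast)
    std::cout << ' ' << b;
  std::cout << '\n';
}
```

Run, this prints `permutation = 2 0 1` and `broadcast = 0 4`, matching the `linalg.transpose`/`linalg.broadcast` pair in the example above.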

/// For the given `map`, determine which dimensions are transposed and which
/// dimensions are broadcast.
/// Returns:
///   transpose-permutation, broadcast-dimensions (each empty if not needed)
///
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
  assert(map.isProjectedPermutation(false) && "not a projection");
  int64_t minorSize = map.getNumResults();

  SmallVector<int64_t> minorResult;

  // Each result of the map references some iteration-space dimension;
  // collect the referenced positions, in map-result order.
  for (int64_t i = 0; i < minorSize; ++i) {
    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
    minorResult.push_back(expr.getPosition());
  }

  // If dims are not monotonically increasing then transpose is present.
  SmallVector<int64_t> sortedResMap(minorResult);
  std::sort(sortedResMap.begin(), sortedResMap.end());
  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
                                  sortedResMap.begin(), sortedResMap.end());

  // Walk the sorted map result to determine which dimensions are broadcasted.
  SmallVector<int64_t> broadcast;
  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
    if (j < minorSize && sortedResMap[j] == i) {
      j++;
      continue;
    }
    broadcast.push_back(i);
  }

  SmallVector<int64_t> permutation;
  if (hasTranspose) {
    /// Consider an operand `x : tensor<7x8x9>` of a genericOp that has

Outdated review suggestions:

Suggested change:
-    /// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
+    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has

[nit] Suggested change:
-    /// `x`s access is both transposed and brodcast. But when specifying
+    /// `x`s access is both transposed and broadcast. But when specifying

Suggested change:
-  for (auto &operand : op->getOpOperands()) {
-    auto opType = cast<RankedTensorType>(operand.get().getType());
-    for (auto size : opType.getShape())
-      if (size == ShapedType::kDynamic)
-        return failure();
-  }
+  if (!llvm::all_of(op->getOpOperands(), [](OpOperand &oper) {
+        auto opType = cast<RankedTensorType>(oper.get().getType());
+        return !ShapedType::isDynamicShape(opType.getShape());
+      }))
+    return failure();

Originally posted here:
GitHub is really good at hiding these 😓
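
For context on that last suggestion: it replaces a hand-written nested loop with a single predicate that rejects the rewrite whenever any operand has a dynamic dimension. A minimal standalone analogue, assuming only the C++ standard library (`kDynamic` and shapes-as-vectors are stand-ins for MLIR's `ShapedType` API):

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic, the sentinel marking a dynamic dim.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// True iff every operand shape is fully static; the pattern bails out
// (returns failure()) when this is false.
bool allShapesStatic(const std::vector<std::vector<int64_t>> &operandShapes) {
  return std::all_of(
      operandShapes.begin(), operandShapes.end(),
      [](const std::vector<int64_t> &shape) {
        return std::none_of(shape.begin(), shape.end(),
                            [](int64_t size) { return size == kDynamic; });
      });
}
```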

@@ -0,0 +1,71 @@
// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s

#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
#identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

func.func @transpose_and_broadcast(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
  %res = linalg.generic
     { indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
     ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %div = arith.divf %in, %in_1 : f32
    linalg.yield %div : f32
  } -> tensor<5x9x7x8x10xf32>
  return %res : tensor<5x9x7x8x10xf32>
}

// CHECK-LABEL: transpose_and_broadcast
// CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
// CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
// CHECK: %[[X_trans_bc:.+]] = linalg.broadcast ins(%[[X_trans]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
// CHECK: {{.*}} = linalg.div ins(%[[X_trans_bc]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>

func.func @transpose_only(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %res = linalg.generic
     { indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
     outs(%z : tensor<2x16x32xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %div = arith.divf %in, %in_1 : f32
    linalg.yield %div : f32
  } -> tensor<2x16x32xf32>
  return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: transpose_only
// CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
// CHECK: {{.*}} = linalg.div ins(%[[X_trans]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %res = linalg.generic
     { indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
     outs(%z : tensor<2x16x32xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %div = arith.divf %in, %in_1 : f32
    linalg.yield %div : f32
  } -> tensor<2x16x32xf32>
  return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: broadcast_only
// CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[Y_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[Y_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic
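
As a closing note: to drive the new pattern outside of `--linalg-specialize-generic-ops`, one could apply it with the greedy rewriter. This is a hedged sketch; the `populateDecomposeProjectedPermutationPatterns` entry point is an assumption, since the header changes are not shown on this page.

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collect the decomposition pattern(s) and apply them greedily to `root`.
static LogicalResult decomposeProjectedPermutations(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Assumed populate entry point added by this PR; if the name differs,
  // inserting the DecomposeProjectedPermutation pattern directly also works.
  linalg::populateDecomposeProjectedPermutationPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```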