[mlir][linalg] unfold projected permutation. #114704
UnfoldProjectedPermutation.cpp (new file; first lines shown):

```cpp
//===- UnfoldProjectedPermutation.cpp - unfold projected permutations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <utility>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include <map>
#include <optional>
#include <vector>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This file implements pattern to decompose the input operand(s) of a
```
Suggested change:

```diff
-/// This file implements pattern to decompose the input operand(s) of a
+/// This pattern decomposes the input operand(s) of a
```

Isn't this comment attached to the pattern itself?
I just followed the style of [1], which I really liked: right at the beginning it explains the algorithm the file implements.

[1] https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
Suggested change:

```diff
-/// linalg.generic that has a `transpose`, `broadcast` or a mixture of two,
+/// linalg.generic that has a `transpose`, `broadcast`, or a mixture of the two,
```
Suggested change:

```diff
-/// linalg.generic is a good optimization but sometimes we may want to unwrap
-/// i.e. `unfold` them as explicit transpose and broadcast. This rewrite
+/// linalg.generic is a good optimization but sometimes we may want to unwrap,
+/// i.e., `unfold` them as explicit transpose and broadcast. This rewrite
```
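For readers skimming the thread: "projected permutation" is the property that `AffineMap::isProjectedPermutation()` tests, i.e. a permutation of the iteration dimensions with some dimensions dropped. A hedged sketch of the kind of precondition such a pattern checks (the helper name and exact predicate are illustrative, not the patch's actual code):

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;

// Illustrative helper: an input operand is a candidate for unfolding when its
// indexing map is a projected permutation (it drops and/or permutes iteration
// dims, encoding a broadcast, a transpose, or both) but not the identity.
static bool isUnfoldCandidate(linalg::GenericOp op, OpOperand &operand) {
  AffineMap map = op.getMatchingIndexingMap(&operand);
  return map.isProjectedPermutation(/*allowZeroInResults=*/false) &&
         !map.isIdentity();
}
```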
Intuitively this makes sense, but ... why? 😅 Which part would break?
Ping
Could this work at all for dynamic shapes?
For a start, this will assert when trying to create a tensor.empty with a dynamic shape: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp#L874
OK, rather than documenting what the code does, could you add a comment saying "why"? Or what's missing? From what you are saying, we'd need to add logic to compute dynamic sizes of the input tensors for ops like EmptyOp? And probably something else as well?
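For reference, supporting dynamic shapes would mean materializing the runtime sizes before creating the tensor.empty. A hedged sketch of what that could look like (the helper and its name are hypothetical, not part of this patch):

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: build a tensor.empty for a transposed operand,
// deriving each dynamic size from the input at the permuted position
// via tensor.dim.
static Value createPermutedEmpty(PatternRewriter &rewriter, Location loc,
                                 Value input, ArrayRef<int64_t> perm) {
  auto inType = cast<RankedTensorType>(input.getType());
  SmallVector<int64_t> shape;
  SmallVector<Value> dynSizes;
  for (int64_t dim : perm) {
    shape.push_back(inType.getDimSize(dim));
    if (inType.isDynamicDim(dim))
      dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, dim));
  }
  return rewriter.create<tensor::EmptyOp>(loc, shape,
                                          inType.getElementType(), dynSizes);
}
```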
Suggested change:

```diff
-    for (auto &operand : op->getOpOperands()) {
-      auto opType = cast<RankedTensorType>(operand.get().getType());
-      for (auto size : opType.getShape())
-        if (size == ShapedType::kDynamic)
-          return failure();
-    }
+    if (!llvm::all_of(op->getOpOperands(), [](OpOperand &oper) {
+          auto opType = cast<RankedTensorType>(oper.get().getType());
+          return !ShapedType::isDynamicShape(opType.getShape());
+        }))
+      return failure();
```

This way it's easier to emphasize the key logic rather than all the loops and if statements. Please double check the syntax 😅
Ping
```cpp
if (!llvm::all_of(op->getOpOperands(), [](OpOperand &oper) {
      auto opType = cast<RankedTensorType>(oper.get().getType());
      return !ShapedType::isDynamicShape(opType.getShape());
    }))
  return failure();
```

Thanks for this more elegant code. I guess you mean 'any_of'.
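The `any_of` spelling the author alludes to would read as follows; a sketch, logically equivalent to the `all_of` form above (`llvm::any_of` comes from llvm/ADT/STLExtras.h):

```cpp
// Bail out if any operand has a dynamically shaped tensor type: the rewrite
// only creates tensor.empty ops with static shapes.
if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
      auto opType = cast<RankedTensorType>(oper.get().getType());
      return ShapedType::isDynamicShape(opType.getShape());
    }))
  return failure();
```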
The new test file:

```mlir
// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s

#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
#identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

func.func @transpose_and_broadcast(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
  %res = linalg.generic
     { indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
     ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %div = arith.divf %in, %in_1 : f32
    linalg.yield %div : f32
  } -> tensor<5x9x7x8x10xf32>
  return %res : tensor<5x9x7x8x10xf32>
}

// CHECK-LABEL: transpose_and_broadcast
// CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
```
Contributor: I would have expected this to be …

Contributor: Hm, I managed to convince myself that this is correct, but please double check for yourself 😅 @MaheshRavishankar, you might be skewed by:

```mlir
#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
```

I think this is the trick (IIUC, this is the actual mapping here): dim 0 of `%x` (size 7) is indexed by `d2`, dim 1 (size 8) by `d3`, and dim 2 (size 9) by `d1`.

Whereas you assume that: …

Does it make sense?

Author: Yes, of course you are right! Two parts to this: the transpose first reorders `%x`'s dims into increasing iteration-dim order (`d1`, `d2`, `d3`), and only then does the broadcast insert the missing dims. Therefore, for input `tensor<7x8x9xf32>` the transpose produces `tensor<9x7x8xf32>` with `permutation = [2, 0, 1]`, and the broadcast then adds `d0` and `d4` via `dimensions = [0, 4]`.
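That arithmetic can be checked standalone; a hedged sketch in plain C++ (illustrative only, not code from the patch):

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Results of #projection = (d0, ..., d4) -> (d2, d3, d1): dim i of %x is
  // indexed by iteration dim results[i].
  std::vector<int> results = {2, 3, 1};
  // The unfolded transpose emits %x's dims sorted by iteration dim, so
  // output dim j reads input dim perm[j].
  std::vector<int> perm(results.size());
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(),
            [&](int a, int b) { return results[a] < results[b]; });
  for (int p : perm)
    std::printf("%d ", p); // prints "2 0 1", matching permutation = [2, 0, 1]
  return 0;
}
```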
The diff then continues:

```mlir
// CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
// CHECK: %[[X_trans_bc:.+]] = linalg.broadcast ins(%[[X_trans]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
// CHECK: {{.*}} = linalg.div ins(%[[X_trans_bc]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>

func.func @transpose_only(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %res = linalg.generic
     { indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
     outs(%z : tensor<2x16x32xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %div = arith.divf %in, %in_1 : f32
    linalg.yield %div : f32
  } -> tensor<2x16x32xf32>
  return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: transpose_only
// CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
// CHECK: {{.*}} = linalg.div ins(%[[X_trans]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %res = linalg.generic
     { indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
     outs(%z : tensor<2x16x32xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %div = arith.divf %in, %in_1 : f32
    linalg.yield %div : f32
  } -> tensor<2x16x32xf32>
  return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: broadcast_only
// CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic
```