-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][TD] Support padding with poison #152003
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: James Newling (newling) ChangesTestingIn this file In IREE, I can see this works via a less targeted test using Full diff: https://github.com/llvm/llvm-project/pull/152003.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2e6252336dfeb..3d12bc397813b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
Value paddingValue;
if (auto complexTy =
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
- auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
- paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
- complexTy, complexAttr);
- } else {
- paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
- cast<TypedAttr>(paddingValueAttr));
+ if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
+ paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ complexTy, complexAttr);
+ }
+ } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
+ paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+ getElementTypeOrSelf(v.getType()));
+ } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
+ paddingValue =
+ arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
}
+ assert(paddingValue && "failed to create value from padding attribute");
// Pad the operand to the bounding box defined by `paddedShape`.
SmallVector<int64_t> tensorShape;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index f7418769f79ca..2857b53103779 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -4,6 +4,7 @@
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
{
+ // %goo = ub.poison : f32
%0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}
@@ -18,7 +19,8 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
- padding_values=[0.0 : f32, 0.0 : f32]
+ // padding_values= [poison, 0.0 : f32]
+ padding_values= [0.0 : f32, 0.0 : f32]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
|
@Hardcode84 a poison related question you might have an idea about -- is an array attribute like |
You can try |
Thanks you so much @Hardcode84 ! |
Signed-off-by: James Newling <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
[nit] Isn't this PR updating TD rather than Tensor? |
Maybe Linalg? At least, all the files touched are in a directory /Dialect/Linalg/. But definitely not Tensor, you're right. |
No description provided.