[mlir][affine]fix create affine.for bug. #117721
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir
Author: lonely eagle (linuxlonelyeagle)
Changes
I encountered this in a pass I wrote. When an affine.for is created whose lower or upper bound uses an affine map, and the symbol operand of that map comes from a memref.dim whose memref is a function argument, the affine.for verification fails.
Full diff: https://github.com/llvm/llvm-project/pull/117721.diff
1 Files Affected:
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..0d24e434328419 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -352,9 +353,13 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
   // Conservatively handle remaining BlockArguments as non-valid symbols.
   // E.g. scf.for iterArgs.
-  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
-    return false;
-
+  if (auto blockArgument =
+          llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue())) {
+    if (!llvm::isa<FunctionOpInterface>(
+            blockArgument.getParentRegion()->getParentOp())) {
+      return false;
+    }
+  }
   // The dim op is also okay if its operand memref is a view/subview whose
   // corresponding size is a valid symbol.
   std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
@@ -365,6 +370,11 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
   // Skip over all memref.cast ops (if any).
   Operation *op = dimOp.getShapedValue().getDefiningOp();
+
+  // the ShapedValue of the dim is the function block argument.
+  if (!op)
+    return true;
+
   while (auto castOp = dyn_cast<memref::CastOp>(op)) {
     // Bail on unranked memrefs.
     if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
I believe this issue could be made even clearer. Below are the results after I fixed this bug. If you have any questions, please let me know.
ftynse
left a comment
Thank you. Let's first understand whether this is a bug or the intended behavior. This starts with adding a test. The test should show that something that wasn't accepted as a symbol becomes accepted as a symbol, i.e. does not emit an error, after the patch.
Also consider the fact that not all block arguments are function arguments. One can perfectly well have:
func.func @foo(...) {
  cf.br ^bb1(...)
^bb1(%bbarg: memref<?xf32>):
  %dim = memref.dim %bbarg, %c0
  %new = call @memref_realloc(%bbarg, 2 * %dim)
  cf.cond_br ^bb1(%new), ^bb2
^bb2:
  return
}
where %bbarg is a block argument, but its dim cannot be used as a symbol because it changes.
    if (!llvm::isa<FunctionOpInterface>(
            blockArgument.getParentRegion()->getParentOp())) {
Functions may have blocks other than the entry block. Not all block arguments are function arguments, so this change looks suspicious to me.
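A minimal sketch of the distinction raised in this comment, written against a toy model rather than the real MLIR classes. `ToyOp`, `ToyRegion`, `ToyBlock`, and both helper functions are hypothetical names introduced only for illustration. The first helper mirrors the shape of the check in the patch (it looks only at the op that owns the region); the second additionally requires the owning block to be the region's entry block.

```cpp
#include <cassert>
#include <vector>

// Toy stand-ins for operations, regions, and blocks. NOT the MLIR API.
struct ToyRegion;

struct ToyOp {
  bool isFunction; // Models "implements FunctionOpInterface".
};

struct ToyBlock {
  ToyRegion *parent; // Region that contains this block.
};

struct ToyRegion {
  ToyOp *parentOp;                // Op that owns this region.
  std::vector<ToyBlock *> blocks; // blocks.front() is the entry block.
};

// Same shape as the check in the patch: only the region's parent op matters,
// so every block inside a function body passes.
bool parentIsFunction(const ToyBlock &owner) {
  return owner.parent->parentOp->isFunction;
}

// Stricter variant: the owning block must also be the entry block, so only
// arguments of that block are treated as function arguments.
bool isFunctionEntryBlock(const ToyBlock &owner) {
  const ToyRegion *region = owner.parent;
  return region->parentOp->isFunction && !region->blocks.empty() &&
         region->blocks.front() == &owner;
}

// Usage: a function body with an entry block and a second block (like ^bb1).
void runExample() {
  ToyOp func;
  func.isFunction = true;
  ToyRegion body;
  body.parentOp = &func;
  ToyBlock entry, loop;
  entry.parent = &body;
  loop.parent = &body;
  body.blocks = {&entry, &loop};

  // The patch-style check cannot tell the two blocks apart.
  assert(parentIsFunction(entry) && parentIsFunction(loop));
  // The stricter check accepts only the entry block.
  assert(isFunctionEntryBlock(entry));
  assert(!isFunctionEntryBlock(loop));
}
```

In this toy model, the arguments of the second block pass the patch-style check even though they are not function arguments, which is the concern stated above.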
You're right. Thanks for the advice; I think I know how to fix it now. Thanks!
You can parse the following IR with mlir-opt to trigger the bug. I found that I can trigger it via generic IR.
In that case, I can write tests too.
#map = affine_map<()[s0] -> (s0)>
"builtin.module"() ({
  "gpu.module"() <{sym_name = "gpu"}> ({
    "gpu.func"() <{function_type = (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()}> ({
    ^bb0(%arg3: memref<?x?xf32>, %arg4: memref<?x?xf32>, %arg5: memref<?x?xf32>):
      %16 = "arith.constant"() <{value = 1 : index}> : () -> index
      %17 = "memref.dim"(%arg3, %16) : (memref<?x?xf32>, index) -> index
      %18 = "arith.constant"() <{value = 0 : index}> : () -> index
      "affine.for"(%18, %17) <{lowerBoundMap = #map, operandSegmentSizes = array<i32: 1, 1, 0>, step = 32 : index, upperBoundMap = #map}> ({
      ^bb0(%arg6: index):
        "affine.yield"() : () -> ()
      }) : (index, index) -> ()
      "gpu.return"() : () -> ()
    }) {gpu.kernel, sym_name = "gemm", workgroup_attributions = 0 : i64} : () -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> f32, sym_name = "main"}> ({
  ^bb0(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>):
    %0 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
    %1 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
    %2 = "arith.constant"() <{value = 2.000000e+00 : f32}> : () -> f32
    %3 = "arith.constant"() <{value = 0 : index}> : () -> index
    %4 = "memref.dim"(%arg0, %3) : (memref<?x?xf32>, index) -> index
    %5 = "arith.constant"() <{value = 1 : index}> : () -> index
    %6 = "memref.dim"(%arg0, %5) : (memref<?x?xf32>, index) -> index
    %7 = "arith.constant"() <{value = 1 : index}> : () -> index
    %8 = "memref.dim"(%arg1, %7) : (memref<?x?xf32>, index) -> index
    %9 = "arith.constant"() <{value = 128 : index}> : () -> index
    %10 = "arith.ceildivui"(%4, %9) : (index, index) -> index
    %11 = "arith.constant"() <{value = 64 : index}> : () -> index
    %12 = "arith.ceildivsi"(%6, %11) : (index, index) -> index
    %13 = "arith.constant"() <{value = 256 : index}> : () -> index
    %14 = "arith.constant"() <{value = 262144 : i32}> : () -> i32
    %15 = "arith.constant"() <{value = 1 : index}> : () -> index
    "gpu.launch_func"(%12, %10, %15, %13, %15, %15, %14, %arg0, %arg1, %arg2) <{kernel = @gpu::@gemm, operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 3, 0>}> : (index, index, index, index, index, index, i32, memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
    "func.return"(%0) : (f32) -> ()
  }) : () -> ()
}) {gpu.container_module} : () -> ()
But in that case, there is another question I'd like to ask, which I haven't thought through very clearly.
I'm a bit confused. For the block bb0, if its parameter is a memref, then its dimensions can change as well, but that shouldn't cause an effect like the one in the example you gave. I'm not very sure, and I think this needs to be confirmed. I'm not quite sure how to fix this, so I'd appreciate some guidance. Thanks.
I think I've figured it out, and I'll modify the patch later.
There is a new development on this issue: I found that the real problem is that gpu.func doesn't have the AffineScope trait. I'm going to have to look into this further. @ftynse, thank you for the guidance you've given me. I think I'm still making progress.
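A rough sketch of why a missing scope trait on the enclosing function can matter, again using a toy model and not the real MLIR implementation. `ScopeOp`, `closestAffineScope`, and `funcArgIsTopLevelFor` are hypothetical names, and the rule encoded here (a function argument counts as top-level only when the function itself is the closest scope op around its user) is a simplifying assumption made for illustration.

```cpp
#include <cassert>

// Toy stand-in for an operation. NOT the real mlir::Operation.
struct ScopeOp {
  ScopeOp *parent;     // Enclosing op, or nullptr for the outermost op.
  bool hasAffineScope; // Models carrying the AffineScope trait.
};

// Walk outwards from `op` to the closest enclosing op flagged as a scope.
const ScopeOp *closestAffineScope(const ScopeOp *op) {
  for (const ScopeOp *cur = op->parent; cur; cur = cur->parent)
    if (cur->hasAffineScope)
      return cur;
  return nullptr;
}

// Simplifying assumption: an argument of `func` is "top-level" for `user`
// only if `func` is the closest scope op enclosing `user`.
bool funcArgIsTopLevelFor(const ScopeOp *func, const ScopeOp *user) {
  return closestAffineScope(user) == func;
}

// Usage: the same dim-like op nested in a function with and without the flag.
void runScopeExample() {
  ScopeOp moduleOp{nullptr, true};
  ScopeOp scopedFunc{&moduleOp, true};    // function op with the trait
  ScopeOp unscopedFunc{&moduleOp, false}; // function op without the trait
  ScopeOp dimInScoped{&scopedFunc, false};
  ScopeOp dimInUnscoped{&unscopedFunc, false};

  assert(funcArgIsTopLevelFor(&scopedFunc, &dimInScoped));
  // Without the flag, the walk skips the function and lands on the module.
  assert(closestAffineScope(&dimInUnscoped) == &moduleOp);
  assert(!funcArgIsTopLevelFor(&unscopedFunc, &dimInUnscoped));
}
```

Under these assumptions, the arguments of the function without the flag are never top-level for ops in its body, which is consistent with the behavior described for gpu.func in this comment.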
This PR solves the following problem: #117721. To efficiently implement the thread-to-data mapping relationship (data or thread layout), I introduced AffineScope in gpu.func.
I encountered this in a pass I wrote.
When an affine.for is created whose lower or upper bound uses an affine map, and the symbol operand of that map comes from a memref.dim whose memref is a function argument, the affine.for verification fails.
Something like the following, although the code below is not meaningful by itself. What I'm trying to say is that I hit the above bug when creating such an affine.for in a pass. It's worth mentioning that if you write the following IR by hand, there is no problem, so I didn't add a test.