16 changes: 13 additions & 3 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -352,9 +353,13 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {

// Conservatively handle remaining BlockArguments as non-valid symbols.
// E.g. scf.for iterArgs.
if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
return false;

if (auto blockArgument =
llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue())) {
if (!llvm::isa<FunctionOpInterface>(
blockArgument.getParentRegion()->getParentOp())) {
Comment on lines +358 to +359

Member:
Functions may have blocks other than the entry block. Not all block arguments are function arguments, so this change looks suspicious to me.
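A hedged sketch of the direction this review comment suggests, using the standard MLIR C++ API (illustrative only, not the patch's final code): a `BlockArgument` is a function argument only when its owning block is the entry block of a `FunctionOpInterface` op, so the check could additionally test `Block::isEntryBlock()`.

```cpp
// Illustrative sketch only (not from the patch). Assumes the usual MLIR
// BlockArgument/Block API. A block argument is a *function* argument only if
// its owning block is the entry block of a FunctionOpInterface op.
if (auto blockArg =
        llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue())) {
  Block *owner = blockArg.getOwner();
  // Arguments of non-entry blocks (e.g. CFG successor blocks) are never
  // function arguments; treat them conservatively as non-valid symbols.
  if (!owner->isEntryBlock() ||
      !llvm::isa<FunctionOpInterface>(owner->getParentOp()))
    return false;
}
```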

Member Author:

You're right. Thanks for the advice; I think I know how to handle it now. Thanks!

Member Author:

You can parse the following IR with mlir-opt to reproduce the bug; I found that I can trigger it via the generic IR form. In that case, I can write tests too.

#map = affine_map<()[s0] -> (s0)>
"builtin.module"() ({
  "gpu.module"() <{sym_name = "gpu"}> ({
    "gpu.func"() <{function_type = (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()}> ({
    ^bb0(%arg3: memref<?x?xf32>, %arg4: memref<?x?xf32>, %arg5: memref<?x?xf32>):
      %16 = "arith.constant"() <{value = 1 : index}> : () -> index
      %17 = "memref.dim"(%arg3, %16) : (memref<?x?xf32>, index) -> index
      %18 = "arith.constant"() <{value = 0 : index}> : () -> index
      "affine.for"(%18, %17) <{lowerBoundMap = #map, operandSegmentSizes = array<i32: 1, 1, 0>, step = 32 : index, upperBoundMap = #map}> ({
      ^bb0(%arg6: index):
        "affine.yield"() : () -> ()
      }) : (index, index) -> ()
      "gpu.return"() : () -> ()
    }) {gpu.kernel, sym_name = "gemm", workgroup_attributions = 0 : i64} : () -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> f32, sym_name = "main"}> ({
  ^bb0(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>):
    %0 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
    %1 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
    %2 = "arith.constant"() <{value = 2.000000e+00 : f32}> : () -> f32
    %3 = "arith.constant"() <{value = 0 : index}> : () -> index
    %4 = "memref.dim"(%arg0, %3) : (memref<?x?xf32>, index) -> index
    %5 = "arith.constant"() <{value = 1 : index}> : () -> index
    %6 = "memref.dim"(%arg0, %5) : (memref<?x?xf32>, index) -> index
    %7 = "arith.constant"() <{value = 1 : index}> : () -> index
    %8 = "memref.dim"(%arg1, %7) : (memref<?x?xf32>, index) -> index
    %9 = "arith.constant"() <{value = 128 : index}> : () -> index
    %10 = "arith.ceildivui"(%4, %9) : (index, index) -> index
    %11 = "arith.constant"() <{value = 64 : index}> : () -> index
    %12 = "arith.ceildivsi"(%6, %11) : (index, index) -> index
    %13 = "arith.constant"() <{value = 256 : index}> : () -> index
    %14 = "arith.constant"() <{value = 262144 : i32}> : () -> i32
    %15 = "arith.constant"() <{value = 1 : index}> : () -> index
    "gpu.launch_func"(%12, %10, %15, %13, %15, %15, %14, %arg0, %arg1, %arg2) <{kernel = @gpu::@gemm, operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 3, 0>}> : (index, index, index, index, index, index, i32, memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
    "func.return"(%0) : (f32) -> ()
  }) : () -> ()
}) {gpu.container_module} : () -> ()
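Distilled, the kernel portion of that reproducer corresponds to this pretty-printed form (a hand-reduced sketch for readability, not a test from the PR): the `affine.for` upper bound is fed by `memref.dim` of a `gpu.func` block argument, which is exactly the value `isDimOpValidSymbol` must classify.

```mlir
// Hand-reduced sketch (illustrative, not from the PR's tests).
gpu.module @gpu {
  gpu.func @gemm(%arg0: memref<?x?xf32>) kernel {
    %c1 = arith.constant 1 : index
    // Is %dim a valid affine symbol inside gpu.func?
    %dim = memref.dim %arg0, %c1 : memref<?x?xf32>
    affine.for %i = 0 to %dim step 32 {
    }
    gpu.return
  }
}
```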

But in that case, there is another question I'd like to ask, which I haven't fully thought through yet.

Member Author (@linuxlonelyeagle, Nov 27, 2024):

I'm a bit confused: for block ^bb0, if its parameter is a memref, then its dimensions can change as well, but that shouldn't cause an effect like the one in the example you gave. I'm not very sure; I think this needs to be confirmed, and I'm not quite sure how to fix it. I'd appreciate some guidance on this. Thanks.

Member Author:

I think I've figured it out, and I'll modify the patch later.

Member Author:

There is a new development on this issue: I found that the real problem is that gpu.func doesn't have the AffineScope trait. I'm going to look further into this. @ftynse, thank you for the guidance you've given me; I think I'm still making progress.
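The AffineScope observation is the crux: validity of affine symbols is defined relative to the closest enclosing op carrying `OpTrait::AffineScope` (`func.func` has it; at the time of this discussion `gpu.func` did not). A rough sketch of how that scope is found, mirroring the existing `getAffineScope` helper in AffineOps.cpp (the function name here is illustrative):

```cpp
// Illustrative sketch: walk up to the nearest ancestor carrying the
// AffineScope trait; values defined at the top level of that scope's region
// are eligible to act as affine symbols.
Region *findEnclosingAffineScope(Operation *op) {
  Operation *curOp = op;
  while (Operation *parent = curOp->getParentOp()) {
    if (parent->hasTrait<mlir::OpTrait::AffineScope>())
      return curOp->getParentRegion();
    curOp = parent;
  }
  return nullptr; // No enclosing affine scope (e.g. inside gpu.func).
}
```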

return false;
}
}
// The dim op is also okay if its operand memref is a view/subview whose
// corresponding size is a valid symbol.
std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
@@ -365,6 +370,11 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {

// Skip over all memref.cast ops (if any).
Operation *op = dimOp.getShapedValue().getDefiningOp();

// The shaped value has no defining op, i.e. it is a function block argument.
if (!op)
return true;

while (auto castOp = dyn_cast<memref::CastOp>(op)) {
// Bail on unranked memrefs.
if (isa<UnrankedMemRefType>(castOp.getSource().getType()))