1717#include " mlir/IR/Matchers.h"
1818#include " mlir/IR/OpDefinition.h"
1919#include " mlir/IR/PatternMatch.h"
20+ #include " mlir/Interfaces/FunctionInterfaces.h"
2021#include " mlir/Interfaces/ShapedOpInterfaces.h"
2122#include " mlir/Interfaces/ValueBoundsOpInterface.h"
2223#include " mlir/Transforms/InliningUtils.h"
@@ -352,9 +353,13 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
352353
353354 // Conservatively handle remaining BlockArguments as non-valid symbols.
354355 // E.g. scf.for iterArgs.
355- if (llvm::isa<BlockArgument>(dimOp.getShapedValue ()))
356- return false ;
357-
356+ if (auto blockArgument =
357+ llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue ())) {
358+ if (!llvm::isa<FunctionOpInterface>(
359+ blockArgument.getParentRegion ()->getParentOp ())) {
360+ return false ;
361+ }
362+ }
358363 // The dim op is also okay if its operand memref is a view/subview whose
359364 // corresponding size is a valid symbol.
360365 std::optional<int64_t > index = getConstantIntValue (dimOp.getDimension ());
@@ -365,6 +370,11 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
365370
366371 // Skip over all memref.cast ops (if any).
367372 Operation *op = dimOp.getShapedValue ().getDefiningOp ();
373+
374+ // the ShapedValue of the dim is the function block argument.
375+ if (!op)
376+ return true ;
377+
368378 while (auto castOp = dyn_cast<memref::CastOp>(op)) {
369379 // Bail on unranked memrefs.
370380 if (isa<UnrankedMemRefType>(castOp.getSource ().getType ()))
0 commit comments