Skip to content

Commit c069ddd

Browse files
fix create affine.for bug.
1 parent db6f627 commit c069ddd

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/IR/Matchers.h"
1818
#include "mlir/IR/OpDefinition.h"
1919
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/Interfaces/FunctionInterfaces.h"
2021
#include "mlir/Interfaces/ShapedOpInterfaces.h"
2122
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2223
#include "mlir/Transforms/InliningUtils.h"
@@ -352,9 +353,13 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
352353

353354
// Conservatively handle remaining BlockArguments as non-valid symbols.
354355
// E.g. scf.for iterArgs.
355-
if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
356-
return false;
357-
356+
if (auto blockArgument =
357+
llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue())) {
358+
if (!llvm::isa<FunctionOpInterface>(
359+
blockArgument.getParentRegion()->getParentOp())) {
360+
return false;
361+
}
362+
}
358363
// The dim op is also okay if its operand memref is a view/subview whose
359364
// corresponding size is a valid symbol.
360365
std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
@@ -365,6 +370,11 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
365370

366371
// Skip over all memref.cast ops (if any).
367372
Operation *op = dimOp.getShapedValue().getDefiningOp();
373+
374+
// the ShapedValue of the dim is the function block argument.
375+
if (!op)
376+
return true;
377+
368378
while (auto castOp = dyn_cast<memref::CastOp>(op)) {
369379
// Bail on unranked memrefs.
370380
if (isa<UnrankedMemRefType>(castOp.getSource().getType()))

0 commit comments

Comments
 (0)