Skip to content

Commit 5b9a4e1

Browse files
fix scf.for parser
1 parent 6812fc0 commit 5b9a4e1

File tree

3 files changed

+33
-12
lines changed

3 files changed

+33
-12
lines changed

mlir/lib/AsmParser/Parser.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,12 @@ class OperationParser : public Parser {
820820
/// their first reference, to allow checking for use of undefined values.
821821
DenseMap<Value, SMLoc> forwardRefPlaceholders;
822822

823+
/// Operations that define the placeholders. These are kept until the end of
824+
/// of the lifetime of the parser because some custom parsers may store
825+
/// references to them in local state and use them after forward references
826+
/// have been resolved.
827+
DenseSet<Operation *> forwardRefOps;
828+
823829
/// Deffered locations: when parsing `loc(#loc42)` we add an entry to this
824830
/// map. After parsing the definition `#loc42 = ...` we'll patch back users
825831
/// of this location.
@@ -847,11 +853,11 @@ OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp)
847853
}
848854

849855
OperationParser::~OperationParser() {
850-
for (auto &fwd : forwardRefPlaceholders) {
856+
for (Operation *op : forwardRefOps) {
851857
// Drop all uses of undefined forward declared reference and destroy
852858
// defining operation.
853-
fwd.first.dropAllUses();
854-
fwd.first.getDefiningOp()->destroy();
859+
op->dropAllUses();
860+
op->destroy();
855861
}
856862
for (const auto &scope : forwardRef) {
857863
for (const auto &fwd : scope) {
@@ -1007,7 +1013,6 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo,
10071013
// the actual definition instead, delete the forward ref, and remove it
10081014
// from our set of forward references we track.
10091015
existing.replaceAllUsesWith(value);
1010-
existing.getDefiningOp()->destroy();
10111016
forwardRefPlaceholders.erase(existing);
10121017

10131018
// If a definition of the value already exists, replace it in the assembly
@@ -1194,6 +1199,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
11941199
/*attributes=*/std::nullopt, /*properties=*/nullptr, /*successors=*/{},
11951200
/*numRegions=*/0);
11961201
forwardRefPlaceholders[op->getResult(0)] = loc;
1202+
forwardRefOps.insert(op);
11971203
return op->getResult(0);
11981204
}
11991205

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -499,8 +499,20 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
499499
else if (parser.parseType(type))
500500
return failure();
501501

502-
// Resolve input operands.
502+
// Set block argument types, so that they are known when parsing the region.
503503
regionArgs.front().type = type;
504+
for (auto [iterArg, type] :
505+
llvm::zip(llvm::drop_begin(regionArgs), result.types))
506+
iterArg.type = type;
507+
508+
// Parse the body region.
509+
Region *body = result.addRegion();
510+
if (parser.parseRegion(*body, regionArgs))
511+
return failure();
512+
ForOp::ensureTerminator(*body, builder, result.location);
513+
514+
// Resolve input operands. This should be done after parsing the region to
515+
// catch invalid IR where operands were defined inside of the region.
504516
if (parser.resolveOperand(lb, type, result.operands) ||
505517
parser.resolveOperand(ub, type, result.operands) ||
506518
parser.resolveOperand(step, type, result.operands))
@@ -516,13 +528,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
516528
}
517529
}
518530

519-
// Parse the body region.
520-
Region *body = result.addRegion();
521-
if (parser.parseRegion(*body, regionArgs))
522-
return failure();
523-
524-
ForOp::ensureTerminator(*body, builder, result.location);
525-
526531
// Parse the optional attribute list.
527532
if (parser.parseOptionalAttrDict(result.attributes))
528533
return failure();

mlir/test/Dialect/SCF/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,3 +747,13 @@ func.func @parallel_missing_terminator(%0 : index) {
747747
return
748748
}
749749

750+
// -----
751+
752+
func.func @invalid_reference(%a: index) {
753+
// expected-error @below{{use of undeclared SSA value name}}
754+
scf.for %x = %a to %a step %a iter_args(%var = %foo) -> tensor<?xf32> {
755+
%foo = "test.inner"() : () -> (tensor<?xf32>)
756+
scf.yield %foo : tensor<?xf32>
757+
}
758+
return
759+
}

0 commit comments

Comments
 (0)