|
11 | 11 | //===----------------------------------------------------------------------===// |
12 | 12 |
|
13 | 13 | #include "circt/Dialect/RTG/IR/RTGOps.h" |
| 14 | +#include "circt/Support/ParsingUtils.h" |
14 | 15 | #include "mlir/IR/Builders.h" |
15 | 16 | #include "mlir/IR/DialectImplementation.h" |
| 17 | +#include "llvm/ADT/SmallString.h" |
16 | 18 |
|
17 | 19 | using namespace mlir; |
18 | 20 | using namespace circt; |
@@ -399,6 +401,115 @@ LogicalResult TestOp::verifyRegions() { |
399 | 401 | return success(); |
400 | 402 | } |
401 | 403 |
|
| 404 | +ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) { |
| 405 | + // Parse the name as a symbol. |
| 406 | + if (parser.parseSymbolName( |
| 407 | + result.getOrAddProperties<TestOp::Properties>().sym_name)) |
| 408 | + return failure(); |
| 409 | + |
| 410 | + // Parse the function signature. |
| 411 | + SmallVector<OpAsmParser::Argument> arguments; |
| 412 | + SmallVector<StringAttr> names; |
| 413 | + |
| 414 | + auto parseOneArgument = [&]() -> ParseResult { |
| 415 | + std::string name; |
| 416 | + auto res = |
| 417 | + parser.parseOptionalKeywordOrString(&name) || parser.parseColon(); |
| 418 | + |
| 419 | + auto argLoc = parser.getCurrentLocation(); |
| 420 | + if (failed(parser.parseArgument(arguments.emplace_back(), |
| 421 | + /*allowType=*/true, /*allowAttrs=*/true))) |
| 422 | + return failure(); |
| 423 | + |
| 424 | + // If no explicit name was provided, try to use the SSA name. |
| 425 | + if (res) { |
| 426 | + auto inferredName = parsing_util::getNameFromSSA( |
| 427 | + result.getContext(), arguments.back().ssaName.name); |
| 428 | + if (inferredName.empty()) |
| 429 | + return parser.emitError(argLoc, "invalid SSA name for test argument"); |
| 430 | + names.push_back(inferredName); |
| 431 | + } else { |
| 432 | + names.push_back(StringAttr::get(result.getContext(), name)); |
| 433 | + } |
| 434 | + |
| 435 | + return success(); |
| 436 | + }; |
| 437 | + if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, |
| 438 | + parseOneArgument, " in argument list")) |
| 439 | + return failure(); |
| 440 | + |
| 441 | + SmallVector<Type> argTypes; |
| 442 | + SmallVector<DictEntry> entries; |
| 443 | + SmallVector<Location> argLocs; |
| 444 | + argTypes.reserve(arguments.size()); |
| 445 | + argLocs.reserve(arguments.size()); |
| 446 | + for (auto [name, arg] : llvm::zip(names, arguments)) { |
| 447 | + argTypes.push_back(arg.type); |
| 448 | + argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location); |
| 449 | + entries.push_back({name, arg.type}); |
| 450 | + } |
| 451 | + auto emitError = [&]() -> InFlightDiagnostic { |
| 452 | + return parser.emitError(parser.getCurrentLocation()); |
| 453 | + }; |
| 454 | + Type type = DictType::getChecked(emitError, result.getContext(), |
| 455 | + ArrayRef<DictEntry>(entries)); |
| 456 | + if (!type) |
| 457 | + return failure(); |
| 458 | + result.getOrAddProperties<TestOp::Properties>().target = TypeAttr::get(type); |
| 459 | + |
| 460 | + auto loc = parser.getCurrentLocation(); |
| 461 | + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) |
| 462 | + return failure(); |
| 463 | + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { |
| 464 | + return parser.emitError(loc) |
| 465 | + << "'" << result.name.getStringRef() << "' op "; |
| 466 | + }))) |
| 467 | + return failure(); |
| 468 | + |
| 469 | + std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>(); |
| 470 | + if (parser.parseRegion(*bodyRegionRegion, arguments)) |
| 471 | + return failure(); |
| 472 | + |
| 473 | + if (bodyRegionRegion->empty()) { |
| 474 | + bodyRegionRegion->emplaceBlock(); |
| 475 | + bodyRegionRegion->addArguments(argTypes, argLocs); |
| 476 | + } |
| 477 | + result.addRegion(std::move(bodyRegionRegion)); |
| 478 | + |
| 479 | + return success(); |
| 480 | +} |
| 481 | + |
| 482 | +void TestOp::print(OpAsmPrinter &p) { |
| 483 | + p << ' '; |
| 484 | + p.printSymbolName(getSymNameAttr().getValue()); |
| 485 | + p << "("; |
| 486 | + SmallString<32> resultNameStr; |
| 487 | + llvm::interleaveComma( |
| 488 | + llvm::zip(getTarget().getEntries(), getBody()->getArguments()), p, |
| 489 | + [&](auto entryAndArg) { |
| 490 | + auto [entry, arg] = entryAndArg; |
| 491 | + |
| 492 | + resultNameStr.clear(); |
| 493 | + llvm::raw_svector_ostream tmpStream(resultNameStr); |
| 494 | + p.printOperand(arg, tmpStream); |
| 495 | + if (tmpStream.str().drop_front() != entry.name) |
| 496 | + p << entry.name.getValue() << ": "; |
| 497 | + p.printRegionArgument(arg); |
| 498 | + }); |
| 499 | + p << ")"; |
| 500 | + p.printOptionalAttrDictWithKeyword( |
| 501 | + (*this)->getAttrs(), {getSymNameAttrName(), getTargetAttrName()}); |
| 502 | + p << ' '; |
| 503 | + p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false); |
| 504 | +} |
| 505 | + |
| 506 | +void TestOp::getAsmBlockArgumentNames(Region ®ion, |
| 507 | + OpAsmSetValueNameFn setNameFn) { |
| 508 | + for (auto [entry, arg] : |
| 509 | + llvm::zip(getTarget().getEntries(), region.getArguments())) |
| 510 | + setNameFn(arg, entry.name.getValue()); |
| 511 | +} |
| 512 | + |
402 | 513 | //===----------------------------------------------------------------------===// |
403 | 514 | // TargetOp |
404 | 515 | //===----------------------------------------------------------------------===// |
|
0 commit comments