From 7e4a6a7ce6018cf2ee8da7138bd84a3c66ee82a8 Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Fri, 26 Jan 2024 14:34:14 +0000 Subject: [PATCH 01/11] Added single result SSA processing --- mlir/include/mlir/IR/AsmState.h | 12 +++++++++--- mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h | 10 ++++++++++ mlir/lib/AsmParser/Parser.cpp | 17 +++++++++++++++++ mlir/lib/IR/AsmPrinter.cpp | 17 +++++++++++------ mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 10 ++++++++-- 5 files changed, 55 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h index 42cbedcf9f883..9c4eadb04cdf2 100644 --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -144,8 +144,7 @@ class AsmResourceBlob { /// Return the underlying data as an array of the given type. This is an /// inherrently unsafe operation, and should only be used when the data is /// known to be of the correct type. - template - ArrayRef getDataAs() const { + template ArrayRef getDataAs() const { return llvm::ArrayRef((const T *)data.data(), data.size() / sizeof(T)); } @@ -464,8 +463,10 @@ class ParserConfig { /// `fallbackResourceMap` is an optional fallback handler that can be used to /// parse external resources not explicitly handled by another parser. ParserConfig(MLIRContext *context, bool verifyAfterParse = true, - FallbackAsmResourceMap *fallbackResourceMap = nullptr) + FallbackAsmResourceMap *fallbackResourceMap = nullptr, + bool retainIdentifierNames = false) : context(context), verifyAfterParse(verifyAfterParse), + retainIdentifierNames(retainIdentifierNames), fallbackResourceMap(fallbackResourceMap) { assert(context && "expected valid MLIR context"); } @@ -476,6 +477,10 @@ class ParserConfig { /// Returns if the parser should verify the IR after parsing. bool shouldVerifyAfterParse() const { return verifyAfterParse; } + /// Returns if the parser should retain identifier names collected using + /// parsing. + bool shouldRetainIdentifierNames() const { return retainIdentifierNames; } + /// Returns the parsing configurations associated to the bytecode read. BytecodeReaderConfig &getBytecodeReaderConfig() const { return const_cast(bytecodeReaderConfig); @@ -513,6 +518,7 @@ class ParserConfig { private: MLIRContext *context; bool verifyAfterParse; + bool retainIdentifierNames; DenseMap> resourceParsers; FallbackAsmResourceMap *fallbackResourceMap; BytecodeReaderConfig bytecodeReaderConfig; diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h index 6e90fad1618d2..a85dca186a4f3 100644 --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -176,6 +176,13 @@ class MlirOptMainConfig { /// Reproducer file generation (no crash required). StringRef getReproducerFilename() const { return generateReproducerFileFlag; } + /// Print the pass-pipeline as text before executing. + MlirOptMainConfig &retainIdentifierNames(bool retain) { + retainIdentifierNamesFlag = retain; + return *this; + } + bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; } + protected: /// Allow operation with no registered dialects. /// This option is for convenience during testing only and discouraged in @@ -226,6 +233,9 @@ class MlirOptMainConfig { /// the corresponding line. This is meant for implementing diagnostic tests. bool verifyDiagnosticsFlag = false; + /// Retain identifier names in the output (e.g., `%my_var` instead of `%0`). + bool retainIdentifierNamesFlag = false; + /// Run the verifier after each transformation pass. bool verifyPassesFlag = true; diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 00f2b0c0c2f12..8d3861a2f018d 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -1223,6 +1223,23 @@ ParseResult OperationParser::parseOperation() { } } + // If enabled, store the SSA name(s) for the operation + if (state.config.shouldRetainIdentifierNames()) { + if (opResI == 1) { + for (ResultRecord &resIt : resultIDs) { + for (unsigned subRes : llvm::seq(0, std::get<1>(resIt))) { + op->setDiscardableAttr( + "mlir.ssaName", + StringAttr::get(getContext(), + std::get<0>(resIt).drop_front(1))); + } + } + } else if (opResI > 1) { + emitError( + "have not yet implemented support for multiple return values"); + } + } + // Add this operation to the assembly state if it was provided to populate. } else if (state.asmState) { state.asmState->finalizeOperationDefinition( diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 6b8b7473bf0f8..164b0f97fc1cd 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -72,13 +72,12 @@ OpAsmParser::~OpAsmParser() = default; MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } /// Parse a type list. -/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918 +/// This is out-of-line to work-around +/// https://github.com/llvm/llvm-project/issues/62918 ParseResult AsmParser::parseTypeList(SmallVectorImpl &result) { - return parseCommaSeparatedList( - [&]() { return parseType(result.emplace_back()); }); - } - - + return parseCommaSeparatedList( + [&]() { return parseType(result.emplace_back()); }); +} //===----------------------------------------------------------------------===// // DialectAsmPrinter @@ -1579,6 +1578,12 @@ void SSANameState::numberValuesInOp(Operation &op) { } Value resultBegin = op.getResult(0); + // Get the original SSA for the result if available + if (StringAttr ssaNameAttr = op.getAttrOfType("mlir.ssaName")) { + setValueName(resultBegin, ssaNameAttr.strref()); + op.removeDiscardableAttr("mlir.ssaName"); + } + // If the first result wasn't numbered, give it a default number. if (valueIDs.try_emplace(resultBegin, nextValueID).second) ++nextValueID; diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 5395aa2b502d7..c448243586159 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -149,6 +149,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { cl::desc("Round-trip the IR after parsing and ensure it succeeds"), cl::location(verifyRoundtripFlag), cl::init(false)); + static cl::opt retainIdentifierNames( + "retain-identifier-names", + cl::desc("Retain the original names of identifiers when printing"), + cl::location(retainIdentifierNamesFlag), cl::init(false)); + static cl::list passPlugins( "load-pass-plugin", cl::desc("Load passes from plugin library")); @@ -359,8 +364,9 @@ performActions(raw_ostream &os, // untouched. PassReproducerOptions reproOptions; FallbackAsmResourceMap fallbackResourceMap; - ParserConfig parseConfig(context, /*verifyAfterParse=*/true, - &fallbackResourceMap); + ParserConfig parseConfig( + context, /*verifyAfterParse=*/true, &fallbackResourceMap, + /*retainIdentifierName=*/config.shouldRetainIdentifierNames()); if (config.shouldRunReproducer()) reproOptions.attachResourceParser(parseConfig); From 5a629d2cd3e6f736af20c808d53206f086ea5b05 Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Fri, 26 Jan 2024 14:52:10 +0000 Subject: [PATCH 02/11] Moved SSA name stored to function --- mlir/lib/AsmParser/Parser.cpp | 40 ++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 8d3861a2f018d..282077865c665 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -611,6 +611,10 @@ class OperationParser : public Parser { /// an object of type 'OperationName'. Otherwise, failure is returned. FailureOr parseCustomOperationName(); + /// Store the SSA names for the current operation as attrs for debug purposes. + ParseResult storeSSANames(Operation *&op, + SmallVector resultIDs); + //===--------------------------------------------------------------------===// // Region Parsing //===--------------------------------------------------------------------===// @@ -1224,21 +1228,8 @@ ParseResult OperationParser::parseOperation() { } // If enabled, store the SSA name(s) for the operation - if (state.config.shouldRetainIdentifierNames()) { - if (opResI == 1) { - for (ResultRecord &resIt : resultIDs) { - for (unsigned subRes : llvm::seq(0, std::get<1>(resIt))) { - op->setDiscardableAttr( - "mlir.ssaName", - StringAttr::get(getContext(), - std::get<0>(resIt).drop_front(1))); - } - } - } else if (opResI > 1) { - emitError( - "have not yet implemented support for multiple return values"); - } - } + if (state.config.shouldRetainIdentifierNames()) + storeSSANames(op, resultIDs); // Add this operation to the assembly state if it was provided to populate. } else if (state.asmState) { @@ -1285,6 +1276,25 @@ OperationParser::parseSuccessors(SmallVectorImpl &destinations) { /*allowEmptyList=*/false); } +/// Store the SSA names for the current operation as attrs for debug purposes. +ParseResult +OperationParser::storeSSANames(Operation *&op, + SmallVector resultIDs) { + if (op->getNumResults() == 0) + emitError("Operation has no results\n"); + else if (op->getNumResults() > 1) + emitError("have not yet implemented support for multiple return values\n"); + + for (ResultRecord &resIt : resultIDs) { + for (unsigned subRes : llvm::seq(0, std::get<1>(resIt))) { + op->setDiscardableAttr( + "mlir.ssaName", + StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1))); + } + } + return success(); +} + namespace { // RAII-style guard for cleaning up the regions in the operation state before // deleting them. Within the parser, regions may get deleted if parsing failed, From f2f8047bd80df0f6da6a8db4bfe976491157993e Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Fri, 26 Jan 2024 15:39:18 +0000 Subject: [PATCH 03/11] Added initial block name handling --- mlir/lib/AsmParser/Parser.cpp | 13 +++++++++++++ mlir/lib/IR/AsmPrinter.cpp | 33 ++++++++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 282077865c665..bb3fd5d872940 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -1292,6 +1292,19 @@ OperationParser::storeSSANames(Operation *&op, StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1))); } } + + // Find the name of the block that contains this operation. + Block *blockPtr = op->getBlock(); + for (const auto &map : blocksByName) { + for (const auto &entry : map) { + if (entry.second.block == blockPtr) { + op->setDiscardableAttr("mlir.blockName", + StringAttr::get(getContext(), entry.first)); + llvm::outs() << "Block name: " << entry.first << "\n"; + } + } + } + return success(); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 164b0f97fc1cd..70d38c9bd8b4f 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1298,6 +1298,10 @@ class SSANameState { /// conflicts, it is automatically renamed. StringRef uniqueValueName(StringRef name); + /// Set the original identifier names if available. Used in debugging with + /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig + void setRetainedIdentifierNames(Operation &op); + /// This is the value ID for each SSA value. If this returns NameSentinel, /// then the valueID has an entry in valueNames. DenseMap valueIDs; @@ -1578,11 +1582,9 @@ void SSANameState::numberValuesInOp(Operation &op) { } Value resultBegin = op.getResult(0); - // Get the original SSA for the result if available - if (StringAttr ssaNameAttr = op.getAttrOfType("mlir.ssaName")) { - setValueName(resultBegin, ssaNameAttr.strref()); - op.removeDiscardableAttr("mlir.ssaName"); - } + // Set the original identifier names if available. Used in debugging with + // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig + setRetainedIdentifierNames(op); // If the first result wasn't numbered, give it a default number. if (valueIDs.try_emplace(resultBegin, nextValueID).second) @@ -1595,6 +1597,27 @@ void SSANameState::numberValuesInOp(Operation &op) { } } +void SSANameState::setRetainedIdentifierNames(Operation &op) { + // Get the original SSA for the result(s) if available + Value resultBegin = op.getResult(0); + if (StringAttr ssaNameAttr = op.getAttrOfType("mlir.ssaName")) { + setValueName(resultBegin, ssaNameAttr.strref()); + op.removeDiscardableAttr("mlir.ssaName"); + } + unsigned numResults = op.getNumResults(); + if (numResults > 1) + llvm::outs() + << "have not yet implemented support for multiple return values\n"; + + // Get the original SSA name for the block if available + if (StringAttr blockNameAttr = + op.getAttrOfType("mlir.blockName")) { + blockNames[op.getBlock()] = {-1, blockNameAttr.strref()}; + op.removeDiscardableAttr("mlir.blockName"); + } + return; +} + void SSANameState::getResultIDAndNumber( OpResult result, Value &lookupValue, std::optional &lookupResultNo) const { From 1aa3814f6eb47ec8d51d2eb68e40e469194e2744 Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Fri, 26 Jan 2024 16:18:27 +0000 Subject: [PATCH 04/11] Added block arg name handling --- mlir/lib/AsmParser/Parser.cpp | 25 ++++++++++++++++++++----- mlir/lib/IR/AsmPrinter.cpp | 14 +++++++++++++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index bb3fd5d872940..95b1b2fe1f8f0 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -614,6 +614,7 @@ class OperationParser : public Parser { /// Store the SSA names for the current operation as attrs for debug purposes. ParseResult storeSSANames(Operation *&op, SmallVector resultIDs); + DenseMap blockArgNames; //===--------------------------------------------------------------------===// // Region Parsing @@ -1203,6 +1204,11 @@ ParseResult OperationParser::parseOperation() { << op->getNumResults() << " results but was provided " << numExpectedResults << " to bind"; + // If enabled, store the SSA name(s) for the operation + llvm::outs() << "parsing operation: " << op->getName() << "\n"; + if (state.config.shouldRetainIdentifierNames()) + storeSSANames(op, resultIDs); + // Add this operation to the assembly state if it was provided to populate. if (state.asmState) { unsigned resultIt = 0; @@ -1227,10 +1233,6 @@ ParseResult OperationParser::parseOperation() { } } - // If enabled, store the SSA name(s) for the operation - if (state.config.shouldRetainIdentifierNames()) - storeSSANames(op, resultIDs); - // Add this operation to the assembly state if it was provided to populate. } else if (state.asmState) { state.asmState->finalizeOperationDefinition( @@ -1300,7 +1302,16 @@ OperationParser::storeSSANames(Operation *&op, if (entry.second.block == blockPtr) { op->setDiscardableAttr("mlir.blockName", StringAttr::get(getContext(), entry.first)); - llvm::outs() << "Block name: " << entry.first << "\n"; + + // Store block arguments, if present + llvm::SmallVector argNames; + + for (BlockArgument arg : blockPtr->getArguments()) { + auto it = blockArgNames.find(arg); + if (it != blockArgNames.end()) + argNames.push_back(it->second.drop_front(1)); + } + op->setAttr("mlir.blockArgNames", builder.getStrArrayAttr(argNames)); } } } @@ -2395,6 +2406,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) { } else { auto loc = getEncodedSourceLocation(useInfo.location); arg = owner->addArgument(type, loc); + + // Optionally store argument name for debug purposes + if (state.config.shouldRetainIdentifierNames()) + blockArgNames.insert({arg, useInfo.name}); } // If the argument has an explicit loc(...) specifier, parse and apply diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 70d38c9bd8b4f..34e0bdfccda94 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1609,12 +1609,24 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) { llvm::outs() << "have not yet implemented support for multiple return values\n"; - // Get the original SSA name for the block if available + // Get the original name for the block if available if (StringAttr blockNameAttr = op.getAttrOfType("mlir.blockName")) { blockNames[op.getBlock()] = {-1, blockNameAttr.strref()}; op.removeDiscardableAttr("mlir.blockName"); } + + // Get the original name for the block args if available + if (ArrayAttr blockArgNamesAttr = + op.getAttrOfType("mlir.blockArgNames")) { + auto blockArgNames = blockArgNamesAttr.getValue(); + auto blockArgs = op.getBlock()->getArguments(); + for (int i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) { + auto blockArgName = blockArgNames[i].cast(); + setValueName(blockArgs[i], cast(blockArgNames[i]).strref()); + } + op.removeDiscardableAttr("mlir.blockArgNames"); + } return; } From c0ec451d9ba155478ee5bded93cfaa912480e0ce Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Sat, 27 Jan 2024 11:24:53 +0000 Subject: [PATCH 05/11] Added support for operations with no arguments --- mlir/lib/AsmParser/Parser.cpp | 26 ++++++++++---------------- mlir/lib/IR/AsmPrinter.cpp | 27 +++++++++++++++------------ 2 files changed, 25 insertions(+), 28 deletions(-) diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 95b1b2fe1f8f0..36266be5af033 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -612,8 +612,7 @@ class OperationParser : public Parser { FailureOr parseCustomOperationName(); /// Store the SSA names for the current operation as attrs for debug purposes. - ParseResult storeSSANames(Operation *&op, - SmallVector resultIDs); + void storeSSANames(Operation *&op, ArrayRef resultIDs); DenseMap blockArgNames; //===--------------------------------------------------------------------===// @@ -1204,11 +1203,6 @@ ParseResult OperationParser::parseOperation() { << op->getNumResults() << " results but was provided " << numExpectedResults << " to bind"; - // If enabled, store the SSA name(s) for the operation - llvm::outs() << "parsing operation: " << op->getName() << "\n"; - if (state.config.shouldRetainIdentifierNames()) - storeSSANames(op, resultIDs); - // Add this operation to the assembly state if it was provided to populate. if (state.asmState) { unsigned resultIt = 0; @@ -1279,15 +1273,12 @@ OperationParser::parseSuccessors(SmallVectorImpl &destinations) { } /// Store the SSA names for the current operation as attrs for debug purposes. -ParseResult -OperationParser::storeSSANames(Operation *&op, - SmallVector resultIDs) { - if (op->getNumResults() == 0) - emitError("Operation has no results\n"); - else if (op->getNumResults() > 1) +void OperationParser::storeSSANames(Operation *&op, + ArrayRef resultIDs) { + if (op->getNumResults() > 1) emitError("have not yet implemented support for multiple return values\n"); - for (ResultRecord &resIt : resultIDs) { + for (const ResultRecord &resIt : resultIDs) { for (unsigned subRes : llvm::seq(0, std::get<1>(resIt))) { op->setDiscardableAttr( "mlir.ssaName", @@ -1315,8 +1306,6 @@ OperationParser::storeSSANames(Operation *&op, } } } - - return success(); } namespace { @@ -2082,6 +2071,11 @@ OperationParser::parseCustomOperation(ArrayRef resultIDs) { // Otherwise, create the operation and try to parse a location for it. Operation *op = opBuilder.create(opState); + + // If enabled, store the SSA name(s) for the operation + if (state.config.shouldRetainIdentifierNames()) + storeSSANames(op, resultIDs); + if (parseTrailingLocationSpecifier(op)) return nullptr; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 34e0bdfccda94..d4be3b7802e68 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1571,6 +1571,10 @@ void SSANameState::numberValuesInOp(Operation &op) { } } + // Set the original identifier names if available. Used in debugging with + // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig + setRetainedIdentifierNames(op); + unsigned numResults = op.getNumResults(); if (numResults == 0) { // If value users should be printed, operations with no result need an id. @@ -1582,10 +1586,6 @@ void SSANameState::numberValuesInOp(Operation &op) { } Value resultBegin = op.getResult(0); - // Set the original identifier names if available. Used in debugging with - // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig - setRetainedIdentifierNames(op); - // If the first result wasn't numbered, give it a default number. if (valueIDs.try_emplace(resultBegin, nextValueID).second) ++nextValueID; @@ -1599,15 +1599,17 @@ void SSANameState::numberValuesInOp(Operation &op) { void SSANameState::setRetainedIdentifierNames(Operation &op) { // Get the original SSA for the result(s) if available - Value resultBegin = op.getResult(0); - if (StringAttr ssaNameAttr = op.getAttrOfType("mlir.ssaName")) { - setValueName(resultBegin, ssaNameAttr.strref()); - op.removeDiscardableAttr("mlir.ssaName"); - } unsigned numResults = op.getNumResults(); if (numResults > 1) llvm::outs() << "have not yet implemented support for multiple return values\n"; + else if (numResults == 1) { + Value resultBegin = op.getResult(0); + if (StringAttr ssaNameAttr = op.getAttrOfType("mlir.ssaName")) { + setValueName(resultBegin, ssaNameAttr.strref()); + op.removeDiscardableAttr("mlir.ssaName"); + } + } // Get the original name for the block if available if (StringAttr blockNameAttr = @@ -1621,9 +1623,10 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) { op.getAttrOfType("mlir.blockArgNames")) { auto blockArgNames = blockArgNamesAttr.getValue(); auto blockArgs = op.getBlock()->getArguments(); - for (int i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) { - auto blockArgName = blockArgNames[i].cast(); - setValueName(blockArgs[i], cast(blockArgNames[i]).strref()); + for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) { + auto blockArgName = blockArgNames[i].cast().strref(); + if (!usedNames.count(blockArgName)) + setValueName(blockArgs[i], blockArgName); } op.removeDiscardableAttr("mlir.blockArgNames"); } From a9e3891f4ed0511de6797cfe986fe8d3dcbf8f08 Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Sat, 27 Jan 2024 12:21:39 +0000 Subject: [PATCH 06/11] Added initial unit tests --- mlir/test/IR/print-retain-identifiers.mlir | 54 ++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 mlir/test/IR/print-retain-identifiers.mlir diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir new file mode 100644 index 0000000000000..663f7301b817b --- /dev/null +++ b/mlir/test/IR/print-retain-identifiers.mlir @@ -0,0 +1,54 @@ +// RUN: mlir-opt -retain-identifier-names %s | FileCheck %s + + +//===----------------------------------------------------------------------===// +// Test SSA results (with single return values) +//===----------------------------------------------------------------------===// + +// CHECK: func.func @add_one(%arg0: f64, %arg1: f64) -> f64 { +func.func @add_one(%arg0: f64, %arg1: f64) -> f64 { + // CHECK: %my_constant = arith.constant 1.000000e+00 : f64 + %my_constant = arith.constant 1.000000e+00 : f64 + // CHECK: %my_output = arith.addf %arg0, %my_constant : f64 + %my_output = arith.addf %arg0, %my_constant : f64 + // CHECK: return %my_output : f64 + return %my_output : f64 +} + + +// ----- + +//===----------------------------------------------------------------------===// +// Test basic blocks and their arguments +//===----------------------------------------------------------------------===// + +func.func @simple(i64, i1) -> i64 { +^bb_alpha(%a: i64, %cond: i1): + // CHECK: cf.cond_br %cond, ^bb_beta, ^bb_gamma + cf.cond_br %cond, ^bb_beta, ^bb_gamma + +// CHECK: ^bb_beta: // pred: ^bb_alpha +^bb_beta: + // CHECK: cf.br ^bb_delta(%a : i64) + cf.br ^bb_delta(%a: i64) + +// CHECK: ^bb_gamma: // pred: ^bb_alpha +^bb_gamma: + // CHECK: %b = arith.addi %a, %a : i64 + %b = arith.addi %a, %a : i64 + // CHECK: cf.br ^bb_delta(%b : i64) + cf.br ^bb_delta(%b: i64) + +// CHECK: ^bb_delta(%c: i64): // 2 preds: ^bb_gamma, ^bb_beta +^bb_delta(%c: i64): + // CHECK: cf.br ^bb_eps(%c, %a : i64, i64) + cf.br ^bb_eps(%c, %a : i64, i64) + +// CHECK: ^bb_eps(%d: i64, %e: i64): // pred: ^bb_delta +^bb_eps(%d : i64, %e : i64): + // CHECK: %f = arith.addi %d, %e : i64 + %f = arith.addi %d, %e : i64 + return %f : i64 +} + +// ----- From 467c3a38abfbbc22c25b0874a3943ee7ad8bc78b Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Sat, 27 Jan 2024 14:49:16 +0000 Subject: [PATCH 07/11] Added operation operand name preservation --- mlir/lib/AsmParser/Parser.cpp | 34 +++++++++++++++++----- mlir/lib/IR/AsmPrinter.cpp | 13 +++++++++ mlir/test/IR/print-retain-identifiers.mlir | 8 ++--- 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 36266be5af033..7327776ae5a55 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -613,7 +613,7 @@ class OperationParser : public Parser { /// Store the SSA names for the current operation as attrs for debug purposes. void storeSSANames(Operation *&op, ArrayRef resultIDs); - DenseMap blockArgNames; + DenseMap argNames; //===--------------------------------------------------------------------===// // Region Parsing @@ -1286,7 +1286,19 @@ void OperationParser::storeSSANames(Operation *&op, } } - // Find the name of the block that contains this operation. + // Store the name information of the arguments of this operation. + if (op->getNumOperands() > 0) { + llvm::SmallVector opArgNames; + for (auto &operand : op->getOpOperands()) { + auto it = argNames.find(operand.get()); + if (it != argNames.end()) + opArgNames.push_back(it->second.drop_front(1)); + } + op->setDiscardableAttr("mlir.opArgNames", + builder.getStrArrayAttr(opArgNames)); + } + + // Store the name information of the block that contains this operation. Block *blockPtr = op->getBlock(); for (const auto &map : blocksByName) { for (const auto &entry : map) { @@ -1295,14 +1307,15 @@ void OperationParser::storeSSANames(Operation *&op, StringAttr::get(getContext(), entry.first)); // Store block arguments, if present - llvm::SmallVector argNames; + llvm::SmallVector blockArgNames; for (BlockArgument arg : blockPtr->getArguments()) { - auto it = blockArgNames.find(arg); - if (it != blockArgNames.end()) - argNames.push_back(it->second.drop_front(1)); + auto it = argNames.find(arg); + if (it != argNames.end()) + blockArgNames.push_back(it->second.drop_front(1)); } - op->setAttr("mlir.blockArgNames", builder.getStrArrayAttr(argNames)); + op->setAttr("mlir.blockArgNames", + builder.getStrArrayAttr(blockArgNames)); } } } @@ -1712,6 +1725,11 @@ class CustomOpAsmParser : public AsmParserImpl { SmallVectorImpl &result) override { if (auto value = parser.resolveSSAUse(operand, type)) { result.push_back(value); + + // Optionally store argument name for debug purposes + if (parser.getState().config.shouldRetainIdentifierNames()) + parser.argNames.insert({value, operand.name}); + return success(); } return failure(); @@ -2403,7 +2421,7 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) { // Optionally store argument name for debug purposes if (state.config.shouldRetainIdentifierNames()) - blockArgNames.insert({arg, useInfo.name}); + argNames.insert({arg, useInfo.name}); } // If the argument has an explicit loc(...) specifier, parse and apply diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index d4be3b7802e68..5755b9021dbad 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1611,6 +1611,19 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) { } } + // Get the original name for the op args if available + if (ArrayAttr opArgNamesAttr = + op.getAttrOfType("mlir.opArgNames")) { + auto opArgNames = opArgNamesAttr.getValue(); + auto opArgs = op.getOperands(); + for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) { + auto opArgName = opArgNames[i].cast().strref(); + if (!usedNames.count(opArgName)) + setValueName(opArgs[i], opArgName); + } + op.removeDiscardableAttr("mlir.opArgNames"); + } + // Get the original name for the block if available if (StringAttr blockNameAttr = op.getAttrOfType("mlir.blockName")) { diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir index 663f7301b817b..05a6b9b42d8e8 100644 --- a/mlir/test/IR/print-retain-identifiers.mlir +++ b/mlir/test/IR/print-retain-identifiers.mlir @@ -5,12 +5,12 @@ // Test SSA results (with single return values) //===----------------------------------------------------------------------===// -// CHECK: func.func @add_one(%arg0: f64, %arg1: f64) -> f64 { -func.func @add_one(%arg0: f64, %arg1: f64) -> f64 { +// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 { +func.func @add_one(%my_input: f64, %arg1: f64) -> f64 { // CHECK: %my_constant = arith.constant 1.000000e+00 : f64 %my_constant = arith.constant 1.000000e+00 : f64 - // CHECK: %my_output = arith.addf %arg0, %my_constant : f64 - %my_output = arith.addf %arg0, %my_constant : f64 + // CHECK: %my_output = arith.addf %my_input, %my_constant : f64 + %my_output = arith.addf %my_input, %my_constant : f64 // CHECK: return %my_output : f64 return %my_output : f64 } From d062cb861c283594413c8edc7dca57f9494b1f3e Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Sat, 27 Jan 2024 18:27:53 +0000 Subject: [PATCH 08/11] Added support for result groups --- mlir/lib/AsmParser/Parser.cpp | 19 ++++++----- mlir/lib/IR/AsmPrinter.cpp | 36 ++++++++++++-------- mlir/test/IR/print-retain-identifiers.mlir | 38 ++++++++++++++++++++++ 3 files changed, 72 insertions(+), 21 deletions(-) diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 7327776ae5a55..247e99e61c2c0 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -1275,15 +1275,18 @@ OperationParser::parseSuccessors(SmallVectorImpl &destinations) { /// Store the SSA names for the current operation as attrs for debug purposes. void OperationParser::storeSSANames(Operation *&op, ArrayRef resultIDs) { - if (op->getNumResults() > 1) - emitError("have not yet implemented support for multiple return values\n"); - - for (const ResultRecord &resIt : resultIDs) { - for (unsigned subRes : llvm::seq(0, std::get<1>(resIt))) { - op->setDiscardableAttr( - "mlir.ssaName", - StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1))); + + // Store the name(s) of the result(s) of this operation. + if (op->getNumResults() > 0) { + llvm::SmallVector resultNames; + for (const ResultRecord &resIt : resultIDs) { + resultNames.push_back(std::get<0>(resIt).drop_front(1)); + // Insert empty string for sub-results/result groups + for (unsigned int i = 1; i < std::get<1>(resIt); ++i) + resultNames.push_back(llvm::StringRef()); } + op->setDiscardableAttr("mlir.resultNames", + builder.getStrArrayAttr(resultNames)); } // Store the name information of the arguments of this operation. diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 5755b9021dbad..84603bb6ebfba 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1300,7 +1300,8 @@ class SSANameState { /// Set the original identifier names if available. Used in debugging with /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig - void setRetainedIdentifierNames(Operation &op); + void setRetainedIdentifierNames(Operation &op, + SmallVector &resultGroups); /// This is the value ID for each SSA value. If this returns NameSentinel, /// then the valueID has an entry in valueNames. @@ -1573,7 +1574,7 @@ void SSANameState::numberValuesInOp(Operation &op) { // Set the original identifier names if available. Used in debugging with // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig - setRetainedIdentifierNames(op); + setRetainedIdentifierNames(op, resultGroups); unsigned numResults = op.getNumResults(); if (numResults == 0) { @@ -1597,18 +1598,25 @@ void SSANameState::numberValuesInOp(Operation &op) { } } -void SSANameState::setRetainedIdentifierNames(Operation &op) { - // Get the original SSA for the result(s) if available - unsigned numResults = op.getNumResults(); - if (numResults > 1) - llvm::outs() - << "have not yet implemented support for multiple return values\n"; - else if (numResults == 1) { - Value resultBegin = op.getResult(0); - if (StringAttr ssaNameAttr = op.getAttrOfType("mlir.ssaName")) { - setValueName(resultBegin, ssaNameAttr.strref()); - op.removeDiscardableAttr("mlir.ssaName"); +void SSANameState::setRetainedIdentifierNames( + Operation &op, SmallVector &resultGroups) { + // Get the original names for the results if available + if (ArrayAttr resultNamesAttr = + op.getAttrOfType("mlir.resultNames")) { + auto resultNames = resultNamesAttr.getValue(); + auto results = op.getResults(); + // Conservative in the case that the #results has changed + for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) { + auto resultName = resultNames[i].cast().strref(); + if (!resultName.empty()) { + if (!usedNames.count(resultName)) + setValueName(results[i], resultName); + // If a result has a name, it is the start of a result group. + if (i > 0) + resultGroups.push_back(i); + } } + op.removeDiscardableAttr("mlir.resultNames"); } // Get the original name for the op args if available @@ -1616,6 +1624,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) { op.getAttrOfType("mlir.opArgNames")) { auto opArgNames = opArgNamesAttr.getValue(); auto opArgs = op.getOperands(); + // Conservative in the case that the #operands has changed for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) { auto opArgName = opArgNames[i].cast().strref(); if (!usedNames.count(opArgName)) @@ -1636,6 +1645,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) { op.getAttrOfType("mlir.blockArgNames")) { auto blockArgNames = blockArgNamesAttr.getValue(); auto blockArgs = op.getBlock()->getArguments(); + // Conservative in the case that the #args has changed for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) { auto blockArgName = blockArgNames[i].cast().strref(); if (!usedNames.count(blockArgName)) diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir index 05a6b9b42d8e8..b3e4f075b3936 100644 --- a/mlir/test/IR/print-retain-identifiers.mlir +++ b/mlir/test/IR/print-retain-identifiers.mlir @@ -52,3 +52,41 @@ func.func @simple(i64, i1) -> i64 { } // ----- + +//===----------------------------------------------------------------------===// +// Test multiple return values +//===----------------------------------------------------------------------===// + +func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) { + %gt = arith.cmpf "ogt", %a, %b : f64 + // CHECK: %min, %max = scf.if %gt -> (f64, f64) { + %min, %max = scf.if %gt -> (f64, f64) { + scf.yield %b, %a : f64, f64 + } else { + scf.yield %a, %b : f64, f64 + } + // CHECK: return %min, %max : f64, f64 + return %min, %max : f64, f64 +} + +// ----- + +////===----------------------------------------------------------------------===// +// Test multiple return values, with a grouped value tuple +//===----------------------------------------------------------------------===// + +func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) { + // Find the max between %a and %b, + // with %c and %d being other values that are returned. + %gt = arith.cmpf "ogt", %a, %b : f64 + // CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) { + %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) { + scf.yield %b, %a, %c, %d : f64, f64, f64, f64 + } else { + scf.yield %a, %b, %d, %c : f64, f64, f64, f64 + } + // CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64 + return %max, %others#0, %others#1, %alt : f64, f64, f64, f64 +} + +// ----- From fb557df2c4afec5aaaf876428769b70d29306244 Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Sat, 27 Jan 2024 18:58:35 +0000 Subject: [PATCH 09/11] Fix clang-format issue in AsmState.h --- mlir/include/mlir/IR/AsmState.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h index 9c4eadb04cdf2..36f80712efb6a 100644 --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -144,7 +144,8 @@ class AsmResourceBlob { /// Return the underlying data as an array of the given type. This is an /// inherrently unsafe operation, and should only be used when the data is /// known to be of the correct type. - template ArrayRef getDataAs() const { + template + ArrayRef getDataAs() const { return llvm::ArrayRef((const T *)data.data(), data.size() / sizeof(T)); } From 233a987408742eb47c9d55bbe1013af4c0b0af10 Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Sun, 28 Jan 2024 16:22:52 +0000 Subject: [PATCH 10/11] Added system to handle when we use default names --- mlir/lib/IR/AsmPrinter.cpp | 45 ++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 84603bb6ebfba..e3bbe75441d58 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -981,7 +981,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { /// store the new copy, static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, StringRef allowedPunctChars = "$._-", - bool allowTrailingDigit = true) { + bool allowTrailingDigit = true, + bool allowNumeric = false) { assert(!name.empty() && "Shouldn't have an empty name here"); auto copyNameToBuffer = [&] { @@ -997,16 +998,17 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, // Check to see if this name is valid. If it starts with a digit, then it // could conflict with the autogenerated numeric ID's, so add an underscore - // prefix to avoid problems. - if (isdigit(name[0])) { + // prefix to avoid problems. This can be overridden by setting allowNumeric. + if (isdigit(name[0]) && !allowNumeric) { buffer.push_back('_'); copyNameToBuffer(); return buffer; } // If the name ends with a trailing digit, add a '_' to avoid potential - // conflicts with autogenerated ID's. - if (!allowTrailingDigit && isdigit(name.back())) { + // conflicts with autogenerated ID's. This can be overridden by setting + // allowNumeric. + if (!allowTrailingDigit && isdigit(name.back()) && !allowNumeric) { copyNameToBuffer(); buffer.push_back('_'); return buffer; @@ -1292,11 +1294,11 @@ class SSANameState { std::optional &lookupResultNo) const; /// Set a special value name for the given value. - void setValueName(Value value, StringRef name); + void setValueName(Value value, StringRef name, bool allowNumeric = false); /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. - StringRef uniqueValueName(StringRef name); + StringRef uniqueValueName(StringRef name, bool allowNumeric = false); /// Set the original identifier names if available. Used in debugging with /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig @@ -1541,7 +1543,10 @@ void SSANameState::numberValuesInOp(Operation &op) { // Function used to set the special result names for the operation. SmallVector resultGroups(/*Size=*/1, /*Value=*/0); auto setResultNameFn = [&](Value result, StringRef name) { - assert(!valueIDs.count(result) && "result numbered multiple times"); + // Case where the result has already been named + if (valueIDs.count(result)) + return; + // assert(!valueIDs.count(result) && "result numbered multiple times"); assert(result.getDefiningOp() == &op && "result not defined by 'op'"); setValueName(result, name); @@ -1565,6 +1570,10 @@ void SSANameState::numberValuesInOp(Operation &op) { blockNames[block] = {-1, name}; }; + // Set the original identifier names if available. Used in debugging with + // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig + setRetainedIdentifierNames(op, resultGroups); + if (!printerFlags.shouldPrintGenericOpForm()) { if (OpAsmOpInterface asmInterface = dyn_cast(&op)) { asmInterface.getAsmBlockNames(setBlockNameFn); @@ -1572,10 +1581,6 @@ void SSANameState::numberValuesInOp(Operation &op) { } } - // Set the original identifier names if available. Used in debugging with - // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig - setRetainedIdentifierNames(op, resultGroups); - unsigned numResults = op.getNumResults(); if (numResults == 0) { // If value users should be printed, operations with no result need an id. @@ -1610,7 +1615,7 @@ void SSANameState::setRetainedIdentifierNames( auto resultName = resultNames[i].cast().strref(); if (!resultName.empty()) { if (!usedNames.count(resultName)) - setValueName(results[i], resultName); + setValueName(results[i], resultName, /*allowNumeric=*/true); // If a result has a name, it is the start of a result group. if (i > 0) resultGroups.push_back(i); @@ -1628,7 +1633,7 @@ void SSANameState::setRetainedIdentifierNames( for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) { auto opArgName = opArgNames[i].cast().strref(); if (!usedNames.count(opArgName)) - setValueName(opArgs[i], opArgName); + setValueName(opArgs[i], opArgName, /*allowNumeric=*/true); } op.removeDiscardableAttr("mlir.opArgNames"); } @@ -1649,7 +1654,7 @@ void SSANameState::setRetainedIdentifierNames( for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) { auto blockArgName = blockArgNames[i].cast().strref(); if (!usedNames.count(blockArgName)) - setValueName(blockArgs[i], blockArgName); + setValueName(blockArgs[i], blockArgName, /*allowNumeric=*/true); } op.removeDiscardableAttr("mlir.blockArgNames"); } @@ -1695,7 +1700,8 @@ void SSANameState::getResultIDAndNumber( lookupValue = owner->getResult(groupResultNo); } -void SSANameState::setValueName(Value value, StringRef name) { +void SSANameState::setValueName(Value value, StringRef name, + bool allowNumeric) { // If the name is empty, the value uses the default numbering. if (name.empty()) { valueIDs[value] = nextValueID++; @@ -1703,12 +1709,13 @@ void SSANameState::setValueName(Value value, StringRef name) { } valueIDs[value] = NameSentinel; - valueNames[value] = uniqueValueName(name); + valueNames[value] = uniqueValueName(name, allowNumeric); } -StringRef SSANameState::uniqueValueName(StringRef name) { +StringRef SSANameState::uniqueValueName(StringRef name, bool allowNumeric) { SmallString<16> tmpBuffer; - name = sanitizeIdentifier(name, tmpBuffer); + name = sanitizeIdentifier(name, tmpBuffer, /*allowedPunctChars=*/"$._-", + /*allowTrailingDigit=*/true, allowNumeric); // Check to see if this name is already unique. if (!usedNames.count(name)) { From b257ab53e508e5fb00e16a1c73e7fceeea4d0923 Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Mon, 29 Jan 2024 13:03:04 +0000 Subject: [PATCH 11/11] Added region arg support & ambiguous name test --- mlir/lib/AsmParser/Parser.cpp | 28 ++++- mlir/lib/IR/AsmPrinter.cpp | 122 ++++++++++++--------- mlir/test/IR/print-retain-identifiers.mlir | 27 ++++- 3 files changed, 115 insertions(+), 62 deletions(-) diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 247e99e61c2c0..6f0c5fa30ffa5 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -611,8 +611,9 @@ class OperationParser : public Parser { /// an object of type 'OperationName'. Otherwise, failure is returned. FailureOr parseCustomOperationName(); - /// Store the SSA names for the current operation as attrs for debug purposes. - void storeSSANames(Operation *&op, ArrayRef resultIDs); + /// Store the identifier names for the current operation as attrs for debug + /// purposes. + void storeIdentifierNames(Operation *&op, ArrayRef resultIDs); DenseMap argNames; //===--------------------------------------------------------------------===// @@ -1273,8 +1274,8 @@ OperationParser::parseSuccessors(SmallVectorImpl &destinations) { } /// Store the SSA names for the current operation as attrs for debug purposes. -void OperationParser::storeSSANames(Operation *&op, - ArrayRef resultIDs) { +void OperationParser::storeIdentifierNames(Operation *&op, + ArrayRef resultIDs) { // Store the name(s) of the result(s) of this operation. if (op->getNumResults() > 0) { @@ -1322,6 +1323,18 @@ void OperationParser::storeSSANames(Operation *&op, } } } + + // Store names of region arguments (e.g., for FuncOps) + if (op->getNumRegions() > 0 && op->getRegion(0).getNumArguments() > 0) { + llvm::SmallVector regionArgNames; + for (BlockArgument arg : op->getRegion(0).getArguments()) { + auto it = argNames.find(arg); + if (it != argNames.end()) { + regionArgNames.push_back(it->second.drop_front(1)); + } + } + op->setAttr("mlir.regionArgNames", builder.getStrArrayAttr(regionArgNames)); + } } namespace { @@ -2093,9 +2106,9 @@ OperationParser::parseCustomOperation(ArrayRef resultIDs) { // Otherwise, create the operation and try to parse a location for it. Operation *op = opBuilder.create(opState); - // If enabled, store the SSA name(s) for the operation + // If enabled, store the original identifier name(s) for the operation if (state.config.shouldRetainIdentifierNames()) - storeSSANames(op, resultIDs); + storeIdentifierNames(op, resultIDs); if (parseTrailingLocationSpecifier(op)) return nullptr; @@ -2246,6 +2259,9 @@ ParseResult OperationParser::parseRegionBody(Region ®ion, SMLoc startLoc, if (state.asmState) state.asmState->addDefinition(arg, argInfo.location); + if (state.config.shouldRetainIdentifierNames()) + argNames.insert({arg, argInfo.name}); + // Record the definition for this argument. if (addDefinition(argInfo, arg)) return failure(); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index e3bbe75441d58..51f4bb66a8414 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1303,7 +1303,9 @@ class SSANameState { /// Set the original identifier names if available. Used in debugging with /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig void setRetainedIdentifierNames(Operation &op, - SmallVector &resultGroups); + SmallVector &resultGroups, + bool hasRegion = false); + void setRetainedIdentifierNames(Region ®ion); /// This is the value ID for each SSA value. If this returns NameSentinel, /// then the valueID has an entry in valueNames. @@ -1492,6 +1494,9 @@ void SSANameState::numberValuesInRegion(Region ®ion) { setValueName(arg, name); }; + // Use manually specified region arg names if available + setRetainedIdentifierNames(region); + if (!printerFlags.shouldPrintGenericOpForm()) { if (Operation *op = region.getParentOp()) { if (auto asmInterface = dyn_cast(op)) @@ -1603,64 +1608,75 @@ void SSANameState::numberValuesInOp(Operation &op) { } } -void SSANameState::setRetainedIdentifierNames( - Operation &op, SmallVector &resultGroups) { - // Get the original names for the results if available - if (ArrayAttr resultNamesAttr = - op.getAttrOfType("mlir.resultNames")) { - auto resultNames = resultNamesAttr.getValue(); - auto results = op.getResults(); - // Conservative in the case that the #results has changed - for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) { - auto resultName = resultNames[i].cast().strref(); - if (!resultName.empty()) { - if (!usedNames.count(resultName)) - setValueName(results[i], resultName, /*allowNumeric=*/true); - // If a result has a name, it is the start of a result group. - if (i > 0) - resultGroups.push_back(i); - } - } - op.removeDiscardableAttr("mlir.resultNames"); - } - - // Get the original name for the op args if available - if (ArrayAttr opArgNamesAttr = - op.getAttrOfType("mlir.opArgNames")) { - auto opArgNames = opArgNamesAttr.getValue(); - auto opArgs = op.getOperands(); - // Conservative in the case that the #operands has changed - for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) { - auto opArgName = opArgNames[i].cast().strref(); - if (!usedNames.count(opArgName)) - setValueName(opArgs[i], opArgName, /*allowNumeric=*/true); +void SSANameState::setRetainedIdentifierNames(Operation &op, + SmallVector &resultGroups, + bool hasRegion) { + + // Lambda which fetches the list of relevant attributes (e.g., + // mlir.resultNames) and associates them with the relevant values + auto handleNamedAttributes = + [this](Operation &op, const Twine &attrName, auto getValuesFunc, + std::optional> customAction = + std::nullopt) { + if (ArrayAttr namesAttr = op.getAttrOfType(attrName.str())) { + auto names = namesAttr.getValue(); + auto values = getValuesFunc(); + // Conservative in case the number of values has changed + for (size_t i = 0; i < values.size() && i < names.size(); ++i) { + auto name = names[i].cast().strref(); + if (!name.empty()) { + if (!this->usedNames.count(name)) + this->setValueName(values[i], name, true); + if (customAction.has_value()) + customAction.value()(i); + } + } + op.removeDiscardableAttr(attrName.str()); + } + }; + + if (hasRegion) { + // Get the original name(s) for the region arg(s) if available (e.g., for + // FuncOp args). Requires hasRegion flag to ensure scoping is correct + if (hasRegion && op.getNumRegions() > 0 && + op.getRegion(0).getNumArguments() > 0) { + handleNamedAttributes(op, "mlir.regionArgNames", + [&]() { return op.getRegion(0).getArguments(); }); } - op.removeDiscardableAttr("mlir.opArgNames"); - } - - // Get the original name for the block if available - if (StringAttr blockNameAttr = - op.getAttrOfType("mlir.blockName")) { - blockNames[op.getBlock()] = {-1, blockNameAttr.strref()}; - op.removeDiscardableAttr("mlir.blockName"); - } - - // Get the original name for the block args if available - if (ArrayAttr blockArgNamesAttr = - op.getAttrOfType("mlir.blockArgNames")) { - auto blockArgNames = blockArgNamesAttr.getValue(); - auto blockArgs = op.getBlock()->getArguments(); - // Conservative in the case that the #args has changed - for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) { - auto blockArgName = blockArgNames[i].cast().strref(); - if (!usedNames.count(blockArgName)) - setValueName(blockArgs[i], blockArgName, /*allowNumeric=*/true); + } else { + // Get the original names for the results if available + handleNamedAttributes( + op, "mlir.resultNames", [&]() { return op.getResults(); }, + [&resultGroups](int i) { /*handles result groups*/ + if (i > 0) + resultGroups.push_back(i); + }); + + // Get the original name for the op args if available + handleNamedAttributes(op, "mlir.opArgNames", + [&]() { return op.getOperands(); }); + + // Get the original name for the block if available + if (StringAttr blockNameAttr = + op.getAttrOfType("mlir.blockName")) { + blockNames[op.getBlock()] = {-1, blockNameAttr.strref()}; + op.removeDiscardableAttr("mlir.blockName"); } - op.removeDiscardableAttr("mlir.blockArgNames"); + + // Get the original name(s) for the block arg(s) if available + handleNamedAttributes(op, "mlir.blockArgNames", + [&]() { return op.getBlock()->getArguments(); }); } return; } +void SSANameState::setRetainedIdentifierNames(Region ®ion) { + if (Operation *op = region.getParentOp()) { + SmallVector resultGroups; + setRetainedIdentifierNames(*op, resultGroups, true); + } +} + void SSANameState::getResultIDAndNumber( OpResult result, Value &lookupValue, std::optional &lookupResultNo) const { diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir index b3e4f075b3936..65aa0507d205f 100644 --- a/mlir/test/IR/print-retain-identifiers.mlir +++ b/mlir/test/IR/print-retain-identifiers.mlir @@ -5,8 +5,8 @@ // Test SSA results (with single return values) //===----------------------------------------------------------------------===// -// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 { -func.func @add_one(%my_input: f64, %arg1: f64) -> f64 { +// CHECK: func.func @add_one(%my_input: f64) -> f64 { +func.func @add_one(%my_input: f64) -> f64 { // CHECK: %my_constant = arith.constant 1.000000e+00 : f64 %my_constant = arith.constant 1.000000e+00 : f64 // CHECK: %my_output = arith.addf %my_input, %my_constant : f64 @@ -71,7 +71,7 @@ func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) { // ----- -////===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// // Test multiple return values, with a grouped value tuple //===----------------------------------------------------------------------===// @@ -90,3 +90,24 @@ func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64 } // ----- + +//===----------------------------------------------------------------------===// +// Test identifiers which may clash with OpAsmOpInterface names (e.g., cst, %1, etc) +//===----------------------------------------------------------------------===// + +// CHECK: func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 { +func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 { + %my_constant = arith.constant 1.000000e+00 : f64 + // CHECK: %cst = arith.constant 2.000000e+00 : f64 + %cst = arith.constant 2.000000e+00 : f64 + // CHECK: %cst_1 = arith.constant 3.000000e+00 : f64 + %cst_1 = arith.constant 3.000000e+00 : f64 + // CHECK: %1 = arith.addf %arg1, %cst : f64 + %1 = arith.addf %arg1, %cst : f64 + // CHECK: %0 = arith.addf %arg1, %cst_1 : f64 + %0 = arith.addf %arg1, %cst_1 : f64 + // CHECK: return %1 : f64 + return %1 : f64 +} + +// -----