diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h index 42cbedcf9f883..36f80712efb6a 100644 --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -464,8 +464,10 @@ class ParserConfig { /// `fallbackResourceMap` is an optional fallback handler that can be used to /// parse external resources not explicitly handled by another parser. ParserConfig(MLIRContext *context, bool verifyAfterParse = true, - FallbackAsmResourceMap *fallbackResourceMap = nullptr) + FallbackAsmResourceMap *fallbackResourceMap = nullptr, + bool retainIdentifierNames = false) : context(context), verifyAfterParse(verifyAfterParse), + retainIdentifierNames(retainIdentifierNames), fallbackResourceMap(fallbackResourceMap) { assert(context && "expected valid MLIR context"); } @@ -476,6 +478,10 @@ class ParserConfig { /// Returns if the parser should verify the IR after parsing. bool shouldVerifyAfterParse() const { return verifyAfterParse; } + /// Returns if the parser should retain identifier names collected during + /// parsing. + bool shouldRetainIdentifierNames() const { return retainIdentifierNames; } + /// Returns the parsing configurations associated to the bytecode read. BytecodeReaderConfig &getBytecodeReaderConfig() const { return const_cast(bytecodeReaderConfig); @@ -513,6 +519,7 @@ class ParserConfig { private: MLIRContext *context; bool verifyAfterParse; + bool retainIdentifierNames; DenseMap> resourceParsers; FallbackAsmResourceMap *fallbackResourceMap; BytecodeReaderConfig bytecodeReaderConfig; diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h index 6e90fad1618d2..a85dca186a4f3 100644 --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -176,6 +176,13 @@ class MlirOptMainConfig { /// Reproducer file generation (no crash required). 
StringRef getReproducerFilename() const { return generateReproducerFileFlag; } + /// Set whether identifier names from the input are retained when printing. + MlirOptMainConfig &retainIdentifierNames(bool retain) { + retainIdentifierNamesFlag = retain; + return *this; + } + bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; } + protected: /// Allow operation with no registered dialects. /// This option is for convenience during testing only and discouraged in @@ -226,6 +233,9 @@ class MlirOptMainConfig { /// the corresponding line. This is meant for implementing diagnostic tests. bool verifyDiagnosticsFlag = false; + /// Retain identifier names in the output (e.g., `%my_var` instead of `%0`). + bool retainIdentifierNamesFlag = false; + /// Run the verifier after each transformation pass. bool verifyPassesFlag = true; diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 00f2b0c0c2f12..6f0c5fa30ffa5 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -611,6 +611,11 @@ class OperationParser : public Parser { /// an object of type 'OperationName'. Otherwise, failure is returned. FailureOr parseCustomOperationName(); + /// Store the identifier names for the current operation as attrs for debug + /// purposes. + void storeIdentifierNames(Operation *&op, ArrayRef resultIDs); + DenseMap argNames; + //===--------------------------------------------------------------------===// // Region Parsing //===--------------------------------------------------------------------===// @@ -1268,6 +1273,70 @@ OperationParser::parseSuccessors(SmallVectorImpl &destinations) { /*allowEmptyList=*/false); } +/// Store the parsed identifier names of `op` as `mlir.*` discardable attrs. +void OperationParser::storeIdentifierNames(Operation *&op, + ArrayRef resultIDs) { + // Store the name(s) of the result(s) of this operation. 
+ if (op->getNumResults() > 0) { + llvm::SmallVector resultNames; + for (const ResultRecord &resIt : resultIDs) { + resultNames.push_back(std::get<0>(resIt).drop_front(1)); + // Insert empty string for sub-results/result groups + for (unsigned int i = 1; i < std::get<1>(resIt); ++i) + resultNames.push_back(llvm::StringRef()); + } + op->setDiscardableAttr("mlir.resultNames", + builder.getStrArrayAttr(resultNames)); + } + + // Store the operand names; unnamed operands get "" to keep indices aligned. + if (op->getNumOperands() > 0) { + llvm::SmallVector opArgNames; + for (auto &operand : op->getOpOperands()) { + auto it = argNames.find(operand.get()); + opArgNames.push_back(it != argNames.end() ? it->second.drop_front(1) + : llvm::StringRef()); + } + op->setDiscardableAttr("mlir.opArgNames", + builder.getStrArrayAttr(opArgNames)); + } + + // Store the name information of the block that contains this operation. + Block *blockPtr = op->getBlock(); + for (const auto &map : blocksByName) { + for (const auto &entry : map) { + if (entry.second.block == blockPtr) { + op->setDiscardableAttr("mlir.blockName", + StringAttr::get(getContext(), entry.first)); + + // Store block argument names, one (possibly empty) entry per argument + llvm::SmallVector blockArgNames; + + for (BlockArgument arg : blockPtr->getArguments()) { + auto it = argNames.find(arg); + blockArgNames.push_back(it != argNames.end() ? it->second.drop_front(1) + : llvm::StringRef()); + } + op->setDiscardableAttr("mlir.blockArgNames", + builder.getStrArrayAttr(blockArgNames)); + } + } + } + + // Store names of region arguments (e.g., for FuncOps) + if (op->getNumRegions() > 0 && op->getRegion(0).getNumArguments() > 0) { + llvm::SmallVector regionArgNames; + for (BlockArgument arg : op->getRegion(0).getArguments()) { + auto it = argNames.find(arg); + // Unnamed arguments get an empty entry so indices stay aligned. + regionArgNames.push_back(it != argNames.end() ? it->second.drop_front(1) + : llvm::StringRef()); + } + op->setDiscardableAttr("mlir.regionArgNames", + builder.getStrArrayAttr(regionArgNames)); + } +} + namespace { // RAII-style guard for cleaning up the regions in the operation 
state before // deleting them. Within the parser, regions may get deleted if parsing failed, @@ -1672,6 +1741,11 @@ class CustomOpAsmParser : public AsmParserImpl { SmallVectorImpl &result) override { if (auto value = parser.resolveSSAUse(operand, type)) { result.push_back(value); + + // Optionally store argument name for debug purposes + if (parser.getState().config.shouldRetainIdentifierNames()) + parser.argNames.insert({value, operand.name}); + return success(); } return failure(); @@ -2031,6 +2105,11 @@ OperationParser::parseCustomOperation(ArrayRef resultIDs) { // Otherwise, create the operation and try to parse a location for it. Operation *op = opBuilder.create(opState); + + // If enabled, store the original identifier name(s) for the operation + if (state.config.shouldRetainIdentifierNames()) + storeIdentifierNames(op, resultIDs); + if (parseTrailingLocationSpecifier(op)) return nullptr; @@ -2180,6 +2259,9 @@ ParseResult OperationParser::parseRegionBody(Region &region, SMLoc startLoc, if (state.asmState) state.asmState->addDefinition(arg, argInfo.location); + if (state.config.shouldRetainIdentifierNames()) + argNames.insert({arg, argInfo.name}); + // Record the definition for this argument. if (addDefinition(argInfo, arg)) return failure(); @@ -2355,6 +2437,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) { } else { auto loc = getEncodedSourceLocation(useInfo.location); arg = owner->addArgument(type, loc); + + // Optionally store argument name for debug purposes + if (state.config.shouldRetainIdentifierNames()) + argNames.insert({arg, useInfo.name}); } // If the argument has an explicit loc(...) 
specifier, parse and apply diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 6b8b7473bf0f8..51f4bb66a8414 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -72,13 +72,12 @@ OpAsmParser::~OpAsmParser() = default; MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } /// Parse a type list. -/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918 +/// This is out-of-line to work-around +/// https://github.com/llvm/llvm-project/issues/62918 ParseResult AsmParser::parseTypeList(SmallVectorImpl &result) { - return parseCommaSeparatedList( - [&]() { return parseType(result.emplace_back()); }); - } - - + return parseCommaSeparatedList( + [&]() { return parseType(result.emplace_back()); }); +} //===----------------------------------------------------------------------===// // DialectAsmPrinter @@ -982,7 +981,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { /// store the new copy, static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, StringRef allowedPunctChars = "$._-", - bool allowTrailingDigit = true) { + bool allowTrailingDigit = true, + bool allowNumeric = false) { assert(!name.empty() && "Shouldn't have an empty name here"); auto copyNameToBuffer = [&] { @@ -998,16 +998,17 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, // Check to see if this name is valid. If it starts with a digit, then it // could conflict with the autogenerated numeric ID's, so add an underscore - // prefix to avoid problems. - if (isdigit(name[0])) { + // prefix to avoid problems. This can be overridden by setting allowNumeric. + if (isdigit(name[0]) && !allowNumeric) { buffer.push_back('_'); copyNameToBuffer(); return buffer; } // If the name ends with a trailing digit, add a '_' to avoid potential - // conflicts with autogenerated ID's. 
- if (!allowTrailingDigit && isdigit(name.back())) { + // conflicts with autogenerated ID's. This can be overridden by setting + // allowNumeric. + if (!allowTrailingDigit && isdigit(name.back()) && !allowNumeric) { copyNameToBuffer(); buffer.push_back('_'); return buffer; @@ -1293,11 +1294,18 @@ class SSANameState { std::optional &lookupResultNo) const; /// Set a special value name for the given value. - void setValueName(Value value, StringRef name); + void setValueName(Value value, StringRef name, bool allowNumeric = false); /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. - StringRef uniqueValueName(StringRef name); + StringRef uniqueValueName(StringRef name, bool allowNumeric = false); + + /// Set the original identifier names if available. Used in debugging with + /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig. + void setRetainedIdentifierNames(Operation &op, + SmallVector &resultGroups, + bool hasRegion = false); + void setRetainedIdentifierNames(Region &region); /// This is the value ID for each SSA value. If this returns NameSentinel, /// then the valueID has an entry in valueNames. @@ -1486,6 +1494,9 @@ void SSANameState::numberValuesInRegion(Region &region) { setValueName(arg, name); }; + // Use manually specified region arg names if available + setRetainedIdentifierNames(region); + if (!printerFlags.shouldPrintGenericOpForm()) { if (Operation *op = region.getParentOp()) { if (auto asmInterface = dyn_cast(op)) @@ -1537,7 +1548,10 @@ void SSANameState::numberValuesInOp(Operation &op) { // Function used to set the special result names for the operation. 
SmallVector resultGroups(/*Size=*/1, /*Value=*/0); auto setResultNameFn = [&](Value result, StringRef name) { - assert(!valueIDs.count(result) && "result numbered multiple times"); + // Skip results that were already named (e.g., from retained identifier + // names) instead of asserting that each result is numbered only once. + if (valueIDs.count(result)) + return; assert(result.getDefiningOp() == &op && "result not defined by 'op'"); setValueName(result, name); @@ -1561,6 +1575,10 @@ void SSANameState::numberValuesInOp(Operation &op) { blockNames[block] = {-1, name}; }; + // Set the original identifier names if available. Used in debugging with + // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig. + setRetainedIdentifierNames(op, resultGroups); + if (!printerFlags.shouldPrintGenericOpForm()) { if (OpAsmOpInterface asmInterface = dyn_cast(&op)) { asmInterface.getAsmBlockNames(setBlockNameFn); @@ -1590,6 +1608,75 @@ void SSANameState::numberValuesInOp(Operation &op) { } } +/// Apply names retained in `mlir.*` attributes of `op` to its values/blocks. +void SSANameState::setRetainedIdentifierNames(Operation &op, + SmallVector &resultGroups, + bool hasRegion) { + // Fetches a names attribute (e.g., mlir.resultNames), applies its names to + // the matching values by index, then removes the attribute from the op. + auto handleNamedAttributes = + [this](Operation &op, const Twine &attrName, auto getValuesFunc, + std::optional> customAction = + std::nullopt) { + if (ArrayAttr namesAttr = op.getAttrOfType(attrName.str())) { + auto names = namesAttr.getValue(); + auto values = getValuesFunc(); + // Conservative in case the number of values has changed + for (size_t i = 0; i < values.size() && i < names.size(); ++i) { + auto name = names[i].cast().strref(); + if (!name.empty()) { + if (!this->usedNames.count(name)) + this->setValueName(values[i], name, true); + if (customAction.has_value()) + customAction.value()(i); + } + } + op.removeDiscardableAttr(attrName.str()); + } + }; + + if (hasRegion) { + // Get the original name(s) for the region arg(s) if available (e.g., 
for + // FuncOp args). Requires hasRegion flag to ensure scoping is correct + if (op.getNumRegions() > 0 && + op.getRegion(0).getNumArguments() > 0) { + handleNamedAttributes(op, "mlir.regionArgNames", + [&]() { return op.getRegion(0).getArguments(); }); + } + } else { + // Get the original names for the results if available + handleNamedAttributes( + op, "mlir.resultNames", [&]() { return op.getResults(); }, + [&resultGroups](int i) { /*handles result groups*/ + if (i > 0) + resultGroups.push_back(i); + }); + + // Get the original name for the op args if available + handleNamedAttributes(op, "mlir.opArgNames", + [&]() { return op.getOperands(); }); + + // Get the original name for the block if available + if (StringAttr blockNameAttr = + op.getAttrOfType("mlir.blockName")) { + blockNames[op.getBlock()] = {-1, blockNameAttr.strref()}; + op.removeDiscardableAttr("mlir.blockName"); + } + + // Get the original name(s) for the block arg(s) if available + handleNamedAttributes(op, "mlir.blockArgNames", + [&]() { return op.getBlock()->getArguments(); }); + } +} + +/// Apply retained region argument names stored on the parent op of `region`. +void SSANameState::setRetainedIdentifierNames(Region &region) { + if (Operation *op = region.getParentOp()) { + SmallVector resultGroups; + setRetainedIdentifierNames(*op, resultGroups, true); + } +} + void SSANameState::getResultIDAndNumber( OpResult result, Value &lookupValue, std::optional &lookupResultNo) const { @@ -1629,7 +1716,8 @@ void SSANameState::getResultIDAndNumber( lookupValue = owner->getResult(groupResultNo); } -void SSANameState::setValueName(Value value, StringRef name) { +void SSANameState::setValueName(Value value, StringRef name, + bool allowNumeric) { // If the name is empty, the value uses the default numbering. 
if (name.empty()) { valueIDs[value] = nextValueID++; @@ -1637,12 +1725,13 @@ void SSANameState::setValueName(Value value, StringRef name) { } valueIDs[value] = NameSentinel; - valueNames[value] = uniqueValueName(name); + valueNames[value] = uniqueValueName(name, allowNumeric); } -StringRef SSANameState::uniqueValueName(StringRef name) { +StringRef SSANameState::uniqueValueName(StringRef name, bool allowNumeric) { SmallString<16> tmpBuffer; - name = sanitizeIdentifier(name, tmpBuffer); + name = sanitizeIdentifier(name, tmpBuffer, /*allowedPunctChars=*/"$._-", + /*allowTrailingDigit=*/true, allowNumeric); // Check to see if this name is already unique. if (!usedNames.count(name)) { diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 5395aa2b502d7..c448243586159 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -149,6 +149,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { cl::desc("Round-trip the IR after parsing and ensure it succeeds"), cl::location(verifyRoundtripFlag), cl::init(false)); + static cl::opt retainIdentifierNames( + "retain-identifier-names", + cl::desc("Retain the original names of identifiers when printing"), + cl::location(retainIdentifierNamesFlag), cl::init(false)); + static cl::list passPlugins( "load-pass-plugin", cl::desc("Load passes from plugin library")); @@ -359,8 +364,9 @@ performActions(raw_ostream &os, // untouched. 
PassReproducerOptions reproOptions; FallbackAsmResourceMap fallbackResourceMap; - ParserConfig parseConfig(context, /*verifyAfterParse=*/true, - &fallbackResourceMap); + ParserConfig parseConfig( + context, /*verifyAfterParse=*/true, &fallbackResourceMap, + /*retainIdentifierNames=*/config.shouldRetainIdentifierNames()); if (config.shouldRunReproducer()) reproOptions.attachResourceParser(parseConfig); diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir new file mode 100644 index 0000000000000..65aa0507d205f --- /dev/null +++ b/mlir/test/IR/print-retain-identifiers.mlir @@ -0,0 +1,113 @@ +// RUN: mlir-opt -retain-identifier-names %s | FileCheck %s + + +//===----------------------------------------------------------------------===// +// Test SSA results (with single return values) +//===----------------------------------------------------------------------===// + +// CHECK: func.func @add_one(%my_input: f64) -> f64 { +func.func @add_one(%my_input: f64) -> f64 { + // CHECK: %my_constant = arith.constant 1.000000e+00 : f64 + %my_constant = arith.constant 1.000000e+00 : f64 + // CHECK: %my_output = arith.addf %my_input, %my_constant : f64 + %my_output = arith.addf %my_input, %my_constant : f64 + // CHECK: return %my_output : f64 + return %my_output : f64 +} + + +// ----- + +//===----------------------------------------------------------------------===// +// Test basic blocks and their arguments +//===----------------------------------------------------------------------===// + +func.func @simple(i64, i1) -> i64 { +^bb_alpha(%a: i64, %cond: i1): + // CHECK: cf.cond_br %cond, ^bb_beta, ^bb_gamma + cf.cond_br %cond, ^bb_beta, ^bb_gamma + +// CHECK: ^bb_beta: // pred: ^bb_alpha +^bb_beta: + // CHECK: cf.br ^bb_delta(%a : i64) + cf.br ^bb_delta(%a: i64) + +// CHECK: ^bb_gamma: // pred: ^bb_alpha +^bb_gamma: + // CHECK: %b = arith.addi %a, %a : i64 + %b = arith.addi %a, %a : i64 + // CHECK: cf.br ^bb_delta(%b : i64) + cf.br 
^bb_delta(%b: i64) + +// CHECK: ^bb_delta(%c: i64): // 2 preds: ^bb_gamma, ^bb_beta +^bb_delta(%c: i64): + // CHECK: cf.br ^bb_eps(%c, %a : i64, i64) + cf.br ^bb_eps(%c, %a : i64, i64) + +// CHECK: ^bb_eps(%d: i64, %e: i64): // pred: ^bb_delta +^bb_eps(%d : i64, %e : i64): + // CHECK: %f = arith.addi %d, %e : i64 + %f = arith.addi %d, %e : i64 + return %f : i64 +} + +// ----- + +//===----------------------------------------------------------------------===// +// Test multiple return values +//===----------------------------------------------------------------------===// + +func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) { + %gt = arith.cmpf "ogt", %a, %b : f64 + // CHECK: %min, %max = scf.if %gt -> (f64, f64) { + %min, %max = scf.if %gt -> (f64, f64) { + scf.yield %b, %a : f64, f64 + } else { + scf.yield %a, %b : f64, f64 + } + // CHECK: return %min, %max : f64, f64 + return %min, %max : f64, f64 +} + +// ----- + +//===----------------------------------------------------------------------===// +// Test multiple return values, with a grouped value tuple +//===----------------------------------------------------------------------===// + +func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) { + // Find the max between %a and %b, + // with %c and %d being other values that are returned. 
+ %gt = arith.cmpf "ogt", %a, %b : f64 + // CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) { + %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) { + scf.yield %b, %a, %c, %d : f64, f64, f64, f64 + } else { + scf.yield %a, %b, %d, %c : f64, f64, f64, f64 + } + // CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64 + return %max, %others#0, %others#1, %alt : f64, f64, f64, f64 +} + +// ----- + +//===----------------------------------------------------------------------===// +// Test identifiers which may clash with OpAsmOpInterface names (e.g., cst, %1, etc) +//===----------------------------------------------------------------------===// + +// CHECK: func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 { +func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 { + %my_constant = arith.constant 1.000000e+00 : f64 + // CHECK: %cst = arith.constant 2.000000e+00 : f64 + %cst = arith.constant 2.000000e+00 : f64 + // CHECK: %cst_1 = arith.constant 3.000000e+00 : f64 + %cst_1 = arith.constant 3.000000e+00 : f64 + // CHECK: %1 = arith.addf %arg1, %cst : f64 + %1 = arith.addf %arg1, %cst : f64 + // CHECK: %0 = arith.addf %arg1, %cst_1 : f64 + %0 = arith.addf %arg1, %cst_1 : f64 + // CHECK: return %1 : f64 + return %1 : f64 +} + +// -----