Skip to content
9 changes: 8 additions & 1 deletion mlir/include/mlir/IR/AsmState.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,10 @@ class ParserConfig {
/// `fallbackResourceMap` is an optional fallback handler that can be used to
/// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
FallbackAsmResourceMap *fallbackResourceMap = nullptr)
FallbackAsmResourceMap *fallbackResourceMap = nullptr,
bool retainIdentifierNames = false)
: context(context), verifyAfterParse(verifyAfterParse),
retainIdentifierNames(retainIdentifierNames),
fallbackResourceMap(fallbackResourceMap) {
assert(context && "expected valid MLIR context");
}
Expand All @@ -476,6 +478,10 @@ class ParserConfig {
/// Returns if the parser should verify the IR after parsing.
bool shouldVerifyAfterParse() const { return verifyAfterParse; }

/// Returns if the parser should retain identifier names collected using
/// parsing.
bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }

/// Returns the parsing configurations associated to the bytecode read.
BytecodeReaderConfig &getBytecodeReaderConfig() const {
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
Expand Down Expand Up @@ -513,6 +519,7 @@ class ParserConfig {
private:
MLIRContext *context;
bool verifyAfterParse;
bool retainIdentifierNames;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
BytecodeReaderConfig bytecodeReaderConfig;
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ class MlirOptMainConfig {
/// Reproducer file generation (no crash required).
StringRef getReproducerFilename() const { return generateReproducerFileFlag; }

/// Print the pass-pipeline as text before executing.
MlirOptMainConfig &retainIdentifierNames(bool retain) {
retainIdentifierNamesFlag = retain;
return *this;
}
bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }

protected:
/// Allow operation with no registered dialects.
/// This option is for convenience during testing only and discouraged in
Expand Down Expand Up @@ -226,6 +233,9 @@ class MlirOptMainConfig {
/// the corresponding line. This is meant for implementing diagnostic tests.
bool verifyDiagnosticsFlag = false;

/// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
bool retainIdentifierNamesFlag = false;

/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;

Expand Down
70 changes: 70 additions & 0 deletions mlir/lib/AsmParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ class OperationParser : public Parser {
/// an object of type 'OperationName'. Otherwise, failure is returned.
FailureOr<OperationName> parseCustomOperationName();

/// Store the SSA names for the current operation as attrs for debug purposes.
void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
DenseMap<Value, StringRef> argNames;

//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -1268,6 +1272,58 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
/*allowEmptyList=*/false);
}

/// Store the SSA names for the current operation as attrs for debug purposes.
void OperationParser::storeSSANames(Operation *&op,
ArrayRef<ResultRecord> resultIDs) {

// Store the name(s) of the result(s) of this operation.
if (op->getNumResults() > 0) {
llvm::SmallVector<llvm::StringRef, 1> resultNames;
for (const ResultRecord &resIt : resultIDs) {
resultNames.push_back(std::get<0>(resIt).drop_front(1));
// Insert empty string for sub-results/result groups
for (unsigned int i = 1; i < std::get<1>(resIt); ++i)
resultNames.push_back(llvm::StringRef());
}
op->setDiscardableAttr("mlir.resultNames",
builder.getStrArrayAttr(resultNames));
}

// Store the name information of the arguments of this operation.
if (op->getNumOperands() > 0) {
llvm::SmallVector<llvm::StringRef, 1> opArgNames;
for (auto &operand : op->getOpOperands()) {
auto it = argNames.find(operand.get());
if (it != argNames.end())
opArgNames.push_back(it->second.drop_front(1));
}
op->setDiscardableAttr("mlir.opArgNames",
builder.getStrArrayAttr(opArgNames));
}

// Store the name information of the block that contains this operation.
Block *blockPtr = op->getBlock();
for (const auto &map : blocksByName) {
for (const auto &entry : map) {
if (entry.second.block == blockPtr) {
op->setDiscardableAttr("mlir.blockName",
StringAttr::get(getContext(), entry.first));

// Store block arguments, if present
llvm::SmallVector<llvm::StringRef, 1> blockArgNames;

for (BlockArgument arg : blockPtr->getArguments()) {
auto it = argNames.find(arg);
if (it != argNames.end())
blockArgNames.push_back(it->second.drop_front(1));
}
op->setAttr("mlir.blockArgNames",
builder.getStrArrayAttr(blockArgNames));
}
}
}
}

namespace {
// RAII-style guard for cleaning up the regions in the operation state before
// deleting them. Within the parser, regions may get deleted if parsing failed,
Expand Down Expand Up @@ -1672,6 +1728,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
SmallVectorImpl<Value> &result) override {
if (auto value = parser.resolveSSAUse(operand, type)) {
result.push_back(value);

// Optionally store argument name for debug purposes
if (parser.getState().config.shouldRetainIdentifierNames())
parser.argNames.insert({value, operand.name});

return success();
}
return failure();
Expand Down Expand Up @@ -2031,6 +2092,11 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {

// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);

// If enabled, store the SSA name(s) for the operation
if (state.config.shouldRetainIdentifierNames())
storeSSANames(op, resultIDs);

if (parseTrailingLocationSpecifier(op))
return nullptr;

Expand Down Expand Up @@ -2355,6 +2421,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
} else {
auto loc = getEncodedSourceLocation(useInfo.location);
arg = owner->addArgument(type, loc);

// Optionally store argument name for debug purposes
if (state.config.shouldRetainIdentifierNames())
argNames.insert({arg, useInfo.name});
}

// If the argument has an explicit loc(...) specifier, parse and apply
Expand Down
78 changes: 72 additions & 6 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }

/// Parse a type list.
/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
/// This is out-of-line to work-around
/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
}


return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
}

//===----------------------------------------------------------------------===//
// DialectAsmPrinter
Expand Down Expand Up @@ -1299,6 +1298,11 @@ class SSANameState {
/// conflicts, it is automatically renamed.
StringRef uniqueValueName(StringRef name);

/// Set the original identifier names if available. Used in debugging with
/// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
void setRetainedIdentifierNames(Operation &op,
SmallVector<int, 2> &resultGroups);

/// This is the value ID for each SSA value. If this returns NameSentinel,
/// then the valueID has an entry in valueNames.
DenseMap<Value, unsigned> valueIDs;
Expand Down Expand Up @@ -1568,6 +1572,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}

// Set the original identifier names if available. Used in debugging with
// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
setRetainedIdentifierNames(op, resultGroups);

unsigned numResults = op.getNumResults();
if (numResults == 0) {
// If value users should be printed, operations with no result need an id.
Expand All @@ -1590,6 +1598,64 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}

void SSANameState::setRetainedIdentifierNames(
Operation &op, SmallVector<int, 2> &resultGroups) {
// Get the original names for the results if available
if (ArrayAttr resultNamesAttr =
op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
auto resultNames = resultNamesAttr.getValue();
auto results = op.getResults();
// Conservative in the case that the #results has changed
for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
auto resultName = resultNames[i].cast<StringAttr>().strref();
if (!resultName.empty()) {
if (!usedNames.count(resultName))
setValueName(results[i], resultName);
// If a result has a name, it is the start of a result group.
if (i > 0)
resultGroups.push_back(i);
}
}
op.removeDiscardableAttr("mlir.resultNames");
}

// Get the original name for the op args if available
if (ArrayAttr opArgNamesAttr =
op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
auto opArgNames = opArgNamesAttr.getValue();
auto opArgs = op.getOperands();
// Conservative in the case that the #operands has changed
for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
auto opArgName = opArgNames[i].cast<StringAttr>().strref();
if (!usedNames.count(opArgName))
setValueName(opArgs[i], opArgName);
}
op.removeDiscardableAttr("mlir.opArgNames");
}

// Get the original name for the block if available
if (StringAttr blockNameAttr =
op.getAttrOfType<StringAttr>("mlir.blockName")) {
blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
op.removeDiscardableAttr("mlir.blockName");
}

// Get the original name for the block args if available
if (ArrayAttr blockArgNamesAttr =
op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
auto blockArgNames = blockArgNamesAttr.getValue();
auto blockArgs = op.getBlock()->getArguments();
// Conservative in the case that the #args has changed
for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
if (!usedNames.count(blockArgName))
setValueName(blockArgs[i], blockArgName);
}
op.removeDiscardableAttr("mlir.blockArgNames");
}
return;
}

void SSANameState::getResultIDAndNumber(
OpResult result, Value &lookupValue,
std::optional<int> &lookupResultNo) const {
Expand Down
10 changes: 8 additions & 2 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
cl::location(verifyRoundtripFlag), cl::init(false));

static cl::opt<bool, /*ExternalStorage=*/true> retainIdentifierNames(
"retain-identifier-names",
cl::desc("Retain the original names of identifiers when printing"),
cl::location(retainIdentifierNamesFlag), cl::init(false));

static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));

Expand Down Expand Up @@ -359,8 +364,9 @@ performActions(raw_ostream &os,
// untouched.
PassReproducerOptions reproOptions;
FallbackAsmResourceMap fallbackResourceMap;
ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
&fallbackResourceMap);
ParserConfig parseConfig(
context, /*verifyAfterParse=*/true, &fallbackResourceMap,
/*retainIdentifierName=*/config.shouldRetainIdentifierNames());
if (config.shouldRunReproducer())
reproOptions.attachResourceParser(parseConfig);

Expand Down
92 changes: 92 additions & 0 deletions mlir/test/IR/print-retain-identifiers.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// RUN: mlir-opt -retain-identifier-names %s | FileCheck %s


//===----------------------------------------------------------------------===//
// Test SSA results (with single return values)
//===----------------------------------------------------------------------===//

// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
// CHECK: %my_constant = arith.constant 1.000000e+00 : f64
%my_constant = arith.constant 1.000000e+00 : f64
// CHECK: %my_output = arith.addf %my_input, %my_constant : f64
%my_output = arith.addf %my_input, %my_constant : f64
// CHECK: return %my_output : f64
return %my_output : f64
}


// -----

//===----------------------------------------------------------------------===//
// Test basic blocks and their arguments
//===----------------------------------------------------------------------===//

func.func @simple(i64, i1) -> i64 {
^bb_alpha(%a: i64, %cond: i1):
// CHECK: cf.cond_br %cond, ^bb_beta, ^bb_gamma
cf.cond_br %cond, ^bb_beta, ^bb_gamma

// CHECK: ^bb_beta: // pred: ^bb_alpha
^bb_beta:
// CHECK: cf.br ^bb_delta(%a : i64)
cf.br ^bb_delta(%a: i64)

// CHECK: ^bb_gamma: // pred: ^bb_alpha
^bb_gamma:
// CHECK: %b = arith.addi %a, %a : i64
%b = arith.addi %a, %a : i64
// CHECK: cf.br ^bb_delta(%b : i64)
cf.br ^bb_delta(%b: i64)

// CHECK: ^bb_delta(%c: i64): // 2 preds: ^bb_gamma, ^bb_beta
^bb_delta(%c: i64):
// CHECK: cf.br ^bb_eps(%c, %a : i64, i64)
cf.br ^bb_eps(%c, %a : i64, i64)

// CHECK: ^bb_eps(%d: i64, %e: i64): // pred: ^bb_delta
^bb_eps(%d : i64, %e : i64):
// CHECK: %f = arith.addi %d, %e : i64
%f = arith.addi %d, %e : i64
return %f : i64
}

// -----

//===----------------------------------------------------------------------===//
// Test multiple return values
//===----------------------------------------------------------------------===//

func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
%gt = arith.cmpf "ogt", %a, %b : f64
// CHECK: %min, %max = scf.if %gt -> (f64, f64) {
%min, %max = scf.if %gt -> (f64, f64) {
scf.yield %b, %a : f64, f64
} else {
scf.yield %a, %b : f64, f64
}
// CHECK: return %min, %max : f64, f64
return %min, %max : f64, f64
}

// -----

////===----------------------------------------------------------------------===//
// Test multiple return values, with a grouped value tuple
//===----------------------------------------------------------------------===//

func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) {
// Find the max between %a and %b,
// with %c and %d being other values that are returned.
%gt = arith.cmpf "ogt", %a, %b : f64
// CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
%max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
scf.yield %b, %a, %c, %d : f64, f64, f64, f64
} else {
scf.yield %a, %b, %d, %c : f64, f64, f64, f64
}
// CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
}

// -----