Skip to content

Commit 68ae2fa

Browse files
committed
Add support for type alias via ad-hoc parser
1 parent 3eebcfd commit 68ae2fa

File tree

2 files changed

+247
-49
lines changed

2 files changed

+247
-49
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-format %s --mlir-use-nameloc-as-prefix | FileCheck %s
2+
3+
// CHECK: !funky64 = f64
4+
!funky64 = f64
5+
// CHECK: !fancy64 = f64
6+
!fancy64 = f64
7+
8+
// CHECK: func.func @add_one(%b: f64) -> (f64, !funky64, !fancy64) {
9+
func.func @add_one(%b: f64) -> (f64, !funky64, !fancy64) {
10+
// CHECK: %c = arith.constant 1.00000e+00 : !funky64
11+
%c = arith.constant 1.00000e+00 : !funky64
12+
// CHECK: %x1 = arith.addf %b, %c : f64
13+
%x1 = arith.addf %b,
14+
%c : f64
15+
// CHECK: %x2 = arith.addf %b, %b : !funky64
16+
%x2 = arith.addf %b, %b : !funky64
17+
// CHECK: %x3 = arith.addf %x2, %b : !fancy64
18+
%x3 = arith.addf %x2, %b : !fancy64
19+
// CHECK: return %x1, %x2, %x3 : f64, !funky64, !fancy64
20+
return %x1, %x2, %x3 : f64, !funky64, !fancy64
21+
}

mlir/tools/mlir-format/mlir-format.cpp

Lines changed: 226 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,157 @@ class CombinedOpDefIterator {
7373
BaseIterator fmtIter;
7474
};
7575

76+
// Given the scopeLoc of an operation, extract src locations of the input and
77+
// output type
78+
std::pair<SmallVector<llvm::SMRange>, SmallVector<llvm::SMRange>>
79+
getOpTypeLoc(llvm::SMRange op_loc) {
80+
SmallVector<llvm::SMRange> inputTypeRanges;
81+
SmallVector<llvm::SMRange> outputTypeRanges;
82+
83+
// Extract the string from the range
84+
const char *startPtr = op_loc.Start.getPointer();
85+
const char *endPtr = op_loc.End.getPointer();
86+
StringRef opString(startPtr, endPtr - startPtr);
87+
88+
// Find the position of the last ':' in the string
89+
size_t colonPos = opString.rfind(':');
90+
if (colonPos == StringRef::npos) {
91+
// No ':' found, return empty vectors
92+
return {inputTypeRanges, outputTypeRanges};
93+
}
94+
95+
// Extract the type definition substring
96+
StringRef typeDefStr = opString.substr(colonPos + 1).trim();
97+
98+
// Check if the type definition substring contains '->' (input -> output
99+
// types)
100+
size_t arrowPos = typeDefStr.find("->");
101+
102+
if (arrowPos != StringRef::npos) {
103+
// Split into input and output type strings
104+
StringRef inputTypeStr = typeDefStr.substr(0, arrowPos).trim();
105+
StringRef outputTypeStr = typeDefStr.substr(arrowPos + 2).trim();
106+
107+
// Parse input type ranges (if any)
108+
if (!inputTypeStr.empty() && inputTypeStr != "()") {
109+
SmallVector<StringRef> inputTypeParts;
110+
inputTypeStr
111+
.drop_front() // Remove '('
112+
.drop_back() // Remove ')'
113+
.split(inputTypeParts, ',');
114+
115+
for (const auto &typeStr : inputTypeParts) {
116+
const char *start = typeStr.trim().data();
117+
const char *end = start + typeStr.trim().size();
118+
inputTypeRanges.push_back(
119+
llvm::SMRange(llvm::SMLoc::getFromPointer(start),
120+
llvm::SMLoc::getFromPointer(end)));
121+
}
122+
}
123+
124+
// Parse output type ranges (if any)
125+
if (!outputTypeStr.empty() && outputTypeStr != "()") {
126+
SmallVector<StringRef> outputTypeParts;
127+
outputTypeStr.split(outputTypeParts, ',');
128+
129+
for (const auto &typeStr : outputTypeParts) {
130+
const char *start = typeStr.trim().data();
131+
const char *end = start + typeStr.trim().size();
132+
outputTypeRanges.push_back(
133+
llvm::SMRange(llvm::SMLoc::getFromPointer(start),
134+
llvm::SMLoc::getFromPointer(end)));
135+
}
136+
}
137+
} else {
138+
// Single type definition (no '->'), assume it's an output type
139+
SmallVector<StringRef> typeParts;
140+
typeDefStr.split(typeParts, ',');
141+
142+
for (const auto &typeStr : typeParts) {
143+
const char *start = typeStr.trim().data();
144+
const char *end = start + typeStr.trim().size();
145+
outputTypeRanges.push_back(
146+
llvm::SMRange(llvm::SMLoc::getFromPointer(start),
147+
llvm::SMLoc::getFromPointer(end)));
148+
}
149+
}
150+
151+
return {inputTypeRanges, outputTypeRanges};
152+
}
153+
154+
// Builds an SMRange spanning the whole character buffer of `str`.
// The range stores raw pointers into `str`, so it is only meaningful while
// `str` is alive and its buffer has not been modified.
llvm::SMRange getSMRangeFromString(const std::string &str) {
  const llvm::SMLoc rangeBegin = llvm::SMLoc::getFromPointer(str.data());
  const llvm::SMLoc rangeEnd =
      llvm::SMLoc::getFromPointer(str.data() + str.size());
  return llvm::SMRange(rangeBegin, rangeEnd);
}
160+
161+
void replaceTypesInString(std::string &formattedStr,
162+
const SmallVector<llvm::SMRange> &inputTypes,
163+
const SmallVector<llvm::SMRange> &outputTypes) {
164+
// Get type locations from the formatted string
165+
llvm::SMRange formattedLoc = getSMRangeFromString(formattedStr);
166+
auto formattedTypes = getOpTypeLoc(formattedLoc);
167+
168+
// Ensure the number of types matches
169+
if (inputTypes.size() != formattedTypes.first.size() ||
170+
outputTypes.size() != formattedTypes.second.size()) {
171+
llvm::errs() << "Error: Mismatched number of input/output types in "
172+
"replacement operation.\n";
173+
return;
174+
}
175+
176+
// Perform input type replacements backwards to avoid index issues
177+
for (size_t i = inputTypes.size(); i-- > 0;) {
178+
const llvm::SMRange &formattedRange = formattedTypes.first[i];
179+
const llvm::SMRange &inputRange = inputTypes[i];
180+
181+
const char *formattedStart = formattedRange.Start.getPointer();
182+
const char *formattedEnd = formattedRange.End.getPointer();
183+
184+
const char *inputStart = inputRange.Start.getPointer();
185+
const char *inputEnd = inputRange.End.getPointer();
186+
187+
llvm::StringRef formattedType(formattedStart,
188+
formattedEnd - formattedStart);
189+
llvm::StringRef inputType(inputStart, inputEnd - inputStart);
190+
191+
// Replace in the formatted string
192+
size_t pos = formattedStr.find(formattedType.str());
193+
if (pos != std::string::npos) {
194+
formattedStr.replace(pos, formattedType.size(), inputType.str());
195+
} else {
196+
llvm::errs() << "Warning: Input type not found in formatted string: "
197+
<< formattedType << "\n";
198+
}
199+
}
200+
201+
// Perform output type replacements backwards to avoid index issues
202+
for (size_t i = outputTypes.size(); i-- > 0;) {
203+
const llvm::SMRange &formattedRange = formattedTypes.second[i];
204+
const llvm::SMRange &outputRange = outputTypes[i];
205+
206+
const char *formattedStart = formattedRange.Start.getPointer();
207+
const char *formattedEnd = formattedRange.End.getPointer();
208+
209+
const char *outputStart = outputRange.Start.getPointer();
210+
const char *outputEnd = outputRange.End.getPointer();
211+
212+
llvm::StringRef formattedType(formattedStart,
213+
formattedEnd - formattedStart);
214+
llvm::StringRef outputType(outputStart, outputEnd - outputStart);
215+
216+
// Replace in the formatted string
217+
size_t pos = formattedStr.find(formattedType.str());
218+
if (pos != std::string::npos) {
219+
formattedStr.replace(pos, formattedType.size(), outputType.str());
220+
} else {
221+
llvm::errs() << "Warning: Output type not found in formatted string: "
222+
<< formattedType << "\n";
223+
}
224+
}
225+
}
226+
76227
// Function to find the character before the previous comma
77228
const char *findPrevComma(const char *start, const char *stop_point) {
78229
if (!start) {
@@ -256,13 +407,11 @@ void Formatter::formatOps() {
256407
ParserConfig parseConfig(&context, /*verifyAfterParse=*/true,
257408
&fallbackResourceMap);
258409

259-
// Write the rewriteBuffer to a stream, that we can then parse
260410
std::string bufferContent;
261411
llvm::raw_string_ostream stream(bufferContent);
262412
rewriteBuffer.write(stream);
263413
stream.flush();
264414

265-
// Print the bufferContent to llvm::outs() for debugging.
266415
fmtSourceMgr.AddNewSourceBuffer(
267416
llvm::MemoryBuffer::getMemBufferCopy(bufferContent), SMLoc());
268417

@@ -285,67 +434,93 @@ void Formatter::formatOps() {
285434
continue;
286435

287436
// Print the fmtDef op and store as a string.
288-
// Replace the opDef with this formatted string.
289437
std::string formattedStr;
290438
llvm::raw_string_ostream stream(formattedStr);
291439
fmtDef.op->print(stream);
292440

293-
// Replacing the range:
441+
// Use the original type aliases
442+
auto orig_types = getOpTypeLoc(opDef.scopeLoc);
443+
replaceTypesInString(formattedStr, orig_types.first, orig_types.second);
444+
445+
// Replace the opDef with this formatted string.
294446
replaceRangeFmt({startOp, endOp}, formattedStr);
447+
448+
// Write the updated buffer to llvm::outs()
449+
writeFmt(llvm::outs());
295450
}
296451

297-
// Write the updated buffer to llvm::outs()
298-
writeFmt(llvm::outs());
299-
}
452+
std::string getNamedLoc(
453+
const OperationDefinition::ResultGroupDefinition &resultGroup) {
454+
auto sm_range = resultGroup.definition.loc;
455+
const char *start = sm_range.Start.getPointer();
456+
const int len = sm_range.End.getPointer() - start;
300457

301-
void markNames(Formatter &formatState, raw_ostream &os) {
302-
// Get the operation definitions from the AsmParserState.
303-
for (OperationDefinition &it : formatState.getOpDefs()) {
304-
auto [startOp, endOp] = getOpRange(it);
305-
// loop through the resultgroups
306-
for (auto &resultGroup : it.resultGroups) {
307-
auto def = resultGroup.definition;
308-
auto sm_range = def.loc;
309-
const char *start = sm_range.Start.getPointer();
310-
int len = sm_range.End.getPointer() - start;
311-
// Drop the % prefix, and put in new string with `loc("name")` format.
312-
auto name = StringRef(start + 1, len - 1);
313-
314-
// Add loc("{name}") to the end of the op
315-
std::string formattedStr = " loc(\"" + name.str() + "\")";
316-
StringRef namedLoc(formattedStr);
317-
formatState.insertText(endOp, namedLoc);
318-
}
458+
// Drop the '%' prefix and construct the `loc("name")` string
459+
auto name = llvm::StringRef(start + 1,
460+
len - 1); // Assumes the '%' is always present
461+
std::string formattedStr = " loc(\"" + name.str() + "\")";
462+
463+
return formattedStr;
319464
}
320465

321-
// Insert the NameLocs for the block arguments
322-
for (BlockDefinition &block : formatState.getBlockDefs()) {
323-
for (size_t i = 0; i < block.arguments.size(); ++i) {
324-
SMDefinition &arg = block.arguments[i];
325-
326-
// Find where to insert the NameLoc. Either before the next argument,
327-
// or at the end of the arg list
328-
const char *insertPointPtr;
329-
const char *arg_end = arg.loc.End.getPointer();
330-
SMDefinition *nextArg =
331-
(i + 1 < block.arguments.size()) ? &block.arguments[i + 1] : nullptr;
332-
if (nextArg) {
333-
const char *nextStart = nextArg->loc.Start.getPointer();
334-
insertPointPtr = findPrevComma(nextStart, arg_end);
466+
// To handle ops with multiple result groups, create a dummy "alias" op
467+
// so that we can give each group its own NameLoc
468+
// NOTE: currently an empty placeholder; it takes no arguments and does nothing.
void insertAliasOp() {}
469+
470+
// Appends ` loc("name")` text after every named SSA value so that the names
// survive as NameLocs.
//
// For each operation definition with exactly one result group, the NameLoc
// is inserted at the end of the op range. Operations with no result group
// are left untouched. Operations with more than one result group are not
// supported yet: an error is printed and failure() is returned.
//
// For each block argument, the NameLoc is inserted either before the comma
// preceding the next argument or, for the last argument, at the position
// returned by findNextCloseParenth.
//
// `os` is currently unused.
LogicalResult markNames(Formatter &formatState, raw_ostream &os) {
  // Get the operation definitions from the AsmParserState.
  for (OperationDefinition &opDef : formatState.getOpDefs()) {
    auto [startOp, endOp] = getOpRange(opDef);

    const size_t numResultGroups = opDef.resultGroups.size();
    if (numResultGroups == 0)
      continue;

    if (numResultGroups > 1) {
      // Complex case, where we have more than one result group, e.g.,
      // `%x, %y = op` or `%xs:2, %ys:3 = op`.
      // In this case we need to insert some aliasing ops.
      llvm::errs() << "Not implemented yet\n";
      return failure();
    }

    // Simple case, where we have only one result group for the op,
    // e.g., `%v = op` or `%v:2 = op`.
    const std::string nameLoc = getNamedLoc(opDef.resultGroups[0]);
    formatState.insertText(endOp, StringRef(nameLoc));
  }

  // Insert the NameLocs for the block arguments.
  for (BlockDefinition &block : formatState.getBlockDefs()) {
    for (size_t i = 0; i < block.arguments.size(); ++i) {
      SMDefinition &arg = block.arguments[i];
      const char *argStart = arg.loc.Start.getPointer();
      const char *argEnd = arg.loc.End.getPointer();

      // Find where to insert the NameLoc. Either before the next argument,
      // or at the end of the arg list.
      const char *insertPointPtr;
      if (i + 1 < block.arguments.size()) {
        const char *nextStart = block.arguments[i + 1].loc.Start.getPointer();
        insertPointPtr = findPrevComma(nextStart, argEnd);
      } else {
        insertPointPtr = findNextCloseParenth(argEnd);
      }

      // Drop the % prefix, and put in new string with `loc("name")` format.
      const int len = argEnd - argStart;
      auto name = StringRef(argStart + 1, len - 1);
      std::string formattedStr = " loc(\"" + name.str() + "\")";
      formatState.insertText(SMLoc::getFromPointer(insertPointPtr),
                             StringRef(formattedStr));
    }
  }
  return success();
}
348-
}
349524
} // namespace mlir
350525

351526
int main(int argc, char **argv) {
@@ -370,7 +545,9 @@ int main(int argc, char **argv) {
370545
auto f = Formatter::init(inputFilename, outputFilename);
371546

372547
// Append the SSA names as NameLocs
373-
markNames(*f, llvm::outs());
548+
LogicalResult result = markNames(*f, llvm::outs());
549+
if (!succeeded(result))
550+
return mlir::asMainReturnCode(mlir::failure());
374551

375552
if (nameLocOnly) {
376553
// Return the original buffer with NameLocs appended to ops

0 commit comments

Comments
 (0)