@@ -73,6 +73,157 @@ class CombinedOpDefIterator {
7373 BaseIterator fmtIter;
7474};
7575
76+ // Given the scopeLoc of an operation, extract src locations of the input and
77+ // output type
78+ std::pair<SmallVector<llvm::SMRange>, SmallVector<llvm::SMRange>>
79+ getOpTypeLoc (llvm::SMRange op_loc) {
80+ SmallVector<llvm::SMRange> inputTypeRanges;
81+ SmallVector<llvm::SMRange> outputTypeRanges;
82+
83+ // Extract the string from the range
84+ const char *startPtr = op_loc.Start .getPointer ();
85+ const char *endPtr = op_loc.End .getPointer ();
86+ StringRef opString (startPtr, endPtr - startPtr);
87+
88+ // Find the position of the last ':' in the string
89+ size_t colonPos = opString.rfind (' :' );
90+ if (colonPos == StringRef::npos) {
91+ // No ':' found, return empty vectors
92+ return {inputTypeRanges, outputTypeRanges};
93+ }
94+
95+ // Extract the type definition substring
96+ StringRef typeDefStr = opString.substr (colonPos + 1 ).trim ();
97+
98+ // Check if the type definition substring contains '->' (input -> output
99+ // types)
100+ size_t arrowPos = typeDefStr.find (" ->" );
101+
102+ if (arrowPos != StringRef::npos) {
103+ // Split into input and output type strings
104+ StringRef inputTypeStr = typeDefStr.substr (0 , arrowPos).trim ();
105+ StringRef outputTypeStr = typeDefStr.substr (arrowPos + 2 ).trim ();
106+
107+ // Parse input type ranges (if any)
108+ if (!inputTypeStr.empty () && inputTypeStr != " ()" ) {
109+ SmallVector<StringRef> inputTypeParts;
110+ inputTypeStr
111+ .drop_front () // Remove '('
112+ .drop_back () // Remove ')'
113+ .split (inputTypeParts, ' ,' );
114+
115+ for (const auto &typeStr : inputTypeParts) {
116+ const char *start = typeStr.trim ().data ();
117+ const char *end = start + typeStr.trim ().size ();
118+ inputTypeRanges.push_back (
119+ llvm::SMRange (llvm::SMLoc::getFromPointer (start),
120+ llvm::SMLoc::getFromPointer (end)));
121+ }
122+ }
123+
124+ // Parse output type ranges (if any)
125+ if (!outputTypeStr.empty () && outputTypeStr != " ()" ) {
126+ SmallVector<StringRef> outputTypeParts;
127+ outputTypeStr.split (outputTypeParts, ' ,' );
128+
129+ for (const auto &typeStr : outputTypeParts) {
130+ const char *start = typeStr.trim ().data ();
131+ const char *end = start + typeStr.trim ().size ();
132+ outputTypeRanges.push_back (
133+ llvm::SMRange (llvm::SMLoc::getFromPointer (start),
134+ llvm::SMLoc::getFromPointer (end)));
135+ }
136+ }
137+ } else {
138+ // Single type definition (no '->'), assume it's an output type
139+ SmallVector<StringRef> typeParts;
140+ typeDefStr.split (typeParts, ' ,' );
141+
142+ for (const auto &typeStr : typeParts) {
143+ const char *start = typeStr.trim ().data ();
144+ const char *end = start + typeStr.trim ().size ();
145+ outputTypeRanges.push_back (
146+ llvm::SMRange (llvm::SMLoc::getFromPointer (start),
147+ llvm::SMLoc::getFromPointer (end)));
148+ }
149+ }
150+
151+ return {inputTypeRanges, outputTypeRanges};
152+ }
153+
154+ llvm::SMRange getSMRangeFromString (const std::string &str) {
155+ const char *startPtr = str.data ();
156+ const char *endPtr = startPtr + str.size ();
157+ return llvm::SMRange (llvm::SMLoc::getFromPointer (startPtr),
158+ llvm::SMLoc::getFromPointer (endPtr));
159+ }
160+
161+ void replaceTypesInString (std::string &formattedStr,
162+ const SmallVector<llvm::SMRange> &inputTypes,
163+ const SmallVector<llvm::SMRange> &outputTypes) {
164+ // Get type locations from the formatted string
165+ llvm::SMRange formattedLoc = getSMRangeFromString (formattedStr);
166+ auto formattedTypes = getOpTypeLoc (formattedLoc);
167+
168+ // Ensure the number of types matches
169+ if (inputTypes.size () != formattedTypes.first .size () ||
170+ outputTypes.size () != formattedTypes.second .size ()) {
171+ llvm::errs () << " Error: Mismatched number of input/output types in "
172+ " replacement operation.\n " ;
173+ return ;
174+ }
175+
176+ // Perform input type replacements backwards to avoid index issues
177+ for (size_t i = inputTypes.size (); i-- > 0 ;) {
178+ const llvm::SMRange &formattedRange = formattedTypes.first [i];
179+ const llvm::SMRange &inputRange = inputTypes[i];
180+
181+ const char *formattedStart = formattedRange.Start .getPointer ();
182+ const char *formattedEnd = formattedRange.End .getPointer ();
183+
184+ const char *inputStart = inputRange.Start .getPointer ();
185+ const char *inputEnd = inputRange.End .getPointer ();
186+
187+ llvm::StringRef formattedType (formattedStart,
188+ formattedEnd - formattedStart);
189+ llvm::StringRef inputType (inputStart, inputEnd - inputStart);
190+
191+ // Replace in the formatted string
192+ size_t pos = formattedStr.find (formattedType.str ());
193+ if (pos != std::string::npos) {
194+ formattedStr.replace (pos, formattedType.size (), inputType.str ());
195+ } else {
196+ llvm::errs () << " Warning: Input type not found in formatted string: "
197+ << formattedType << " \n " ;
198+ }
199+ }
200+
201+ // Perform output type replacements backwards to avoid index issues
202+ for (size_t i = outputTypes.size (); i-- > 0 ;) {
203+ const llvm::SMRange &formattedRange = formattedTypes.second [i];
204+ const llvm::SMRange &outputRange = outputTypes[i];
205+
206+ const char *formattedStart = formattedRange.Start .getPointer ();
207+ const char *formattedEnd = formattedRange.End .getPointer ();
208+
209+ const char *outputStart = outputRange.Start .getPointer ();
210+ const char *outputEnd = outputRange.End .getPointer ();
211+
212+ llvm::StringRef formattedType (formattedStart,
213+ formattedEnd - formattedStart);
214+ llvm::StringRef outputType (outputStart, outputEnd - outputStart);
215+
216+ // Replace in the formatted string
217+ size_t pos = formattedStr.find (formattedType.str ());
218+ if (pos != std::string::npos) {
219+ formattedStr.replace (pos, formattedType.size (), outputType.str ());
220+ } else {
221+ llvm::errs () << " Warning: Output type not found in formatted string: "
222+ << formattedType << " \n " ;
223+ }
224+ }
225+ }
226+
76227// Function to find the character before the previous comma
77228const char *findPrevComma (const char *start, const char *stop_point) {
78229 if (!start) {
@@ -256,13 +407,11 @@ void Formatter::formatOps() {
256407 ParserConfig parseConfig (&context, /* verifyAfterParse=*/ true ,
257408 &fallbackResourceMap);
258409
259- // Write the rewriteBuffer to a stream, that we can then parse
260410 std::string bufferContent;
261411 llvm::raw_string_ostream stream (bufferContent);
262412 rewriteBuffer.write (stream);
263413 stream.flush ();
264414
265- // Print the bufferContent to llvm::outs() for debugging.
266415 fmtSourceMgr.AddNewSourceBuffer (
267416 llvm::MemoryBuffer::getMemBufferCopy (bufferContent), SMLoc ());
268417
@@ -285,67 +434,93 @@ void Formatter::formatOps() {
285434 continue ;
286435
287436 // Print the fmtDef op and store as a string.
288- // Replace the opDef with this formatted string.
289437 std::string formattedStr;
290438 llvm::raw_string_ostream stream (formattedStr);
291439 fmtDef.op ->print (stream);
292440
293- // Replacing the range:
441+ // Use the original type aliases
442+ auto orig_types = getOpTypeLoc (opDef.scopeLoc );
443+ replaceTypesInString (formattedStr, orig_types.first , orig_types.second );
444+
445+ // Replace the opDef with this formatted string.
294446 replaceRangeFmt ({startOp, endOp}, formattedStr);
447+
448+ // Write the updated buffer to llvm::outs()
449+ writeFmt (llvm::outs ());
295450 }
296451
297- // Write the updated buffer to llvm::outs()
298- writeFmt (llvm::outs ());
299- }
452+ std::string getNamedLoc (
453+ const OperationDefinition::ResultGroupDefinition &resultGroup) {
454+ auto sm_range = resultGroup.definition .loc ;
455+ const char *start = sm_range.Start .getPointer ();
456+ const int len = sm_range.End .getPointer () - start;
300457
301- void markNames (Formatter &formatState, raw_ostream &os) {
302- // Get the operation definitions from the AsmParserState.
303- for (OperationDefinition &it : formatState.getOpDefs ()) {
304- auto [startOp, endOp] = getOpRange (it);
305- // loop through the resultgroups
306- for (auto &resultGroup : it.resultGroups ) {
307- auto def = resultGroup.definition ;
308- auto sm_range = def.loc ;
309- const char *start = sm_range.Start .getPointer ();
310- int len = sm_range.End .getPointer () - start;
311- // Drop the % prefix, and put in new string with `loc("name")` format.
312- auto name = StringRef (start + 1 , len - 1 );
313-
314- // Add loc("{name}") to the end of the op
315- std::string formattedStr = " loc(\" " + name.str () + " \" )" ;
316- StringRef namedLoc (formattedStr);
317- formatState.insertText (endOp, namedLoc);
318- }
458+ // Drop the '%' prefix and construct the `loc("name")` string
459+ auto name = llvm::StringRef (start + 1 ,
460+ len - 1 ); // Assumes the '%' is always present
461+ std::string formattedStr = " loc(\" " + name.str () + " \" )" ;
462+
463+ return formattedStr;
319464 }
320465
321- // Insert the NameLocs for the block arguments
322- for (BlockDefinition &block : formatState.getBlockDefs ()) {
323- for (size_t i = 0 ; i < block.arguments .size (); ++i) {
324- SMDefinition &arg = block.arguments [i];
325-
326- // Find where to insert the NameLoc. Either before the next argument,
327- // or at the end of the arg list
328- const char *insertPointPtr;
329- const char *arg_end = arg.loc .End .getPointer ();
330- SMDefinition *nextArg =
331- (i + 1 < block.arguments .size ()) ? &block.arguments [i + 1 ] : nullptr ;
332- if (nextArg) {
333- const char *nextStart = nextArg->loc .Start .getPointer ();
334- insertPointPtr = findPrevComma (nextStart, arg_end);
// Placeholder for handling ops with multiple result groups: the idea is to
// create a dummy "alias" op so that each group can be given its own NameLoc.
// Not implemented yet; this function currently does nothing.
void insertAliasOp() {}
469+
470+ LogicalResult markNames (Formatter & formatState, raw_ostream & os) {
471+ // Get the operation definitions from the AsmParserState.
472+ for (OperationDefinition &it : formatState.getOpDefs ()) {
473+ auto [startOp, endOp] = getOpRange (it);
474+
475+ if (it.resultGroups .size () == 1 ) {
476+ // Simple case, where we have only one result group for the op,
477+ // e.g., `%v = op` or `%v:2 = op`
478+ auto resultGroup = it.resultGroups [0 ];
479+ auto nameLoc = getNamedLoc (resultGroup);
480+ formatState.insertText (endOp, StringRef (nameLoc));
335481 } else {
336- insertPointPtr = findNextCloseParenth (arg.loc .End .getPointer ());
482+ // Complex case, where we have more than one result group, e.g.,
483+ // `%x, %y = op` or `%xs:2, %ys:3 = op`.
484+ // In this case we need insert some aliasing ops.
485+ for (auto &resultGroup : it.resultGroups ) {
486+ auto nameLoc = getNamedLoc (resultGroup);
487+ // StringRef namedLoc(getNamedLoc(resultGroup));
488+ llvm::errs () << " Not implemented yet\n " ;
489+ return failure ();
490+ }
337491 }
492+ }
338493
339- // Drop the % prefix, and put in new string with `loc("name")` format.
340- const char *start = arg.loc .Start .getPointer ();
341- const int len = arg_end - start;
342- auto name = StringRef (start + 1 , len - 1 );
343- std::string formattedStr = " loc(\" " + name.str () + " \" )" ;
344- StringRef namedLoc (formattedStr);
345- formatState.insertText (SMLoc::getFromPointer (insertPointPtr), namedLoc);
494+ // Insert the NameLocs for the block arguments
495+ for (BlockDefinition &block : formatState.getBlockDefs ()) {
496+ for (size_t i = 0 ; i < block.arguments .size (); ++i) {
497+ SMDefinition &arg = block.arguments [i];
498+
499+ // Find where to insert the NameLoc. Either before the next argument,
500+ // or at the end of the arg list
501+ const char *insertPointPtr;
502+ const char *arg_end = arg.loc .End .getPointer ();
503+ SMDefinition *nextArg = (i + 1 < block.arguments .size ())
504+ ? &block.arguments [i + 1 ]
505+ : nullptr ;
506+ if (nextArg) {
507+ const char *nextStart = nextArg->loc .Start .getPointer ();
508+ insertPointPtr = findPrevComma (nextStart, arg_end);
509+ } else {
510+ insertPointPtr = findNextCloseParenth (arg.loc .End .getPointer ());
511+ }
512+
513+ // Drop the % prefix, and put in new string with `loc("name")` format.
514+ const char *start = arg.loc .Start .getPointer ();
515+ const int len = arg_end - start;
516+ auto name = StringRef (start + 1 , len - 1 );
517+ std::string formattedStr = " loc(\" " + name.str () + " \" )" ;
518+ StringRef namedLoc (formattedStr);
519+ formatState.insertText (SMLoc::getFromPointer (insertPointPtr), namedLoc);
520+ }
346521 }
522+ return success ();
347523 }
348- }
349524} // namespace mlir
350525
351526int main (int argc, char **argv) {
@@ -370,7 +545,9 @@ int main(int argc, char **argv) {
370545 auto f = Formatter::init (inputFilename, outputFilename);
371546
372547 // Append the SSA names as NameLocs
373- markNames (*f, llvm::outs ());
548+ LogicalResult result = markNames (*f, llvm::outs ());
549+ if (!succeeded (result))
550+ return mlir::asMainReturnCode (mlir::failure ());
374551
375552 if (nameLocOnly) {
376553 // Return the original buffer with NameLocs appended to ops
0 commit comments