1414#include " mlir/IR/Operation.h"
1515#include " mlir/Pass/Pass.h"
1616#include " mlir/Support/IndentedOstream.h"
17+ #include " llvm/ADT/STLExtras.h"
1718#include " llvm/Support/Format.h"
1819#include " llvm/Support/GraphWriter.h"
1920#include < map>
@@ -29,7 +30,7 @@ using namespace mlir;
2930
3031static const StringRef kLineStyleControlFlow = " dashed" ;
3132static const StringRef kLineStyleDataFlow = " solid" ;
32- static const StringRef kShapeNode = " ellipse " ;
33+ static const StringRef kShapeNode = " Mrecord " ;
3334static const StringRef kShapeNone = " plain" ;
3435
3536// / Return the size limits for eliding large attributes.
@@ -49,16 +50,25 @@ static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
4950 return buf;
5051}
5152
52- // / Escape special characters such as '\n' and quotation marks.
53- static std::string escapeString (std::string str) {
54- return strFromOs ([&](raw_ostream &os) { os.write_escaped (str); });
55- }
56-
5753// / Put quotation marks around a given string.
5854static std::string quoteString (const std::string &str) {
5955 return " \" " + str + " \" " ;
6056}
6157
58+ // / For Graphviz record nodes:
59+ // / " Braces, vertical bars and angle brackets must be escaped with a backslash
60+ // / character if you wish them to appear as a literal character "
61+ std::string escapeLabelString (const std::string &str) {
62+ std::string buf;
63+ llvm::raw_string_ostream os (buf);
64+ for (char c : str) {
65+ if (llvm::is_contained ({' {' , ' |' , ' <' , ' }' , ' >' , ' \n ' , ' "' }, c))
66+ os << ' \\ ' ;
67+ os << c;
68+ }
69+ return buf;
70+ }
71+
6272using AttributeMap = std::map<std::string, std::string>;
6373
6474namespace {
@@ -79,6 +89,12 @@ struct Node {
7989 std::optional<int > clusterId;
8090};
8191
92+ struct DataFlowEdge {
93+ Value value;
94+ Node node;
95+ std::string port;
96+ };
97+
8298// / This pass generates a Graphviz dataflow visualization of an MLIR operation.
8399// / Note: See https://www.graphviz.org/doc/info/lang.html for more information
84100// / about the Graphviz DOT language.
@@ -107,7 +123,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
107123private:
108124 // / Generate a color mapping that will color every operation with the same
109125 // / name the same way. It'll interpolate the hue in the HSV color-space,
110- // / attempting to keep the contrast suitable for black text.
126+ // / using muted colors that provide good contrast for black text.
111127 template <typename T>
112128 void initColorMapping (T &irEntity) {
113129 backgroundColors.clear ();
@@ -120,17 +136,19 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
120136 });
121137 for (auto indexedOps : llvm::enumerate (ops)) {
122138 double hue = ((double )indexedOps.index ()) / ops.size ();
139+ // Use lower saturation (0.3) and higher value (0.95) for better
140+ // readability
123141 backgroundColors[indexedOps.value ()->getName ()].second =
124- std::to_string (hue) + " 1.0 1.0 " ;
142+ std::to_string (hue) + " 0.3 0.95 " ;
125143 }
126144 }
127145
128146 // / Emit all edges. This function should be called after all nodes have been
129147 // / emitted.
130148 void emitAllEdgeStmts () {
131149 if (printDataFlowEdges) {
132- for (const auto &[value, node, label] : dataFlowEdges) {
133- emitEdgeStmt (valueToNode[value], node, label , kLineStyleDataFlow );
150+ for (const auto &e : dataFlowEdges) {
151+ emitEdgeStmt (valueToNode[e. value ], e. node , e. port , kLineStyleDataFlow );
134152 }
135153 }
136154
@@ -147,8 +165,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
147165 os.indent ();
148166 // Emit invisible anchor node from/to which arrows can be drawn.
149167 Node anchorNode = emitNodeStmt (" " , kShapeNone );
150- os << attrStmt (" label" , quoteString (escapeString (std::move (label))))
151- << " ;\n " ;
168+ os << attrStmt (" label" , quoteString (label)) << " ;\n " ;
152169 builder ();
153170 os.unindent ();
154171 os << " }\n " ;
@@ -176,16 +193,17 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
176193
177194 // Always emit splat attributes.
178195 if (isa<SplatElementsAttr>(attr)) {
179- attr.print (os);
196+ os << escapeLabelString (
197+ strFromOs ([&](raw_ostream &os) { attr.print (os); }));
180198 return ;
181199 }
182200
183201 // Elide "big" elements attributes.
184202 auto elements = dyn_cast<ElementsAttr>(attr);
185203 if (elements && elements.getNumElements () > largeAttrLimit) {
186204 os << std::string (elements.getShapedType ().getRank (), ' [' ) << " ..."
187- << std::string (elements.getShapedType ().getRank (), ' ]' ) << " : "
188- << elements.getType ();
205+ << std::string (elements.getShapedType ().getRank (), ' ]' ) << " : " ;
206+ emitMlirType (os, elements.getType () );
189207 return ;
190208 }
191209
@@ -199,27 +217,43 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
199217 std::string buf;
200218 llvm::raw_string_ostream ss (buf);
201219 attr.print (ss);
202- os << truncateString (buf);
220+ os << escapeLabelString (truncateString (buf));
221+ }
222+
223+ // Print a truncated and escaped MLIR type to `os`.
224+ void emitMlirType (raw_ostream &os, Type type) {
225+ std::string buf;
226+ llvm::raw_string_ostream ss (buf);
227+ type.print (ss);
228+ os << escapeLabelString (truncateString (buf));
229+ }
230+
231+ // Print a truncated and escaped MLIR operand to `os`.
232+ void emitMlirOperand (raw_ostream &os, Value operand) {
233+ operand.printAsOperand (os, OpPrintingFlags ());
203234 }
204235
205236 // / Append an edge to the list of edges.
206237 // / Note: Edges are written to the output stream via `emitAllEdgeStmts`.
207- void emitEdgeStmt (Node n1, Node n2, std::string label , StringRef style) {
238+ void emitEdgeStmt (Node n1, Node n2, std::string port , StringRef style) {
208239 AttributeMap attrs;
209240 attrs[" style" ] = style.str ();
210- // Do not label edges that start/end at a cluster boundary. Such edges are
211- // clipped at the boundary, but labels are not. This can lead to labels
212- // floating around without any edge next to them.
213- if (!n1.clusterId && !n2.clusterId )
214- attrs[" label" ] = quoteString (escapeString (std::move (label)));
215241 // Use `ltail` and `lhead` to draw edges between clusters.
216242 if (n1.clusterId )
217243 attrs[" ltail" ] = " cluster_" + std::to_string (*n1.clusterId );
218244 if (n2.clusterId )
219245 attrs[" lhead" ] = " cluster_" + std::to_string (*n2.clusterId );
220246
221247 edges.push_back (strFromOs ([&](raw_ostream &os) {
222- os << llvm::format (" v%i -> v%i " , n1.id , n2.id );
248+ os << " v" << n1.id ;
249+ if (!port.empty () && !n1.clusterId )
250+ // Attach edge to south compass point of the result
251+ os << " :res" << port << " :s" ;
252+ os << " -> " ;
253+ os << " v" << n2.id ;
254+ if (!port.empty () && !n2.clusterId )
255+ // Attach edge to north compass point of the operand
256+ os << " :arg" << port << " :n" ;
223257 emitAttrList (os, attrs);
224258 }));
225259 }
@@ -240,20 +274,30 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
240274 StringRef background = " " ) {
241275 int nodeId = ++counter;
242276 AttributeMap attrs;
243- attrs[" label" ] = quoteString (escapeString ( std::move ( label)) );
277+ attrs[" label" ] = quoteString (label);
244278 attrs[" shape" ] = shape.str ();
245279 if (!background.empty ()) {
246280 attrs[" style" ] = " filled" ;
247- attrs[" fillcolor" ] = ( " \" " + background + " \" " ) .str ();
281+ attrs[" fillcolor" ] = quoteString ( background.str () );
248282 }
249283 os << llvm::format (" v%i " , nodeId);
250284 emitAttrList (os, attrs);
251285 os << " ;\n " ;
252286 return Node (nodeId);
253287 }
254288
255- // / Generate a label for an operation.
256- std::string getLabel (Operation *op) {
289+ std::string getValuePortName (Value operand) {
290+ // Print value as an operand and omit the leading '%' character.
291+ auto str = strFromOs ([&](raw_ostream &os) {
292+ operand.printAsOperand (os, OpPrintingFlags ());
293+ });
294+ // Replace % and # with _
295+ std::replace (str.begin (), str.end (), ' %' , ' _' );
296+ std::replace (str.begin (), str.end (), ' #' , ' _' );
297+ return str;
298+ }
299+
300+ std::string getClusterLabel (Operation *op) {
257301 return strFromOs ([&](raw_ostream &os) {
258302 // Print operation name and type.
259303 os << op->getName ();
@@ -267,18 +311,73 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
267311
268312 // Print attributes.
269313 if (printAttrs) {
270- os << " \n " ;
314+ os << " \\ l" ;
315+ for (const NamedAttribute &attr : op->getAttrs ()) {
316+ os << escapeLabelString (attr.getName ().getValue ().str ()) << " : " ;
317+ emitMlirAttr (os, attr.getValue ());
318+ os << " \\ l" ;
319+ }
320+ }
321+ });
322+ }
323+
324+ // / Generate a label for an operation.
325+ std::string getRecordLabel (Operation *op) {
326+ return strFromOs ([&](raw_ostream &os) {
327+ os << " {" ;
328+
329+ // Print operation inputs.
330+ if (op->getNumOperands () > 0 ) {
331+ os << " {" ;
332+ auto operandToPort = [&](Value operand) {
333+ os << " <arg" << getValuePortName (operand) << " > " ;
334+ emitMlirOperand (os, operand);
335+ };
336+ interleave (op->getOperands (), os, operandToPort, " |" );
337+ os << " }|" ;
338+ }
339+ // Print operation name and type.
340+ os << op->getName () << " \\ l" ;
341+
342+ // Print attributes.
343+ if (printAttrs && !op->getAttrs ().empty ()) {
344+ // Extra line break to separate attributes from the operation name.
345+ os << " \\ l" ;
271346 for (const NamedAttribute &attr : op->getAttrs ()) {
272- os << ' \n ' << attr.getName ().getValue () << " : " ;
347+ os << attr.getName ().getValue () << " : " ;
273348 emitMlirAttr (os, attr.getValue ());
349+ os << " \\ l" ;
274350 }
275351 }
352+
353+ if (op->getNumResults () > 0 ) {
354+ os << " |{" ;
355+ auto resultToPort = [&](Value result) {
356+ os << " <res" << getValuePortName (result) << " > " ;
357+ emitMlirOperand (os, result);
358+ if (printResultTypes) {
359+ os << " " ;
360+ emitMlirType (os, result.getType ());
361+ }
362+ };
363+ interleave (op->getResults (), os, resultToPort, " |" );
364+ os << " }" ;
365+ }
366+
367+ os << " }" ;
276368 });
277369 }
278370
279371 // / Generate a label for a block argument.
280372 std::string getLabel (BlockArgument arg) {
281- return " arg" + std::to_string (arg.getArgNumber ());
373+ return strFromOs ([&](raw_ostream &os) {
374+ os << " <res" << getValuePortName (arg) << " > " ;
375+ arg.printAsOperand (os, OpPrintingFlags ());
376+ if (printResultTypes) {
377+ os << " " ;
378+ emitMlirType (os, arg.getType ());
379+ }
380+ });
282381 }
283382
284383 // / Process a block. Emit a cluster and one node per block argument and
@@ -287,14 +386,12 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
287386 emitClusterStmt ([&]() {
288387 for (BlockArgument &blockArg : block.getArguments ())
289388 valueToNode[blockArg] = emitNodeStmt (getLabel (blockArg));
290-
291389 // Emit a node for each operation.
292390 std::optional<Node> prevNode;
293391 for (Operation &op : block) {
294392 Node nextNode = processOperation (&op);
295393 if (printControlFlowEdges && prevNode)
296- emitEdgeStmt (*prevNode, nextNode, /* label=*/ " " ,
297- kLineStyleControlFlow );
394+ emitEdgeStmt (*prevNode, nextNode, /* port=*/ " " , kLineStyleControlFlow );
298395 prevNode = nextNode;
299396 }
300397 });
@@ -311,18 +408,19 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
311408 for (Region ®ion : op->getRegions ())
312409 processRegion (region);
313410 },
314- getLabel (op));
411+ getClusterLabel (op));
315412 } else {
316- node = emitNodeStmt (getLabel (op), kShapeNode ,
413+ node = emitNodeStmt (getRecordLabel (op), kShapeNode ,
317414 backgroundColors[op->getName ()].second );
318415 }
319416
320417 // Insert data flow edges originating from each operand.
321418 if (printDataFlowEdges) {
322419 unsigned numOperands = op->getNumOperands ();
323- for (unsigned i = 0 ; i < numOperands; i++)
324- dataFlowEdges.push_back ({op->getOperand (i), node,
325- numOperands == 1 ? " " : std::to_string (i)});
420+ for (unsigned i = 0 ; i < numOperands; i++) {
421+ auto operand = op->getOperand (i);
422+ dataFlowEdges.push_back ({operand, node, getValuePortName (operand)});
423+ }
326424 }
327425
328426 for (Value result : op->getResults ())
@@ -352,7 +450,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
352450 // / Mapping of SSA values to Graphviz nodes/clusters.
353451 DenseMap<Value, Node> valueToNode;
354452 // / Output for data flow edges is delayed until the end to handle cycles
355- std::vector<std::tuple<Value, Node, std::string> > dataFlowEdges;
453+ std::vector<DataFlowEdge > dataFlowEdges;
356454 // / Counter for generating unique node/subgraph identifiers.
357455 int counter = 0 ;
358456
0 commit comments