Skip to content

Commit 96ec5a9

Browse files
authored
Fix DAG.to_dot when reducers have multiple outputs (#3150)
1 parent a8cb40e commit 96ec5a9

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

mars/core/graph/core.pyx

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,28 +339,38 @@ cdef class DirectedGraph:
339339
continue
340340
for input_chunk in (op.inputs or []):
341341
if input_chunk.key not in visited:
342-
sio.write(f'"Chunk:{input_chunk.key[:trunc_key]}" {chunk_style}\n')
342+
sio.write(f'"Chunk:{self._gen_chunk_key(input_chunk, trunc_key)}" {chunk_style}\n')
343343
visited.add(input_chunk.key)
344344
if op.key not in visited:
345345
sio.write(f'"{op_name}:{op.key[:trunc_key]}" {operand_style}\n')
346346
visited.add(op.key)
347-
sio.write(f'"Chunk:{input_chunk.key[:trunc_key]}" -> "{op_name}:{op.key[:5]}"\n')
347+
sio.write(f'"Chunk:{self._gen_chunk_key(input_chunk, trunc_key)}" -> '
348+
f'"{op_name}:{op.key[:trunc_key]}"\n')
348349

349350
for output_chunk in (op.outputs or []):
350351
if output_chunk.key not in visited:
351352
tmp_chunk_style = chunk_style
352353
if result_chunk_keys and output_chunk.key in result_chunk_keys:
353354
tmp_chunk_style = '[shape=box,style=filled,fillcolor=cadetblue1]'
354-
sio.write(f'"Chunk:{output_chunk.key[:trunc_key]}" {tmp_chunk_style}\n')
355+
sio.write(f'"Chunk:{self._gen_chunk_key(output_chunk, trunc_key)}" {tmp_chunk_style}\n')
355356
visited.add(output_chunk.key)
356357
if op.key not in visited:
357358
sio.write(f'"{op_name}:{op.key[:trunc_key]}" {operand_style}\n')
358359
visited.add(op.key)
359-
sio.write(f'"{op_name}:{op.key[:trunc_key]}" -> "Chunk:{output_chunk.key[:5]}"\n')
360+
sio.write(f'"{op_name}:{op.key[:trunc_key]}" -> '
361+
f'"Chunk:{self._gen_chunk_key(output_chunk, trunc_key)}"\n')
360362

361363
sio.write('}')
362364
return sio.getvalue()
363365

366+
@classmethod
367+
def _gen_chunk_key(cls, chunk, trunc_key):
368+
if "_" in chunk.key:
369+
key, index = chunk.key.split("_", 1)
370+
return "_".join([key[:trunc_key], index])
371+
else: # pragma: no cover
372+
return chunk.key[:trunc_key]
373+
364374
def _repr_svg_(self): # pragma: no cover
365375
from graphviz import Source
366376
return Source(self.to_dot())._repr_svg_()

0 commit comments

Comments
 (0)