Skip to content

Commit 646057a

Browse files
Update mpl_draw() to fix multigraph plots (#1204)
* Update #1: Fixing mpl_draw() for multigraphs * Update matplotlib.py for formatting * Update rustworkx/visualization/matplotlib.py Co-authored-by: Ivan Carvalho <[email protected]> * Update matplotlib.py to remove the loop * Add releasenotes * Reformat connectionstyle string in rustworkx/visualization/matplotlib.py Co-authored-by: Ivan Carvalho <[email protected]> * Fixes #774 * Optimize edge search by using sets --------- Co-authored-by: Ivan Carvalho <[email protected]>
1 parent c7a7d53 commit 646057a

File tree

2 files changed

+91
-46
lines changed

2 files changed

+91
-46
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
---
2+
fixes:
3+
- |
4+
Fixed the plots of multigraphs using :func:`.mpl_draw`. Previously, parallel edges of
5+
multigraphs were plotted on top of each other, with overlapping arrows and labels.
6+
The radius of parallel edges of the multigraph was fixed to be `0.25` for
7+
`connectionstyle` supporting this argument in :func:`.draw_edges`. The edge lables
8+
were offset to `0.25` in :func:`.draw_edge_labels` to align with their respective
9+
edges. This fix can be tested using the following code:
10+
11+
.. jupyter-execute::
12+
13+
import rustworkx
14+
from rustworkx.visualization import mpl_draw
15+
16+
graph = rustworkx.PyDiGraph()
17+
graph.add_node('A')
18+
graph.add_node('B')
19+
graph.add_node('C')
20+
21+
graph.add_edge(1, 0, 2)
22+
graph.add_edge(0, 1, 3)
23+
graph.add_edge(1, 2, 4)
24+
25+
mpl_draw(graph, with_labels=True, labels=str, edge_labels=str, alpha=0.5)
26+
27+
- |
28+
Refer to `#774 <https://github.com/Qiskit/rustworkx/issues/774>` for more
29+
details.

rustworkx/visualization/matplotlib.py

Lines changed: 62 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,9 @@ def draw_edges(
636636
edge_color = "k"
637637

638638
# set edge positions
639-
edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edge_list])
639+
edge_pos = set()
640+
for e in edge_list:
641+
edge_pos.add((tuple(pos[e[0]]), tuple(pos[e[1]])))
640642

641643
# Check if edge_color is an array of floats and map to edge_cmap.
642644
# This is the only case handled differently from matplotlib
@@ -670,58 +672,17 @@ def to_marker_edge(marker_size, marker):
670672
arrow_collection = []
671673
mutation_scale = arrow_size # scale factor of arrow head
672674

673-
# compute view
674-
mirustworkx = np.amin(np.ravel(edge_pos[:, :, 0]))
675-
maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
676-
miny = np.amin(np.ravel(edge_pos[:, :, 1]))
677-
maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
678-
w = maxx - mirustworkx
679-
h = maxy - miny
680-
681675
base_connectionstyle = mpl.patches.ConnectionStyle(connectionstyle)
682676

683677
# Fallback for self-loop scale. Left outside of _connectionstyle so it is
684678
# only computed once
685679
max_nodesize = np.array(node_size).max()
686680

687-
def _connectionstyle(posA, posB, *args, **kwargs):
688-
# check if we need to do a self-loop
689-
if np.all(posA == posB):
690-
# Self-loops are scaled by view extent, except in cases the extent
691-
# is 0, e.g. for a single node. In this case, fall back to scaling
692-
# by the maximum node size
693-
selfloop_ht = 0.005 * max_nodesize if h == 0 else h
694-
# this is called with _screen space_ values so covert back
695-
# to data space
696-
data_loc = ax.transData.inverted().transform(posA)
697-
v_shift = 0.1 * selfloop_ht
698-
h_shift = v_shift * 0.5
699-
# put the top of the loop first so arrow is not hidden by node
700-
path = [
701-
# 1
702-
data_loc + np.asarray([0, v_shift]),
703-
# 4 4 4
704-
data_loc + np.asarray([h_shift, v_shift]),
705-
data_loc + np.asarray([h_shift, 0]),
706-
data_loc,
707-
# 4 4 4
708-
data_loc + np.asarray([-h_shift, 0]),
709-
data_loc + np.asarray([-h_shift, v_shift]),
710-
data_loc + np.asarray([0, v_shift]),
711-
]
712-
713-
ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4])
714-
# if not, fall back to the user specified behavior
715-
else:
716-
ret = base_connectionstyle(posA, posB, *args, **kwargs)
717-
718-
return ret
719-
720681
# FancyArrowPatch doesn't handle color strings
721682
arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)
722-
for i, (src, dst) in enumerate(edge_pos):
723-
x1, y1 = src
724-
x2, y2 = dst
683+
for i, edge in enumerate(edge_pos):
684+
x1, y1 = edge[0][0], edge[0][1]
685+
x2, y2 = edge[1][0], edge[1][1]
725686
shrink_source = 0 # space from source to tail
726687
shrink_target = 0 # space from head to target
727688
if np.iterable(node_size): # many node sizes
@@ -754,6 +715,12 @@ def _connectionstyle(posA, posB, *args, **kwargs):
754715
else:
755716
line_width = width
756717

718+
# radius of edges
719+
if tuple(reversed(edge)) in edge_pos:
720+
rad = 0.25
721+
else:
722+
rad = 0.0
723+
757724
arrow = mpl.patches.FancyArrowPatch(
758725
(x1, y1),
759726
(x2, y2),
@@ -763,14 +730,57 @@ def _connectionstyle(posA, posB, *args, **kwargs):
763730
mutation_scale=mutation_scale,
764731
color=arrow_color,
765732
linewidth=line_width,
766-
connectionstyle=_connectionstyle,
733+
connectionstyle=connectionstyle + f", rad = {rad}",
767734
linestyle=style,
768735
zorder=1,
769736
) # arrows go behind nodes
770737

771738
arrow_collection.append(arrow)
772739
ax.add_patch(arrow)
773740

741+
edge_pos = np.asarray(tuple(edge_pos))
742+
743+
# compute view
744+
mirustworkx = np.amin(np.ravel(edge_pos[:, :, 0]))
745+
maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
746+
miny = np.amin(np.ravel(edge_pos[:, :, 1]))
747+
maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
748+
w = maxx - mirustworkx
749+
h = maxy - miny
750+
751+
def _connectionstyle(posA, posB, *args, **kwargs):
752+
# check if we need to do a self-loop
753+
if np.all(posA == posB):
754+
# Self-loops are scaled by view extent, except in cases the extent
755+
# is 0, e.g. for a single node. In this case, fall back to scaling
756+
# by the maximum node size
757+
selfloop_ht = 0.005 * max_nodesize if h == 0 else h
758+
# this is called with _screen space_ values so covert back
759+
# to data space
760+
data_loc = ax.transData.inverted().transform(posA)
761+
v_shift = 0.1 * selfloop_ht
762+
h_shift = v_shift * 0.5
763+
# put the top of the loop first so arrow is not hidden by node
764+
path = [
765+
# 1
766+
data_loc + np.asarray([0, v_shift]),
767+
# 4 4 4
768+
data_loc + np.asarray([h_shift, v_shift]),
769+
data_loc + np.asarray([h_shift, 0]),
770+
data_loc,
771+
# 4 4 4
772+
data_loc + np.asarray([-h_shift, 0]),
773+
data_loc + np.asarray([-h_shift, v_shift]),
774+
data_loc + np.asarray([0, v_shift]),
775+
]
776+
777+
ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4])
778+
# if not, fall back to the user specified behavior
779+
else:
780+
ret = base_connectionstyle(posA, posB, *args, **kwargs)
781+
782+
return ret
783+
774784
# update view
775785
padx, pady = 0.05 * w, 0.05 * h
776786
corners = (mirustworkx - padx, miny - pady), (maxx + padx, maxy + pady)
@@ -1001,6 +1011,12 @@ def draw_edge_labels(
10011011
x1 * label_pos + x2 * (1.0 - label_pos),
10021012
y1 * label_pos + y2 * (1.0 - label_pos),
10031013
)
1014+
if (n2, n1) in labels.keys(): # loop
1015+
dy = np.abs(y2 - y1)
1016+
if n2 > n1:
1017+
y -= 0.25 * dy
1018+
else:
1019+
y += 0.25 * dy
10041020

10051021
if rotate:
10061022
# in degrees

0 commit comments

Comments
 (0)