@@ -636,7 +636,9 @@ def draw_edges(
636636 edge_color = "k"
637637
638638 # set edge positions
639- edge_pos = np .asarray ([(pos [e [0 ]], pos [e [1 ]]) for e in edge_list ])
639+ edge_pos = set ()
640+ for e in edge_list :
641+ edge_pos .add ((tuple (pos [e [0 ]]), tuple (pos [e [1 ]])))
640642
641643 # Check if edge_color is an array of floats and map to edge_cmap.
642644 # This is the only case handled differently from matplotlib
@@ -670,58 +672,17 @@ def to_marker_edge(marker_size, marker):
670672 arrow_collection = []
671673 mutation_scale = arrow_size # scale factor of arrow head
672674
673- # compute view
674- mirustworkx = np .amin (np .ravel (edge_pos [:, :, 0 ]))
675- maxx = np .amax (np .ravel (edge_pos [:, :, 0 ]))
676- miny = np .amin (np .ravel (edge_pos [:, :, 1 ]))
677- maxy = np .amax (np .ravel (edge_pos [:, :, 1 ]))
678- w = maxx - mirustworkx
679- h = maxy - miny
680-
681675 base_connectionstyle = mpl .patches .ConnectionStyle (connectionstyle )
682676
683677 # Fallback for self-loop scale. Left outside of _connectionstyle so it is
684678 # only computed once
685679 max_nodesize = np .array (node_size ).max ()
686680
687- def _connectionstyle (posA , posB , * args , ** kwargs ):
688- # check if we need to do a self-loop
689- if np .all (posA == posB ):
690- # Self-loops are scaled by view extent, except in cases the extent
691- # is 0, e.g. for a single node. In this case, fall back to scaling
692- # by the maximum node size
693- selfloop_ht = 0.005 * max_nodesize if h == 0 else h
694- # this is called with _screen space_ values so covert back
695- # to data space
696- data_loc = ax .transData .inverted ().transform (posA )
697- v_shift = 0.1 * selfloop_ht
698- h_shift = v_shift * 0.5
699- # put the top of the loop first so arrow is not hidden by node
700- path = [
701- # 1
702- data_loc + np .asarray ([0 , v_shift ]),
703- # 4 4 4
704- data_loc + np .asarray ([h_shift , v_shift ]),
705- data_loc + np .asarray ([h_shift , 0 ]),
706- data_loc ,
707- # 4 4 4
708- data_loc + np .asarray ([- h_shift , 0 ]),
709- data_loc + np .asarray ([- h_shift , v_shift ]),
710- data_loc + np .asarray ([0 , v_shift ]),
711- ]
712-
713- ret = mpl .path .Path (ax .transData .transform (path ), [1 , 4 , 4 , 4 , 4 , 4 , 4 ])
714- # if not, fall back to the user specified behavior
715- else :
716- ret = base_connectionstyle (posA , posB , * args , ** kwargs )
717-
718- return ret
719-
720681 # FancyArrowPatch doesn't handle color strings
721682 arrow_colors = mpl .colors .colorConverter .to_rgba_array (edge_color , alpha )
722- for i , ( src , dst ) in enumerate (edge_pos ):
723- x1 , y1 = src
724- x2 , y2 = dst
683+ for i , edge in enumerate (edge_pos ):
684+ x1 , y1 = edge [ 0 ][ 0 ], edge [ 0 ][ 1 ]
685+ x2 , y2 = edge [ 1 ][ 0 ], edge [ 1 ][ 1 ]
725686 shrink_source = 0 # space from source to tail
726687 shrink_target = 0 # space from head to target
727688 if np .iterable (node_size ): # many node sizes
@@ -754,6 +715,12 @@ def _connectionstyle(posA, posB, *args, **kwargs):
754715 else :
755716 line_width = width
756717
718+ # radius of edges
719+ if tuple (reversed (edge )) in edge_pos :
720+ rad = 0.25
721+ else :
722+ rad = 0.0
723+
757724 arrow = mpl .patches .FancyArrowPatch (
758725 (x1 , y1 ),
759726 (x2 , y2 ),
@@ -763,14 +730,57 @@ def _connectionstyle(posA, posB, *args, **kwargs):
763730 mutation_scale = mutation_scale ,
764731 color = arrow_color ,
765732 linewidth = line_width ,
766- connectionstyle = _connectionstyle ,
733+ connectionstyle = connectionstyle + f", rad = { rad } " ,
767734 linestyle = style ,
768735 zorder = 1 ,
769736 ) # arrows go behind nodes
770737
771738 arrow_collection .append (arrow )
772739 ax .add_patch (arrow )
773740
741+ edge_pos = np .asarray (tuple (edge_pos ))
742+
743+ # compute view
744+ mirustworkx = np .amin (np .ravel (edge_pos [:, :, 0 ]))
745+ maxx = np .amax (np .ravel (edge_pos [:, :, 0 ]))
746+ miny = np .amin (np .ravel (edge_pos [:, :, 1 ]))
747+ maxy = np .amax (np .ravel (edge_pos [:, :, 1 ]))
748+ w = maxx - mirustworkx
749+ h = maxy - miny
750+
751+ def _connectionstyle (posA , posB , * args , ** kwargs ):
752+ # check if we need to do a self-loop
753+ if np .all (posA == posB ):
754+ # Self-loops are scaled by view extent, except in cases the extent
755+ # is 0, e.g. for a single node. In this case, fall back to scaling
756+ # by the maximum node size
757+ selfloop_ht = 0.005 * max_nodesize if h == 0 else h
758+ # this is called with _screen space_ values so covert back
759+ # to data space
760+ data_loc = ax .transData .inverted ().transform (posA )
761+ v_shift = 0.1 * selfloop_ht
762+ h_shift = v_shift * 0.5
763+ # put the top of the loop first so arrow is not hidden by node
764+ path = [
765+ # 1
766+ data_loc + np .asarray ([0 , v_shift ]),
767+ # 4 4 4
768+ data_loc + np .asarray ([h_shift , v_shift ]),
769+ data_loc + np .asarray ([h_shift , 0 ]),
770+ data_loc ,
771+ # 4 4 4
772+ data_loc + np .asarray ([- h_shift , 0 ]),
773+ data_loc + np .asarray ([- h_shift , v_shift ]),
774+ data_loc + np .asarray ([0 , v_shift ]),
775+ ]
776+
777+ ret = mpl .path .Path (ax .transData .transform (path ), [1 , 4 , 4 , 4 , 4 , 4 , 4 ])
778+ # if not, fall back to the user specified behavior
779+ else :
780+ ret = base_connectionstyle (posA , posB , * args , ** kwargs )
781+
782+ return ret
783+
774784 # update view
775785 padx , pady = 0.05 * w , 0.05 * h
776786 corners = (mirustworkx - padx , miny - pady ), (maxx + padx , maxy + pady )
@@ -1001,6 +1011,12 @@ def draw_edge_labels(
10011011 x1 * label_pos + x2 * (1.0 - label_pos ),
10021012 y1 * label_pos + y2 * (1.0 - label_pos ),
10031013 )
1014+ if (n2 , n1 ) in labels .keys (): # loop
1015+ dy = np .abs (y2 - y1 )
1016+ if n2 > n1 :
1017+ y -= 0.25 * dy
1018+ else :
1019+ y += 0.25 * dy
10041020
10051021 if rotate :
10061022 # in degrees
0 commit comments