99from matplotlib .patches import Circle
1010
1111from compas .datastructures import Network
12+ from compas .artists import NetworkArtist
13+ from compas .utilities .colors import is_color_light
1214from .artist import PlotterArtist
1315
1416Color = Tuple [float , float , float ]
1517
1618
17- class NetworkArtist (PlotterArtist ):
19+ class NetworkArtist (PlotterArtist , NetworkArtist ):
1820 """Artist for COMPAS network data structures.
1921
2022 Parameters
@@ -57,18 +59,19 @@ def __init__(self,
5759 network : Network ,
5860 nodes : Optional [List [int ]] = None ,
5961 edges : Optional [List [int ]] = None ,
60- nodecolor : Color = (1 , 1 , 1 ),
61- edgecolor : Color = (0 , 0 , 0 ),
62+ nodecolor : Color = (1.0 , 1.0 , 1.0 ),
63+ edgecolor : Color = (0.0 , 0.0 , 0. 0 ),
6264 edgewidth : float = 1.0 ,
6365 show_nodes : bool = True ,
6466 show_edges : bool = True ,
6567 nodesize : int = 5 ,
6668 sizepolicy : Literal ['relative' , 'absolute' ] = 'relative' ,
67- zorder : int = 2000 ,
69+ zorder : int = 1000 ,
6870 ** kwargs ):
6971
7072 super ().__init__ (network = network , ** kwargs )
7173
74+ self .sizepolicy = sizepolicy
7275 self .nodes = nodes
7376 self .edges = edges
7477 self .node_color = nodecolor
@@ -77,16 +80,15 @@ def __init__(self,
7780 self .edge_width = edgewidth
7881 self .show_nodes = show_nodes
7982 self .show_edges = show_edges
80- self .sizepolicy = sizepolicy
8183 self .zorder = zorder
8284
8385 @property
8486 def zorder_edges (self ):
85- return self .zorder
87+ return self .zorder + 10
8688
8789 @property
8890 def zorder_nodes (self ):
89- return self .zorder + 10
91+ return self .zorder + 20
9092
9193 @property
9294 def item (self ):
@@ -101,6 +103,22 @@ def item(self, item: Network):
101103 def data (self ) -> List [List [float ]]:
102104 return self .network .nodes_attributes ('xy' )
103105
106+ @property
107+ def node_size (self ):
108+ if not self ._node_size :
109+ factor = self .plotter .dpi if self .sizepolicy == 'absolute' else self .network .number_of_nodes ()
110+ size = self .default_nodesize / factor
111+ self ._node_size = {node : size for node in self .network .nodes ()}
112+ return self ._node_size
113+
114+ @node_size .setter
115+ def node_size (self , nodesize ):
116+ factor = self .plotter .dpi if self .sizepolicy == 'absolute' else self .network .number_of_nodes ()
117+ if isinstance (nodesize , dict ):
118+ self .node_size .update ({node : size / factor for node , size in nodesize .items ()})
119+ elif isinstance (nodesize , (int , float )):
120+ self ._node_size = {node : nodesize / factor for node in self .network .nodes ()}
121+
104122 # ==============================================================================
105123 # clear and draw
106124 # ==============================================================================
@@ -218,3 +236,91 @@ def draw_edges(self,
218236 )
219237 self .plotter .axes .add_collection (collection )
220238 self ._edgecollection = collection
239+
240+ def draw_nodelabels (self , text : Optional [Dict [int , str ]] = None ) -> None :
241+ """Draw a selection of node labels.
242+
243+ Parameters
244+ ----------
245+ text : dict of int to str, optional
246+ A node-label map.
247+ If not text dict is provided, the node identifiers are drawn.
248+
249+ Returns
250+ -------
251+ None
252+ """
253+ if self ._nodelabelcollection :
254+ for artist in self ._nodelabelcollection :
255+ artist .remove ()
256+
257+ if text :
258+ self .node_text = text
259+
260+ labels = []
261+ for node in self .nodes :
262+ bgcolor = self .node_color .get (node , self .default_nodecolor )
263+ color = (0 , 0 , 0 ) if is_color_light (bgcolor ) else (1 , 1 , 1 )
264+
265+ text = self .node_text .get (node , None )
266+ print (text )
267+ if text is None :
268+ continue
269+
270+ x , y = self .node_xyz [node ][:2 ]
271+ artist = self .plotter .axes .text (
272+ x , y ,
273+ f'{ text } ' ,
274+ fontsize = self .plotter .fontsize ,
275+ family = 'monospace' ,
276+ ha = 'center' , va = 'center' ,
277+ zorder = 10000 ,
278+ color = color
279+ )
280+ labels .append (artist )
281+
282+ self ._nodelabelcollection = labels
283+
284+ def draw_edgelabels (self , text : Optional [Dict [int , str ]] = None ) -> None :
285+ """Draw a selection of edge labels.
286+
287+ Parameters
288+ ----------
289+ text : dict of tuple of int to str
290+ An edge-label map.
291+
292+ Returns
293+ -------
294+ None
295+ """
296+ if self ._edgelabelcollection :
297+ for artist in self ._edgelabelcollection :
298+ artist .remove ()
299+
300+ if text :
301+ self .edge_text = text
302+
303+ labels = []
304+ for edge in self .edges :
305+ u , v = edge
306+ text = self .edge_text .get (edge , self .edge_text .get ((v , u ), None ))
307+ if text is None :
308+ continue
309+
310+ x0 , y0 = self .node_xyz [edge [0 ]][:2 ]
311+ x1 , y1 = self .node_xyz [edge [1 ]][:2 ]
312+ x = 0.5 * (x0 + x1 )
313+ y = 0.5 * (y0 + y1 )
314+
315+ artist = self .plotter .axes .text (
316+ x , y , f'{ text } ' ,
317+ fontsize = self .plotter .fontsize ,
318+ family = 'monospace' ,
319+ ha = 'center' , va = 'center' ,
320+ zorder = 10000 ,
321+ color = (0 , 0 , 0 ),
322+ bbox = dict (boxstyle = 'round, pad=0.3' , facecolor = (1 , 1 , 1 ), edgecolor = None , linewidth = 0 )
323+ )
324+ labels .append (artist )
325+
326+ self ._edgelabelcollection = labels
0 commit comments