@@ -1335,17 +1335,24 @@ def forward(self,
13351335
13361336 def plot_node_graph (self ,
13371337 fig_size : tuple = (10 , 10 ),
1338- node_size : int = 2000 ,
1338+ node_size : int = 1000 ,
13391339 arrow_size : int = 20 ,
1340- layout = 'shell_layout' ):
1340+ layout = 'shell_layout' ,
1341+ show = True ,
1342+ legends = None ,
1343+ ax = None ):
13411344 """Plot the node graph based on NetworkX package
13421345
13431346 Parameters
13441347 ----------
13451348 fig_size: tuple, default to (10, 10)
13461349 The size of the figure
1347- node_size: int, default to 2000
1348- The size of the node
1350+
1351+ .. deprecated:: 2.1.9
1352+ Please use ``ax`` variable.
1353+
1354+ node_size: int
1355+ The size of the node. default to 1000
13491356 arrow_size:int, default to 20
13501357 The size of the arrow
13511358 layout: str
@@ -1412,11 +1419,15 @@ def plot_node_graph(self,
14121419 raise UnsupportedError (f'Only support layouts: { SUPPORTED_LAYOUTS } ' )
14131420 layout = getattr (nx , layout )(G )
14141421
1415- plt .figure (figsize = fig_size )
1422+ if ax is None :
1423+ from brainpy .visualization .figures import get_figure
1424+ fig , gs = get_figure (1 , 1 , fig_size [1 ], fig_size [0 ])
1425+ ax = fig .add_subplot (gs [0 , 0 ])
14161426 nx .draw_networkx_nodes (G , pos = layout ,
14171427 nodelist = nodes_trainable ,
14181428 node_color = trainable_color ,
1419- node_size = node_size )
1429+ node_size = node_size ,
1430+ ax = ax )
14201431 nx .draw_networkx_nodes (G , pos = layout ,
14211432 nodelist = nodes_untrainable ,
14221433 node_color = untrainable_color ,
@@ -1449,12 +1460,10 @@ def plot_node_graph(self,
14491460 proxie = []
14501461 labels = []
14511462 if len (nodes_trainable ):
1452- proxie .append (Line2D ([], [], color = 'white' , marker = 'o' ,
1453- markerfacecolor = trainable_color ))
1463+ proxie .append (Line2D ([], [], color = 'white' , marker = 'o' , markerfacecolor = trainable_color ))
14541464 labels .append ('Trainable' )
14551465 if len (nodes_untrainable ):
1456- proxie .append (Line2D ([], [], color = 'white' , marker = 'o' ,
1457- markerfacecolor = untrainable_color ))
1466+ proxie .append (Line2D ([], [], color = 'white' , marker = 'o' , markerfacecolor = untrainable_color ))
14581467 labels .append ('Nontrainable' )
14591468 if len (ff_edges ):
14601469 proxie .append (Line2D ([], [], color = ff_color , linewidth = 2 ))
@@ -1466,9 +1475,11 @@ def plot_node_graph(self,
14661475 proxie .append (Line2D ([], [], color = rec_color , linewidth = 2 ))
14671476 labels .append ('Recurrent' )
14681477
1469- plt .legend (proxie , labels , scatterpoints = 1 , markerscale = 2 , loc = 'best' )
1470- plt .tight_layout ()
1471- plt .show ()
1478+ legends = dict () if legends is None else legends
1479+ ax .legend (proxie , labels , scatterpoints = 1 , markerscale = 2 , loc = 'best' , ** legends )
1480+ if show :
1481+ plt .tight_layout ()
1482+ plt .show ()
14721483
14731484
14741485class FrozenNetwork (Network ):
0 commit comments