Skip to content

Commit 4e6f7ac

Browse files
committed
improve node graph visualization
1 parent 9b339b9 commit 4e6f7ac

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

brainpy/math/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from .autograd import *
4747
from .controls import *
4848
from .jit import *
49-
# from .parallels import *
5049

5150
# settings
5251
from . import setting

brainpy/nn/base.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,17 +1335,24 @@ def forward(self,
13351335

13361336
def plot_node_graph(self,
13371337
fig_size: tuple = (10, 10),
1338-
node_size: int = 2000,
1338+
node_size: int = 1000,
13391339
arrow_size: int = 20,
1340-
layout='shell_layout'):
1340+
layout='shell_layout',
1341+
show=True,
1342+
legends=None,
1343+
ax=None):
13411344
"""Plot the node graph based on NetworkX package
13421345
13431346
Parameters
13441347
----------
13451348
fig_size: tuple, default to (10, 10)
13461349
The size of the figure
1347-
node_size: int, default to 2000
1348-
The size of the node
1350+
1351+
.. deprecated:: 2.1.9
1352+
Please use ``ax`` variable.
1353+
1354+
node_size: int
1355+
The size of the node. default to 1000
13491356
arrow_size:int, default to 20
13501357
The size of the arrow
13511358
layout: str
@@ -1412,11 +1419,15 @@ def plot_node_graph(self,
14121419
raise UnsupportedError(f'Only support layouts: {SUPPORTED_LAYOUTS}')
14131420
layout = getattr(nx, layout)(G)
14141421

1415-
plt.figure(figsize=fig_size)
1422+
if ax is None:
1423+
from brainpy.visualization.figures import get_figure
1424+
fig, gs = get_figure(1, 1, fig_size[1], fig_size[0])
1425+
ax = fig.add_subplot(gs[0, 0])
14161426
nx.draw_networkx_nodes(G, pos=layout,
14171427
nodelist=nodes_trainable,
14181428
node_color=trainable_color,
1419-
node_size=node_size)
1429+
node_size=node_size,
1430+
ax=ax)
14201431
nx.draw_networkx_nodes(G, pos=layout,
14211432
nodelist=nodes_untrainable,
14221433
node_color=untrainable_color,
@@ -1449,12 +1460,10 @@ def plot_node_graph(self,
14491460
proxie = []
14501461
labels = []
14511462
if len(nodes_trainable):
1452-
proxie.append(Line2D([], [], color='white', marker='o',
1453-
markerfacecolor=trainable_color))
1463+
proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=trainable_color))
14541464
labels.append('Trainable')
14551465
if len(nodes_untrainable):
1456-
proxie.append(Line2D([], [], color='white', marker='o',
1457-
markerfacecolor=untrainable_color))
1466+
proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=untrainable_color))
14581467
labels.append('Nontrainable')
14591468
if len(ff_edges):
14601469
proxie.append(Line2D([], [], color=ff_color, linewidth=2))
@@ -1466,9 +1475,11 @@ def plot_node_graph(self,
14661475
proxie.append(Line2D([], [], color=rec_color, linewidth=2))
14671476
labels.append('Recurrent')
14681477

1469-
plt.legend(proxie, labels, scatterpoints=1, markerscale=2, loc='best')
1470-
plt.tight_layout()
1471-
plt.show()
1478+
legends = dict() if legends is None else legends
1479+
ax.legend(proxie, labels, scatterpoints=1, markerscale=2, loc='best', **legends)
1480+
if show:
1481+
plt.tight_layout()
1482+
plt.show()
14721483

14731484

14741485
class FrozenNetwork(Network):

0 commit comments

Comments
 (0)