Skip to content

Commit 532dafb

Browse files
authored
Implement function to visualize connection matrix (#405)
Implement function to visualize connection matrix
2 parents 98f8aef + 91c1c3d commit 532dafb

File tree

4 files changed

+195
-0
lines changed

4 files changed

+195
-0
lines changed

brainpy/_src/connect/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
from brainpy import tools, math as bm
1010
from brainpy.errors import ConnectorError
1111

12+
import matplotlib.pyplot as plt
13+
import seaborn as sns
14+
import textwrap
15+
1216
__all__ = [
1317
# the connection types
1418
'CONN_MAT',
@@ -27,6 +31,9 @@
2731
'mat2coo', 'mat2csc', 'mat2csr',
2832
'csr2csc', 'csr2mat', 'csr2coo',
2933
'coo2csr', 'coo2csc', 'coo2mat',
34+
35+
# visualize
36+
'visualizeMat',
3037
]
3138

3239
CONN_MAT = 'conn_mat'
@@ -719,3 +726,10 @@ def coo2csc(coo, post_num, data=None):
719726
else:
720727
data_new = data[sort_ids]
721728
return pre_ids_new, indptr_new, data_new
729+
730+
731+
def visualizeMat(mat, description):
732+
sns.heatmap(mat, cmap='viridis')
733+
warpped_title = textwrap.fill(description, width=60)
734+
plt.title(warpped_title)
735+
plt.show()
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import pytest
4+
5+
import unittest
6+
7+
import brainpy as bp
8+
9+
10+
def test_random_fix_pre1():
11+
for num in [0.4, 20]:
12+
conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
13+
mat1 = conn1.require(bp.connect.CONN_MAT)
14+
15+
conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
16+
mat2 = conn2.require(bp.connect.CONN_MAT)
17+
18+
print()
19+
print(f'num = {num}')
20+
print('conn_mat 1\n', mat1)
21+
print(mat1.sum())
22+
print('conn_mat 2\n', mat2)
23+
print(mat2.sum())
24+
25+
assert bp.math.array_equal(mat1, mat2)
26+
bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num)
27+
28+
29+
def test_random_fix_pre2():
30+
for num in [0.5, 3]:
31+
conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4)
32+
mat1 = conn1.require(bp.connect.CONN_MAT)
33+
print()
34+
print(mat1)
35+
36+
bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=5, post_size=4' % num)
37+
38+
39+
def test_random_fix_pre3():
40+
with pytest.raises(bp.errors.ConnectorError):
41+
conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4)
42+
conn1.require(bp.connect.CONN_MAT)
43+
44+
bp.connect.visualizeMat(conn1, 'FixedPreNum: num=6, pre_size=3, post_size=4')
45+
46+
47+
def test_random_fix_post1():
48+
for num in [0.4, 20]:
49+
conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
50+
mat1 = conn1.require(bp.connect.CONN_MAT)
51+
52+
conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
53+
mat2 = conn2.require(bp.connect.CONN_MAT)
54+
55+
print()
56+
print('conn_mat 1\n', mat1)
57+
print('conn_mat 2\n', mat2)
58+
59+
assert bp.math.array_equal(mat1, mat2)
60+
bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num)
61+
62+
63+
def test_random_fix_post2():
64+
for num in [0.5, 3]:
65+
conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4)
66+
mat1 = conn1.require(bp.connect.CONN_MAT)
67+
print(mat1)
68+
bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=5, post_size=4' % num)
69+
70+
71+
def test_random_fix_post3():
72+
with pytest.raises(bp.errors.ConnectorError):
73+
conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4)
74+
conn1.require(bp.connect.CONN_MAT)
75+
bp.connect.visualizeMat(conn1, 'FixedPostNum: num=6, pre_size=3, post_size=4')
76+
77+
78+
def test_gaussian_prob1():
79+
conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100)
80+
mat = conn.require(bp.connect.CONN_MAT)
81+
82+
print()
83+
print('conn_mat', mat)
84+
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=1., include_self=False, pre_size=100')
85+
86+
87+
def test_gaussian_prob2():
88+
conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50))
89+
mat = conn.require(bp.connect.CONN_MAT)
90+
91+
print()
92+
print('conn_mat', mat)
93+
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, pre_size=(50, 50)')
94+
95+
96+
def test_gaussian_prob3():
97+
conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50))
98+
mat = conn.require(bp.connect.CONN_MAT)
99+
100+
print()
101+
print('conn_mat', mat)
102+
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(50, 50)')
103+
104+
105+
def test_gaussian_prob4():
106+
conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10))
107+
conn.require(bp.connect.CONN_MAT,
108+
bp.connect.PRE_IDS, bp.connect.POST_IDS,
109+
bp.connect.PRE2POST, bp.connect.POST_IDS)
110+
mat = conn.require(bp.connect.CONN_MAT)
111+
bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(10, 10, 10)')
112+
113+
114+
def test_SmallWorld1():
115+
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False)
116+
conn(pre_size=10, post_size=10)
117+
118+
mat = conn.require(bp.connect.CONN_MAT)
119+
120+
print('conn_mat', mat)
121+
bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=False, pre_size=10, post_size=10')
122+
123+
124+
def test_SmallWorld3():
125+
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True)
126+
conn(pre_size=20, post_size=20)
127+
128+
mat = conn.require(bp.connect.CONN_MAT)
129+
130+
print('conn_mat', mat)
131+
132+
bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=True, pre_size=20, post_size=20')
133+
134+
135+
def test_SmallWorld2():
136+
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5)
137+
conn(pre_size=(100,), post_size=(100,))
138+
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
139+
bp.connect.PRE_IDS, bp.connect.POST_IDS,
140+
bp.connect.PRE2POST, bp.connect.POST_IDS)
141+
print()
142+
print('conn_mat', mat)
143+
bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, pre_size=(100,), post_size=(100,)')
144+
145+
146+
def test_ScaleFreeBA():
147+
conn = bp.connect.ScaleFreeBA(m=2)
148+
for size in [100, (10, 20), (2, 10, 20)]:
149+
conn(pre_size=size, post_size=size)
150+
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
151+
bp.connect.PRE_IDS, bp.connect.POST_IDS,
152+
bp.connect.PRE2POST, bp.connect.POST_IDS)
153+
print()
154+
print('conn_mat', mat)
155+
bp.connect.visualizeMat(mat, 'ScaleFreeBA: m=2, pre_size=%s, post_size=%s' % (size, size))
156+
157+
158+
def test_ScaleFreeBADual():
159+
conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4)
160+
for size in [100, (10, 20), (2, 10, 20)]:
161+
conn(pre_size=size, post_size=size)
162+
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
163+
bp.connect.PRE_IDS, bp.connect.POST_IDS,
164+
bp.connect.PRE2POST, bp.connect.POST_IDS)
165+
print()
166+
print('conn_mat', mat)
167+
bp.connect.visualizeMat(mat, 'ScaleFreeBADual: m1=2, m2=3, p=0.4, pre_size=%s, post_size=%s' % (size, size))
168+
169+
170+
def test_PowerLaw():
171+
conn = bp.connect.PowerLaw(m=3, p=0.4)
172+
for size in [100, (10, 20), (2, 10, 20)]:
173+
conn(pre_size=size, post_size=size)
174+
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
175+
bp.connect.PRE_IDS, bp.connect.POST_IDS,
176+
bp.connect.PRE2POST, bp.connect.POST_IDS)
177+
print()
178+
print('conn_mat', mat)
179+
bp.connect.visualizeMat(mat, 'PowerLaw: m=3, p=0.4, pre_size=%s, post_size=%s' % (size, size))

brainpy/connect.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
coo2csr as coo2csr,
1414
coo2csc as coo2csc,
1515
coo2mat as coo2mat,
16+
visualizeMat as visualizeMat,
1617

1718
CONN_MAT,
1819
PRE_IDS, POST_IDS,

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ scipy>=1.1.0
88
brainpylib
99
h5py
1010
pathos
11+
seaborn
1112

1213
# test requirements
1314
pytest

0 commit comments

Comments
 (0)