Skip to content

Commit d5e7655

Browse files
authored
Merge pull request #406 from Routhleck/master
Optimize GaussianProb
2 parents 532dafb + b5da3e0 commit d5e7655

File tree

4 files changed

+137
-23
lines changed

4 files changed

+137
-23
lines changed

brainpy/_src/connect/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
from brainpy import tools, math as bm
1010
from brainpy.errors import ConnectorError
1111

12-
import matplotlib.pyplot as plt
13-
import seaborn as sns
1412
import textwrap
1513

1614
__all__ = [
@@ -729,6 +727,12 @@ def coo2csc(coo, post_num, data=None):
729727

730728

731729
def visualizeMat(mat, description):
730+
try:
731+
import seaborn as sns
732+
import matplotlib.pyplot as plt
733+
except (ModuleNotFoundError, ImportError):
734+
print('Please install seaborn and matplotlib for this function')
735+
return
732736
sns.heatmap(mat, cmap='viridis')
733737
warpped_title = textwrap.fill(description, width=60)
734738
plt.title(warpped_title)

brainpy/_src/connect/random_conn.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
2-
2+
from functools import partial
33
from typing import Optional
44

5-
import jax.numpy as jnp
5+
from jax import vmap, jit, numpy as jnp
66
import numpy as np
77

88
import brainpy.math as bm
@@ -327,6 +327,21 @@ def build_csr(self):
327327
selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums]))
328328
return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE)
329329

330+
@jit
331+
@partial(vmap, in_axes=(0, None, None))
332+
def gaussian_prob_dist_cal1(i_value, post_values, sigma):
333+
dists = jnp.abs(i_value - post_values)
334+
exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2)
335+
return bm.asarray(exp_dists)
336+
337+
@jit
338+
@partial(vmap, in_axes=(0, None, None, None))
339+
def gaussian_prob_dist_cal2(i_value, post_values, value_sizes, sigma):
340+
dists = jnp.abs(i_value - post_values)
341+
dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists)
342+
exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2)
343+
return bm.asarray(exp_dists)
344+
330345

331346
class GaussianProb(OneEndConnector):
332347
r"""Builds a Gaussian connectivity pattern within a population of neurons,
@@ -392,7 +407,8 @@ def __repr__(self):
392407
f'include_self={self.include_self}, '
393408
f'seed={self.seed})')
394409

395-
def build_mat(self, pre_size=None, post_size=None):
410+
def build_mat(self, isOptimized=True):
411+
self.rng = np.random.RandomState(self.seed)
396412
# value range to encode
397413
if self.encoding_values is None:
398414
value_ranges = tuple([(0, s) for s in self.pre_size])
@@ -426,24 +442,45 @@ def build_mat(self, pre_size=None, post_size=None):
426442
value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
427443

428444
# probability of connections
429-
prob_mat = []
430-
for i in range(self.pre_num):
431-
# values for node i
432-
i_coordinate = tuple()
433-
for s in self.pre_size[:-1]:
434-
i, pos = divmod(i, s)
435-
i_coordinate += (pos,)
436-
i_coordinate += (i,)
437-
i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)])
438-
if i_value.ndim < post_values.ndim:
439-
i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
440-
# distances
441-
dists = np.abs(i_value - post_values)
445+
if isOptimized:
446+
i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1))
447+
for i in range(self.pre_num):
448+
list_index = i
449+
# values for node i
450+
i_coordinate = tuple()
451+
for s in self.pre_size[:-1]:
452+
i, pos = divmod(i, s)
453+
i_coordinate += (pos,)
454+
i_coordinate += (i,)
455+
i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)])
456+
if i_value.ndim < post_values.ndim:
457+
i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
458+
i_value_list[list_index] = i_value
459+
442460
if self.periodic_boundary:
443-
dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists)
444-
exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2)
445-
prob_mat.append(exp_dists)
446-
prob_mat = np.stack(prob_mat)
461+
prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma)
462+
else:
463+
prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma)
464+
else:
465+
prob_mat = []
466+
for i in range(self.pre_num):
467+
# values for node i
468+
i_coordinate = tuple()
469+
for s in self.pre_size[:-1]:
470+
i, pos = divmod(i, s)
471+
i_coordinate += (pos,)
472+
i_coordinate += (i,)
473+
i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)])
474+
if i_value.ndim < post_values.ndim:
475+
i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
476+
# distances
477+
dists = np.abs(i_value - post_values)
478+
if self.periodic_boundary:
479+
dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists)
480+
exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2)
481+
prob_mat.append(exp_dists)
482+
prob_mat = np.stack(prob_mat)
483+
447484
if self.normalize:
448485
prob_mat /= prob_mat.max()
449486

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import pytest
4+
5+
import unittest
6+
7+
import brainpy as bp
8+
9+
from time import time
10+
11+
12+
def test_gaussian_prob1():
13+
conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=100)
14+
15+
mat = conn.build_mat(isOptimized=True)
16+
time0 = time()
17+
mat1 = conn.build_mat(isOptimized=True)
18+
time_optimized = time() - time0
19+
20+
time0 = time()
21+
mat2 = conn.build_mat(isOptimized=False)
22+
time_origin = time() - time0
23+
24+
assert bp.math.array_equal(mat1, mat2)
25+
print()
26+
print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
27+
28+
29+
def test_gaussian_prob2():
30+
conn = bp.connect.GaussianProb(sigma=4, seed=123)(pre_size=(10, 10))
31+
mat = conn.build_mat(isOptimized=True)
32+
time0 = time()
33+
mat1 = conn.build_mat(isOptimized=True)
34+
time_optimized = time() - time0
35+
36+
time0 = time()
37+
mat2 = conn.build_mat(isOptimized=False)
38+
time_origin = time() - time0
39+
40+
assert bp.math.array_equal(mat1, mat2)
41+
print()
42+
print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
43+
44+
45+
def test_gaussian_prob3():
46+
conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)(pre_size=(10, 10))
47+
mat = conn.build_mat(isOptimized=True)
48+
time0 = time()
49+
mat1 = conn.build_mat(isOptimized=True)
50+
time_optimized = time() - time0
51+
52+
time0 = time()
53+
mat2 = conn.build_mat(isOptimized=False)
54+
time_origin = time() - time0
55+
56+
assert bp.math.array_equal(mat1, mat2)
57+
print()
58+
print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
59+
60+
61+
def test_gaussian_prob4():
62+
conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)(pre_size=(10, 10, 10))
63+
mat = conn.build_mat(isOptimized=True)
64+
time0 = time()
65+
mat1 = conn.build_mat(isOptimized=True)
66+
time_optimized = time() - time0
67+
68+
time0 = time()
69+
mat2 = conn.build_mat(isOptimized=False)
70+
time_origin = time() - time0
71+
72+
assert bp.math.array_equal(mat1, mat2)
73+
print()
74+
print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')

requirements-dev.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ scipy>=1.1.0
88
brainpylib
99
h5py
1010
pathos
11-
seaborn
1211

1312
# test requirements
1413
pytest

0 commit comments

Comments
 (0)