Skip to content

Commit 724db67

Browse files
committed
Optimize GaussianProb
Optimize GaussianProb with JAX and Create test_GaussianProb_opt.py to test its correction and optimization effect Below is the test report ============================= test session starts ============================= collecting ... collected 4 items test_GaussianProb_opt.py::test_gaussian_prob1 test_GaussianProb_opt.py::test_gaussian_prob2 test_GaussianProb_opt.py::test_gaussian_prob3 test_GaussianProb_opt.py::test_gaussian_prob4 ================== 4 passed, 1 warning in 134.70s (0:02:14) =================== PASSED [ 25%] time_optimized:0.040474891662597656 time_origin:0.16180109977722168 PASSED [ 50%] time_optimized:1.0694525241851807 time_origin:7.487141132354736 PASSED [ 75%] time_optimized:22.59926676750183 time_origin:52.102147579193115 PASSED [100%] time_optimized:2.2735044956207275 time_origin:21.08591890335083 Process finished with exit code 0
1 parent 91c1c3d commit 724db67

File tree

2 files changed

+131
-20
lines changed

2 files changed

+131
-20
lines changed

brainpy/_src/connect/random_conn.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
2-
2+
from functools import partial
33
from typing import Optional
44

5-
import jax.numpy as jnp
5+
from jax import vmap, jit, numpy as jnp
66
import numpy as np
77

88
import brainpy.math as bm
@@ -327,6 +327,21 @@ def build_csr(self):
327327
selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums]))
328328
return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE)
329329

330+
@jit
331+
@partial(vmap, in_axes=(0, None, None))
332+
def gaussian_prob_dist_cal1(i_value, post_values, sigma):
333+
dists = jnp.abs(i_value - post_values)
334+
exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2)
335+
return bm.asarray(exp_dists)
336+
337+
@jit
338+
@partial(vmap, in_axes=(0, None, None, None))
339+
def gaussian_prob_dist_cal2(i_value, post_values, value_sizes, sigma):
340+
dists = jnp.abs(i_value - post_values)
341+
dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists)
342+
exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2)
343+
return bm.asarray(exp_dists)
344+
330345

331346
class GaussianProb(OneEndConnector):
332347
r"""Builds a Gaussian connectivity pattern within a population of neurons,
@@ -392,7 +407,8 @@ def __repr__(self):
392407
f'include_self={self.include_self}, '
393408
f'seed={self.seed})')
394409

395-
def build_mat(self, pre_size=None, post_size=None):
410+
def build_mat(self, isOptimized=True):
411+
self.rng = np.random.RandomState(self.seed)
396412
# value range to encode
397413
if self.encoding_values is None:
398414
value_ranges = tuple([(0, s) for s in self.pre_size])
@@ -426,24 +442,45 @@ def build_mat(self, pre_size=None, post_size=None):
426442
value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
427443

428444
# probability of connections
429-
prob_mat = []
430-
for i in range(self.pre_num):
431-
# values for node i
432-
i_coordinate = tuple()
433-
for s in self.pre_size[:-1]:
434-
i, pos = divmod(i, s)
435-
i_coordinate += (pos,)
436-
i_coordinate += (i,)
437-
i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)])
438-
if i_value.ndim < post_values.ndim:
439-
i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
440-
# distances
441-
dists = np.abs(i_value - post_values)
445+
if isOptimized:
446+
i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1))
447+
for i in range(self.pre_num):
448+
list_index = i
449+
# values for node i
450+
i_coordinate = tuple()
451+
for s in self.pre_size[:-1]:
452+
i, pos = divmod(i, s)
453+
i_coordinate += (pos,)
454+
i_coordinate += (i,)
455+
i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)])
456+
if i_value.ndim < post_values.ndim:
457+
i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
458+
i_value_list[list_index] = i_value
459+
442460
if self.periodic_boundary:
443-
dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists)
444-
exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2)
445-
prob_mat.append(exp_dists)
446-
prob_mat = np.stack(prob_mat)
461+
prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma)
462+
else:
463+
prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma)
464+
else:
465+
prob_mat = []
466+
for i in range(self.pre_num):
467+
# values for node i
468+
i_coordinate = tuple()
469+
for s in self.pre_size[:-1]:
470+
i, pos = divmod(i, s)
471+
i_coordinate += (pos,)
472+
i_coordinate += (i,)
473+
i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)])
474+
if i_value.ndim < post_values.ndim:
475+
i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)]))
476+
# distances
477+
dists = np.abs(i_value - post_values)
478+
if self.periodic_boundary:
479+
dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists)
480+
exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2)
481+
prob_mat.append(exp_dists)
482+
prob_mat = np.stack(prob_mat)
483+
447484
if self.normalize:
448485
prob_mat /= prob_mat.max()
449486

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import pytest
4+
5+
import unittest
6+
7+
import brainpy as bp
8+
9+
from time import time
10+
11+
12+
def test_gaussian_prob1():
13+
conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=2000)
14+
15+
mat = conn.build_mat(isOptimized=True)
16+
time0 = time()
17+
mat1 = conn.build_mat(isOptimized=True)
18+
time_optimized = time() - time0
19+
20+
time0 = time()
21+
mat2 = conn.build_mat(isOptimized=False)
22+
time_origin = time() - time0
23+
24+
assert bp.math.array_equal(mat1, mat2)
25+
print()
26+
print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
27+
28+
29+
def test_gaussian_prob2():
30+
conn = bp.connect.GaussianProb(sigma=4, seed=123)(pre_size=(100, 100))
31+
mat = conn.build_mat(isOptimized=True)
32+
time0 = time()
33+
mat1 = conn.build_mat(isOptimized=True)
34+
time_optimized = time() - time0
35+
36+
time0 = time()
37+
mat2 = conn.build_mat(isOptimized=False)
38+
time_origin = time() - time0
39+
40+
assert bp.math.array_equal(mat1, mat2)
41+
print()
42+
print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
43+
44+
45+
def test_gaussian_prob3():
46+
conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)(pre_size=(200, 200))
47+
mat = conn.build_mat(isOptimized=True)
48+
time0 = time()
49+
mat1 = conn.build_mat(isOptimized=True)
50+
time_optimized = time() - time0
51+
52+
time0 = time()
53+
mat2 = conn.build_mat(isOptimized=False)
54+
time_origin = time() - time0
55+
56+
assert bp.math.array_equal(mat1, mat2)
57+
print()
58+
print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
59+
60+
61+
def test_gaussian_prob4():
62+
conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)(pre_size=(25, 25, 25))
63+
mat = conn.build_mat(isOptimized=True)
64+
time0 = time()
65+
mat1 = conn.build_mat(isOptimized=True)
66+
time_optimized = time() - time0
67+
68+
time0 = time()
69+
mat2 = conn.build_mat(isOptimized=False)
70+
time_origin = time() - time0
71+
72+
assert bp.math.array_equal(mat1, mat2)
73+
print()
74+
print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')

0 commit comments

Comments
 (0)