|
1 | 1 | # -*- coding: utf-8 -*- |
2 | | - |
| 2 | +from functools import partial |
3 | 3 | from typing import Optional |
4 | 4 |
|
5 | | -import jax.numpy as jnp |
| 5 | +from jax import vmap, jit, numpy as jnp |
6 | 6 | import numpy as np |
7 | 7 |
|
8 | 8 | import brainpy.math as bm |
@@ -327,6 +327,21 @@ def build_csr(self): |
327 | 327 | selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) |
328 | 328 | return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE) |
329 | 329 |
|
| 330 | +@jit |
| 331 | +@partial(vmap, in_axes=(0, None, None)) |
| 332 | +def gaussian_prob_dist_cal1(i_value, post_values, sigma): |
| 333 | + dists = jnp.abs(i_value - post_values) |
| 334 | + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) |
| 335 | + return bm.asarray(exp_dists) |
| 336 | + |
| 337 | +@jit |
| 338 | +@partial(vmap, in_axes=(0, None, None, None)) |
| 339 | +def gaussian_prob_dist_cal2(i_value, post_values, value_sizes, sigma): |
| 340 | + dists = jnp.abs(i_value - post_values) |
| 341 | + dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) |
| 342 | + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) |
| 343 | + return bm.asarray(exp_dists) |
| 344 | + |
330 | 345 |
|
331 | 346 | class GaussianProb(OneEndConnector): |
332 | 347 | r"""Builds a Gaussian connectivity pattern within a population of neurons, |
@@ -392,7 +407,8 @@ def __repr__(self): |
392 | 407 | f'include_self={self.include_self}, ' |
393 | 408 | f'seed={self.seed})') |
394 | 409 |
|
395 | | - def build_mat(self, pre_size=None, post_size=None): |
| 410 | + def build_mat(self, isOptimized=True): |
| 411 | + self.rng = np.random.RandomState(self.seed) |
396 | 412 | # value range to encode |
397 | 413 | if self.encoding_values is None: |
398 | 414 | value_ranges = tuple([(0, s) for s in self.pre_size]) |
@@ -426,24 +442,45 @@ def build_mat(self, pre_size=None, post_size=None): |
426 | 442 | value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) |
427 | 443 |
|
428 | 444 | # probability of connections |
429 | | - prob_mat = [] |
430 | | - for i in range(self.pre_num): |
431 | | - # values for node i |
432 | | - i_coordinate = tuple() |
433 | | - for s in self.pre_size[:-1]: |
434 | | - i, pos = divmod(i, s) |
435 | | - i_coordinate += (pos,) |
436 | | - i_coordinate += (i,) |
437 | | - i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) |
438 | | - if i_value.ndim < post_values.ndim: |
439 | | - i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) |
440 | | - # distances |
441 | | - dists = np.abs(i_value - post_values) |
| 445 | + if isOptimized: |
| 446 | + i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1)) |
| 447 | + for i in range(self.pre_num): |
| 448 | + list_index = i |
| 449 | + # values for node i |
| 450 | + i_coordinate = tuple() |
| 451 | + for s in self.pre_size[:-1]: |
| 452 | + i, pos = divmod(i, s) |
| 453 | + i_coordinate += (pos,) |
| 454 | + i_coordinate += (i,) |
| 455 | + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) |
| 456 | + if i_value.ndim < post_values.ndim: |
| 457 | + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) |
| 458 | + i_value_list[list_index] = i_value |
| 459 | + |
442 | 460 | if self.periodic_boundary: |
443 | | - dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) |
444 | | - exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) |
445 | | - prob_mat.append(exp_dists) |
446 | | - prob_mat = np.stack(prob_mat) |
| 461 | + prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) |
| 462 | + else: |
| 463 | + prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma) |
| 464 | + else: |
| 465 | + prob_mat = [] |
| 466 | + for i in range(self.pre_num): |
| 467 | + # values for node i |
| 468 | + i_coordinate = tuple() |
| 469 | + for s in self.pre_size[:-1]: |
| 470 | + i, pos = divmod(i, s) |
| 471 | + i_coordinate += (pos,) |
| 472 | + i_coordinate += (i,) |
| 473 | + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) |
| 474 | + if i_value.ndim < post_values.ndim: |
| 475 | + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) |
| 476 | + # distances |
| 477 | + dists = np.abs(i_value - post_values) |
| 478 | + if self.periodic_boundary: |
| 479 | + dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) |
| 480 | + exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) |
| 481 | + prob_mat.append(exp_dists) |
| 482 | + prob_mat = np.stack(prob_mat) |
| 483 | + |
447 | 484 | if self.normalize: |
448 | 485 | prob_mat /= prob_mat.max() |
449 | 486 |
|
|
0 commit comments