Skip to content

Commit 1c2df42

Browse files
committed
Merge branch 'master' into updates
2 parents 1df9c76 + 9bb11a8 commit 1c2df42

File tree

6 files changed

+492
-364
lines changed

6 files changed

+492
-364
lines changed

brainpy/connect/custom_conn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, conn_mat):
2525
self.pre_num, self.post_num = conn_mat.shape
2626
self.pre_size, self.post_size = (self.pre_num,), (self.post_num,)
2727

28-
self.conn_mat = np.asarray(conn_mat, dtype=MAT_DTYPE)
28+
self.conn_mat = np.asarray(conn_mat).astype(MAT_DTYPE)
2929

3030
def __call__(self, pre_size, post_size):
3131
assert self.pre_num == tools.size2num(pre_size)
@@ -47,8 +47,8 @@ def __init__(self, i, j):
4747
assert i.size == j.size
4848

4949
# initialize the class via "pre_ids" and "post_ids"
50-
self.pre_ids = np.asarray(i, dtype=IDX_DTYPE)
51-
self.post_ids = np.asarray(j, dtype=IDX_DTYPE)
50+
self.pre_ids = np.asarray(i).astype(IDX_DTYPE)
51+
self.post_ids = np.asarray(j).astype(IDX_DTYPE)
5252

5353
def __call__(self, pre_size, post_size):
5454
super(IJConn, self).__call__(pre_size, post_size)
@@ -80,7 +80,7 @@ def __init__(self, csr_mat):
8080
f'Please run "pip install scipy" to install scipy.')
8181

8282
assert isinstance(csr_mat, csr_matrix)
83-
csr_mat.data = np.asarray(csr_mat.data, dtype=MAT_DTYPE)
83+
csr_mat.data = np.asarray(csr_mat.data).astype(MAT_DTYPE)
8484
self.csr_mat = csr_mat
8585
self.pre_num, self.post_num = csr_mat.shape
8686

brainpy/dyn/base.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -483,16 +483,19 @@ def register_delay(
483483
def get_delay(
484484
self,
485485
name: str,
486-
delay_step: Union[int, bm.JaxArray, bm.ndarray]
486+
delay_step: Union[int, bm.JaxArray, bm.ndarray],
487+
indices=None,
487488
):
488-
"""Get delay data according to the delay times.
489+
"""Get delay data according to the provided delay steps.
489490
490491
Parameters
491492
----------
492493
name: str
493494
The delay variable name.
494495
delay_step: int, JaxArray, ndarray
495496
The delay length.
497+
indices: optional, JaxArray, ndarray
498+
The indices of the delay.
496499
497500
Returns
498501
-------
@@ -501,14 +504,18 @@ def get_delay(
501504
"""
502505
if name in self.global_delay_vars:
503506
if isinstance(delay_step, int):
504-
return self.global_delay_vars[name](delay_step)
507+
return self.global_delay_vars[name](delay_step, indices)
505508
else:
506-
return self.global_delay_vars[name](delay_step, jnp.arange(delay_step.size))
509+
if indices is None:
510+
indices = jnp.arange(delay_step.size)
511+
return self.global_delay_vars[name](delay_step, indices)
507512
elif name in self.local_delay_vars:
508513
if isinstance(delay_step, int):
509514
return self.local_delay_vars[name](delay_step)
510515
else:
511-
return self.local_delay_vars[name](delay_step, jnp.arange(delay_step.size))
516+
if indices is None:
517+
indices = jnp.arange(delay_step.size)
518+
return self.local_delay_vars[name](delay_step, indices)
512519
else:
513520
raise ValueError(f'{name} is not defined in delay variables.')
514521

0 commit comments

Comments
 (0)