
Commit d9a737b

update brainpy package
1 parent af476ba · commit d9a737b

File tree

14 files changed: +163 −232 lines


brainpy/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -58,14 +58,14 @@
   DynamicalSystem as DynamicalSystem,
   DynSysGroup as DynSysGroup,  # collectors
   Sequential as Sequential,
-  Network as Network,
   Dynamic as Dynamic,  # category
   Projection as Projection,
 )
 DynamicalSystemNS = DynamicalSystem
+Network = DynSysGroup
 # delays
 from brainpy._src.delay import (
-  VariableDelay as VariableDelay,
+  VariDelay as VariDelay,
 )

 # building blocks
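
Note: an illustrative sketch of what the re-export change means downstream (the assertion is mine, not part of the commit). brainpy.Network is now a module-level alias of DynSysGroup rather than a separately exported class, so code built on the old name keeps working:

import brainpy as bp

assert bp.Network is bp.DynSysGroup  # alias, not a distinct class
net = bp.Network()                   # old spelling still builds a DynSysGroup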

brainpy/_src/delay.py

Lines changed: 36 additions & 13 deletions
@@ -21,8 +21,9 @@

 __all__ = [
   'Delay',
-  'VariableDelay',
+  'VariDelay',
   'DataDelay',
+  'DelayAccess',
 ]

@@ -431,8 +432,8 @@ def _check_target_sharding(sharding, ndim, mode: bm.Mode):
   return sharding


-class VariableDelay(Delay):
-  """Delay variable which has a fixed delay length.
+class VariDelay(Delay):
+  """Generate Delays for the given :py:class:`~.Variable` instance.

   The data in this delay variable is arranged as::

@@ -517,8 +518,8 @@ def __init__(

     # other info
     if entries is not None:
-      for entry, value in entries.items():
-        self.register_entry(entry, value)
+      for entry, delay_time in entries.items():
+        self.register_entry(entry, delay_time)

   def register_entry(
       self,
@@ -572,11 +573,17 @@ def at(self, entry: str, *indices) -> bm.Array:
       raise KeyError(f'Does not find delay entry "{entry}".')
     delay_step = self._registered_entries[entry]
     if delay_step is None or delay_step == 0.:
-      return self.target.value
+      if len(indices):
+        return self.target[indices]
+      else:
+        return self.target.value
     else:
       assert self.data is not None
       if delay_step == 0:
-        return self.target.value
+        if len(indices):
+          return self.target[indices]
+        else:
+          return self.target.value
       else:
         return self.retrieve(delay_step, *indices)

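Note on the at() hunk above: previously a zero-delay entry always returned the whole target value and silently ignored any indices passed in; now it honors them, matching retrieve(). A minimal standalone sketch of the patched branch (target stands in for a brainpy.math Variable):

def read_zero_delay(target, *indices):
  if len(indices):
    return target[indices]  # honor the requested slice, as retrieve() does
  return target.value       # no indices: return the whole current value
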
@@ -683,16 +690,15 @@ def _init_data(self, length: int, batch_size: int = None):
       self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype)


-class DataDelay(VariableDelay):
-
+class DataDelay(VariDelay):
   not_desc_params = ('time', 'entries')

   def __init__(
       self,

       # delay target
-      target: bm.Variable,
-      target_init: Callable,
+      data: bm.Variable,
+      data_init: Union[Callable, bm.Array, jax.Array],

       # delay time
       time: Optional[Union[int, float]] = None,
@@ -710,8 +716,8 @@ def __init__(
       name: Optional[str] = None,
       mode: Optional[bm.Mode] = None,
   ):
-    self.target_init = target_init
-    super().__init__(target=target,
+    self.target_init = data_init
+    super().__init__(target=data,
                      time=time,
                      init=init,
                      entries=entries,
@@ -736,3 +742,20 @@ def update(
     super().update(latest_value)


+class DelayAccess(DynamicalSystem):
+  def __init__(
+      self,
+      delay: Delay,
+      time: Union[None, int, float],
+      *indices
+  ):
+    super().__init__(mode=delay.mode)
+    self.delay = delay
+    assert isinstance(delay, Delay)
+    delay.register_entry(self.name, time)
+    self.indices = indices
+
+  def update(self):
+    return self.delay.at(self.name, *self.indices)
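
Note: a hypothetical usage sketch for the new DelayAccess wrapper (names invented for illustration; vd is assumed to be a VariDelay instance). It registers its own entry under self.name at construction and reads it back, optionally sliced, on every update:

access = DelayAccess(vd, 2.0, 0)  # entry registered at a 2.0 ms delay, index 0
delayed = access.update()         # equivalent to vd.at(access.name, 0)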

brainpy/_src/dnn/linear.py

Lines changed: 15 additions & 14 deletions
@@ -4,6 +4,7 @@
 from typing import Dict, Optional, Union, Callable

 import jax
+import numpy as np
 import jax.numpy as jnp

 from brainpy import math as bm
@@ -63,8 +64,8 @@ def __init__(
       num_out: int,
       W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
       b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
-      mode: bm.Mode = None,
-      name: str = None,
+      mode: Optional[bm.Mode] = None,
+      name: Optional[str] = None,
   ):
     super(Dense, self).__init__(mode=mode, name=name)
@@ -642,7 +643,7 @@ def __init__(
       num_out: int,
       prob: float,
       weight: float,
-      seed: int,
+      seed: Optional[int] = None,
       sharding: Optional[Sharding] = None,
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
@@ -654,7 +655,7 @@ def __init__(
     self.prob = prob
     self.sharding = sharding
     self.transpose = transpose
-    self.seed = seed
+    self.seed = np.random.randint(0, 100000) if seed is None else seed
     self.atomic = atomic
     self.num_in = num_in
     self.num_out = num_out
@@ -723,7 +724,7 @@ def __init__(
       prob: float,
       w_low: float,
       w_high: float,
-      seed: int,
+      seed: Optional[int] = None,
       sharding: Optional[Sharding] = None,
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
@@ -735,7 +736,7 @@ def __init__(
     self.prob = prob
     self.sharding = sharding
     self.transpose = transpose
-    self.seed = seed
+    self.seed = np.random.randint(0, 100000) if seed is None else seed
     self.atomic = atomic
     self.num_in = num_in
     self.num_out = num_out
@@ -803,7 +804,7 @@ def __init__(
       prob: float,
       w_mu: float,
       w_sigma: float,
-      seed: int,
+      seed: Optional[int] = None,
       sharding: Optional[Sharding] = None,
       transpose: bool = False,
       atomic: bool = False,
@@ -815,7 +816,7 @@ def __init__(
     self.prob = prob
     self.sharding = sharding
     self.transpose = transpose
-    self.seed = seed
+    self.seed = np.random.randint(0, 100000) if seed is None else seed
     self.atomic = atomic
     self.num_in = num_in
     self.num_out = num_out
@@ -881,7 +882,7 @@ def __init__(
       num_out: int,
       prob: float,
       weight: float,
-      seed: int,
+      seed: Optional[int] = None,
       sharding: Optional[Sharding] = None,
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
@@ -893,7 +894,7 @@ def __init__(
     self.prob = prob
     self.sharding = sharding
     self.transpose = transpose
-    self.seed = seed
+    self.seed = np.random.randint(0, 1000000) if seed is None else seed
     self.atomic = atomic
     self.num_in = num_in
     self.num_out = num_out
@@ -962,7 +963,7 @@ def __init__(
       prob: float,
       w_low: float,
       w_high: float,
-      seed: int,
+      seed: Optional[int] = None,
       sharding: Optional[Sharding] = None,
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
@@ -974,7 +975,7 @@ def __init__(
     self.prob = prob
     self.sharding = sharding
     self.transpose = transpose
-    self.seed = seed
+    self.seed = np.random.randint(0, 100000) if seed is None else seed
     self.atomic = atomic
     self.num_in = num_in
     self.num_out = num_out
@@ -1042,7 +1043,7 @@ def __init__(
       prob: float,
       w_mu: float,
       w_sigma: float,
-      seed: int,
+      seed: Optional[int] = None,
       sharding: Optional[Sharding] = None,
       transpose: bool = False,
       atomic: bool = False,
@@ -1054,7 +1055,7 @@ def __init__(
     self.prob = prob
     self.sharding = sharding
     self.transpose = transpose
-    self.seed = seed
+    self.seed = np.random.randint(0, 100000) if seed is None else seed
     self.atomic = atomic
     self.num_in = num_in
     self.num_out = num_out
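
Note: all twelve seed hunks follow one pattern, sketched here in isolation (my paraphrase, not the class code): seed becomes optional, and a concrete value is drawn once at construction time, so each layer still ends up with a reproducible, inspectable self.seed. One hunk draws from range (0, 1000000) rather than (0, 100000); that inconsistency is in the commit itself.

import numpy as np
from typing import Optional

def resolve_seed(seed: Optional[int] = None) -> int:
  # mirrors `np.random.randint(0, 100000) if seed is None else seed` in the diff
  return np.random.randint(0, 100000) if seed is None else seed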

brainpy/_src/dyn/others/input.py

Lines changed: 20 additions & 20 deletions
@@ -40,11 +40,11 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
   ):
-    super(InputGroup, self).__init__(name=name,
-                                     sharding=sharding,
-                                     size=size,
-                                     keep_size=keep_size,
-                                     mode=mode)
+    super().__init__(name=name,
+                     sharding=sharding,
+                     size=size,
+                     keep_size=keep_size,
+                     mode=mode)

   def update(self, x):
     return x
@@ -74,11 +74,11 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
   ):
-    super(OutputGroup, self).__init__(name=name,
-                                      sharding=sharding,
-                                      size=size,
-                                      keep_size=keep_size,
-                                      mode=mode)
+    super().__init__(name=name,
+                     sharding=sharding,
+                     size=size,
+                     keep_size=keep_size,
+                     mode=mode)

   def update(self, x):
     return x
@@ -130,11 +130,11 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       need_sort: bool = True,
   ):
-    super(SpikeTimeGroup, self).__init__(size=size,
-                                         sharding=sharding,
-                                         name=name,
-                                         keep_size=keep_size,
-                                         mode=mode)
+    super().__init__(size=size,
+                     sharding=sharding,
+                     name=name,
+                     keep_size=keep_size,
+                     mode=mode)

     # parameters
     if keep_size:
@@ -202,11 +202,11 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       seed=None,
   ):
-    super(PoissonGroup, self).__init__(size=size,
-                                       sharding=sharding,
-                                       name=name,
-                                       keep_size=keep_size,
-                                       mode=mode)
+    super().__init__(size=size,
+                     sharding=sharding,
+                     name=name,
+                     keep_size=keep_size,
+                     mode=mode)

     if seed is not None:
       warnings.warn('')
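
Note: all four hunks are the same mechanical modernization. The zero-argument super() form is equivalent to the explicit Python 2 style but avoids repeating the class name; a minimal standalone illustration:

class Base:
  def __init__(self, size):
    self.size = size

class Child(Base):
  def __init__(self, size):
    super().__init__(size=size)  # same as super(Child, self).__init__(size=size)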

brainpy/_src/dyn/synapses/abstract_models.py

Lines changed: 12 additions & 12 deletions
@@ -334,7 +334,8 @@ def add_current(self, inp):
     self.g_decay += inp

   def return_info(self):
-    return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros)
+    return ReturnInfo(self.varshape, self.sharding, self.mode,
+                      lambda shape: self.coeff * (self.g_decay - self.g_rise))


 DualExponV2.__doc__ = DualExponV2.__doc__ % (pneu_doc,)
@@ -677,22 +678,21 @@ def update(self, pre_spike):
     t = share.load('t')
     dt = share.load('dt')
     u, x = self.integral(self.u.value, self.x.value, t, dt)
-    if pre_spike.dtype == jax.numpy.bool_:
-      u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
-      x = bm.where(pre_spike, x - u * self.x, x)
-    else:
-      u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u
-      x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x
+    # if pre_spike.dtype == jax.numpy.bool_:
+    #   u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
+    #   x = bm.where(pre_spike, x - u * self.x, x)
+    # else:
+    #   u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u
+    #   x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x
+    u = pre_spike * self.U * (1 - self.u) + u
+    x = pre_spike * -u * self.x + x
     self.x.value = x
     self.u.value = u
     return u * x

   def return_info(self):
-    return ReturnInfo(size=self.varshape,
-                      batch_or_mode=self.mode,
-                      axis_names=self.sharding,
-                      init=Constant(self.U))
+    return ReturnInfo(self.varshape, self.sharding, self.mode,
+                      lambda shape: self.u * self.x)


 STP.__doc__ = STP.__doc__ % (pneu_doc,)
-
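
Note on the STP update: for spike values restricted to {0, 1}, bm.where(s, a, b) equals s * a + (1 - s) * b, and expanding that gives the single branch-free form in the diff; the same expression also covers float (surrogate-gradient) spike signals, which is why both old branches could be collapsed. A quick standalone check in plain numpy (values here are arbitrary):

import numpy as np

s = np.array([0., 1.])             # no spike / spike
u_int, U, u_prev = 0.3, 0.2, 0.25  # integrated u, increment U, previous u
old = np.where(s.astype(bool), u_int + U * (1 - u_prev), u_int)
new = s * U * (1 - u_prev) + u_int
assert np.allclose(old, new)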

brainpy/_src/math/object_transform/jit.py

Lines changed: 3 additions & 2 deletions
@@ -405,13 +405,14 @@ def _make_jit_fun(

   @wraps(fun)
   def call_fun(self, *args, **kwargs):
-    fun2 = partial(fun, self)
     if jax.config.jax_disable_jit:
-      return fun2(*args, **kwargs)
+      return fun(self, *args, **kwargs)

     hash_v = hash(fun) + hash(self)
     cache = get_stack_cache(hash_v)  # TODO: better cache mechanism
     if cache is None:
+      fun2 = partial(fun, self)
+
       with jax.ensure_compile_time_eval():
         if len(static_argnums) or len(static_argnames):
           fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames)
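
Note: the point of this hunk is to stop allocating a functools.partial object on every call; it is now built only when the compilation cache misses, and the disable-jit path calls the function directly. The pattern in isolation (a self-contained sketch, not the brainpy source):

from functools import partial

_cache = {}

def call(self, fun, *args, jit_disabled=False, **kwargs):
  if jit_disabled:
    return fun(self, *args, **kwargs)      # fast path: no partial allocated
  key = (hash(fun), id(self))
  fn = _cache.get(key)
  if fn is None:                           # only a cache miss pays for the partial
    fn = _cache[key] = partial(fun, self)  # the real code jit-compiles here
  return fn(*args, **kwargs)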

brainpy/_src/math/object_transform/variables.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def dict_data(self) -> dict:
     """Get all data in the collected variables with a python dict structure."""
     new_dict = dict()
     for id_, elem in tuple(self.items()):
-      new_dict[id_] = elem.value if isinstance(elem, Array) else elem
+      new_dict[id_] = elem._value if isinstance(elem, Array) else elem
     return new_dict

   def list_data(self) -> list:
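
Note: reading _value reaches the raw underlying buffer directly, while the public value property can add bookkeeping (checks, conversion) on each access. A simplified model of the distinction (my sketch, not the actual brainpy Array class):

class Array:
  def __init__(self, buf):
    self._value = buf    # raw jax array

  @property
  def value(self):
    return self._value   # the real property may validate or convert here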
