Skip to content

Commit f36dd10

Browse files
authored
Merge pull request #421 from brainpy/updates
Recent updates
2 parents bee7382 + d9a737b commit f36dd10

File tree

26 files changed

+1000
-401
lines changed

26 files changed

+1000
-401
lines changed

brainpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@
5858
DynamicalSystem as DynamicalSystem,
5959
DynSysGroup as DynSysGroup, # collectors
6060
Sequential as Sequential,
61-
Network as Network,
6261
Dynamic as Dynamic, # category
6362
Projection as Projection,
6463
)
6564
DynamicalSystemNS = DynamicalSystem
65+
Network = DynSysGroup
6666
# delays
6767
from brainpy._src.delay import (
68-
VariableDelay as VariableDelay,
68+
VariDelay as VariDelay,
6969
)
7070

7171
# building blocks

brainpy/_src/delay.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121

2222
__all__ = [
2323
'Delay',
24-
'VariableDelay',
24+
'VariDelay',
2525
'DataDelay',
26+
'DelayAccess',
2627
]
2728

2829

@@ -431,8 +432,8 @@ def _check_target_sharding(sharding, ndim, mode: bm.Mode):
431432
return sharding
432433

433434

434-
class VariableDelay(Delay):
435-
"""Delay variable which has a fixed delay length.
435+
class VariDelay(Delay):
436+
"""Generate Delays for the given :py:class:`~.Variable` instance.
436437
437438
The data in this delay variable is arranged as::
438439
@@ -517,8 +518,8 @@ def __init__(
517518

518519
# other info
519520
if entries is not None:
520-
for entry, value in entries.items():
521-
self.register_entry(entry, value)
521+
for entry, delay_time in entries.items():
522+
self.register_entry(entry, delay_time)
522523

523524
def register_entry(
524525
self,
@@ -572,11 +573,17 @@ def at(self, entry: str, *indices) -> bm.Array:
572573
raise KeyError(f'Does not find delay entry "{entry}".')
573574
delay_step = self._registered_entries[entry]
574575
if delay_step is None or delay_step == 0.:
575-
return self.target.value
576+
if len(indices):
577+
return self.target[indices]
578+
else:
579+
return self.target.value
576580
else:
577581
assert self.data is not None
578582
if delay_step == 0:
579-
return self.target.value
583+
if len(indices):
584+
return self.target[indices]
585+
else:
586+
return self.target.value
580587
else:
581588
return self.retrieve(delay_step, *indices)
582589

@@ -683,16 +690,15 @@ def _init_data(self, length: int, batch_size: int = None):
683690
self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype)
684691

685692

686-
class DataDelay(VariableDelay):
687-
693+
class DataDelay(VariDelay):
688694
not_desc_params = ('time', 'entries')
689695

690696
def __init__(
691697
self,
692698

693699
# delay target
694-
target: bm.Variable,
695-
target_init: Callable,
700+
data: bm.Variable,
701+
data_init: Union[Callable, bm.Array, jax.Array],
696702

697703
# delay time
698704
time: Optional[Union[int, float]] = None,
@@ -710,8 +716,8 @@ def __init__(
710716
name: Optional[str] = None,
711717
mode: Optional[bm.Mode] = None,
712718
):
713-
self.target_init = target_init
714-
super().__init__(target=target,
719+
self.target_init = data_init
720+
super().__init__(target=data,
715721
time=time,
716722
init=init,
717723
entries=entries,
@@ -736,3 +742,20 @@ def update(
736742
super().update(latest_value)
737743

738744

745+
class DelayAccess(DynamicalSystem):
746+
def __init__(
747+
self,
748+
delay: Delay,
749+
time: Union[None, int, float],
750+
*indices
751+
):
752+
super().__init__(mode=delay.mode)
753+
self.delay = delay
754+
assert isinstance(delay, Delay)
755+
delay.register_entry(self.name, time)
756+
self.indices = indices
757+
758+
def update(self):
759+
return self.delay.at(self.name, *self.indices)
760+
761+

brainpy/_src/dnn/linear.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Dict, Optional, Union, Callable
55

66
import jax
7+
import numpy as np
78
import jax.numpy as jnp
89

910
from brainpy import math as bm
@@ -63,8 +64,8 @@ def __init__(
6364
num_out: int,
6465
W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
6566
b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
66-
mode: bm.Mode = None,
67-
name: str = None,
67+
mode: Optional[bm.Mode] = None,
68+
name: Optional[str] = None,
6869
):
6970
super(Dense, self).__init__(mode=mode, name=name)
7071

@@ -642,7 +643,7 @@ def __init__(
642643
num_out: int,
643644
prob: float,
644645
weight: float,
645-
seed: int,
646+
seed: Optional[int] = None,
646647
sharding: Optional[Sharding] = None,
647648
mode: Optional[bm.Mode] = None,
648649
name: Optional[str] = None,
@@ -654,7 +655,7 @@ def __init__(
654655
self.prob = prob
655656
self.sharding = sharding
656657
self.transpose = transpose
657-
self.seed = seed
658+
self.seed = np.random.randint(0, 100000) if seed is None else seed
658659
self.atomic = atomic
659660
self.num_in = num_in
660661
self.num_out = num_out
@@ -723,7 +724,7 @@ def __init__(
723724
prob: float,
724725
w_low: float,
725726
w_high: float,
726-
seed: int,
727+
seed: Optional[int] = None,
727728
sharding: Optional[Sharding] = None,
728729
mode: Optional[bm.Mode] = None,
729730
name: Optional[str] = None,
@@ -735,7 +736,7 @@ def __init__(
735736
self.prob = prob
736737
self.sharding = sharding
737738
self.transpose = transpose
738-
self.seed = seed
739+
self.seed = np.random.randint(0, 100000) if seed is None else seed
739740
self.atomic = atomic
740741
self.num_in = num_in
741742
self.num_out = num_out
@@ -803,7 +804,7 @@ def __init__(
803804
prob: float,
804805
w_mu: float,
805806
w_sigma: float,
806-
seed: int,
807+
seed: Optional[int] = None,
807808
sharding: Optional[Sharding] = None,
808809
transpose: bool = False,
809810
atomic: bool = False,
@@ -815,7 +816,7 @@ def __init__(
815816
self.prob = prob
816817
self.sharding = sharding
817818
self.transpose = transpose
818-
self.seed = seed
819+
self.seed = np.random.randint(0, 100000) if seed is None else seed
819820
self.atomic = atomic
820821
self.num_in = num_in
821822
self.num_out = num_out
@@ -881,7 +882,7 @@ def __init__(
881882
num_out: int,
882883
prob: float,
883884
weight: float,
884-
seed: int,
885+
seed: Optional[int] = None,
885886
sharding: Optional[Sharding] = None,
886887
mode: Optional[bm.Mode] = None,
887888
name: Optional[str] = None,
@@ -893,7 +894,7 @@ def __init__(
893894
self.prob = prob
894895
self.sharding = sharding
895896
self.transpose = transpose
896-
self.seed = seed
897+
self.seed = np.random.randint(0, 1000000) if seed is None else seed
897898
self.atomic = atomic
898899
self.num_in = num_in
899900
self.num_out = num_out
@@ -962,7 +963,7 @@ def __init__(
962963
prob: float,
963964
w_low: float,
964965
w_high: float,
965-
seed: int,
966+
seed: Optional[int] = None,
966967
sharding: Optional[Sharding] = None,
967968
mode: Optional[bm.Mode] = None,
968969
name: Optional[str] = None,
@@ -974,7 +975,7 @@ def __init__(
974975
self.prob = prob
975976
self.sharding = sharding
976977
self.transpose = transpose
977-
self.seed = seed
978+
self.seed = np.random.randint(0, 100000) if seed is None else seed
978979
self.atomic = atomic
979980
self.num_in = num_in
980981
self.num_out = num_out
@@ -1042,7 +1043,7 @@ def __init__(
10421043
prob: float,
10431044
w_mu: float,
10441045
w_sigma: float,
1045-
seed: int,
1046+
seed: Optional[int] = None,
10461047
sharding: Optional[Sharding] = None,
10471048
transpose: bool = False,
10481049
atomic: bool = False,
@@ -1054,7 +1055,7 @@ def __init__(
10541055
self.prob = prob
10551056
self.sharding = sharding
10561057
self.transpose = transpose
1057-
self.seed = seed
1058+
self.seed = np.random.randint(0, 100000) if seed is None else seed
10581059
self.atomic = atomic
10591060
self.num_in = num_in
10601061
self.num_out = num_out

brainpy/_src/dyn/channels/potassium_compatible.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99

1010
import brainpy.math as bm
1111
from brainpy._src.context import share
12-
from brainpy._src.dyn.channels.leaky import LeakyChannel
12+
from brainpy._src.dyn.channels.base import IonChannel
1313
from brainpy._src.dyn.neurons.hh import HHTypedNeuron
1414
from brainpy._src.initialize import Initializer, parameter, variable
1515
from brainpy._src.integrators import odeint, JointEq
1616
from brainpy.types import ArrayType
17-
from .potassium import PotassiumChannel
1817

1918
__all__ = [
2019
'IKDR_Ba2002',
@@ -29,7 +28,7 @@
2928
]
3029

3130

32-
class _IK_p4_markov(PotassiumChannel):
31+
class _IK_p4_markov(IonChannel):
3332
r"""The delayed rectifier potassium channel of :math:`p^4`
3433
current which described with first-order Markov chain.
3534
@@ -339,7 +338,7 @@ def f_p_beta(self, V):
339338
return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80)
340339

341340

342-
class _IKA_p4q_ss(PotassiumChannel):
341+
class _IKA_p4q_ss(IonChannel):
343342
r"""The rapidly inactivating Potassium channel of :math:`p^4q`
344343
current which described with steady-state format.
345344
@@ -634,7 +633,7 @@ def f_q_tau(self, V):
634633
19.)
635634

636635

637-
class _IKK2_pq_ss(PotassiumChannel):
636+
class _IKK2_pq_ss(IonChannel):
638637
r"""The slowly inactivating Potassium channel of :math:`pq`
639638
current which described with steady-state format.
640639
@@ -921,7 +920,7 @@ def f_q_tau(self, V):
921920
8.9)
922921

923922

924-
class IKNI_Ya1989(PotassiumChannel):
923+
class IKNI_Ya1989(IonChannel):
925924
r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_.
926925
927926
This slow potassium current can effectively account for spike-frequency adaptation.
@@ -1019,7 +1018,7 @@ def f_p_tau(self, V):
10191018
return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.))
10201019

10211020

1022-
class IKL(LeakyChannel):
1021+
class IKL(IonChannel):
10231022
"""The potassium leak channel current.
10241023
10251024
Parameters
@@ -1031,6 +1030,8 @@ class IKL(LeakyChannel):
10311030
The reversal potential.
10321031
"""
10331032

1033+
master_type = HHTypedNeuron
1034+
10341035
def __init__(
10351036
self,
10361037
size: Union[int, Sequence[int]],

brainpy/_src/dyn/channels/sodium_compatible.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from brainpy._src.initialize import Initializer, parameter, variable
1414
from brainpy._src.integrators import odeint, JointEq
1515
from brainpy.types import ArrayType
16-
from .sodium import SodiumChannel
16+
from .base import IonChannel
1717

1818
__all__ = [
1919
'INa_Ba2002',
@@ -22,7 +22,7 @@
2222
]
2323

2424

25-
class _INa_p3q_markov(SodiumChannel):
25+
class _INa_p3q_markov(IonChannel):
2626
r"""The sodium current model of :math:`p^3q` current which described with first-order Markov chain.
2727
2828
The general model can be used to model the dynamics with:
@@ -64,7 +64,7 @@ def __init__(
6464
name: str = None,
6565
mode: bm.Mode = None,
6666
):
67-
super(_INa_p3q_markov, self).__init__(size=size,
67+
super().__init__(size=size,
6868
keep_size=keep_size,
6969
name=name,
7070
mode=mode)
@@ -173,7 +173,7 @@ def __init__(
173173
name: str = None,
174174
mode: bm.Mode = None,
175175
):
176-
super(INa_Ba2002, self).__init__(size,
176+
super().__init__(size,
177177
keep_size=keep_size,
178178
name=name,
179179
method=method,
@@ -260,7 +260,7 @@ def __init__(
260260
name: str = None,
261261
mode: bm.Mode = None,
262262
):
263-
super(INa_TM1991, self).__init__(size,
263+
super().__init__(size,
264264
keep_size=keep_size,
265265
name=name,
266266
method=method,
@@ -347,7 +347,7 @@ def __init__(
347347
name: str = None,
348348
mode: bm.Mode = None,
349349
):
350-
super(INa_HH1952, self).__init__(size,
350+
super().__init__(size,
351351
keep_size=keep_size,
352352
name=name,
353353
method=method,

brainpy/_src/dyn/ions/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,14 @@ def update(self, V):
166166
for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values():
167167
node.update(V, self.C, self.E)
168168

169-
def current(self, V, C=None, E=None):
169+
def current(self, V, C=None, E=None, external: bool = False):
170170
"""Generate ion channel current.
171171
172172
Args:
173173
V: The membrane potential.
174174
C: The ion concentration.
175175
E: The reversal potential.
176+
external: Include the external current.
176177
177178
Returns:
178179
Current.
@@ -186,8 +187,9 @@ def current(self, V, C=None, E=None):
186187
if len(nodes) > 0:
187188
for node in nodes:
188189
current = current + node.current(V, C, E)
189-
for key, node in self.external.items():
190-
current = current + node(V, C, E)
190+
if external:
191+
for key, node in self.external.items():
192+
current = current + node(V, C, E)
191193
return current
192194

193195
def reset_state(self, V, batch_size=None):

0 commit comments

Comments
 (0)