|
6 | 6 | from jax import vmap |
7 | 7 |
|
8 | 8 | import brainpy.math as bm |
9 | | -from brainpy.dyn.base import DynamicalSystem |
| 9 | +from brainpy.dyn.base import SynConn, SynOut |
| 10 | +from brainpy.dyn.synouts import CUBA |
10 | 11 | from brainpy.initialize import Initializer |
| 12 | +from brainpy.dyn.neurons.input_groups import InputGroup, OutputGroup |
11 | 13 | from brainpy.modes import Mode, TrainingMode, normal |
12 | 14 | from brainpy.tools.checking import check_sequence |
13 | 15 | from brainpy.types import Tensor |
|
19 | 21 | ] |
20 | 22 |
|
21 | 23 |
|
22 | | -class DelayCoupling(DynamicalSystem): |
| 24 | +class DelayCoupling(SynConn): |
23 | 25 | """Delay coupling. |
24 | 26 |
|
25 | 27 | Parameters |
@@ -49,7 +51,10 @@ def __init__( |
49 | 51 | name: str = None, |
50 | 52 | mode: Mode = normal, |
51 | 53 | ): |
52 | | - super(DelayCoupling, self).__init__(name=name, mode=mode) |
| 54 | + super(DelayCoupling, self).__init__(name=name, |
| 55 | + mode=mode, |
| 56 | + pre=InputGroup(1), |
| 57 | + post=OutputGroup(1)) |
53 | 58 |
|
54 | 59 | # delay variable |
55 | 60 | if not isinstance(delay_var, bm.Variable): |
@@ -201,8 +206,8 @@ def update(self, tdi): |
201 | 206 | indices = (slice(None, None, None), bm.arange(self.coupling_var1.size),) |
202 | 207 | else: |
203 | 208 | indices = (bm.arange(self.coupling_var1.size),) |
204 | | - f = vmap(lambda i: delay_var(self.delay_steps[:, i], *indices)) # (..., pre.num) |
205 | | - delays = f(bm.arange(self.coupling_var2.size).value) # (..., post.num, pre.num) |
| 209 | + f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (..., pre.num) |
| 210 | + delays = f(self.delay_steps) # (..., post.num, pre.num) |
206 | 211 | diffusive = (bm.moveaxis(delays, axis - 1, axis) - |
207 | 212 | bm.expand_dims(self.coupling_var2, axis=axis - 1)) # (..., pre.num, post.num) |
208 | 213 | diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) |
@@ -284,8 +289,8 @@ def update(self, tdi): |
284 | 289 | indices = (slice(None, None, None), bm.arange(self.coupling_var.size),) |
285 | 290 | else: |
286 | 291 | indices = (bm.arange(self.coupling_var.size),) |
287 | | - f = vmap(lambda i: delay_var(self.delay_steps[:, i], *indices)) # (.., pre.num,) |
288 | | - delays = f(bm.arange(self.coupling_var.size).value) # (..., post.num, pre.num) |
| 292 | + f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (.., pre.num,) |
| 293 | + delays = f(self.delay_steps) # (..., post.num, pre.num) |
289 | 294 | additive = (self.conn_mat * bm.moveaxis(delays, axis - 1, axis)).sum(axis=axis - 1) |
290 | 295 | elif self.delay_type == 'int': |
291 | 296 | delayed_var = delay_var(self.delay_steps) # (..., pre.num) |
|
0 commit comments