|
1 | 1 | # -*- coding: utf-8 -*- |
2 | 2 |
|
3 | | -from typing import Union, Callable, Optional, Tuple, Sequence, Dict |
| 3 | +from typing import Union, Callable, Optional, Dict |
4 | 4 |
|
5 | 5 | import jax |
6 | | -import jax.numpy as jnp |
7 | | -import numpy as np |
8 | | -from jax.lax import stop_gradient |
9 | 6 |
|
10 | | -from brainpy import check, math as bm |
11 | | -from brainpy._src.math.object_transform.base import Collector |
| 7 | +from brainpy import math as bm |
12 | 8 | from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs |
13 | | -from brainpy.check import is_integer, jit_error_checking |
| 9 | +from brainpy._src.math.delayvars import DelayVariable, ROTATE_UPDATE, CONCAT_UPDATE |
14 | 10 |
|
15 | | -ROTATE_UPDATE = 'rotation' |
16 | | -CONCAT_UPDATE = 'concat' |
17 | 11 |
|
| 12 | +class Delay(DynamicalSystem, DelayVariable): |
| 13 | + """Delay for dynamical systems which has a fixed delay length. |
18 | 14 |
|
19 | | -class Delay(DynamicalSystem): |
20 | | - """Delay variable which has a fixed delay length. |
21 | | -
|
22 | | - The data in this delay variable is arranged as:: |
23 | | -
|
24 | | - delay = 0 [ data |
25 | | - delay = 1 data |
26 | | - delay = 2 data |
27 | | - ... .... |
28 | | - ... .... |
29 | | - delay = length-1 data |
30 | | - delay = length data ] |
31 | | -
|
32 | | - Parameters |
33 | | - ---------- |
34 | | - target: Variable |
35 | | - The initial delay data. |
36 | | - length: int |
37 | | - The delay data length. |
38 | | - initial_delay_data: Any |
39 | | - The delay data. It can be a Python number, like float, int, boolean values. |
40 | | - It can also be arrays. Or a callable function or instance of ``Connector``. |
41 | | - Note that ``initial_delay_data`` should be arranged as the following way:: |
42 | | -
|
43 | | - delay = 1 [ data |
44 | | - delay = 2 data |
45 | | - ... .... |
46 | | - ... .... |
47 | | - delay = length-1 data |
48 | | - delay = length data ] |
49 | | - method: str |
50 | | - The method used for updating delay. |
51 | | -
|
| 15 | + Detailed docstring please see :py:class:`~.DelayVariable`. |
52 | 16 | """ |
53 | 17 |
|
54 | | - data: Optional[bm.Variable] |
55 | | - idx: Optional[bm.Variable] |
56 | | - length: int |
57 | | - |
58 | 18 | def __init__( |
59 | 19 | self, |
60 | 20 | target: bm.Variable, |
61 | 21 | length: int = 0, |
62 | | - initial_delay_data: Union[float, int, bool, bm.Array, jax.Array, Callable] = None, |
| 22 | + before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None, |
63 | 23 | entries: Optional[Dict] = None, |
| 24 | + method: str = ROTATE_UPDATE, |
64 | 25 | mode: bm.Mode = None, |
65 | 26 | name: str = None, |
66 | | - method: str = None, |
67 | 27 | ): |
68 | | - super().__init__(mode=mode, name=name) |
69 | | - |
70 | | - # delay updating method |
| 28 | + DynamicalSystem.__init__(self, mode=mode) |
71 | 29 | if method is None: |
72 | 30 | if self.mode.is_a(bm.NonBatchingMode): |
73 | 31 | method = ROTATE_UPDATE |
74 | | - else: |
| 32 | + elif self.mode.is_parent_of(bm.TrainingMode): |
75 | 33 | method = CONCAT_UPDATE |
76 | | - assert method in [ROTATE_UPDATE, CONCAT_UPDATE] |
77 | | - self.method = method |
78 | | - |
79 | | - # target |
80 | | - self.target = target |
81 | | - if not isinstance(target, bm.Variable): |
82 | | - raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}') |
83 | | - |
84 | | - # delay length |
85 | | - self.length = is_integer(length, allow_none=False, min_bound=0) |
86 | | - |
87 | | - # delay data |
88 | | - if initial_delay_data is not None: |
89 | | - assert isinstance(initial_delay_data, (int, float, bool, bm.Array, jax.Array, Callable)) |
90 | | - self._initial_delay_data = initial_delay_data |
91 | | - if length > 0: |
92 | | - self._init_data(length) |
93 | | - else: |
94 | | - self.data = None |
95 | | - |
96 | | - # time variables |
97 | | - if self.method == ROTATE_UPDATE: |
98 | | - self.idx = bm.Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32))) |
99 | | - |
100 | | - # other info |
101 | | - self._access_to_step = dict() |
102 | | - for entry, value in entries.items(): |
103 | | - self.register_entry(entry, value) |
104 | | - |
105 | | - def register_entry( |
106 | | - self, |
107 | | - entry: str, |
108 | | - delay_time: Optional[Union[float, bm.Array, Callable]] = None, |
109 | | - delay_step: Optional[Union[int, bm.Array, Callable]] = None, |
110 | | - ) -> 'Delay': |
111 | | - """Register an entry to access the data. |
112 | | -
|
113 | | - Args: |
114 | | - entry (str): The entry to access the delay data. |
115 | | - delay_step: The delay step of the entry (must be an integer, denoting the delay step). |
116 | | - delay_time: The delay time of the entry (can be a float). |
117 | | -
|
118 | | - Returns: |
119 | | - Return the self. |
120 | | - """ |
121 | | - if entry in self._access_to_step: |
122 | | - raise KeyError(f'Entry {entry} has been registered.') |
123 | | - |
124 | | - if delay_time is not None: |
125 | | - if delay_step is not None: |
126 | | - raise ValueError('Provide either "delay_time" or "delay_step". Both you have given both.') |
127 | | - if callable(delay_time): |
128 | | - delay_time = bm.as_jax(delay_time(self.delay_target_shape)) |
129 | | - delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int()) |
130 | | - elif isinstance(delay_time, float): |
131 | | - delay_step = int(delay_time / bm.get_dt()) |
132 | 34 | else: |
133 | | - delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int()) |
134 | | - |
135 | | - # delay steps |
136 | | - if delay_step is None: |
137 | | - delay_type = 'none' |
138 | | - elif isinstance(delay_step, int): |
139 | | - delay_type = 'homo' |
140 | | - elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)): |
141 | | - if delay_step.size == 1 and delay_step.ndim == 0: |
142 | | - delay_type = 'homo' |
143 | | - else: |
144 | | - delay_type = 'heter' |
145 | | - delay_step = bm.Array(delay_step) |
146 | | - elif callable(delay_step): |
147 | | - delay_step = delay_step(self.delay_target_shape) |
148 | | - delay_type = 'heter' |
149 | | - else: |
150 | | - raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' |
151 | | - f'integer, array of integers, callable function, brainpy.init.Initializer.') |
152 | | - if delay_type == 'heter': |
153 | | - if delay_step.dtype not in [jnp.int32, jnp.int64]: |
154 | | - raise ValueError('Only support delay steps of int32, int64. If your ' |
155 | | - 'provide delay time length, please divide the "dt" ' |
156 | | - 'then provide us the number of delay steps.') |
157 | | - if self.delay_target_shape[0] != delay_step.shape[0]: |
158 | | - raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}') |
159 | | - if delay_type == 'heter': |
160 | | - max_delay_step = int(max(delay_step)) |
161 | | - elif delay_type == 'homo': |
162 | | - max_delay_step = delay_step |
163 | | - else: |
164 | | - max_delay_step = None |
165 | | - |
166 | | - # delay variable |
167 | | - if max_delay_step is not None: |
168 | | - if self.length < max_delay_step: |
169 | | - self._init_data(max_delay_step) |
170 | | - self.length = max_delay_step |
171 | | - self._access_to_step[entry] = delay_step |
172 | | - return self |
173 | | - |
174 | | - def at_entry(self, entry: str, *indices) -> bm.Array: |
175 | | - """Get the data at the given entry. |
176 | | -
|
177 | | - Args: |
178 | | - entry (str): The entry to access the data. |
179 | | - *indices: |
180 | | -
|
181 | | - Returns: |
182 | | - The data. |
183 | | - """ |
184 | | - assert isinstance(entry, str) |
185 | | - if entry not in self._access_to_step: |
186 | | - raise KeyError(f'Does not find delay entry "{entry}".') |
187 | | - delay_step = self._access_to_step[entry] |
188 | | - if delay_step is None: |
189 | | - return self.target.value |
190 | | - else: |
191 | | - if self.data is None: |
192 | | - return self.target.value |
193 | | - else: |
194 | | - if isinstance(delay_step, slice): |
195 | | - return self.retrieve(delay_step, *indices) |
196 | | - elif np.ndim(delay_step) == 0: |
197 | | - return self.retrieve(delay_step, *indices) |
198 | | - else: |
199 | | - if len(indices) == 0 and len(delay_step) == self.target.shape[0]: |
200 | | - indices = (jnp.arange(delay_step.size),) |
201 | | - return self.retrieve(delay_step, *indices) |
202 | | - |
203 | | - @property |
204 | | - def delay_target_shape(self): |
205 | | - """The data shape of the delay target.""" |
206 | | - return self.target.shape |
207 | | - |
208 | | - def __repr__(self): |
209 | | - name = self.__class__.__name__ |
210 | | - return (f'{name}(num_delay_step={self.length}, ' |
211 | | - f'delay_target_shape={self.delay_target_shape}, ' |
212 | | - f'update_method={self.method})') |
213 | | - |
214 | | - def _check_delay(self, delay_len): |
215 | | - raise ValueError(f'The request delay length should be less than the ' |
216 | | - f'maximum delay {self.length}. ' |
217 | | - f'But we got {delay_len}') |
218 | | - |
219 | | - def retrieve(self, delay_step, *indices): |
220 | | - """Retrieve the delay data according to the delay length. |
221 | | -
|
222 | | - Parameters |
223 | | - ---------- |
224 | | - delay_step: int, ArrayType |
225 | | - The delay length used to retrieve the data. |
226 | | - """ |
227 | | - assert delay_step is not None |
228 | | - if check.is_checking(): |
229 | | - jit_error_checking(jnp.any(delay_step > self.length), self._check_delay, delay_step) |
230 | | - |
231 | | - if self.method == ROTATE_UPDATE: |
232 | | - delay_idx = (self.idx.value + delay_step) % (self.length + 1) |
233 | | - delay_idx = stop_gradient(delay_idx) |
234 | | - |
235 | | - elif self.method == CONCAT_UPDATE: |
236 | | - delay_idx = delay_step |
237 | | - |
238 | | - else: |
239 | | - raise ValueError(f'Unknown updating method "{self.method}"') |
240 | | - |
241 | | - # the delay index |
242 | | - if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): |
243 | | - raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') |
244 | | - indices = (delay_idx,) + tuple(indices) |
245 | | - |
246 | | - # the delay data |
247 | | - return self.data[indices] |
| 35 | + method = ROTATE_UPDATE |
| 36 | + DelayVariable.__init__(self, |
| 37 | + target=target, |
| 38 | + length=length, |
| 39 | + before_t0=before_t0, |
| 40 | + entries=entries, |
| 41 | + method=method, |
| 42 | + name=name) |
248 | 43 |
|
249 | 44 | @not_pass_shargs |
250 | | - def update(self, latest_value: Optional[Union[bm.Array, jax.Array]] = None) -> None: |
251 | | - """Update delay variable with the new data. |
252 | | - """ |
253 | | - if self.data is not None: |
254 | | - # get the latest target value |
255 | | - if latest_value is None: |
256 | | - latest_value = self.target.value |
257 | | - |
258 | | - # update the delay data at the rotation index |
259 | | - if self.method == ROTATE_UPDATE: |
260 | | - self.idx.value = stop_gradient(bm.as_jax((self.idx - 1) % (self.length + 1))) |
261 | | - self.data[self.idx.value] = latest_value |
262 | | - |
263 | | - # update the delay data at the first position |
264 | | - elif self.method == CONCAT_UPDATE: |
265 | | - if self.length >= 2: |
266 | | - self.data.value = bm.vstack([latest_value, self.data[1:]]) |
267 | | - else: |
268 | | - self.data[0] = latest_value |
269 | | - |
270 | | - def reset_state(self, batch_size: int = None): |
271 | | - """Reset the delay data. |
272 | | - """ |
273 | | - # initialize delay data |
274 | | - if self.data is not None: |
275 | | - self._init_data(self.length, batch_size) |
276 | | - |
277 | | - # time variables |
278 | | - if self.method == ROTATE_UPDATE: |
279 | | - self.idx.value = stop_gradient(jnp.asarray(0, dtype=jnp.int32)) |
280 | | - |
281 | | - def _init_data(self, length, batch_size: int = None): |
282 | | - if batch_size is not None: |
283 | | - if self.target.batch_size != batch_size: |
284 | | - raise ValueError(f'The batch sizes of delay variable and target variable differ ' |
285 | | - f'({self.target.batch_size} != {batch_size}). ' |
286 | | - 'Please reset the target variable first, because delay data ' |
287 | | - 'depends on the target variable. ') |
| 45 | + def update(self, *args, **kwargs): |
| 46 | + return DelayVariable.update(self, *args, **kwargs) |
288 | 47 |
|
289 | | - if self.target.batch_axis is None: |
290 | | - batch_axis = None |
291 | | - else: |
292 | | - batch_axis = self.target.batch_axis + 1 |
293 | | - self.data = bm.Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype), |
294 | | - batch_axis=batch_axis) |
295 | | - # update delay data |
296 | | - self.data[0] = self.target.value |
297 | | - if isinstance(self._initial_delay_data, (bm.Array, jax.Array, float, int, bool)): |
298 | | - self.data[1:] = self._initial_delay_data |
299 | | - elif callable(self._initial_delay_data): |
300 | | - self.data[1:] = self._initial_delay_data((length,) + self.target.shape, dtype=self.target.dtype) |
0 commit comments