
Commit 78cbc1d

fix bugs
1 parent 16c9780 commit 78cbc1d

File tree: 6 files changed, 116 insertions(+), 59 deletions(-)

brainpy/dyn/runners.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -346,7 +346,7 @@ def __init__(
     self.i0 = bm.Variable(bm.asarray([1], dtype=bm.int_))
     self.t0 = bm.Variable(bm.asarray([t0], dtype=bm.float_))
     if data_first_axis is None:
-      data_first_axis = 'B' if isinstance(self.target, bm.BatchingMode) else 'T'
+      data_first_axis = 'B' if isinstance(self.target.mode, bm.BatchingMode) else 'T'
     assert data_first_axis in ['B', 'T']
     self.data_first_axis = data_first_axis

@@ -380,8 +380,8 @@ def __repr__(self):

   def reset_state(self):
     """Reset state of the ``DSRunner``."""
-    self.i0[0] = 0
-    self.t0[0] = self._t0
+    self.i0.value = bm.zeros_like(self.i0)
+    self.t0.value = bm.ones_like(self.t0) * self._t0

   def predict(
     self,

@@ -635,8 +635,10 @@ def _get_f_predict(self, shared_args: Dict = None):

     shared_kwargs_str = serialize_kwargs(shared_args)
     if shared_kwargs_str not in self._f_predict_compiled:
-      dyn_vars = self.vars().unique()
-      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
+      dyn_vars = self.target.vars()
+      dyn_vars.update(self._dyn_vars)
+      dyn_vars.update(self.vars(level=0))
+      dyn_vars = dyn_vars.unique()

       def run_func(all_inputs):
         with jax.disable_jit(not self.jit['predict']):
```
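The first hunk is the substantive fix: `self.target` is the dynamical system itself and is never a `BatchingMode` instance, so the old check always fell back to `'T'`; the mode is now read from `target.mode`. A minimal sketch of the resulting behavior, assuming `bp.neurons.LIF` and `bp.dyn.DSRunner` from the same release:

```python
import brainpy as bp
import brainpy.math as bm

# Build a batching-mode target; DSRunner now inspects `target.mode`
# (not the target object) when inferring `data_first_axis`.
net = bp.neurons.LIF(10, mode=bm.batching_mode)
runner = bp.dyn.DSRunner(net, monitors=['V'])
print(runner.data_first_axis)  # 'B'; a non-batching target yields 'T'
```

The `reset_state` hunk assigns through `.value`, which replaces the whole buffer rather than writing a single element in place.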

brainpy/integrators/runner.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -168,8 +168,14 @@ def __init__(
       self.variables[k][:] = inits[k]

     # format string monitors
-    monitors = self._format_seq_monitors(monitors)
-    monitors = {k: (self.variables[k], i) for k, i in monitors}
+    if isinstance(monitors, (tuple, list)):
+      monitors = self._format_seq_monitors(monitors)
+      monitors = {k: (self.variables[k], i) for k, i in monitors}
+    elif isinstance(monitors, dict):
+      monitors = self._format_dict_monitors(monitors)
+      monitors = {k: ((self.variables[i], i) if isinstance(i, str) else i) for k, i in monitors.items()}
+    else:
+      raise ValueError

     # initialize super class
     super(IntegratorRunner, self).__init__(target=target,

@@ -218,12 +224,6 @@ def __init__(
     else:
       self._dyn_args = dict()

-    # monitors
-    for k in self.mon.var_names:
-      if k not in self.target.variables:
-        raise MonitorError(f'Variable "{k}" to monitor is not defined '
-                           f'in the integrator {self.target}.')
-
     # start simulation time and index
     self.start_t = bm.Variable(bm.zeros(1))
     self.idx = bm.Variable(bm.zeros(1, dtype=bm.int_))
```
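With the new branch, `IntegratorRunner` accepts dict monitors as well as sequences, and the stale up-front check against `self.target.variables` is gone. A minimal sketch of the intended usage, assuming the top-level `bp.IntegratorRunner` export and `bp.odeint`; the variable name `v` comes from the integrator's own signature:

```python
import brainpy as bp

# A scalar ODE, dv/dt = -v / tau, integrated with the default solver.
integral = bp.odeint(lambda v, t, tau=10.: -v / tau)

# Dict monitors (newly accepted): monitor name -> state-variable name.
runner = bp.IntegratorRunner(integral,
                             monitors={'v': 'v'},
                             inits={'v': 10.},
                             dt=0.1)
runner.run(100.)
print(runner.mon['v'])  # one row per integration step
```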

brainpy/math/environment.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -6,12 +6,12 @@
 import os
 import re
 import sys
-import warnings
 from typing import Any, Callable, TypeVar, cast

 from jax import config, numpy as jnp, devices
 from jax.lib import xla_bridge

+from brainpy import errors
 from . import modes

 bm = None

@@ -63,8 +63,8 @@ def ditype():
   .. deprecated:: 2.3.1
      Use `brainpy.math.int_` instead.
   """
-  warnings.warn('\nGet default integer data type through `ditype()` has been deprecated. \n'
-                'Use `brainpy.math.int_` instead.')
+  # raise errors.NoLongerSupportError('\nGet default integer data type through `ditype()` has been deprecated. \n'
+  #                                   'Use `brainpy.math.int_` instead.')
   global bm
   if bm is None: from brainpy import math as bm
   return bm.int_

@@ -77,8 +77,8 @@ def dftype():
      Use `brainpy.math.float_` instead.
   """

-  warnings.warn('\nGet default floating data type through `dftype()` has been deprecated. \n'
-                'Use `brainpy.math.float_` instead.')
+  # raise errors.NoLongerSupportError('\nGet default floating data type through `dftype()` has been deprecated. \n'
+  #                                   'Use `brainpy.math.float_` instead.')
   global bm
   if bm is None: from brainpy import math as bm
   return bm.float_
```
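Both getters are now silent: the `warnings.warn` calls are removed, and the stricter `NoLongerSupportError` is left commented out (with its `errors` import already in place) for a later release. A quick sanity check, assuming `ditype`/`dftype` remain exported from `brainpy.math`:

```python
import brainpy.math as bm

# The deprecated getters return the defaults without emitting a warning.
assert bm.ditype() is bm.int_
assert bm.dftype() is bm.float_
```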

brainpy/math/object_transform/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -16,6 +16,9 @@
 """

 from . import (
+  base_object,
+  base_transform,
+  collector,
   autograd,
   controls,
   jit,
```
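Importing these modules in `__init__` makes them reachable as attributes of the package, which matters for code that introspects or reloads submodules. A minimal check, assuming the module names match the files under `brainpy/math/object_transform/`:

```python
# The submodules are now importable as package attributes.
from brainpy.math import object_transform

print(object_transform.base_object)
print(object_transform.base_transform)
print(object_transform.collector)
```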

brainpy/running/runner.py

Lines changed: 77 additions & 31 deletions
```diff
@@ -99,36 +99,12 @@ def __init__(
       # format string monitors
       monitors = self._format_seq_monitors(monitors)
       # get monitor targets
-      monitors = self._find_monitor_targets(monitors)
+      monitors = self._find_seq_monitor_targets(monitors)
     elif isinstance(monitors, dict):
-      _monitors = dict()
-      for key, val in monitors.items():
-        if not isinstance(key, str):
-          raise MonitorError('Expect the key of the dict "monitors" must be a string. But got '
-                             f'{type(key)}: {key}')
-        if isinstance(val, bm.Variable):
-          val = (val, None)
-        if isinstance(val, (tuple, list)):
-          if not isinstance(val[0], bm.Variable):
-            raise MonitorError('Expect the format of (variable, index) in the monitor setting. '
-                               f'But we got {val}')
-          if len(val) == 1:
-            _monitors[key] = (val[0], None)
-          elif len(val) == 2:
-            if isinstance(val[1], (int, np.integer)):
-              idx = bm.array([val[1]])
-            else:
-              idx = None if val[1] is None else bm.asarray(val[1])
-            _monitors[key] = (val[0], idx)
-          else:
-            raise MonitorError('Expect the format of (variable, index) in the monitor setting. '
-                               f'But we got {val}')
-        elif callable(val):
-          _monitors[key] = val
-        else:
-          raise MonitorError('The value of dict monitor expect a sequence with (variable, index) '
-                             f'or a callable function. But we got {val}')
-      monitors = _monitors
+      # format string monitors
+      monitors = self._format_dict_monitors(monitors)
+      # get monitor targets
+      monitors = self._find_dict_monitor_targets(monitors)
     else:
       raise MonitorError(f'We only supports a format of list/tuple/dict of '
                          f'"vars", while we got {type(monitors)}.')

@@ -160,7 +136,7 @@ def __init__(

   def _format_seq_monitors(self, monitors):
     if not isinstance(monitors, (tuple, list)):
-      raise TypeError(f'Must be a sequence, but we got {type(monitors)}')
+      raise TypeError(f'Must be a tuple/list, but we got {type(monitors)}')
     _monitors = []
     for mon in monitors:
       if isinstance(mon, str):

@@ -183,7 +159,40 @@ def _format_seq_monitors(self, monitors):
         raise MonitorError(f'We do not support monitor with {type(mon)}: {mon}')
     return _monitors

-  def _find_monitor_targets(self, _monitors):
+  def _format_dict_monitors(self, monitors):
+    if not isinstance(monitors, dict):
+      raise TypeError(f'Must be a dict, but we got {type(monitors)}')
+    _monitors = dict()
+    for key, val in monitors.items():
+      if not isinstance(key, str):
+        raise MonitorError('Expect the key of the dict "monitors" must be a string. But got '
+                           f'{type(key)}: {key}')
+      if isinstance(val, (bm.Variable, str)):
+        val = (val, None)
+
+      if isinstance(val, (tuple, list)):
+        if not isinstance(val[0], (bm.Variable, str)):
+          raise MonitorError('Expect the format of (variable, index) in the monitor setting. '
+                             f'But we got {val}')
+        if len(val) == 1:
+          _monitors[key] = (val[0], None)
+        elif len(val) == 2:
+          if isinstance(val[1], (int, np.integer)):
+            idx = bm.array([val[1]])
+          else:
+            idx = None if val[1] is None else bm.asarray(val[1])
+          _monitors[key] = (val[0], idx)
+        else:
+          raise MonitorError('Expect the format of (variable, index) in the monitor setting. '
+                             f'But we got {val}')
+      elif callable(val):
+        _monitors[key] = val
+      else:
+        raise MonitorError('The value of dict monitor expect a sequence with (variable, index) '
+                           f'or a callable function. But we got {val}')
+    return _monitors
+
+  def _find_seq_monitor_targets(self, _monitors):
     if not isinstance(_monitors, (tuple, list)):
       raise TypeError(f'Must be a sequence, but we got {type(_monitors)}')
     # get monitor targets

@@ -214,6 +223,43 @@ def _find_monitor_targets(self, _monitors):
           monitors[key] = (getattr(master, splits[-1]), index)
     return monitors

+  def _find_dict_monitor_targets(self, _monitors):
+    if not isinstance(_monitors, dict):
+      raise TypeError(f'Must be a dict, but we got {type(_monitors)}')
+    # get monitor targets
+    monitors = {}
+    name2node = None
+    for _key, _mon in _monitors.items():
+      if isinstance(_mon, str):
+        if name2node is None:
+          name2node = {node.name: node for node in list(self.target.nodes(level=-1).unique().values())}
+
+        key, index = _mon[0], _mon[1]
+        splits = key.split('.')
+        if len(splits) == 1:
+          if not hasattr(self.target, splits[0]):
+            raise RunningError(f'{self.target} does not has variable {key}.')
+          monitors[key] = (getattr(self.target, splits[-1]), index)
+        else:
+          if not hasattr(self.target, splits[0]):
+            if splits[0] not in name2node:
+              raise MonitorError(f'Cannot find target {key} in monitor of {self.target}, please check.')
+            else:
+              master = name2node[splits[0]]
+              assert len(splits) == 2
+              monitors[key] = (getattr(master, splits[-1]), index)
+          else:
+            master = self.target
+            for s in splits[:-1]:
+              try:
+                master = getattr(master, s)
+              except KeyError:
+                raise MonitorError(f'Cannot find {key} in {master}, please check.')
+            monitors[key] = (getattr(master, splits[-1]), index)
+      else:
+        monitors[_key] = _mon
+    return monitors
+
   def __del__(self):
     if hasattr(self, 'mon'):
       for key in tuple(self.mon.keys()):
```
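The refactor splits the old inline dict handling into `_format_dict_monitors` (normalization, which now also accepts plain strings as values) and `_find_dict_monitor_targets` (resolution of string paths against the target's attribute tree), mirroring the existing `_format_seq_monitors`/`_find_seq_monitor_targets` pair. A sketch of the monitor formats the new helpers are written to accept, assuming `bp.neurons.LIF` and `bp.dyn.DSRunner`; the names are illustrative:

```python
import brainpy as bp

neu = bp.neurons.LIF(5)
runner = bp.dyn.DSRunner(
    neu,
    monitors={
        'v': neu.V,        # a Variable object (previously supported)
        'v0': (neu.V, 0),  # (variable, index): record element 0 only
        'spk': 'spike',    # an attribute path as a string (the new case)
    },
)
runner.run(10.)
```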

brainpy/train/back_propagation.py

Lines changed: 16 additions & 10 deletions
```diff
@@ -112,7 +112,8 @@ def __init__(
       lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
       optimizer = optim.Adam(lr=lr)
     self.optimizer: optim.Optimizer = optimizer
-    self.optimizer.register_vars(self.target.vars(level=-1, include_self=True).subset(bm.TrainVar).unique())
+    if len(self.optimizer.vars_to_train) == 0:
+      self.optimizer.register_vars(self.target.vars(level=-1, include_self=True).subset(bm.TrainVar).unique())

     # loss function
     self.loss_has_aux = loss_has_aux

@@ -146,7 +147,7 @@ def __repr__(self):
             f'{prefix}loss={self._loss_func}, \n\t'
             f'{prefix}optimizer={self.optimizer})')

-  def get_hist_metric(self, phase='fit', metric='loss', which='detailed'):
+  def get_hist_metric(self, phase='fit', metric='loss', which='report'):
     """Get history losses."""
     assert phase in [c.FIT_PHASE, c.TEST_PHASE, c.TRAIN_PHASE, c.PREDICT_PHASE]
     assert which in ['report', 'detailed']

@@ -332,7 +333,7 @@ def fit(
         self.target.reset_state(self._get_input_batch_size(x))
         self.reset_state()

-        # training
+        # testing
         res = self._get_f_loss(shared_args)(x, y)

         # loss

@@ -406,7 +407,7 @@ def _get_f_loss(self, shared_args=None, jit=True) -> Callable:
       if self.jit[c.LOSS_PHASE] and jit:
         dyn_vars = self.target.vars()
         dyn_vars.update(self._dyn_vars)
-        dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
+        dyn_vars.update(self.vars(level=0))
         self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
                                                         dyn_vars=dyn_vars.unique())
     return self._f_loss_compiled[shared_args_str]

@@ -437,18 +438,23 @@ def _get_f_train(self, shared_args=None) -> Callable:

     shared_args_str = serialize_kwargs(shared_args)
     if shared_args_str not in self._f_fit_compiled:
-      self._f_fit_compiled[shared_args_str] = partial(self._step_func_train, shared_args)
+      self._f_fit_compiled[shared_args_str] = partial(self._step_func_fit, shared_args)
       if self.jit[c.FIT_PHASE]:
-        dyn_vars = self.vars().unique()
-        dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
+        dyn_vars = self.target.vars()
+        dyn_vars.update(self.optimizer.vars())
+        if isinstance(self._loss_func, bm.BrainPyObject):
+          dyn_vars.update(self._loss_func)
+        dyn_vars.update(self._dyn_vars)
+        dyn_vars.update(self.vars(level=0))
+        dyn_vars = dyn_vars.unique()
         self._f_fit_compiled[shared_args_str] = bm.jit(self._f_fit_compiled[shared_args_str],
                                                        dyn_vars=dyn_vars)
     return self._f_fit_compiled[shared_args_str]

   def _step_func_loss(self, shared_args, inputs, targets):
     raise NotImplementedError

-  def _step_func_train(self, shared_args, inputs, targets):
+  def _step_func_fit(self, shared_args, inputs, targets):
     raise NotImplementedError

@@ -508,7 +514,7 @@ def _step_func_loss(self, shared_args, inputs, targets):
     predicts = (outs, mons) if len(mons) > 0 else outs
     return self._loss_func(predicts, targets)

-  def _step_func_train(self, shared_args, inputs, targets):
+  def _step_func_fit(self, shared_args, inputs, targets):
     res = self._get_f_grad(shared_args)(inputs, targets)
     self.optimizer.update(res[0])
     return res[1:]

@@ -529,7 +535,7 @@ def _step_func_loss(self, shared_args, inputs, targets):
     loss = self._loss_func(outs, targets)
     return loss

-  def _step_func_train(self, shared_args, inputs, targets):
+  def _step_func_fit(self, shared_args, inputs, targets):
     res = self._get_f_grad(shared_args)(inputs, targets)
     self.optimizer.update(res[0])
     return res[1:]
```
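The guard in the first hunk lets callers hand `BPTT` an optimizer whose trainable variables were registered in advance; `__init__` only falls back to collecting every `TrainVar` when `vars_to_train` is empty. A minimal sketch, assuming `bp.layers.Dense`, `bp.optim.Adam`, and `bp.losses.mean_squared_error` from the same release:

```python
import brainpy as bp
import brainpy.math as bm

model = bp.layers.Dense(4, 2, mode=bm.training_mode)

# Pre-register the trainable variables ourselves.
opt = bp.optim.Adam(lr=1e-3)
opt.register_vars(model.vars().subset(bm.TrainVar).unique())

# opt.vars_to_train is already populated, so BPTT.__init__ now keeps this
# registration instead of registering the same variables a second time.
trainer = bp.train.BPTT(model,
                        loss_fun=bp.losses.mean_squared_error,
                        optimizer=opt)
```

The remaining hunks are renames (`_step_func_train` to `_step_func_fit`), a corrected comment in `fit`, and the same explicit `dyn_vars` collection used in `DSRunner._get_f_predict`.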
