Skip to content

Commit 4b967ea

Browse files
committed
move variables into brainpy.math.object_transform
1 parent 31a97b9 commit 4b967ea

File tree

29 files changed

+799
-693
lines changed

29 files changed

+799
-693
lines changed

brainpy/__init__.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,9 @@
132132

133133
# deprecated
134134
from brainpy._src.math.object_transform.base import (Base as Base,
135-
DynVarCollector,
135+
ArrayCollector,
136136
Collector as Collector, )
137-
globals()['ArrayCollector'] = DynVarCollector
138-
globals()['TensorCollector'] = DynVarCollector
137+
globals()['TensorCollector'] = ArrayCollector
139138

140139
train.__dict__['DSTrainer'] = DSTrainer
141140
train.__dict__['BPTT'] = BPTT
@@ -150,8 +149,8 @@
150149
base.base.__dict__['BrainPyObject'] = BrainPyObject
151150
base.base.__dict__['Base'] = Base
152151
base.collector.__dict__['Collector'] = Collector
153-
base.collector.__dict__['ArrayCollector'] = DynVarCollector
154-
base.collector.__dict__['TensorCollector'] = DynVarCollector
152+
base.collector.__dict__['ArrayCollector'] = ArrayCollector
153+
base.collector.__dict__['TensorCollector'] = ArrayCollector
155154
base.function.__dict__['FunAsObject'] = math.FunAsObject
156155
base.function.__dict__['Function'] = math.FunAsObject
157156
base.io.__dict__['save_as_h5'] = checkpoints.io.save_as_h5
@@ -166,8 +165,8 @@
166165
base.__dict__['BrainPyObject'] = BrainPyObject
167166
base.__dict__['Base'] = Base
168167
base.__dict__['Collector'] = Collector
169-
base.__dict__['ArrayCollector'] = DynVarCollector
170-
base.__dict__['TensorCollector'] = DynVarCollector
168+
base.__dict__['ArrayCollector'] = ArrayCollector
169+
base.__dict__['TensorCollector'] = ArrayCollector
171170
base.__dict__['FunAsObject'] = math.FunAsObject
172171
base.__dict__['Function'] = math.FunAsObject
173172
base.__dict__['save_as_h5'] = checkpoints.io.save_as_h5

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,11 @@ def __init__(
132132

133133
# update function
134134
if target_vars is None:
135-
self.target_vars = bm.DynVarCollector()
135+
self.target_vars = bm.ArrayCollector()
136136
else:
137137
if not isinstance(target_vars, dict):
138138
raise TypeError(f'"target_vars" must be a dict but we got {type(target_vars)}')
139-
self.target_vars = bm.DynVarCollector(target_vars)
139+
self.target_vars = bm.ArrayCollector(target_vars)
140140
excluded_vars = () if excluded_vars is None else excluded_vars
141141
if isinstance(excluded_vars, dict):
142142
excluded_vars = tuple(excluded_vars.values())

brainpy/_src/analysis/utils/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33

4-
from brainpy._src.math.ndarray import Variable
4+
from brainpy._src.math.object_transform import Variable
55
from brainpy._src.math.environment import get_float
66
from brainpy._src.math.interoperability import as_jax
77
from brainpy._src.dyn.base import DynamicalSystem

brainpy/_src/checkpoints/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from brainpy import errors
1010
import brainpy.math as bm
11-
from brainpy._src.math.object_transform.base import BrainPyObject, DynVarCollector
11+
from brainpy._src.math.object_transform.base import BrainPyObject, ArrayCollector
1212

1313

1414
logger = logging.getLogger('brainpy.brainpy_object.io')
@@ -120,7 +120,7 @@ def _load(
120120

121121

122122
def _unique_and_duplicate(collector: dict):
123-
gather = DynVarCollector()
123+
gather = ArrayCollector()
124124
id2name = dict()
125125
duplicates = ([], [])
126126
for k, v in collector.items():

brainpy/_src/dyn/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from brainpy._src.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All
1515
from brainpy._src.initialize import Initializer, parameter, variable, Uniform, noise as init_noise
1616
from brainpy._src.integrators import odeint, sdeint
17-
from brainpy._src.math.ndarray import Variable, VariableView
17+
from brainpy._src.math.object_transform.variables import Variable, VariableView
1818
from brainpy._src.math.object_transform.base import BrainPyObject, Collector
1919
from brainpy.errors import NoImplementationError, UnsupportedError
2020
from brainpy.types import ArrayType, Shape

brainpy/_src/integrators/fde/Caputo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,16 +329,16 @@ def __init__(
329329

330330
# initial values
331331
inits = check_inits(inits, self.variables)
332-
self.inits = bm.VarDict({v: bm.asarray(inits[v]) for v in self.variables})
332+
self.inits = bm.VarDict({v: bm.Variable(inits[v]) for v in self.variables})
333333

334334
# coefficients
335335
ranges = bm.asarray([bm.arange(1, num_memory + 2) for _ in self.variables]).T
336336
coef = bm.diff(bm.power(ranges, 1 - self.alpha), axis=0)
337337
self.coef = bm.flip(coef, axis=0)
338338

339339
# used to save the difference of two adjacent states
340-
self.diff_states = bm.VarDict({v + "_diff": bm.zeros((num_memory,) + self.inits[v].shape,
341-
dtype=self.inits[v].dtype)
340+
self.diff_states = bm.VarDict({v + "_diff": bm.Variable(bm.zeros((num_memory,) + self.inits[v].shape,
341+
dtype=self.inits[v].dtype))
342342
for v in self.variables})
343343
self.idx = bm.Variable(bm.asarray([self.num_memory - 1]))
344344

brainpy/_src/math/delayvars.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
from brainpy.errors import UnsupportedError
1414
from .interoperability import as_jax
1515
from .compat_numpy import vstack, broadcast_to
16-
from .environment import get_dt, get_float, get_int
17-
from .ndarray import ndarray, Variable, Array
16+
from .environment import get_dt, get_float
17+
from .ndarray import ndarray, Array
1818
from .object_transform.base import BrainPyObject
19+
from .object_transform.variables import Variable
1920

2021
__all__ = [
2122
'AbstractDelay',

brainpy/_src/math/interoperability.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import jax.numpy as jnp
44
import numpy as np
55

6-
from .ndarray import Array, Variable
6+
from .ndarray import Array
7+
78

89
__all__ = [
910
'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', 'as_variable',
@@ -90,4 +91,5 @@ def as_variable(tensor, dtype=None):
9091
Array interpretation of `tensor`. No copy is performed if the input
9192
is already an ndarray with matching dtype.
9293
"""
94+
from .object_transform.variables import Variable
9395
return Variable(tensor, dtype=dtype)

0 commit comments

Comments
 (0)