Skip to content

Commit d495004

Browse files
authored
consolidate the concept of OO transformation (#309)
consolidate the concept of OO transformation
2 parents ffcbf54 + 4008f07 commit d495004

File tree

162 files changed

+5155
-8387
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

162 files changed

+5155
-8387
lines changed

.github/ISSUE_TEMPLATE/bug_report.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Please:
1212
```python
1313
import brainpy.math as bm
1414
print(bm.asarray([1, 2, 3]))
15-
# JaxArray([1, 2, 3], dtype=int32)
15+
# Array([1, 2, 3], dtype=int32)
1616
```
1717

1818
- [ ] If applicable, include full error messages/tracebacks.

.github/workflows/Linux_CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ jobs:
2929
python -m pip install --upgrade pip
3030
python -m pip install flake8 pytest
3131
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
32+
pip uninstall brainpy -y
3233
python setup.py install
3334
- name: Lint with flake8
3435
run: |

.github/workflows/MacOS_CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ jobs:
2929
python -m pip install --upgrade pip
3030
python -m pip install flake8 pytest
3131
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
32+
pip uninstall brainpy -y
3233
python setup.py install
3334
- name: Lint with flake8
3435
run: |

.github/workflows/Windows_CI.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ jobs:
2929
python -m pip install --upgrade pip
3030
python -m pip install flake8 pytest
3131
python -m pip install numpy>=1.21.0
32-
python -m pip install "jaxlib==0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
33-
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
32+
python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
33+
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz
3434
python -m pip install -r requirements-dev.txt
3535
python -m pip install tqdm brainpylib
36+
pip uninstall brainpy -y
3637
python setup.py install
3738
- name: Lint with flake8
3839
run: |

LICENSE

Lines changed: 674 additions & 201 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
<p align="center">
2-
<img alt="Header image of BrainPy - brain dynamics programming in Python." src="https://github.com/PKU-NIP-Lab/BrainPy/blob/master/images/logo.png" width=80%>
2+
<img alt="Header image of BrainPy - brain dynamics programming in Python." src="https://github.com/brainpy/BrainPy/blob/master/images/logo.png" width=80%>
33
</p>
44

55

66

77
<p align="center">
88
<a href="https://pypi.org/project/brainpy/"><img alt="Supported Python Version" src="https://img.shields.io/pypi/pyversions/brainpy"></a>
9-
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="LICENSE" src="https://anaconda.org/brainpy/brainpy/badges/license.svg"></a>
9+
<a href="https://github.com/brainpy/BrainPy"><img alt="LICENSE" src="https://anaconda.org/brainpy/brainpy/badges/license.svg"></a>
1010
<a href="https://brainpy.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation" src="https://readthedocs.org/projects/brainpy/badge/?version=latest"></a>
1111
<a href="https://badge.fury.io/py/brainpy"><img alt="PyPI version" src="https://badge.fury.io/py/brainpy.svg"></a>
12-
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="Linux CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/Linux_CI.yml/badge.svg"></a>
13-
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="Windows CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/Windows_CI.yml/badge.svg"></a>
14-
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="MacOS CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/MacOS_CI.yml/badge.svg"></a>
12+
<a href="https://github.com/brainpy/BrainPy"><img alt="Linux CI" src="https://github.com/brainpy/BrainPy/actions/workflows/Linux_CI.yml/badge.svg"></a>
13+
<a href="https://github.com/brainpy/BrainPy"><img alt="Windows CI" src="https://github.com/brainpy/BrainPy/actions/workflows/Windows_CI.yml/badge.svg"></a>
14+
<a href="https://github.com/brainpy/BrainPy"><img alt="MacOS CI" src="https://github.com/brainpy/BrainPy/actions/workflows/MacOS_CI.yml/badge.svg"></a>
1515
</p>
1616

1717

@@ -20,17 +20,17 @@
2020
BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Numba](https://github.com/numba/numba), and other JIT compilers). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.
2121

2222
- **Website (documentation and APIs)**: https://brainpy.readthedocs.io/en/latest
23-
- **Source**: https://github.com/PKU-NIP-Lab/BrainPy
24-
- **Bug reports**: https://github.com/PKU-NIP-Lab/BrainPy/issues
23+
- **Source**: https://github.com/brainpy/BrainPy
24+
- **Bug reports**: https://github.com/brainpy/BrainPy/issues
2525
- **Source on OpenI**: https://git.openi.org.cn/OpenI/BrainPy
2626

2727

2828

2929
## Ecosystem
3030

31-
- **[BrainPy](https://github.com/PKU-NIP-Lab/BrainPy)**: The solution for the general-purpose brain dynamics programming.
32-
- **[brainpylib](https://github.com/PKU-NIP-Lab/brainpylib)**: Efficient operators for the sparse and event-driven computation.
33-
- **[BrainPyExamples](https://github.com/PKU-NIP-Lab/BrainPyExamples)**: Comprehensive examples of BrainPy computation.
31+
- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
32+
- **[brainpylib](https://github.com/brainpy/brainpylib)**: Efficient operators for the sparse and event-driven computation.
33+
- **[BrainPyExamples](https://github.com/brainpy/BrainPyExamples)**: Comprehensive examples of BrainPy computation.
3434
- **[brainpy-largescale](https://github.com/NH-NCL/brainpy-largescale)**: One solution for the large-scale brain modeling.
3535

3636

@@ -47,12 +47,7 @@ For detailed installation instructions, please refer to the documentation: [Quic
4747

4848

4949

50-
## License
51-
52-
[Apache License, Version 2.0](https://github.com/PKU-NIP-Lab/BrainPy/blob/master/LICENSE)
53-
54-
5550

5651
## Citing
5752

58-
If you are using BrainPy, please consider citing [the corresponding papers](https://brainpy.readthedocs.io/en/latest/tutorial_FAQs/citing_and_publication.html).
53+
If you are using ``brainpy``, please consider citing [the corresponding papers](https://brainpy.readthedocs.io/en/latest/tutorial_FAQs/citing_and_publication.html).

brainpy/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.2.4.0"
3+
__version__ = "2.3.0"
44

55

66
# fundamental modules
77
from . import errors, tools, check, modes
88

99
# "base" module
1010
from . import base
11-
from .base.base import Base
11+
from .base.base import BrainPyObject, Base
1212
from .base.collector import Collector, TensorCollector
1313

1414
# math foundation
@@ -72,7 +72,6 @@
7272
OfflineTrainer, RidgeTrainer,
7373
BPFF,
7474
BPTT,
75-
OnlineBPTT,
7675
)
7776

7877
# automatic dynamics analysis

brainpy/algorithms/offline.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from jax.lax import while_loop
77

88
import brainpy.math as bm
9-
from brainpy.base import Base
9+
from brainpy.base import BrainPyObject
1010
from brainpy.types import Array
1111
from .utils import (Sigmoid,
1212
Regularization, L1Regularization, L1L2Regularization, L2Regularization,
@@ -33,7 +33,7 @@
3333
name2func = dict()
3434

3535

36-
class OfflineAlgorithm(Base):
36+
class OfflineAlgorithm(BrainPyObject):
3737
"""Base class for offline training algorithm."""
3838

3939
def __init__(self, name=None):
@@ -46,16 +46,16 @@ def __call__(self, identifier, target, input, output):
4646
----------
4747
identifier: str
4848
The variable name.
49-
target: JaxArray, ndarray
49+
target: Array, ndarray
5050
The 2d target data with the shape of `(num_batch, num_output)`.
51-
input: JaxArray, ndarray
51+
input: Array, ndarray
5252
The 2d input data with the shape of `(num_batch, num_input)`.
53-
output: JaxArray, ndarray
53+
output: Array, ndarray
5454
The 2d output data with the shape of `(num_batch, num_output)`.
5555
5656
Returns
5757
-------
58-
weight: JaxArray
58+
weight: Array
5959
The weights after fit.
6060
"""
6161
return self.call(identifier, target, input, output)
@@ -68,21 +68,21 @@ def call(self, identifier, targets, inputs, outputs) -> Array:
6868
identifier: str
6969
The identifier.
7070
71-
inputs: JaxArray, jax.numpy.ndarray, numpy.ndarray
71+
inputs: Array, jax.numpy.ndarray, numpy.ndarray
7272
The 3d input data with the shape of `(num_batch, num_time, num_input)`,
7373
or, the 2d input data with the shape of `(num_time, num_input)`.
7474
75-
targets: JaxArray, jax.numpy.ndarray, numpy.ndarray
75+
targets: Array, jax.numpy.ndarray, numpy.ndarray
7676
The 3d target data with the shape of `(num_batch, num_time, num_output)`,
7777
or the 2d target data with the shape of `(num_time, num_output)`.
7878
79-
outputs: JaxArray, jax.numpy.ndarray, numpy.ndarray
79+
outputs: Array, jax.numpy.ndarray, numpy.ndarray
8080
The 3d output data with the shape of `(num_batch, num_time, num_output)`,
8181
or the 2d output data with the shape of `(num_time, num_output)`.
8282
8383
Returns
8484
-------
85-
weight: JaxArray
85+
weight: Array
8686
The weights after fit.
8787
"""
8888
raise NotImplementedError('Must implement the __call__ function by the subclass itself.')

brainpy/algorithms/online.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33
import brainpy.math as bm
4-
from brainpy.base import Base
4+
from brainpy.base import BrainPyObject
55
from jax import vmap
66
import jax.numpy as jnp
77

@@ -21,7 +21,7 @@
2121
name2func = dict()
2222

2323

24-
class OnlineAlgorithm(Base):
24+
class OnlineAlgorithm(BrainPyObject):
2525
"""Base class for online training algorithm."""
2626

2727
def __init__(self, name=None):
@@ -34,16 +34,16 @@ def __call__(self, identifier, target, input, output):
3434
----------
3535
identifier: str
3636
The variable name.
37-
target: JaxArray, ndarray
37+
target: Array, ndarray
3838
The 2d target data with the shape of `(num_batch, num_output)`.
39-
input: JaxArray, ndarray
39+
input: Array, ndarray
4040
The 2d input data with the shape of `(num_batch, num_input)`.
41-
output: JaxArray, ndarray
41+
output: Array, ndarray
4242
The 2d output data with the shape of `(num_batch, num_output)`.
4343
4444
Returns
4545
-------
46-
weight: JaxArray
46+
weight: Array
4747
The weights after fit.
4848
"""
4949
return self.call(identifier, target, input, output)
@@ -58,16 +58,16 @@ def call(self, identifier, target, input, output):
5858
----------
5959
identifier: str
6060
The variable name.
61-
target: JaxArray, ndarray
61+
target: Array, ndarray
6262
The 2d target data with the shape of `(num_batch, num_output)`.
63-
input: JaxArray, ndarray
63+
input: Array, ndarray
6464
The 2d input data with the shape of `(num_batch, num_input)`.
65-
output: JaxArray, ndarray
65+
output: Array, ndarray
6666
The 2d output data with the shape of `(num_batch, num_output)`.
6767
6868
Returns
6969
-------
70-
weight: JaxArray
70+
weight: Array
7171
The weights after fit.
7272
"""
7373
raise NotImplementedError('Must implement the call() function by the subclass itself.')

brainpy/analysis/highdim/slow_points.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -335,14 +335,14 @@ def find_fps_with_gd_method(
335335
# set up optimization
336336
num_candidate = self._check_candidates(candidates)
337337
if not (isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)) or isinstance(candidates, dict)):
338-
raise ValueError('Candidates must be instance of JaxArray or dict of JaxArray.')
339-
fixed_points = tree_map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.JaxArray))
338+
raise ValueError('Candidates must be instance of Array or dict of Array.')
339+
fixed_points = tree_map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.Array))
340340
f_eval_loss = self._get_f_eval_loss()
341341

342342
def f_loss():
343343
return f_eval_loss(tree_map(lambda a: bm.as_device_array(a),
344344
fixed_points,
345-
is_leaf=lambda x: isinstance(x, bm.JaxArray))).mean()
345+
is_leaf=lambda x: isinstance(x, bm.Array))).mean()
346346

347347
grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True)
348348
optimizer.register_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points})
@@ -355,7 +355,7 @@ def train(idx):
355355
return loss
356356

357357
def batch_train(start_i, n_batch):
358-
return bm.for_loop(train, dyn_vars, bm.arange(start_i, start_i + n_batch))
358+
return bm.for_loop(train, bm.arange(start_i, start_i + n_batch), dyn_vars=dyn_vars)
359359

360360
# Run the optimization
361361
if self.verbose:
@@ -387,10 +387,10 @@ def batch_train(start_i, n_batch):
387387
self._opt_losses = bm.concatenate(opt_losses)
388388
self._losses = f_eval_loss(tree_map(lambda a: bm.as_device_array(a),
389389
fixed_points,
390-
is_leaf=lambda x: isinstance(x, bm.JaxArray)))
390+
is_leaf=lambda x: isinstance(x, bm.Array)))
391391
self._fixed_points = tree_map(lambda a: bm.as_device_array(a),
392392
fixed_points,
393-
is_leaf=lambda x: isinstance(x, bm.JaxArray))
393+
is_leaf=lambda x: isinstance(x, bm.Array))
394394
self._selected_ids = jnp.arange(num_candidate)
395395

396396
if isinstance(self.target, DynamicalSystem):
@@ -428,7 +428,7 @@ def find_fps_with_opt_solver(
428428
# optimizing
429429
res = f_opt(tree_map(lambda a: bm.as_device_array(a),
430430
candidates,
431-
is_leaf=lambda a: isinstance(a, bm.JaxArray)))
431+
is_leaf=lambda a: isinstance(a, bm.Array)))
432432

433433
# results
434434
valid_ids = jnp.where(res.success)[0]
@@ -546,7 +546,7 @@ def compute_jacobians(
546546
547547
Parameters
548548
----------
549-
points: np.ndarray, bm.JaxArray, jax.ndarray
549+
points: np.ndarray, bm.Array, jax.ndarray
550550
The fixed points with the shape of (num_point, num_dim).
551551
stack_dict_var: bool
552552
Stack dictionary variables to calculate Jacobian matrix?
@@ -561,7 +561,7 @@ def compute_jacobians(
561561
"""
562562
# check data
563563
info = np.asarray([(l.ndim, l.shape[0])
564-
for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.JaxArray))[0]])
564+
for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.Array))[0]])
565565
ndim = np.unique(info[:, 0])
566566
if len(ndim) != 1: raise ValueError(f'Get multiple dimension of the evaluated points. {ndim}')
567567
if ndim[0] == 1:
@@ -606,7 +606,7 @@ def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False)
606606
607607
Parameters
608608
----------
609-
matrices: np.ndarray, bm.JaxArray, jax.ndarray
609+
matrices: np.ndarray, bm.Array, jax.ndarray
610610
A 3D array with the shape of (num_matrices, dim, dim).
611611
sort_by: str
612612
The method of sorting.

0 commit comments

Comments
 (0)