Skip to content

Commit 3ebeb2a

Browse files
authored
Unify autograd transformations and upgrade others (#314)
Unify autograd transformations and upgrade others
2 parents 33cff62 + d37c113 commit 3ebeb2a

File tree

16 files changed

+1834
-1075
lines changed

16 files changed

+1834
-1075
lines changed

brainpy/base/function.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# -*- coding: utf-8 -*-
22

3-
from typing import Callable, Sequence, Dict, Union
3+
from typing import Callable, Sequence, Dict, Union, TypeVar
44

55
from brainpy.base.base import BrainPyObject
6-
from brainpy.types import ArrayType
6+
7+
8+
Variable = TypeVar('Variable')
9+
710

811
__all__ = [
912
'FunAsObject',
@@ -28,7 +31,7 @@ class FunAsObject(BrainPyObject):
2831
def __init__(self,
2932
f: Callable,
3033
child_objs: Union[BrainPyObject, Sequence[BrainPyObject], Dict[dict, BrainPyObject]] = None,
31-
dyn_vars: Union[ArrayType, Sequence[ArrayType], Dict[dict, ArrayType]] = None,
34+
dyn_vars: Union[Variable, Sequence[Variable], Dict[dict, Variable]] = None,
3235
name: str = None):
3336
super(FunAsObject, self).__init__(name=name)
3437
self._f = f

brainpy/checkpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,7 @@ def load(
13041304
gda_manager: Optional[Any] = None,
13051305
allow_partial_mpa_restoration: bool = False,
13061306
) -> PyTree:
1307-
"""Load last or best checkpoint from the given checkpoint path.
1307+
"""Load last or best checkpoint from the given checkpoint path.
13081308
13091309
Sorts the checkpoint files naturally, returning the highest-valued
13101310
file, e.g.:

brainpy/math/environment.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from typing import Any, Callable, TypeVar, cast
1111

12-
from jax import dtypes, config, numpy as jnp, devices
12+
from jax import config, numpy as jnp, devices
1313
from jax.lib import xla_bridge
1414

1515
from . import modes
@@ -329,6 +329,7 @@ def clone(self):
329329
def set_environment(
330330
mode: modes.Mode = None,
331331
dt: float = None,
332+
x64: bool = None,
332333
complex_: type = None,
333334
float_: type = None,
334335
int_: type = None,
@@ -342,6 +343,8 @@ def set_environment(
342343
The computing mode.
343344
dt: float
344345
The numerical integration precision.
346+
x64: bool
347+
Enable x64 computation.
345348
complex_: type
346349
The complex data type.
347350
float_
@@ -359,6 +362,10 @@ def set_environment(
359362
assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
360363
set_mode(mode)
361364

365+
if x64 is not None:
366+
assert isinstance(x64, bool), f'"x64" must be a bool.'
367+
set_x64(x64)
368+
362369
if float_ is not None:
363370
assert isinstance(float_, type), '"float_" must a float.'
364371
set_float(float_)
@@ -402,8 +409,9 @@ class environment(_DecoratorContextManager):
402409

403410
def __init__(
404411
self,
405-
dt: float = None,
406412
mode: modes.Mode = None,
413+
dt: float = None,
414+
x64: bool = None,
407415
complex_: type = None,
408416
float_: type = None,
409417
int_: type = None,
@@ -412,6 +420,7 @@ def __init__(
412420
super().__init__()
413421
self.old_dt = get_dt()
414422
self.old_mode = get_mode()
423+
self.old_x64 = config.read("jax_enable_x64")
415424
self.old_int = get_int()
416425
self.old_bool = get_bool()
417426
self.old_float = get_float()
@@ -421,6 +430,8 @@ def __init__(
421430
assert isinstance(dt, float), '"dt" must a float.'
422431
if mode is not None:
423432
assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
433+
if x64 is not None:
434+
assert isinstance(x64, bool), f'"x64" must be a bool.'
424435
if float_ is not None:
425436
assert isinstance(float_, type), '"float_" must a float.'
426437
if int_ is not None:
@@ -431,6 +442,7 @@ def __init__(
431442
assert isinstance(complex_, type), '"complex_" must a type.'
432443
self.dt = dt
433444
self.mode = mode
445+
self.x64 = x64
434446
self.complex_ = complex_
435447
self.float_ = float_
436448
self.int_ = int_
@@ -439,6 +451,7 @@ def __init__(
439451
def __enter__(self) -> 'environment':
440452
if self.dt is not None: set_dt(self.dt)
441453
if self.mode is not None: set_mode(self.mode)
454+
if self.x64 is not None: set_x64(self.x64)
442455
if self.float_ is not None: set_float(self.float_)
443456
if self.int_ is not None: set_int(self.int_)
444457
if self.complex_ is not None: set_complex(self.complex_)
@@ -448,6 +461,7 @@ def __enter__(self) -> 'environment':
448461
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
449462
if self.dt is not None: set_dt(self.old_dt)
450463
if self.mode is not None: set_mode(self.old_mode)
464+
if self.x64 is not None: set_x64(self.old_x64)
451465
if self.int_ is not None: set_int(self.old_int)
452466
if self.float_ is not None: set_float(self.old_float)
453467
if self.complex_ is not None: set_complex(self.old_complex)
@@ -456,6 +470,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
456470
def clone(self):
457471
return self.__class__(dt=self.dt,
458472
mode=self.mode,
473+
x64=self.x64,
459474
bool_=self.bool_,
460475
complex_=self.complex_,
461476
float_=self.float_,
@@ -468,6 +483,7 @@ class training_environment(environment):
468483
This is a short-cut context setting for an environment with the training mode.
469484
It is equivalent to::
470485
486+
>>> import brainpy.math as bm
471487
>>> with bm.environment(mode=bm.training_mode):
472488
>>> pass
473489
@@ -476,11 +492,17 @@ class training_environment(environment):
476492

477493
def __init__(self,
478494
dt: float = None,
495+
x64: bool = None,
479496
complex_: type = None,
480497
float_: type = None,
481498
int_: type = None,
482499
bool_: type = None):
483-
super().__init__(dt=dt, complex_=complex_, float_=float_, int_=int_, bool_=bool_,
500+
super().__init__(dt=dt,
501+
x64=x64,
502+
complex_=complex_,
503+
float_=float_,
504+
int_=int_,
505+
bool_=bool_,
484506
mode=modes.TrainingMode())
485507

486508

@@ -490,6 +512,7 @@ class batching_environment(environment):
490512
This is a short-cut context setting for an environment with the batching mode.
491513
It is equivalent to::
492514
515+
>>> import brainpy.math as bm
493516
>>> with bm.environment(mode=bm.batching_mode):
494517
>>> pass
495518
@@ -498,11 +521,17 @@ class batching_environment(environment):
498521

499522
def __init__(self,
500523
dt: float = None,
524+
x64: bool = None,
501525
complex_: type = None,
502526
float_: type = None,
503527
int_: type = None,
504528
bool_: type = None):
505-
super().__init__(dt=dt, complex_=complex_, float_=float_, int_=int_, bool_=bool_,
529+
super().__init__(dt=dt,
530+
x64=x64,
531+
complex_=complex_,
532+
float_=float_,
533+
int_=int_,
534+
bool_=bool_,
506535
mode=modes.BatchingMode())
507536

508537

@@ -520,6 +549,14 @@ def disable_x64():
520549
set_complex(jnp.complex64)
521550

522551

552+
def set_x64(enable: bool):
553+
assert isinstance(enable, bool)
554+
if enable:
555+
enable_x64()
556+
else:
557+
disable_x64()
558+
559+
523560
def set_platform(platform: str):
524561
"""
525562
Changes platform to CPU, GPU, or TPU. This utility only takes

brainpy/math/fft.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# -*- coding: utf-8 -*-
22

3+
from typing import Optional
34
import jax.numpy.fft
45

5-
from brainpy.math.ndarray import Array
6-
from brainpy.math.numpy_ops import _remove_brainpy_array
6+
from brainpy.math.numpy_ops import _as_jax_array_
77

88
__all__ = [
99
"fft", "fft2", "fftfreq", "fftn", "fftshift", "hfft",
@@ -12,89 +12,95 @@
1212
]
1313

1414

15-
def fft(a, n=None, axis=-1, norm=None):
16-
a = _remove_brainpy_array(a)
17-
return Array(jax.numpy.fft.fft(a=a, n=n, axis=axis, norm=norm))
15+
def fft(a,
16+
n: Optional[int] = None,
17+
axis: int = -1,
18+
norm: Optional[str] = None):
19+
a = _as_jax_array_(a)
20+
return jax.numpy.fft.fft(a=a, n=n, axis=axis, norm=norm)
1821

1922

2023
def fft2(a, s=None, axes=(-2, -1), norm=None):
21-
a = _remove_brainpy_array(a)
22-
return Array(jax.numpy.fft.fft2(a=a, s=s, axes=axes, norm=norm))
24+
a = _as_jax_array_(a)
25+
return jax.numpy.fft.fft2(a=a, s=s, axes=axes, norm=norm)
2326

2427

2528
def fftfreq(n, d=1.0):
26-
return Array(jax.numpy.fft.fftfreq(n=n, d=d))
29+
return jax.numpy.fft.fftfreq(n=n, d=d)
2730

2831

2932
def fftn(a, s=None, axes=None, norm=None):
30-
a = _remove_brainpy_array(a)
31-
return Array(jax.numpy.fft.fftn(a=a, s=s, axes=axes, norm=norm))
33+
a = _as_jax_array_(a)
34+
return jax.numpy.fft.fftn(a=a, s=s, axes=axes, norm=norm)
3235

3336

3437
def fftshift(x, axes=None):
35-
x = _remove_brainpy_array(x)
36-
return Array(jax.numpy.fft.fftshift(x=x, axes=axes))
38+
x = _as_jax_array_(x)
39+
return jax.numpy.fft.fftshift(x=x, axes=axes)
3740

3841

3942
def hfft(a, n=None, axis=-1, norm=None):
40-
a = _remove_brainpy_array(a)
41-
return Array(jax.numpy.fft.hfft(a=a, n=n, axis=axis, norm=norm))
43+
a = _as_jax_array_(a)
44+
return jax.numpy.fft.hfft(a=a, n=n, axis=axis, norm=norm)
4245

4346

44-
def ifft(a, n=None, axis=-1, norm=None):
45-
a = _remove_brainpy_array(a)
46-
return Array(jax.numpy.fft.ifft(a=a, n=n, axis=axis, norm=norm))
47+
def ifft(a,
48+
n: Optional[int] = None,
49+
axis: int = -1,
50+
norm: Optional[str] = None):
51+
a = _as_jax_array_(a)
52+
return jax.numpy.fft.ifft(a=a, n=n, axis=axis, norm=norm)
4753

4854

4955
def ifft2(a, s=None, axes=(-2, -1), norm=None):
50-
a = _remove_brainpy_array(a)
51-
return Array(jax.numpy.fft.ifft2(a=a, s=s, axes=axes, norm=norm))
56+
a = _as_jax_array_(a)
57+
return jax.numpy.fft.ifft2(a=a, s=s, axes=axes, norm=norm)
5258

5359

5460
def ifftn(a, s=None, axes=None, norm=None):
55-
a = _remove_brainpy_array(a)
56-
return Array(jax.numpy.fft.ifftn(a=a, s=s, axes=axes, norm=norm))
61+
a = _as_jax_array_(a)
62+
return jax.numpy.fft.ifftn(a=a, s=s, axes=axes, norm=norm)
5763

5864

5965
def ifftshift(x, axes=None):
60-
x = _remove_brainpy_array(x)
61-
return Array(jax.numpy.fft.ifftshift(x=x, axes=axes))
66+
x = _as_jax_array_(x)
67+
return jax.numpy.fft.ifftshift(x=x, axes=axes)
6268

6369

6470
def ihfft(a, n=None, axis=-1, norm=None):
65-
a = _remove_brainpy_array(a)
66-
return Array(jax.numpy.fft.ihfft(a=a, n=n, axis=axis, norm=norm))
71+
a = _as_jax_array_(a)
72+
return jax.numpy.fft.ihfft(a=a, n=n, axis=axis, norm=norm)
6773

6874

6975
def irfft(a, n=None, axis=-1, norm=None):
70-
a = _remove_brainpy_array(a)
71-
return Array(jax.numpy.fft.irfft(a=a, n=n, axis=axis, norm=norm))
76+
a = _as_jax_array_(a)
77+
return jax.numpy.fft.irfft(a=a, n=n, axis=axis, norm=norm)
7278

7379

7480
def irfft2(a, s=None, axes=(-2, -1), norm=None):
75-
a = _remove_brainpy_array(a)
76-
return Array(jax.numpy.fft.irfft2(a=a, s=s, axes=axes, norm=norm))
81+
a = _as_jax_array_(a)
82+
return jax.numpy.fft.irfft2(a=a, s=s, axes=axes, norm=norm)
7783

7884

7985
def irfftn(a, s=None, axes=None, norm=None):
80-
a = _remove_brainpy_array(a)
81-
return Array(jax.numpy.fft.irfftn(a=a, s=s, axes=axes, norm=norm))
86+
a = _as_jax_array_(a)
87+
return jax.numpy.fft.irfftn(a=a, s=s, axes=axes, norm=norm)
8288

8389

8490
def rfft(a, n=None, axis=-1, norm=None):
85-
a = _remove_brainpy_array(a)
86-
return Array(jax.numpy.fft.rfft(a=a, n=n, axis=axis, norm=norm))
91+
a = _as_jax_array_(a)
92+
return jax.numpy.fft.rfft(a=a, n=n, axis=axis, norm=norm)
8793

8894

8995
def rfft2(a, s=None, axes=(-2, -1), norm=None):
90-
a = _remove_brainpy_array(a)
91-
return Array(jax.numpy.fft.rfft2(a=a, s=s, axes=axes, norm=norm))
96+
a = _as_jax_array_(a)
97+
return jax.numpy.fft.rfft2(a=a, s=s, axes=axes, norm=norm)
9298

9399

94100
def rfftfreq(n, d=1.0):
95-
return Array(jax.numpy.fft.rfftfreq(n=n, d=d))
101+
return jax.numpy.fft.rfftfreq(n=n, d=d)
96102

97103

98104
def rfftn(a, s=None, axes=None, norm=None):
99-
a = _remove_brainpy_array(a)
100-
return Array(jax.numpy.fft.rfftn(a=a, s=s, axes=axes, norm=norm))
105+
a = _as_jax_array_(a)
106+
return jax.numpy.fft.rfftn(a=a, s=s, axes=axes, norm=norm)

0 commit comments

Comments
 (0)