Skip to content

Commit 3973f3c

Browse files
committed
enable x64 setting in an environment
1 parent 0c53428 commit 3973f3c

File tree

1 file changed

+41
-4
lines changed

1 file changed

+41
-4
lines changed

brainpy/math/environment.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from typing import Any, Callable, TypeVar, cast
1111

12-
from jax import dtypes, config, numpy as jnp, devices
12+
from jax import config, numpy as jnp, devices
1313
from jax.lib import xla_bridge
1414

1515
from . import modes
@@ -329,6 +329,7 @@ def clone(self):
329329
def set_environment(
330330
mode: modes.Mode = None,
331331
dt: float = None,
332+
x64: bool = None,
332333
complex_: type = None,
333334
float_: type = None,
334335
int_: type = None,
@@ -342,6 +343,8 @@ def set_environment(
342343
The computing mode.
343344
dt: float
344345
The numerical integration precision.
346+
x64: bool
347+
Enable x64 computation.
345348
complex_: type
346349
The complex data type.
347350
float_
@@ -359,6 +362,10 @@ def set_environment(
359362
assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
360363
set_mode(mode)
361364

365+
if x64 is not None:
366+
assert isinstance(x64, bool), f'"x64" must be a bool.'
367+
set_x64(x64)
368+
362369
if float_ is not None:
363370
assert isinstance(float_, type), '"float_" must a float.'
364371
set_float(float_)
@@ -402,8 +409,9 @@ class environment(_DecoratorContextManager):
402409

403410
def __init__(
404411
self,
405-
dt: float = None,
406412
mode: modes.Mode = None,
413+
dt: float = None,
414+
x64: bool = None,
407415
complex_: type = None,
408416
float_: type = None,
409417
int_: type = None,
@@ -412,6 +420,7 @@ def __init__(
412420
super().__init__()
413421
self.old_dt = get_dt()
414422
self.old_mode = get_mode()
423+
self.old_x64 = config.read("jax_enable_x64")
415424
self.old_int = get_int()
416425
self.old_bool = get_bool()
417426
self.old_float = get_float()
@@ -421,6 +430,8 @@ def __init__(
421430
assert isinstance(dt, float), '"dt" must a float.'
422431
if mode is not None:
423432
assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
433+
if x64 is not None:
434+
assert isinstance(x64, bool), f'"x64" must be a bool.'
424435
if float_ is not None:
425436
assert isinstance(float_, type), '"float_" must a float.'
426437
if int_ is not None:
@@ -431,6 +442,7 @@ def __init__(
431442
assert isinstance(complex_, type), '"complex_" must a type.'
432443
self.dt = dt
433444
self.mode = mode
445+
self.x64 = x64
434446
self.complex_ = complex_
435447
self.float_ = float_
436448
self.int_ = int_
@@ -439,6 +451,7 @@ def __init__(
439451
def __enter__(self) -> 'environment':
440452
if self.dt is not None: set_dt(self.dt)
441453
if self.mode is not None: set_mode(self.mode)
454+
if self.x64 is not None: set_x64(self.x64)
442455
if self.float_ is not None: set_float(self.float_)
443456
if self.int_ is not None: set_int(self.int_)
444457
if self.complex_ is not None: set_complex(self.complex_)
@@ -448,6 +461,7 @@ def __enter__(self) -> 'environment':
448461
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
449462
if self.dt is not None: set_dt(self.old_dt)
450463
if self.mode is not None: set_mode(self.old_mode)
464+
if self.x64 is not None: set_x64(self.old_x64)
451465
if self.int_ is not None: set_int(self.old_int)
452466
if self.float_ is not None: set_float(self.old_float)
453467
if self.complex_ is not None: set_complex(self.old_complex)
@@ -456,6 +470,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
456470
def clone(self):
457471
return self.__class__(dt=self.dt,
458472
mode=self.mode,
473+
x64=self.x64,
459474
bool_=self.bool_,
460475
complex_=self.complex_,
461476
float_=self.float_,
@@ -468,6 +483,7 @@ class training_environment(environment):
468483
This is a short-cut context setting for an environment with the training mode.
469484
It is equivalent to::
470485
486+
>>> import brainpy.math as bm
471487
>>> with bm.environment(mode=bm.training_mode):
472488
>>> pass
473489
@@ -476,11 +492,17 @@ class training_environment(environment):
476492

477493
def __init__(self,
478494
dt: float = None,
495+
x64: bool = None,
479496
complex_: type = None,
480497
float_: type = None,
481498
int_: type = None,
482499
bool_: type = None):
483-
super().__init__(dt=dt, complex_=complex_, float_=float_, int_=int_, bool_=bool_,
500+
super().__init__(dt=dt,
501+
x64=x64,
502+
complex_=complex_,
503+
float_=float_,
504+
int_=int_,
505+
bool_=bool_,
484506
mode=modes.TrainingMode())
485507

486508

@@ -490,6 +512,7 @@ class batching_environment(environment):
490512
This is a short-cut context setting for an environment with the batching mode.
491513
It is equivalent to::
492514
515+
>>> import brainpy.math as bm
493516
>>> with bm.environment(mode=bm.batching_mode):
494517
>>> pass
495518
@@ -498,11 +521,17 @@ class batching_environment(environment):
498521

499522
def __init__(self,
500523
dt: float = None,
524+
x64: bool = None,
501525
complex_: type = None,
502526
float_: type = None,
503527
int_: type = None,
504528
bool_: type = None):
505-
super().__init__(dt=dt, complex_=complex_, float_=float_, int_=int_, bool_=bool_,
529+
super().__init__(dt=dt,
530+
x64=x64,
531+
complex_=complex_,
532+
float_=float_,
533+
int_=int_,
534+
bool_=bool_,
506535
mode=modes.BatchingMode())
507536

508537

@@ -520,6 +549,14 @@ def disable_x64():
520549
set_complex(jnp.complex64)
521550

522551

552+
def set_x64(enable: bool):
553+
assert isinstance(enable, bool)
554+
if enable:
555+
enable_x64()
556+
else:
557+
disable_x64()
558+
559+
523560
def set_platform(platform: str):
524561
"""
525562
Changes platform to CPU, GPU, or TPU. This utility only takes

0 commit comments

Comments
 (0)