99import warnings
1010from typing import Any , Callable , TypeVar , cast
1111
12- from jax import dtypes , config , numpy as jnp , devices
12+ from jax import config , numpy as jnp , devices
1313from jax .lib import xla_bridge
1414
1515from . import modes
@@ -329,6 +329,7 @@ def clone(self):
329329def set_environment (
330330 mode : modes .Mode = None ,
331331 dt : float = None ,
332+ x64 : bool = None ,
332333 complex_ : type = None ,
333334 float_ : type = None ,
334335 int_ : type = None ,
@@ -342,6 +343,8 @@ def set_environment(
342343 The computing mode.
343344 dt: float
344345 The numerical integration precision.
346+ x64: bool
347+ Enable x64 computation.
345348 complex_: type
346349 The complex data type.
347350 float_
@@ -359,6 +362,10 @@ def set_environment(
359362 assert isinstance (mode , modes .Mode ), f'"mode" must a { modes .Mode } .'
360363 set_mode (mode )
361364
365+ if x64 is not None :
366+ assert isinstance (x64 , bool ), f'"x64" must be a bool.'
367+ set_x64 (x64 )
368+
362369 if float_ is not None :
363370 assert isinstance (float_ , type ), '"float_" must a float.'
364371 set_float (float_ )
@@ -402,8 +409,9 @@ class environment(_DecoratorContextManager):
402409
403410 def __init__ (
404411 self ,
405- dt : float = None ,
406412 mode : modes .Mode = None ,
413+ dt : float = None ,
414+ x64 : bool = None ,
407415 complex_ : type = None ,
408416 float_ : type = None ,
409417 int_ : type = None ,
@@ -412,6 +420,7 @@ def __init__(
412420 super ().__init__ ()
413421 self .old_dt = get_dt ()
414422 self .old_mode = get_mode ()
423+ self .old_x64 = config .read ("jax_enable_x64" )
415424 self .old_int = get_int ()
416425 self .old_bool = get_bool ()
417426 self .old_float = get_float ()
@@ -421,6 +430,8 @@ def __init__(
421430 assert isinstance (dt , float ), '"dt" must a float.'
422431 if mode is not None :
423432 assert isinstance (mode , modes .Mode ), f'"mode" must a { modes .Mode } .'
433+ if x64 is not None :
434+ assert isinstance (x64 , bool ), f'"x64" must be a bool.'
424435 if float_ is not None :
425436 assert isinstance (float_ , type ), '"float_" must a float.'
426437 if int_ is not None :
@@ -431,6 +442,7 @@ def __init__(
431442 assert isinstance (complex_ , type ), '"complex_" must a type.'
432443 self .dt = dt
433444 self .mode = mode
445+ self .x64 = x64
434446 self .complex_ = complex_
435447 self .float_ = float_
436448 self .int_ = int_
@@ -439,6 +451,7 @@ def __init__(
439451 def __enter__ (self ) -> 'environment' :
440452 if self .dt is not None : set_dt (self .dt )
441453 if self .mode is not None : set_mode (self .mode )
454+ if self .x64 is not None : set_x64 (self .x64 )
442455 if self .float_ is not None : set_float (self .float_ )
443456 if self .int_ is not None : set_int (self .int_ )
444457 if self .complex_ is not None : set_complex (self .complex_ )
@@ -448,6 +461,7 @@ def __enter__(self) -> 'environment':
448461 def __exit__ (self , exc_type : Any , exc_value : Any , traceback : Any ) -> None :
449462 if self .dt is not None : set_dt (self .old_dt )
450463 if self .mode is not None : set_mode (self .old_mode )
464+ if self .x64 is not None : set_x64 (self .old_x64 )
451465 if self .int_ is not None : set_int (self .old_int )
452466 if self .float_ is not None : set_float (self .old_float )
453467 if self .complex_ is not None : set_complex (self .old_complex )
@@ -456,6 +470,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
456470 def clone (self ):
457471 return self .__class__ (dt = self .dt ,
458472 mode = self .mode ,
473+ x64 = self .x64 ,
459474 bool_ = self .bool_ ,
460475 complex_ = self .complex_ ,
461476 float_ = self .float_ ,
@@ -468,6 +483,7 @@ class training_environment(environment):
468483 This is a short-cut context setting for an environment with the training mode.
469484 It is equivalent to::
470485
486+ >>> import brainpy.math as bm
471487 >>> with bm.environment(mode=bm.training_mode):
472488 >>> pass
473489
@@ -476,11 +492,17 @@ class training_environment(environment):
476492
477493 def __init__ (self ,
478494 dt : float = None ,
495+ x64 : bool = None ,
479496 complex_ : type = None ,
480497 float_ : type = None ,
481498 int_ : type = None ,
482499 bool_ : type = None ):
483- super ().__init__ (dt = dt , complex_ = complex_ , float_ = float_ , int_ = int_ , bool_ = bool_ ,
500+ super ().__init__ (dt = dt ,
501+ x64 = x64 ,
502+ complex_ = complex_ ,
503+ float_ = float_ ,
504+ int_ = int_ ,
505+ bool_ = bool_ ,
484506 mode = modes .TrainingMode ())
485507
486508
@@ -490,6 +512,7 @@ class batching_environment(environment):
490512 This is a short-cut context setting for an environment with the batching mode.
491513 It is equivalent to::
492514
515+ >>> import brainpy.math as bm
493516 >>> with bm.environment(mode=bm.batching_mode):
494517 >>> pass
495518
@@ -498,11 +521,17 @@ class batching_environment(environment):
498521
499522 def __init__ (self ,
500523 dt : float = None ,
524+ x64 : bool = None ,
501525 complex_ : type = None ,
502526 float_ : type = None ,
503527 int_ : type = None ,
504528 bool_ : type = None ):
505- super ().__init__ (dt = dt , complex_ = complex_ , float_ = float_ , int_ = int_ , bool_ = bool_ ,
529+ super ().__init__ (dt = dt ,
530+ x64 = x64 ,
531+ complex_ = complex_ ,
532+ float_ = float_ ,
533+ int_ = int_ ,
534+ bool_ = bool_ ,
506535 mode = modes .BatchingMode ())
507536
508537
@@ -520,6 +549,14 @@ def disable_x64():
520549 set_complex (jnp .complex64 )
521550
522551
552+ def set_x64 (enable : bool ):
553+ assert isinstance (enable , bool )
554+ if enable :
555+ enable_x64 ()
556+ else :
557+ disable_x64 ()
558+
559+
523560def set_platform (platform : str ):
524561 """
525562 Changes platform to CPU, GPU, or TPU. This utility only takes
0 commit comments