Skip to content

Commit 59bfc0d

Browse files
authored
reorganize operators in brainpy.math (#357)
reorganize operators in brainpy.math (#357)
2 parents c0c7910 + 1836e2b commit 59bfc0d

File tree

11 files changed

+137
-143
lines changed

11 files changed

+137
-143
lines changed

brainpy/_src/math/environment.py

Lines changed: 60 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -324,66 +324,6 @@ def clone(self):
324324
return self.__class__()
325325

326326

327-
def set(
328-
mode: modes.Mode = None,
329-
dt: float = None,
330-
x64: bool = None,
331-
complex_: type = None,
332-
float_: type = None,
333-
int_: type = None,
334-
bool_: type = None,
335-
):
336-
"""Set the default computation environment.
337-
338-
Parameters
339-
----------
340-
mode: Mode
341-
The computing mode.
342-
dt: float
343-
The numerical integration precision.
344-
x64: bool
345-
Enable x64 computation.
346-
complex_: type
347-
The complex data type.
348-
float_
349-
The floating data type.
350-
int_
351-
The integer data type.
352-
bool_
353-
The bool data type.
354-
"""
355-
if dt is not None:
356-
assert isinstance(dt, float), '"dt" must a float.'
357-
set_dt(dt)
358-
359-
if mode is not None:
360-
assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
361-
set_mode(mode)
362-
363-
if x64 is not None:
364-
assert isinstance(x64, bool), f'"x64" must be a bool.'
365-
set_x64(x64)
366-
367-
if float_ is not None:
368-
assert isinstance(float_, type), '"float_" must a float.'
369-
set_float(float_)
370-
371-
if int_ is not None:
372-
assert isinstance(int_, type), '"int_" must a type.'
373-
set_int(int_)
374-
375-
if bool_ is not None:
376-
assert isinstance(bool_, type), '"bool_" must a type.'
377-
set_bool(bool_)
378-
379-
if complex_ is not None:
380-
assert isinstance(complex_, type), '"complex_" must a type.'
381-
set_complex(complex_)
382-
383-
384-
set_environment = set
385-
386-
387327
class environment(_DecoratorContextManager):
388328
r"""Context-manager that sets a computing environment for brain dynamics computation.
389329
@@ -541,6 +481,66 @@ def __init__(
541481
mode=modes.BatchingMode(batch_size))
542482

543483

484+
def set(
485+
mode: modes.Mode = None,
486+
dt: float = None,
487+
x64: bool = None,
488+
complex_: type = None,
489+
float_: type = None,
490+
int_: type = None,
491+
bool_: type = None,
492+
):
493+
"""Set the default computation environment.
494+
495+
Parameters
496+
----------
497+
mode: Mode
498+
The computing mode.
499+
dt: float
500+
The numerical integration precision.
501+
x64: bool
502+
Enable x64 computation.
503+
complex_: type
504+
The complex data type.
505+
float_
506+
The floating data type.
507+
int_
508+
The integer data type.
509+
bool_
510+
The bool data type.
511+
"""
512+
if dt is not None:
513+
assert isinstance(dt, float), '"dt" must a float.'
514+
set_dt(dt)
515+
516+
if mode is not None:
517+
assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
518+
set_mode(mode)
519+
520+
if x64 is not None:
521+
assert isinstance(x64, bool), f'"x64" must be a bool.'
522+
set_x64(x64)
523+
524+
if float_ is not None:
525+
assert isinstance(float_, type), '"float_" must a float.'
526+
set_float(float_)
527+
528+
if int_ is not None:
529+
assert isinstance(int_, type), '"int_" must a type.'
530+
set_int(int_)
531+
532+
if bool_ is not None:
533+
assert isinstance(bool_, type), '"bool_" must a type.'
534+
set_bool(bool_)
535+
536+
if complex_ is not None:
537+
assert isinstance(complex_, type), '"complex_" must a type.'
538+
set_complex(complex_)
539+
540+
541+
set_environment = set
542+
543+
544544
def enable_x64():
545545
config.update("jax_enable_x64", True)
546546
set_int(jnp.int64)

brainpy/_src/math/ndarray.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class Array(object):
9090

9191
is_brainpy_array = True
9292
_need_check_context = True
93-
__slots__ = ("_value", "_transform_context")
93+
__slots__ = ("_value", )
9494

9595
def __init__(self, value, dtype=None):
9696
# array value
@@ -101,26 +101,6 @@ def __init__(self, value, dtype=None):
101101
if dtype is not None:
102102
value = jnp.asarray(value, dtype=dtype)
103103
self._value = value
104-
# jit mode
105-
self._transform_context = get_context()
106-
107-
def __check_context(self) -> None:
108-
# raise error when in-place updating a
109-
if self._need_check_context:
110-
if self._transform_context is None:
111-
if len(_jax_transformation_context_) > 0:
112-
raise MathError(f'Array created outside of the transformation functions '
113-
f'({_jax_transformation_context_[-1]}) cannot be updated. '
114-
f'You should mark it as a brainpy.math.Variable instead.')
115-
else:
116-
if len(_jax_transformation_context_) > 0:
117-
if self._transform_context != _jax_transformation_context_[-1]:
118-
raise MathError(f'Array context "{self._transform_context}" differs from the JAX '
119-
f'transformation context "{_jax_transformation_context_[-1]}"'
120-
'\n\n'
121-
'Array created in one transformation function '
122-
'cannot be updated another transformation function. '
123-
'You should mark it as a brainpy.math.Variable instead.')
124104

125105
@property
126106
def value(self):
@@ -1455,7 +1435,7 @@ class Variable(Array):
14551435
"""
14561436

14571437
_need_check_context = False
1458-
__slots__ = ('_value', '_batch_axis')
1438+
__slots__ = ('_value', '_batch_axis', '_env')
14591439

14601440
def __init__(
14611441
self,
@@ -1487,6 +1467,8 @@ def __init__(
14871467
raise MathError(f'This variables has {self.ndim} dimension, '
14881468
f'but the batch axis is set to be {batch_axis}.')
14891469

1470+
self._env = True
1471+
14901472
@property
14911473
def nobatch_shape(self) -> TupleType[int, ...]:
14921474
"""Shape without batch axis."""

brainpy/math/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# data structure
55
from .ndarray import *
66
from .delayvars import *
7-
from .arrayinterporate import *
7+
from .interoperability import *
88
from .datatypes import *
99
from .compat_numpy import *
1010
from .compat_tensorflow import *
@@ -15,7 +15,11 @@
1515
from . import activations
1616

1717
# operators
18-
from .operators import *
18+
from .event_ops import *
19+
from .jitconn_ops import *
20+
from .pre_syn_post import *
21+
from .sparse_ops import *
22+
from .op_register import *
1923
from . import surrogate
2024

2125
# Variable and Objects for object-oriented JAX transformations

brainpy/math/event_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
from brainpy._src.math.operators.event_ops import (
3+
event_csr_matvec,
4+
event_info
5+
)
6+
File renamed without changes.

brainpy/math/jitconn_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
from brainpy._src.math.operators.jitconn_ops import (
3+
event_matvec_prob_conn_homo_weight,
4+
event_matvec_prob_conn_uniform_weight,
5+
event_matvec_prob_conn_normal_weight,
6+
7+
matmat_prob_conn_homo_weight,
8+
matmat_prob_conn_uniform_weight,
9+
matmat_prob_conn_normal_weight,
10+
11+
matvec_prob_conn_homo_weight,
12+
matvec_prob_conn_uniform_weight,
13+
matvec_prob_conn_normal_weight,
14+
)

brainpy/math/op_register.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
from brainpy._src.math.operators.op_registers import (
5+
XLACustomOp,
6+
compile_cpu_signature_with_numba,
7+
)
8+
9+
10+

brainpy/math/operators.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

brainpy/math/pre_syn_post.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from brainpy._src.math.operators.pre_syn_post import (
2+
pre2post_sum,
3+
pre2post_prod,
4+
pre2post_max,
5+
pre2post_min,
6+
pre2post_mean,
7+
8+
pre2post_event_sum,
9+
pre2post_coo_event_sum,
10+
pre2post_event_prod,
11+
12+
pre2syn,
13+
14+
syn2post_sum, syn2post,
15+
syn2post_prod,
16+
syn2post_max,
17+
syn2post_min,
18+
syn2post_mean,
19+
syn2post_softmax,
20+
)

brainpy/math/sparse_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from brainpy._src.math.operators.sparse_ops import (
2+
cusparse_csr_matvec,
3+
cusparse_coo_matvec,
4+
csr_matvec,
5+
sparse_matmul,
6+
coo_to_csr,
7+
csr_to_coo,
8+
csr_to_dense
9+
)

0 commit comments

Comments
 (0)