Commit a52de70

[compatibility] more operators in pytorch and tensorflow
1 parent: 33db354 · commit: a52de70

File tree

7 files changed (+148, -22 lines)


brainpy/_src/math/_utils.py

Lines changed: 20 additions & 1 deletion
@@ -39,10 +39,29 @@ def _compatible_with_brainpy_array(fun: Callable):
   @functools.wraps(fun)
   def new_fun(*args, **kwargs):
     args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf)
+    out = None
     if len(kwargs):
+      # compatible with PyTorch syntax
+      if 'dim' in kwargs:
+        kwargs['axis'] = kwargs.pop('dim')
+      # compatible with PyTorch syntax
+      if 'keepdim' in kwargs:
+        kwargs['keep_dims'] = kwargs.pop('keepdim')
+      # compatible with TensorFlow syntax
+      if 'keepdims' in kwargs:
+        kwargs['keep_dims'] = kwargs.pop('keepdims')
+      # compatible with NumPy/PyTorch syntax
+      if 'out' in kwargs:
+        out = kwargs.get('out')
+        if not isinstance(out, Array):
+          raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}')
+      # format
       kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf)
     r = fun(*args, **kwargs)
-    return tree_map(_return, r)
+    if out is None:
+      return tree_map(_return, r)
+    else:
+      out.value = r
 
   new_fun.__doc__ = getattr(fun, "__doc__", None)

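For illustration, a minimal sketch of what the new kwargs handling enables, assuming `bm.sum` is one of the functions routed through this wrapper (the diff does not show which functions are wrapped):

import brainpy.math as bm

x = bm.arange(6.).reshape(2, 3)

# PyTorch-style `dim=` is rewritten to `axis=` before the underlying call
s = bm.sum(x, dim=1)        # equivalent to bm.sum(x, axis=1)

# NumPy/PyTorch-style `out=`: the wrapper returns None and instead
# assigns the result to the given brainpy Array
out = bm.zeros(2)
bm.sum(x, dim=1, out=out)   # out.value now holds the per-row sums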
brainpy/_src/math/compat_numpy.py

Lines changed: 45 additions & 4 deletions
@@ -1,14 +1,24 @@
 # -*- coding: utf-8 -*-
 
+from typing import (Union, Any, Protocol)
+
 import jax.numpy as jnp
 import numpy as np
-from jax.tree_util import tree_map
 from jax.tree_util import tree_flatten, tree_unflatten
+from jax.tree_util import tree_map
 
 from ._utils import _compatible_with_brainpy_array, _as_jax_array_
 from .arrayinterporate import *
 from .ndarray import Array
 
+
+class SupportsDType(Protocol):
+  @property
+  def dtype(self) -> np.dtype: ...
+
+
+DTypeLike = Union[Any, str, np.dtype, SupportsDType]
+
 __all__ = [
   'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu',
   'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like',
@@ -99,10 +109,39 @@
 ]
 
-
 _min = min
 _max = max
 
+# def concatenate(arrays: Union[np.ndarray, Array, Sequence[Array]],
+#                 axis: Optional[int] = None,
+#                 dim: Optional[int] = None,
+#                 dtype: Optional[DTypeLike] = None) -> Array:
+#   """Join a sequence of arrays along an existing axis.
+#
+#   Parameters
+#   ----------
+#   a1, a2, ... : sequence of array_like
+#     The arrays must have the same shape, except in the dimension
+#     corresponding to `axis` (the first, by default).
+#   axis : int, optional
+#     The axis along which the arrays will be joined. If axis is None,
+#     arrays are flattened before use. Default is 0.
+#   dtype : str or dtype
+#     If provided, the destination array will have this dtype. Cannot be
+#     provided together with `out`.
+#
+#   Returns
+#   -------
+#   res : ndarray
+#     The concatenated array.
+#   """
+#   axis = one_of(0, axis, dim, ['axis', 'dim'])
+#   r = jnp.concatenate(tree_map(_as_jax_array_, arrays, is_leaf=_is_leaf),
+#                       axis=axis,
+#                       dtype=dtype)
+#   return _return(r)
+
 
 def fill_diagonal(a, val, inplace=True):
   if a.ndim < 2:
@@ -112,13 +151,14 @@ def fill_diagonal(a, val, inplace=True):
                      'it requires a brainpy Array. If you want to disable '
                      'inplace updating, use ``fill_diagonal(inplace=False)``.')
   val = val.value if isinstance(val, Array) else val
-  i, j = jnp.diag_indices(min(a.shape[-2:]))
+  i, j = jnp.diag_indices(_min(a.shape[-2:]))
   r = as_jax(a).at[..., i, j].set(val)
   if inplace:
     a.value = r
   else:
     return r
 
+
 def zeros(shape, dtype=None):
   return Array(jnp.zeros(shape, dtype=dtype))
 
@@ -191,6 +231,7 @@ def logspace(*args, **kwargs):
   kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
   return Array(jnp.logspace(*args, **kwargs))
 
+
 def asanyarray(a, dtype=None, order=None):
   return asarray(a, dtype=dtype, order=order)
 
@@ -612,7 +653,7 @@ def common_type(*arrays):
     p = array_precision.get(t, None)
     if p is None:
       raise TypeError("can't get common type for non-numeric array")
-    precision = max(precision, p)
+    precision = _max(precision, p)
   if is_complex:
     return array_type[1][precision]
   else:

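The `_min`/`_max` aliases exist because this module defines and exports its own `min` and `max` wrappers, which shadow the Python builtins at module scope; `fill_diagonal` and `common_type` must therefore reach the builtins through the saved names. A sketch of the pitfall, with illustrative definitions:

import builtins

_min = builtins.min   # keep handles on the real builtins
_max = builtins.max

def min(a, axis=None):  # illustrative stand-in for the exported array reduction
  ...

# Inside this module, min((4, 3)) would now dispatch to the wrapper above,
# while _min((4, 3)) still does a plain tuple comparison and returns 3.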
brainpy/_src/math/compat_pytorch.py

Lines changed: 32 additions & 0 deletions
@@ -5,12 +5,25 @@
 import numpy as np
 
 from .ndarray import Array, _as_jax_array_
+from .compat_numpy import (
+  concatenate,
+)
 
 __all__ = [
+  'Tensor',
   'flatten',
+  'cat',
+
+  # data types
+  'bfloat16', 'half', 'float', 'double', 'cfloat', 'cdouble', 'short', 'int', 'long', 'bool'
 ]
 
 
+Tensor = Array
+cat = concatenate
+
+
 def flatten(input: Union[jax.Array, Array],
             start_dim: Optional[int] = None,
             end_dim: Optional[int] = None) -> jax.Array:
@@ -56,3 +69,22 @@ def flatten(input: Union[jax.Array, Array],
   new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int), ) + shape[end_dim:]
   return jnp.reshape(input, new_shape)
 
+
+# data types
+bfloat16 = jnp.bfloat16
+half = jnp.float16
+float = jnp.float32
+double = jnp.float64
+cfloat = jnp.complex64
+cdouble = jnp.complex128
+short = jnp.int16
+int = jnp.int32
+long = jnp.int64
+bool = jnp.bool_
+# missing types #
+# chalf = np.complex32
+# quint8 = jnp.quint8
+# qint8 = jnp.qint8
+# qint32 = jnp.qint32
+# quint4x2 = jnp.quint4x2

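A short sketch of the new aliases, assuming the wrapped `concatenate` accepts a sequence of brainpy Arrays and returns one (its implementation is not part of this diff):

import brainpy.math as bm
from brainpy._src.math.compat_pytorch import Tensor, cat
from brainpy._src.math import compat_pytorch as torch_like

a = bm.ones((2, 3))
b = bm.zeros((2, 3))

x = cat([a, b], axis=0)       # forwards to concatenate; shape (4, 3)
print(isinstance(x, Tensor))  # True: Tensor is simply an alias of Array

# torch-style dtype names resolve to jnp dtypes
y = bm.asarray([1., 2.], dtype=torch_like.float)   # jnp.float32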
brainpy/_src/math/compat_tensorflow.py

Lines changed: 20 additions & 16 deletions
@@ -2,18 +2,31 @@
 import jax.ops
 
 from .ndarray import _return, _as_jax_array_
-from .compat_numpy import prod, min, sum, all, any, mean, std, var
+from .compat_numpy import (
+  prod, min, sum, all, any, mean, std, var, concatenate, clip
+)
 
 __all__ = [
-  'reduce_sum', 'reduce_max', 'reduce_min', 'reduce_mean', 'reduce_all',
-  'reduce_any', 'reduce_logsumexp', 'reduce_prod', 'reduce_std', 'reduce_variance',
-  'reduce_euclidean_norm',
-  'unsorted_segment_sqrt_n', 'segment_mean', 'unsorted_segment_sum',
-  'unsorted_segment_prod', 'unsorted_segment_max', 'unsorted_segment_min',
-  'unsorted_segment_mean',
+  'concat',
+  'reduce_sum', 'reduce_max', 'reduce_min', 'reduce_mean', 'reduce_all', 'reduce_any',
+  'reduce_logsumexp', 'reduce_prod', 'reduce_std', 'reduce_variance', 'reduce_euclidean_norm',
+  'unsorted_segment_sqrt_n', 'segment_mean', 'unsorted_segment_sum', 'unsorted_segment_prod',
+  'unsorted_segment_max', 'unsorted_segment_min', 'unsorted_segment_mean',
+  'clip_by_value',
 ]
 
 
+reduce_prod = prod
+reduce_sum = sum
+reduce_all = all
+reduce_any = any
+reduce_min = min
+reduce_mean = mean
+reduce_std = std
+reduce_variance = var
+concat = concatenate
+clip_by_value = clip
+
 def reduce_logsumexp(input_tensor, axis=None, keep_dims=False):
   """Computes log(sum(exp(elements across dimensions of a tensor))).
 
@@ -95,15 +108,6 @@ def reduce_max(input_tensor, axis=None, keep_dims=False):
   return _return(jnp.max(_as_jax_array_(input_tensor), axis=axis, keep_dims=keep_dims))
 
 
-reduce_prod = prod
-reduce_sum = sum
-reduce_all = all
-reduce_any = any
-reduce_min = min
-reduce_mean = mean
-reduce_std = std
-reduce_variance = var
-
 
 
 def segment_mean(data, segment_ids):

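A usage sketch of the TensorFlow-style spellings, assuming `clip` and `concatenate` behave like their NumPy namesakes:

import brainpy.math as bm
from brainpy._src.math.compat_tensorflow import (
  reduce_sum, reduce_mean, concat, clip_by_value,
)

x = bm.asarray([[1., 2.], [3., 4.]])

s = reduce_sum(x, axis=1)       # alias of sum   -> [3., 7.]
m = reduce_mean(x)              # alias of mean  -> 2.5
c = clip_by_value(x, 1.5, 3.5)  # alias of clip, TF argument order
z = concat([x, x], axis=0)      # alias of concatenate; shape (4, 2)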
brainpy/_src/tools/others.py

Lines changed: 16 additions & 1 deletion
@@ -3,14 +3,15 @@
 import collections.abc
 import _thread as thread
 import threading
-from typing import Optional, Tuple, Callable, Union, Sequence, TypeVar
+from typing import Optional, Tuple, Callable, Union, Sequence, TypeVar, Any
 
 import numpy as np
 from jax import lax
 from jax.experimental import host_callback
 from tqdm.auto import tqdm
 
 __all__ = [
+  'one_of',
   'replicate',
   'not_customized',
   'to_size',
@@ -20,6 +21,20 @@
 ]
 
 
+def one_of(default: Any, *choices, names: Sequence[str] = None):
+  names = [f'arg{i}' for i in range(len(choices))] if names is None else names
+  res = default
+  has_chosen = False
+  for c in choices:
+    if c is not None:
+      if has_chosen:
+        raise ValueError(f'Provide one of {names}, but we got {list(zip(choices, names))}')
+      else:
+        has_chosen = True
+      res = c
+  return res
+
+
 T = TypeVar('T')

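`one_of` resolves a value from several mutually exclusive argument aliases (e.g. NumPy's `axis` vs. PyTorch's `dim`), falling back to the default when none is given and rejecting conflicting inputs:

from brainpy._src.tools.others import one_of

one_of(0, None, None, names=['axis', 'dim'])  # -> 0   (default; neither given)
one_of(0, 2, None, names=['axis', 'dim'])     # -> 2   (axis given)
one_of(0, None, -1, names=['axis', 'dim'])    # -> -1  (dim given)
one_of(0, 2, -1, names=['axis', 'dim'])       # ValueError: both aliases given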
brainpy/math/compat_pytorch.py

Lines changed: 13 additions & 0 deletions
@@ -1,4 +1,17 @@
 
 from brainpy._src.math.compat_pytorch import (
+  Tensor as Tensor,
   flatten as flatten,
+  cat as cat,
+
+  bfloat16 as bfloat16,
+  half as half,
+  float as float,
+  double as double,
+  cfloat as cfloat,
+  cdouble as cdouble,
+  short as short,
+  int as int,
+  long as long,
+  bool as bool,
 )

brainpy/math/compat_tensorflow.py

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 
 from brainpy._src.math.compat_tensorflow import (
+  concat as concat,
   reduce_sum as reduce_sum,
   reduce_max as reduce_max,
   reduce_min as reduce_min,
@@ -18,5 +19,6 @@
   unsorted_segment_max as unsorted_segment_max,
   unsorted_segment_min as unsorted_segment_min,
   unsorted_segment_mean as unsorted_segment_mean,
+  clip_by_value as clip_by_value,
 )

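Taken together, the two re-export files above make the compatibility layers importable from the public package path:

# PyTorch-flavored names
from brainpy.math.compat_pytorch import Tensor, cat, flatten, long

# TensorFlow-flavored names
from brainpy.math.compat_tensorflow import concat, reduce_sum, clip_by_value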