11# -*- coding: utf-8 -*-
22
3+ from typing import (Union , Any , Protocol )
4+
35import jax .numpy as jnp
46import numpy as np
5- from jax .tree_util import tree_map
67from jax .tree_util import tree_flatten , tree_unflatten
8+ from jax .tree_util import tree_map
79
810from ._utils import _compatible_with_brainpy_array , _as_jax_array_
911from .arrayinterporate import *
1012from .ndarray import Array
1113
14+
15+ class SupportsDType (Protocol ):
16+ @property
17+ def dtype (self ) -> np .dtype : ...
18+
19+
20+ DTypeLike = Union [Any , str , np .dtype , SupportsDType ]
21+
1222__all__ = [
1323 'full' , 'full_like' , 'eye' , 'identity' , 'diag' , 'tri' , 'tril' , 'triu' ,
1424 'empty' , 'empty_like' , 'ones' , 'ones_like' , 'zeros' , 'zeros_like' ,
99109
100110]
101111
102-
103112_min = min
104113_max = max
105114
115+ # def concatenate(arrays: Union[np.ndarray, Array, Sequence[Array]],
116+ # axis: Optional[int] = None,
117+ # dim: Optional[int] = None,
118+ # dtype: Optional[DTypeLike] = None) -> Array:
119+ # """Join a sequence of arrays along an existing axis.
120+ #
121+ #
122+ # Parameters
123+ # ----------
124+ # a1, a2, ... : sequence of array_like
125+ # The arrays must have the same shape, except in the dimension
126+ # corresponding to `axis` (the first, by default).
127+ # axis : int, optional
128+ # The axis along which the arrays will be joined. If axis is None,
129+ # arrays are flattened before use. Default is 0.
130+ # dtype : str or dtype
131+ # If provided, the destination array will have this dtype. Cannot be
132+ # provided together with `out`.
133+ #
134+ # Returns
135+ # -------
136+ # res : ndarray
137+ # The concatenated array.
138+ # """
139+ # axis = one_of(0, axis, dim, ['axis', 'dim'])
140+ # r = jnp.concatenate(tree_map(_as_jax_array_, arrays, is_leaf=_is_leaf),
141+ # axis=axis,
142+ # dtype=dtype)
143+ # return _return(r)
144+
106145
107146def fill_diagonal (a , val , inplace = True ):
108147 if a .ndim < 2 :
@@ -112,13 +151,14 @@ def fill_diagonal(a, val, inplace=True):
112151 'it requires a brainpy Array. If you want to disable '
113152 'inplace updating, use ``fill_diagonal(inplace=False)``.' )
114153 val = val .value if isinstance (val , Array ) else val
115- i , j = jnp .diag_indices (min (a .shape [- 2 :]))
154+ i , j = jnp .diag_indices (_min (a .shape [- 2 :]))
116155 r = as_jax (a ).at [..., i , j ].set (val )
117156 if inplace :
118157 a .value = r
119158 else :
120159 return r
121160
161+
122162def zeros (shape , dtype = None ):
123163 return Array (jnp .zeros (shape , dtype = dtype ))
124164
@@ -191,6 +231,7 @@ def logspace(*args, **kwargs):
191231 kwargs = {k : _as_jax_array_ (v ) for k , v in kwargs .items ()}
192232 return Array (jnp .logspace (* args , ** kwargs ))
193233
234+
194235def asanyarray (a , dtype = None , order = None ):
195236 return asarray (a , dtype = dtype , order = order )
196237
@@ -612,7 +653,7 @@ def common_type(*arrays):
612653 p = array_precision .get (t , None )
613654 if p is None :
614655 raise TypeError ("can't get common type for non-numeric array" )
615- precision = max (precision , p )
656+ precision = _max (precision , p )
616657 if is_complex :
617658 return array_type [1 ][precision ]
618659 else :
0 commit comments