55
66import jax .numpy as jnp
77import numpy as onp
8-
8+ from jax import config
99from brainpy import tools , math as bm
1010from brainpy .errors import ConnectorError
11+ from brainpy .tools .others import numba_jit , numba_range
1112
1213__all__ = [
1314 # the connection types
2425 'Connector' , 'TwoEndConnector' , 'OneEndConnector' ,
2526
2627 # methods
27- 'csr2csc' , 'csr2mat' , 'mat2csr' , 'ij2csr'
28+ 'csr2csc' , 'csr2mat' , 'mat2csr' , 'ij2csr' , 'ij2csr2'
2829]
2930
3031CONN_MAT = 'conn_mat'
4344 PRE2SYN , POST2SYN ,
4445 PRE_SLICE , POST_SLICE ]
4546
46- MAT_DTYPE = onp .bool_
47- IDX_DTYPE = onp .uint32
48-
47+ MAT_DTYPE = jnp .bool_
48+ IDX_DTYPE = jnp .uint32
4949
5050def set_default_dtype (mat_dtype = None , idx_dtype = None ):
5151 """Set the default dtype.
@@ -203,15 +203,15 @@ def _return_by_mat(self, structures, mat, all_data: dict):
203203
204204 require_other_structs = len ([s for s in structures if s != CONN_MAT ]) > 0
205205 if require_other_structs :
206- np = onp if isinstance (mat , onp .ndarray ) else bm
206+ np = jnp if isinstance (mat , jnp .ndarray ) else bm
207207 pre_ids , post_ids = np .where (mat > 0 )
208208 pre_ids = np .asarray (pre_ids , dtype = IDX_DTYPE )
209209 post_ids = np .asarray (post_ids , dtype = IDX_DTYPE )
210210 self ._return_by_ij (structures , ij = (pre_ids , post_ids ), all_data = all_data )
211211
212212 def _return_by_csr (self , structures , csr : tuple , all_data : dict ):
213213 indices , indptr = csr
214- np = onp if isinstance (indices , onp .ndarray ) else bm
214+ np = jnp if isinstance (indices , jnp .ndarray ) else bm
215215 assert self .pre_num == indptr .size - 1
216216
217217 if (CONN_MAT in structures ) and (CONN_MAT not in all_data ):
@@ -260,7 +260,10 @@ def _return_by_ij(self, structures, ij: tuple, all_data: dict):
260260 require_other_structs = len ([s for s in structures
261261 if s not in [CONN_MAT , PRE_IDS , POST_IDS ]]) > 0
262262 if require_other_structs :
263- csr = ij2csr (pre_ids , post_ids , self .pre_num )
263+ if config .read ('jax_platform_name' ) == "gpu" :
264+ csr = ij2csr (pre_ids , post_ids , self .pre_num )
265+ else :
266+ csr = ij2csr2 (pre_ids , post_ids , self .pre_num )
264267 self ._return_by_csr (structures , csr = csr , all_data = all_data )
265268
266269 def make_returns (self , structures , conn_data , csr = None , mat = None , ij = None ):
@@ -421,19 +424,19 @@ def _reset_conn(self, pre_size, post_size=None):
421424def csr2csc (csr , post_num , data = None ):
422425 """Convert csr to csc."""
423426 indices , indptr = csr
424- np = onp if isinstance (indices , onp .ndarray ) else bm
425- kind = 'quicksort' if isinstance (indices , onp .ndarray ) else 'stable'
427+ np = jnp if isinstance (indices , jnp .ndarray ) else bm
428+ # kind = 'quicksort' if isinstance(indices, jnp .ndarray) else 'stable'
426429
427430 pre_ids = np .repeat (np .arange (indptr .size - 1 ), np .diff (indptr ))
428431
429- sort_ids = np .argsort (indices , kind = kind ) # to maintain the original order of the elements with the same value
432+ sort_ids = np .argsort (indices ) # to maintain the original order of the elements with the same value
430433 if isinstance (sort_ids , bm .JaxArray ):
431434 sort_ids = sort_ids .value
432435 pre_ids_new = np .asarray (pre_ids [sort_ids ], dtype = IDX_DTYPE )
433436
434437 unique_post_ids , count = np .unique (indices , return_counts = True )
435438 post_count = np .zeros (post_num , dtype = IDX_DTYPE )
436- post_count [unique_post_ids ] = count
439+ post_count = post_count . at [unique_post_ids . value if isinstance ( unique_post_ids , bm . JaxArray ) else unique_post_ids ]. set ( count . value if isinstance ( count , bm . JaxArray ) else count )
437440
438441 indptr_new = post_count .cumsum ()
439442 indptr_new = np .insert (indptr_new , 0 , 0 )
@@ -448,14 +451,14 @@ def csr2csc(csr, post_num, data=None):
448451
449452def mat2csr (dense ):
450453 """convert a dense matrix to (indices, indptr)."""
451- np = onp if isinstance (dense , onp .ndarray ) else bm
454+ np = jnp if isinstance (dense , jnp .ndarray ) else bm
452455
453456 pre_ids , post_ids = np .where (dense > 0 )
454457 pre_num = dense .shape [0 ]
455458
456459 uni_idx , count = np .unique (pre_ids , return_counts = True )
457460 pre_count = np .zeros (pre_num , dtype = IDX_DTYPE )
458- pre_count [uni_idx ] = count
461+ pre_count = pre_count . at [uni_idx . value if isinstance ( uni_idx , bm . JaxArray ) else uni_idx ]. set ( count . value if isinstance ( count , bm . JaxArray ) else count )
459462 indptr = count .cumsum ()
460463 indptr = np .insert (indptr , 0 , 0 )
461464
@@ -465,38 +468,54 @@ def mat2csr(dense):
465468def csr2mat (csr , num_pre , num_post ):
466469 """convert (indices, indptr) to a dense matrix."""
467470 indices , indptr = csr
468- np = onp if isinstance (indices , onp .ndarray ) else bm
471+ np = jnp if isinstance (indices , jnp .ndarray ) else bm
469472
470473 d = np .zeros ((num_pre , num_post ), dtype = MAT_DTYPE ) # num_pre, num_post
471474 pre_ids = np .repeat (np .arange (indptr .size - 1 ), np .diff (indptr ))
472- d [pre_ids , indices ] = True
475+ d = d . at [pre_ids . value if isinstance ( pre_ids , bm . JaxArray ) else pre_ids , indices . value if isinstance ( indices , bm . JaxArray ) else indices ]. set ( True )
473476 return d
474477
475478
476479def ij2mat (ij , num_pre , num_post ):
477480 """convert (indices, indptr) to a dense matrix."""
478481 pre_ids , post_ids = ij
479- np = onp if isinstance (pre_ids , onp .ndarray ) else bm
482+ np = jnp if isinstance (pre_ids , jnp .ndarray ) else bm
480483
481484 d = np .zeros ((num_pre , num_post ), dtype = MAT_DTYPE ) # num_pre, num_post
482- d [pre_ids , post_ids ] = True
485+ d = d . at [pre_ids . value if isinstance ( pre_ids , bm . JaxArray ) else pre_ids , post_ids . value if isinstance ( post_ids , bm . JaxArray ) else post_ids ]. set ( True )
483486 return d
484487
485488
486489def ij2csr (pre_ids , post_ids , num_pre ):
487- """convert pre_ids, post_ids to (indices, indptr)."""
488- np = onp if isinstance (pre_ids , onp .ndarray ) else bm
489- kind = 'quicksort' if isinstance (pre_ids , onp .ndarray ) else 'stable'
490-
491- # sorting
492- sort_ids = np .argsort (pre_ids , kind = kind )
490+ """convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'"""
491+ np = jnp if isinstance (pre_ids , jnp .ndarray ) else bm
492+ sort_ids = np .argsort (pre_ids )
493493 post_ids = post_ids [sort_ids .value if isinstance (sort_ids , bm .JaxArray ) else sort_ids ]
494-
495494 indices = post_ids
496495 unique_pre_ids , pre_count = np .unique (pre_ids , return_counts = True )
497- final_pre_count = np .zeros (num_pre , dtype = IDX_DTYPE )
498- final_pre_count [unique_pre_ids ] = pre_count
496+ final_pre_count = np .zeros (num_pre , dtype = jnp . uint32 )
497+ final_pre_count = final_pre_count . at [unique_pre_ids . value if isinstance ( unique_pre_ids , bm . JaxArray ) else unique_pre_ids ]. set ( pre_count . value if isinstance ( pre_count , bm . JaxArray ) else pre_count )
499498 indptr = final_pre_count .cumsum ()
500499 indptr = np .insert (indptr , 0 , 0 )
501500
502501 return np .asarray (indices , dtype = IDX_DTYPE ), np .asarray (indptr , dtype = IDX_DTYPE )
502+
503+ def ij2csr2 (pre_ids , post_ids , num_pre ):
504+ """convert pre_ids, post_ids to (indices, indptr). and use numba for sort function when'jax_platform_name' = 'cpu'"""
505+ np = jnp if isinstance (pre_ids , jnp .ndarray ) else bm
506+ post_ids = onp .asarray (post_ids )
507+ unique_pre_ids , pre_count = onp .unique (pre_ids , return_counts = True )
508+ final_pre_count = onp .zeros (num_pre , dtype = onp .uint32 )
509+ final_pre_count [unique_pre_ids ] = pre_count
510+ indptr = final_pre_count .cumsum ()
511+ indptr = onp .insert (indptr , 0 , 0 )
512+ @numba_jit (parallel = True , nogil = True )
513+ def single_sort (pre_ids ,post_ids ,indptr ):
514+ pre_tmp = indptr .copy ()
515+ indices = onp .zeros ((indptr [- 1 ],))
516+ for i in numba_range (indptr [- 1 ]):
517+ indices [pre_tmp [pre_ids [i ]]]= post_ids [i ]
518+ pre_tmp [pre_ids [i ]]+= 1
519+ return indices
520+ indices = single_sort (onp .asarray (pre_ids ),onp .asarray (post_ids ),indptr )
521+ return np .asarray (indices , dtype = IDX_DTYPE ), np .asarray (indptr , dtype = IDX_DTYPE )
0 commit comments