2525 'Connector' , 'TwoEndConnector' , 'OneEndConnector' ,
2626
2727 # methods
28- 'csr2csc' , 'csr2mat' , 'mat2csr' , 'ij2csr' , 'ij2csr2'
28+ 'csr2csc' , 'csr2mat' , 'mat2csr' , 'ij2csr'
2929]
3030
3131CONN_MAT = 'conn_mat'
4747MAT_DTYPE = jnp .bool_
4848IDX_DTYPE = jnp .uint32
4949
50+
5051def set_default_dtype (mat_dtype = None , idx_dtype = None ):
5152 """Set the default dtype.
5253
@@ -99,7 +100,7 @@ class TwoEndConnector(Connector):
99100
100101 1. Implementing ``build_conn(self)`` function, which returns one of
101102 the connection data ``csr`` (CSR sparse data, a tuple of <post_ids, inptr>),
102- ``ij`` (COO sparse data, a tuple of <pre_ids, post_ids>), and ``mat``
103+ ``ij`` (COO sparse data, a tuple of <pre_ids, post_ids>), or ``mat``
103104 (a binary connection matrix). For instance,
104105
105106 .. code-block:: python
@@ -185,7 +186,7 @@ def is_version2_style(self):
185186 else :
186187 return True
187188
188- def check (self , structures : Union [Tuple , List , str ]):
189+ def _check (self , structures : Union [Tuple , List , str ]):
189190 # check synaptic structures
190191 if isinstance (structures , str ):
191192 structures = [structures ]
@@ -203,15 +204,15 @@ def _return_by_mat(self, structures, mat, all_data: dict):
203204
204205 require_other_structs = len ([s for s in structures if s != CONN_MAT ]) > 0
205206 if require_other_structs :
206- np = jnp if isinstance (mat , jnp .ndarray ) else bm
207+ np = onp if isinstance (mat , onp .ndarray ) else bm
207208 pre_ids , post_ids = np .where (mat > 0 )
208209 pre_ids = np .asarray (pre_ids , dtype = IDX_DTYPE )
209210 post_ids = np .asarray (post_ids , dtype = IDX_DTYPE )
210211 self ._return_by_ij (structures , ij = (pre_ids , post_ids ), all_data = all_data )
211212
212213 def _return_by_csr (self , structures , csr : tuple , all_data : dict ):
213214 indices , indptr = csr
214- np = jnp if isinstance (indices , jnp .ndarray ) else bm
215+ np = onp if isinstance (indices , onp .ndarray ) else bm
215216 assert self .pre_num == indptr .size - 1
216217
217218 if (CONN_MAT in structures ) and (CONN_MAT not in all_data ):
@@ -260,15 +261,15 @@ def _return_by_ij(self, structures, ij: tuple, all_data: dict):
260261 require_other_structs = len ([s for s in structures
261262 if s not in [CONN_MAT , PRE_IDS , POST_IDS ]]) > 0
262263 if require_other_structs :
263- if config .read ('jax_platform_name' ) == "gpu" :
264- csr = ij2csr (pre_ids , post_ids , self .pre_num )
265- else :
266- csr = ij2csr2 (pre_ids , post_ids , self .pre_num )
264+ csr = ij2csr (pre_ids , post_ids , self .pre_num )
267265 self ._return_by_csr (structures , csr = csr , all_data = all_data )
268266
269- def make_returns (self , structures , conn_data , csr = None , mat = None , ij = None ):
267+ def _make_returns (self , structures , conn_data ):
270268 """Make the desired synaptic structures and return them.
271269 """
270+ csr = None
271+ mat = None
272+ ij = None
272273 if isinstance (conn_data , dict ):
273274 csr = conn_data .get ('csr' , None )
274275 mat = conn_data .get ('mat' , None )
@@ -320,28 +321,41 @@ def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):
320321 else :
321322 return tuple ([all_data [n ] for n in structures ])
322323
323- @tools .not_customized
324- def build_conn (self ):
325- """build connections with certain data type.
324+ def require (self , * structures ):
325+ """Require all the connection data needed.
326326
327- Returns
328- -------
329- A tuple with two elements: connection type (str) and connection data.
330- example: return 'csr', (ind, indptr)
331- Or a dict with three elements: csr, mat and ij.
332- example: return dict(csr=(ind, indptr), mat=None, ij=None)
327+ Examples
328+ --------
329+
330+ >>> import brainpy as bp
331+ >>> conn = bp.connect.FixedProb(0.1)
332+ >>> mat = conn.require(10, 20, 'conn_mat')
333+ >>> mat.shape
334+ (10, 20)
333335 """
334- pass
335336
336- def require (self , * structures ):
337+ if len (structures ) > 0 :
338+ pre_size = None
339+ post_size = None
340+ if not isinstance (structures [0 ], str ):
341+ pre_size = structures [0 ]
342+ structures = structures [1 :]
343+ if len (structures ) > 0 :
344+ if not isinstance (structures [0 ], str ):
345+ post_size = structures [0 ]
346+ structures = structures [1 :]
347+ if pre_size is not None :
348+ self .__call__ (pre_size , post_size )
349+ else :
350+ return tuple ()
351+
337352 try :
338353 assert self .pre_num is not None and self .post_num is not None
339354 except AssertionError :
340355 raise ConnectorError (f'self.pre_num or self.post_num is not defined. '
341- f'Please use self.__call__() '
342- f'before requiring connection data.' )
356+ f'Please use "self.require(pre_size, post_size, DATA1, DATA2, ...)" ' )
343357
344- self .check (structures )
358+ self ._check (structures )
345359 if self .is_version2_style :
346360 if len (structures ) == 1 :
347361 if PRE2POST in structures and not hasattr (self .build_csr , 'not_customized' ):
@@ -368,21 +382,74 @@ def require(self, *structures):
368382
369383 else :
370384 conn_data = self .build_conn ()
371- return self .make_returns (structures , conn_data )
385+ return self ._make_returns (structures , conn_data )
372386
373387 def requires (self , * structures ):
388+ """Require all the connection data needed."""
374389 return self .require (* structures )
375390
391+ @tools .not_customized
392+ def build_conn (self ):
393+ """build connections with certain data type.
394+
395+ If users want to customize their connections, please provide one
396+ of the following functions:
397+
398+ - ``build_mat()``: build a matrix binary connection matrix.
399+ - ``build_csr()``: build a csr sparse connection data.
400+ - ``build_coo()``: build a coo sparse connection data.
401+ - ``build_conn()``: deprecated.
402+
403+ Returns
404+ -------
405+ conn: tuple, dict
406+ A tuple with two elements: connection type (str) and connection data.
407+ For example: ``return 'csr', (ind, indptr)``
408+ Or a dict with three elements: csr, mat and ij. For example:
409+ ``return dict(csr=(ind, indptr), mat=None, ij=None)``
410+ """
411+ pass
412+
376413 @tools .not_customized
377414 def build_mat (self ):
415+ """Build a binary matrix connection data.
416+
417+
418+ If users want to customize their connections, please provide one
419+ of the following functions:
420+
421+ - ``build_mat()``: build a matrix binary connection matrix.
422+ - ``build_csr()``: build a csr sparse connection data.
423+ - ``build_coo()``: build a coo sparse connection data.
424+ - ``build_conn()``: deprecated.
425+
426+ Returns
427+ -------
428+ conn: Array
429+ A binary matrix with the shape ``(num_pre, num_post)``.
430+ """
378431 pass
379432
380433 @tools .not_customized
381434 def build_csr (self ):
435+ """Build a csr sparse connection data.
436+
437+ Returns
438+ -------
439+ conn: tuple
440+ A tuple denoting the ``(indices, indptr)``.
441+ """
382442 pass
383443
384444 @tools .not_customized
385445 def build_coo (self ):
446+ """Build a coo sparse connection data.
447+
448+ Returns
449+ -------
450+ conn: tuple
451+ A tuple denoting the ``(pre_ids, post_ids)``.
452+ """
386453 pass
387454
388455
@@ -424,7 +491,7 @@ def _reset_conn(self, pre_size, post_size=None):
424491def csr2csc (csr , post_num , data = None ):
425492 """Convert csr to csc."""
426493 indices , indptr = csr
427- np = jnp if isinstance (indices , jnp .ndarray ) else bm
494+ np = onp if isinstance (indices , onp .ndarray ) else bm
428495 # kind = 'quicksort' if isinstance(indices, jnp.ndarray) else 'stable'
429496
430497 pre_ids = np .repeat (np .arange (indptr .size - 1 ), np .diff (indptr ))
@@ -436,7 +503,7 @@ def csr2csc(csr, post_num, data=None):
436503
437504 unique_post_ids , count = np .unique (indices , return_counts = True )
438505 post_count = np .zeros (post_num , dtype = IDX_DTYPE )
439- post_count = post_count . at [unique_post_ids . value if isinstance ( unique_post_ids , bm . JaxArray ) else unique_post_ids ]. set ( count . value if isinstance ( count , bm . JaxArray ) else count )
506+ post_count [unique_post_ids ] = count
440507
441508 indptr_new = post_count .cumsum ()
442509 indptr_new = np .insert (indptr_new , 0 , 0 )
@@ -451,14 +518,14 @@ def csr2csc(csr, post_num, data=None):
451518
452519def mat2csr (dense ):
453520 """convert a dense matrix to (indices, indptr)."""
454- np = jnp if isinstance (dense , jnp .ndarray ) else bm
521+ np = onp if isinstance (dense , onp .ndarray ) else bm
455522
456523 pre_ids , post_ids = np .where (dense > 0 )
457524 pre_num = dense .shape [0 ]
458525
459526 uni_idx , count = np .unique (pre_ids , return_counts = True )
460527 pre_count = np .zeros (pre_num , dtype = IDX_DTYPE )
461- pre_count = pre_count . at [uni_idx . value if isinstance ( uni_idx , bm . JaxArray ) else uni_idx ]. set ( count . value if isinstance ( count , bm . JaxArray ) else count )
528+ pre_count [uni_idx ] = count
462529 indptr = count .cumsum ()
463530 indptr = np .insert (indptr , 0 , 0 )
464531
@@ -468,54 +535,65 @@ def mat2csr(dense):
468535def csr2mat (csr , num_pre , num_post ):
469536 """convert (indices, indptr) to a dense matrix."""
470537 indices , indptr = csr
471- np = jnp if isinstance (indices , jnp .ndarray ) else bm
472-
538+ np = onp if isinstance (indices , onp .ndarray ) else bm
473539 d = np .zeros ((num_pre , num_post ), dtype = MAT_DTYPE ) # num_pre, num_post
474540 pre_ids = np .repeat (np .arange (indptr .size - 1 ), np .diff (indptr ))
475- d = d . at [pre_ids . value if isinstance ( pre_ids , bm . JaxArray ) else pre_ids , indices . value if isinstance ( indices , bm . JaxArray ) else indices ]. set ( True )
541+ d [pre_ids , indices ] = True
476542 return d
477543
478544
479545def ij2mat (ij , num_pre , num_post ):
480546 """convert (indices, indptr) to a dense matrix."""
481547 pre_ids , post_ids = ij
482- np = jnp if isinstance (pre_ids , jnp .ndarray ) else bm
483-
548+ np = onp if isinstance (pre_ids , onp .ndarray ) else bm
484549 d = np .zeros ((num_pre , num_post ), dtype = MAT_DTYPE ) # num_pre, num_post
485- d = d . at [pre_ids . value if isinstance ( pre_ids , bm . JaxArray ) else pre_ids , post_ids . value if isinstance ( post_ids , bm . JaxArray ) else post_ids ]. set ( True )
550+ d [pre_ids , post_ids ] = True
486551 return d
487552
488553
489554def ij2csr (pre_ids , post_ids , num_pre ):
490555 """convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'"""
491- np = jnp if isinstance (pre_ids , jnp .ndarray ) else bm
492- sort_ids = np .argsort (pre_ids )
493- post_ids = post_ids [sort_ids .value if isinstance (sort_ids , bm .JaxArray ) else sort_ids ]
556+ if isinstance (pre_ids , onp .ndarray ):
557+ return _cpu_ij2csr (pre_ids , post_ids , num_pre )
558+ elif isinstance (pre_ids , (jnp .ndarray , bm .ndarray )):
559+ if pre_ids .device ().platform == 'cpu' :
560+ return _cpu_ij2csr (pre_ids , post_ids , num_pre )
561+ else :
562+ return _gpu_ij2csr (pre_ids , post_ids , num_pre )
563+ else :
564+ raise TypeError
565+
566+
567+ def _gpu_ij2csr (pre_ids , post_ids , num_pre ):
568+ """convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'"""
569+ sort_ids = bm .argsort (pre_ids )
570+ post_ids = post_ids [sort_ids ]
494571 indices = post_ids
495- unique_pre_ids , pre_count = np .unique (pre_ids , return_counts = True )
496- final_pre_count = np .zeros (num_pre , dtype = jnp .uint32 )
497- final_pre_count = final_pre_count . at [unique_pre_ids . value if isinstance ( unique_pre_ids , bm . JaxArray ) else unique_pre_ids ]. set ( pre_count . value if isinstance ( pre_count , bm . JaxArray ) else pre_count )
572+ unique_pre_ids , pre_count = bm .unique (pre_ids , return_counts = True )
573+ final_pre_count = bm .zeros (num_pre , dtype = jnp .uint32 )
574+ final_pre_count [unique_pre_ids ] = pre_count
498575 indptr = final_pre_count .cumsum ()
499- indptr = np .insert (indptr , 0 , 0 )
576+ indptr = bm .insert (indptr , 0 , 0 )
577+ return bm .asarray (indices , dtype = IDX_DTYPE ), bm .asarray (indptr , dtype = IDX_DTYPE )
500578
501- return np .asarray (indices , dtype = IDX_DTYPE ), np .asarray (indptr , dtype = IDX_DTYPE )
502579
503- def ij2csr2 (pre_ids , post_ids , num_pre ):
580+ def _cpu_ij2csr (pre_ids , post_ids , num_pre ):
504581 """convert pre_ids, post_ids to (indices, indptr). and use numba for sort function when'jax_platform_name' = 'cpu'"""
505- np = jnp if isinstance (pre_ids , jnp .ndarray ) else bm
506- post_ids = onp .asarray (post_ids )
507- unique_pre_ids , pre_count = onp .unique (pre_ids , return_counts = True )
508- final_pre_count = onp .zeros (num_pre , dtype = onp .uint32 )
582+ np = onp if isinstance (pre_ids , onp .ndarray ) else bm
583+ unique_pre_ids , pre_count = np .unique (pre_ids , return_counts = True )
584+ final_pre_count = np .zeros (num_pre , dtype = np .uint32 )
509585 final_pre_count [unique_pre_ids ] = pre_count
510586 indptr = final_pre_count .cumsum ()
511- indptr = onp .insert (indptr , 0 , 0 )
512- @numba_jit (parallel = True , nogil = True )
513- def single_sort (pre_ids ,post_ids ,indptr ):
587+ indptr = np .insert (indptr , 0 , 0 )
588+
589+ @numba_jit (parallel = True , nogil = True )
590+ def single_sort (pre_ids , post_ids , indptr ):
514591 pre_tmp = indptr .copy ()
515- indices = onp .zeros ((indptr [- 1 ],))
592+ indices = onp .zeros ((indptr [- 1 ],))
516593 for i in numba_range (indptr [- 1 ]):
517- indices [pre_tmp [pre_ids [i ]]]= post_ids [i ]
518- pre_tmp [pre_ids [i ]]+= 1
594+ indices [pre_tmp [pre_ids [i ]]] = post_ids [i ]
595+ pre_tmp [pre_ids [i ]] += 1
519596 return indices
520- indices = single_sort (onp .asarray (pre_ids ),onp .asarray (post_ids ),indptr )
597+
598+ indices = single_sort (bm .as_numpy (pre_ids ), bm .as_numpy (post_ids ), bm .as_numpy (indptr ))
521599 return np .asarray (indices , dtype = IDX_DTYPE ), np .asarray (indptr , dtype = IDX_DTYPE )
0 commit comments