Skip to content

Commit dd5e7ad

Browse files
authored
upgrade connection apis (#299)
upgrade connection apis
2 parents 0ca7fb9 + 48f4e36 commit dd5e7ad

File tree

8 files changed

+240
-141
lines changed

8 files changed

+240
-141
lines changed

brainpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.2.3.6"
3+
__version__ = "2.2.4.0"
44

55
try:
66
import jaxlib

brainpy/connect/base.py

Lines changed: 133 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
'Connector', 'TwoEndConnector', 'OneEndConnector',
2626

2727
# methods
28-
'csr2csc', 'csr2mat', 'mat2csr', 'ij2csr', 'ij2csr2'
28+
'csr2csc', 'csr2mat', 'mat2csr', 'ij2csr'
2929
]
3030

3131
CONN_MAT = 'conn_mat'
@@ -47,6 +47,7 @@
4747
MAT_DTYPE = jnp.bool_
4848
IDX_DTYPE = jnp.uint32
4949

50+
5051
def set_default_dtype(mat_dtype=None, idx_dtype=None):
5152
"""Set the default dtype.
5253
@@ -99,7 +100,7 @@ class TwoEndConnector(Connector):
99100
100101
1. Implementing ``build_conn(self)`` function, which returns one of
101102
the connection data ``csr`` (CSR sparse data, a tuple of <post_ids, inptr>),
102-
``ij`` (COO sparse data, a tuple of <pre_ids, post_ids>), and ``mat``
103+
``ij`` (COO sparse data, a tuple of <pre_ids, post_ids>), or ``mat``
103104
(a binary connection matrix). For instance,
104105
105106
.. code-block:: python
@@ -185,7 +186,7 @@ def is_version2_style(self):
185186
else:
186187
return True
187188

188-
def check(self, structures: Union[Tuple, List, str]):
189+
def _check(self, structures: Union[Tuple, List, str]):
189190
# check synaptic structures
190191
if isinstance(structures, str):
191192
structures = [structures]
@@ -203,15 +204,15 @@ def _return_by_mat(self, structures, mat, all_data: dict):
203204

204205
require_other_structs = len([s for s in structures if s != CONN_MAT]) > 0
205206
if require_other_structs:
206-
np = jnp if isinstance(mat, jnp.ndarray) else bm
207+
np = onp if isinstance(mat, onp.ndarray) else bm
207208
pre_ids, post_ids = np.where(mat > 0)
208209
pre_ids = np.asarray(pre_ids, dtype=IDX_DTYPE)
209210
post_ids = np.asarray(post_ids, dtype=IDX_DTYPE)
210211
self._return_by_ij(structures, ij=(pre_ids, post_ids), all_data=all_data)
211212

212213
def _return_by_csr(self, structures, csr: tuple, all_data: dict):
213214
indices, indptr = csr
214-
np = jnp if isinstance(indices, jnp.ndarray) else bm
215+
np = onp if isinstance(indices, onp.ndarray) else bm
215216
assert self.pre_num == indptr.size - 1
216217

217218
if (CONN_MAT in structures) and (CONN_MAT not in all_data):
@@ -260,15 +261,15 @@ def _return_by_ij(self, structures, ij: tuple, all_data: dict):
260261
require_other_structs = len([s for s in structures
261262
if s not in [CONN_MAT, PRE_IDS, POST_IDS]]) > 0
262263
if require_other_structs:
263-
if config.read('jax_platform_name') == "gpu":
264-
csr = ij2csr(pre_ids, post_ids, self.pre_num)
265-
else:
266-
csr = ij2csr2(pre_ids, post_ids, self.pre_num)
264+
csr = ij2csr(pre_ids, post_ids, self.pre_num)
267265
self._return_by_csr(structures, csr=csr, all_data=all_data)
268266

269-
def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):
267+
def _make_returns(self, structures, conn_data):
270268
"""Make the desired synaptic structures and return them.
271269
"""
270+
csr = None
271+
mat = None
272+
ij = None
272273
if isinstance(conn_data, dict):
273274
csr = conn_data.get('csr', None)
274275
mat = conn_data.get('mat', None)
@@ -320,28 +321,41 @@ def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):
320321
else:
321322
return tuple([all_data[n] for n in structures])
322323

323-
@tools.not_customized
324-
def build_conn(self):
325-
"""build connections with certain data type.
324+
def require(self, *structures):
325+
"""Require all the connection data needed.
326326
327-
Returns
328-
-------
329-
A tuple with two elements: connection type (str) and connection data.
330-
example: return 'csr', (ind, indptr)
331-
Or a dict with three elements: csr, mat and ij.
332-
example: return dict(csr=(ind, indptr), mat=None, ij=None)
327+
Examples
328+
--------
329+
330+
>>> import brainpy as bp
331+
>>> conn = bp.connect.FixedProb(0.1)
332+
>>> mat = conn.require(10, 20, 'conn_mat')
333+
>>> mat.shape
334+
(10, 20)
333335
"""
334-
pass
335336

336-
def require(self, *structures):
337+
if len(structures) > 0:
338+
pre_size = None
339+
post_size = None
340+
if not isinstance(structures[0], str):
341+
pre_size = structures[0]
342+
structures = structures[1:]
343+
if len(structures) > 0:
344+
if not isinstance(structures[0], str):
345+
post_size = structures[0]
346+
structures = structures[1:]
347+
if pre_size is not None:
348+
self.__call__(pre_size, post_size)
349+
else:
350+
return tuple()
351+
337352
try:
338353
assert self.pre_num is not None and self.post_num is not None
339354
except AssertionError:
340355
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
341-
f'Please use self.__call__() '
342-
f'before requiring connection data.')
356+
f'Please use "self.require(pre_size, post_size, DATA1, DATA2, ...)" ')
343357

344-
self.check(structures)
358+
self._check(structures)
345359
if self.is_version2_style:
346360
if len(structures) == 1:
347361
if PRE2POST in structures and not hasattr(self.build_csr, 'not_customized'):
@@ -368,21 +382,74 @@ def require(self, *structures):
368382

369383
else:
370384
conn_data = self.build_conn()
371-
return self.make_returns(structures, conn_data)
385+
return self._make_returns(structures, conn_data)
372386

373387
def requires(self, *structures):
388+
"""Require all the connection data needed."""
374389
return self.require(*structures)
375390

391+
@tools.not_customized
392+
def build_conn(self):
393+
"""build connections with certain data type.
394+
395+
If users want to customize their connections, please provide one
396+
of the following functions:
397+
398+
- ``build_mat()``: build a matrix binary connection matrix.
399+
- ``build_csr()``: build a csr sparse connection data.
400+
- ``build_coo()``: build a coo sparse connection data.
401+
- ``build_conn()``: deprecated.
402+
403+
Returns
404+
-------
405+
conn: tuple, dict
406+
A tuple with two elements: connection type (str) and connection data.
407+
For example: ``return 'csr', (ind, indptr)``
408+
Or a dict with three elements: csr, mat and ij. For example:
409+
``return dict(csr=(ind, indptr), mat=None, ij=None)``
410+
"""
411+
pass
412+
376413
@tools.not_customized
377414
def build_mat(self):
415+
"""Build a binary matrix connection data.
416+
417+
418+
If users want to customize their connections, please provide one
419+
of the following functions:
420+
421+
- ``build_mat()``: build a matrix binary connection matrix.
422+
- ``build_csr()``: build a csr sparse connection data.
423+
- ``build_coo()``: build a coo sparse connection data.
424+
- ``build_conn()``: deprecated.
425+
426+
Returns
427+
-------
428+
conn: Array
429+
A binary matrix with the shape ``(num_pre, num_post)``.
430+
"""
378431
pass
379432

380433
@tools.not_customized
381434
def build_csr(self):
435+
"""Build a csr sparse connection data.
436+
437+
Returns
438+
-------
439+
conn: tuple
440+
A tuple denoting the ``(indices, indptr)``.
441+
"""
382442
pass
383443

384444
@tools.not_customized
385445
def build_coo(self):
446+
"""Build a coo sparse connection data.
447+
448+
Returns
449+
-------
450+
conn: tuple
451+
A tuple denoting the ``(pre_ids, post_ids)``.
452+
"""
386453
pass
387454

388455

@@ -424,7 +491,7 @@ def _reset_conn(self, pre_size, post_size=None):
424491
def csr2csc(csr, post_num, data=None):
425492
"""Convert csr to csc."""
426493
indices, indptr = csr
427-
np = jnp if isinstance(indices, jnp.ndarray) else bm
494+
np = onp if isinstance(indices, onp.ndarray) else bm
428495
# kind = 'quicksort' if isinstance(indices, jnp.ndarray) else 'stable'
429496

430497
pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))
@@ -436,7 +503,7 @@ def csr2csc(csr, post_num, data=None):
436503

437504
unique_post_ids, count = np.unique(indices, return_counts=True)
438505
post_count = np.zeros(post_num, dtype=IDX_DTYPE)
439-
post_count = post_count.at[unique_post_ids.value if isinstance(unique_post_ids, bm.JaxArray) else unique_post_ids].set(count.value if isinstance(count, bm.JaxArray) else count)
506+
post_count[unique_post_ids] = count
440507

441508
indptr_new = post_count.cumsum()
442509
indptr_new = np.insert(indptr_new, 0, 0)
@@ -451,14 +518,14 @@ def csr2csc(csr, post_num, data=None):
451518

452519
def mat2csr(dense):
453520
"""convert a dense matrix to (indices, indptr)."""
454-
np = jnp if isinstance(dense, jnp.ndarray) else bm
521+
np = onp if isinstance(dense, onp.ndarray) else bm
455522

456523
pre_ids, post_ids = np.where(dense > 0)
457524
pre_num = dense.shape[0]
458525

459526
uni_idx, count = np.unique(pre_ids, return_counts=True)
460527
pre_count = np.zeros(pre_num, dtype=IDX_DTYPE)
461-
pre_count = pre_count.at[uni_idx.value if isinstance(uni_idx, bm.JaxArray) else uni_idx].set(count.value if isinstance(count, bm.JaxArray) else count)
528+
pre_count[uni_idx] = count
462529
indptr = count.cumsum()
463530
indptr = np.insert(indptr, 0, 0)
464531

@@ -468,54 +535,65 @@ def mat2csr(dense):
468535
def csr2mat(csr, num_pre, num_post):
469536
"""convert (indices, indptr) to a dense matrix."""
470537
indices, indptr = csr
471-
np = jnp if isinstance(indices, jnp.ndarray) else bm
472-
538+
np = onp if isinstance(indices, onp.ndarray) else bm
473539
d = np.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post
474540
pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))
475-
d = d.at[pre_ids.value if isinstance(pre_ids, bm.JaxArray) else pre_ids, indices.value if isinstance(indices, bm.JaxArray) else indices].set(True)
541+
d[pre_ids, indices] = True
476542
return d
477543

478544

479545
def ij2mat(ij, num_pre, num_post):
480546
"""convert (indices, indptr) to a dense matrix."""
481547
pre_ids, post_ids = ij
482-
np = jnp if isinstance(pre_ids, jnp.ndarray) else bm
483-
548+
np = onp if isinstance(pre_ids, onp.ndarray) else bm
484549
d = np.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post
485-
d = d.at[pre_ids.value if isinstance(pre_ids, bm.JaxArray) else pre_ids, post_ids.value if isinstance(post_ids, bm.JaxArray) else post_ids].set(True)
550+
d[pre_ids, post_ids] = True
486551
return d
487552

488553

489554
def ij2csr(pre_ids, post_ids, num_pre):
490555
"""convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'"""
491-
np = jnp if isinstance(pre_ids, jnp.ndarray) else bm
492-
sort_ids = np.argsort(pre_ids)
493-
post_ids = post_ids[sort_ids.value if isinstance(sort_ids, bm.JaxArray) else sort_ids]
556+
if isinstance(pre_ids, onp.ndarray):
557+
return _cpu_ij2csr(pre_ids, post_ids, num_pre)
558+
elif isinstance(pre_ids, (jnp.ndarray, bm.ndarray)):
559+
if pre_ids.device().platform == 'cpu':
560+
return _cpu_ij2csr(pre_ids, post_ids, num_pre)
561+
else:
562+
return _gpu_ij2csr(pre_ids, post_ids, num_pre)
563+
else:
564+
raise TypeError
565+
566+
567+
def _gpu_ij2csr(pre_ids, post_ids, num_pre):
568+
"""convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'"""
569+
sort_ids = bm.argsort(pre_ids)
570+
post_ids = post_ids[sort_ids]
494571
indices = post_ids
495-
unique_pre_ids, pre_count = np.unique(pre_ids, return_counts=True)
496-
final_pre_count = np.zeros(num_pre, dtype=jnp.uint32)
497-
final_pre_count = final_pre_count.at[unique_pre_ids.value if isinstance(unique_pre_ids, bm.JaxArray) else unique_pre_ids].set(pre_count.value if isinstance(pre_count, bm.JaxArray) else pre_count)
572+
unique_pre_ids, pre_count = bm.unique(pre_ids, return_counts=True)
573+
final_pre_count = bm.zeros(num_pre, dtype=jnp.uint32)
574+
final_pre_count[unique_pre_ids] = pre_count
498575
indptr = final_pre_count.cumsum()
499-
indptr = np.insert(indptr, 0, 0)
576+
indptr = bm.insert(indptr, 0, 0)
577+
return bm.asarray(indices, dtype=IDX_DTYPE), bm.asarray(indptr, dtype=IDX_DTYPE)
500578

501-
return np.asarray(indices, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE)
502579

503-
def ij2csr2(pre_ids, post_ids, num_pre):
580+
def _cpu_ij2csr(pre_ids, post_ids, num_pre):
504581
"""convert pre_ids, post_ids to (indices, indptr). and use numba for sort function when'jax_platform_name' = 'cpu'"""
505-
np = jnp if isinstance(pre_ids, jnp.ndarray) else bm
506-
post_ids = onp.asarray(post_ids)
507-
unique_pre_ids, pre_count = onp.unique(pre_ids, return_counts=True)
508-
final_pre_count = onp.zeros(num_pre, dtype=onp.uint32)
582+
np = onp if isinstance(pre_ids, onp.ndarray) else bm
583+
unique_pre_ids, pre_count = np.unique(pre_ids, return_counts=True)
584+
final_pre_count = np.zeros(num_pre, dtype=np.uint32)
509585
final_pre_count[unique_pre_ids] = pre_count
510586
indptr = final_pre_count.cumsum()
511-
indptr = onp.insert(indptr, 0, 0)
512-
@numba_jit (parallel=True, nogil=True)
513-
def single_sort(pre_ids,post_ids,indptr):
587+
indptr = np.insert(indptr, 0, 0)
588+
589+
@numba_jit(parallel=True, nogil=True)
590+
def single_sort(pre_ids, post_ids, indptr):
514591
pre_tmp = indptr.copy()
515-
indices= onp.zeros((indptr[-1],))
592+
indices = onp.zeros((indptr[-1],))
516593
for i in numba_range(indptr[-1]):
517-
indices[pre_tmp[pre_ids[i]]]=post_ids[i]
518-
pre_tmp[pre_ids[i]]+=1
594+
indices[pre_tmp[pre_ids[i]]] = post_ids[i]
595+
pre_tmp[pre_ids[i]] += 1
519596
return indices
520-
indices = single_sort(onp.asarray(pre_ids),onp.asarray(post_ids),indptr)
597+
598+
indices = single_sort(bm.as_numpy(pre_ids), bm.as_numpy(post_ids), bm.as_numpy(indptr))
521599
return np.asarray(indices, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE)

0 commit comments

Comments
 (0)