Skip to content

Commit d80b715

Browse files
authored
optimizes the connect time when using gpu (#293)
optimizes the connect time when using gpu (#293)
2 parents b9bce46 + cb1018c commit d80b715

File tree

2 files changed

+114
-72
lines changed

2 files changed

+114
-72
lines changed

brainpy/connect/base.py

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
import jax.numpy as jnp
77
import numpy as onp
8-
8+
from jax import config
99
from brainpy import tools, math as bm
1010
from brainpy.errors import ConnectorError
11+
from brainpy.tools.others import numba_jit, numba_range
1112

1213
__all__ = [
1314
# the connection types
@@ -24,7 +25,7 @@
2425
'Connector', 'TwoEndConnector', 'OneEndConnector',
2526

2627
# methods
27-
'csr2csc', 'csr2mat', 'mat2csr', 'ij2csr'
28+
'csr2csc', 'csr2mat', 'mat2csr', 'ij2csr', 'ij2csr2'
2829
]
2930

3031
CONN_MAT = 'conn_mat'
@@ -43,9 +44,8 @@
4344
PRE2SYN, POST2SYN,
4445
PRE_SLICE, POST_SLICE]
4546

46-
MAT_DTYPE = onp.bool_
47-
IDX_DTYPE = onp.uint32
48-
47+
MAT_DTYPE = jnp.bool_
48+
IDX_DTYPE = jnp.uint32
4949

5050
def set_default_dtype(mat_dtype=None, idx_dtype=None):
5151
"""Set the default dtype.
@@ -203,15 +203,15 @@ def _return_by_mat(self, structures, mat, all_data: dict):
203203

204204
require_other_structs = len([s for s in structures if s != CONN_MAT]) > 0
205205
if require_other_structs:
206-
np = onp if isinstance(mat, onp.ndarray) else bm
206+
np = jnp if isinstance(mat, jnp.ndarray) else bm
207207
pre_ids, post_ids = np.where(mat > 0)
208208
pre_ids = np.asarray(pre_ids, dtype=IDX_DTYPE)
209209
post_ids = np.asarray(post_ids, dtype=IDX_DTYPE)
210210
self._return_by_ij(structures, ij=(pre_ids, post_ids), all_data=all_data)
211211

212212
def _return_by_csr(self, structures, csr: tuple, all_data: dict):
213213
indices, indptr = csr
214-
np = onp if isinstance(indices, onp.ndarray) else bm
214+
np = jnp if isinstance(indices, jnp.ndarray) else bm
215215
assert self.pre_num == indptr.size - 1
216216

217217
if (CONN_MAT in structures) and (CONN_MAT not in all_data):
@@ -260,7 +260,10 @@ def _return_by_ij(self, structures, ij: tuple, all_data: dict):
260260
require_other_structs = len([s for s in structures
261261
if s not in [CONN_MAT, PRE_IDS, POST_IDS]]) > 0
262262
if require_other_structs:
263-
csr = ij2csr(pre_ids, post_ids, self.pre_num)
263+
if config.read('jax_platform_name') == "gpu":
264+
csr = ij2csr(pre_ids, post_ids, self.pre_num)
265+
else:
266+
csr = ij2csr2(pre_ids, post_ids, self.pre_num)
264267
self._return_by_csr(structures, csr=csr, all_data=all_data)
265268

266269
def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):
@@ -421,19 +424,19 @@ def _reset_conn(self, pre_size, post_size=None):
421424
def csr2csc(csr, post_num, data=None):
422425
"""Convert csr to csc."""
423426
indices, indptr = csr
424-
np = onp if isinstance(indices, onp.ndarray) else bm
425-
kind = 'quicksort' if isinstance(indices, onp.ndarray) else 'stable'
427+
np = jnp if isinstance(indices, jnp.ndarray) else bm
428+
# kind = 'quicksort' if isinstance(indices, jnp.ndarray) else 'stable'
426429

427430
pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))
428431

429-
sort_ids = np.argsort(indices, kind=kind) # to maintain the original order of the elements with the same value
432+
sort_ids = np.argsort(indices) # to maintain the original order of the elements with the same value
430433
if isinstance(sort_ids, bm.JaxArray):
431434
sort_ids = sort_ids.value
432435
pre_ids_new = np.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)
433436

434437
unique_post_ids, count = np.unique(indices, return_counts=True)
435438
post_count = np.zeros(post_num, dtype=IDX_DTYPE)
436-
post_count[unique_post_ids] = count
439+
post_count = post_count.at[unique_post_ids.value if isinstance(unique_post_ids, bm.JaxArray) else unique_post_ids].set(count.value if isinstance(count, bm.JaxArray) else count)
437440

438441
indptr_new = post_count.cumsum()
439442
indptr_new = np.insert(indptr_new, 0, 0)
@@ -448,14 +451,14 @@ def csr2csc(csr, post_num, data=None):
448451

449452
def mat2csr(dense):
450453
"""convert a dense matrix to (indices, indptr)."""
451-
np = onp if isinstance(dense, onp.ndarray) else bm
454+
np = jnp if isinstance(dense, jnp.ndarray) else bm
452455

453456
pre_ids, post_ids = np.where(dense > 0)
454457
pre_num = dense.shape[0]
455458

456459
uni_idx, count = np.unique(pre_ids, return_counts=True)
457460
pre_count = np.zeros(pre_num, dtype=IDX_DTYPE)
458-
pre_count[uni_idx] = count
461+
pre_count = pre_count.at[uni_idx.value if isinstance(uni_idx, bm.JaxArray) else uni_idx].set(count.value if isinstance(count, bm.JaxArray) else count)
459462
indptr = count.cumsum()
460463
indptr = np.insert(indptr, 0, 0)
461464

@@ -465,38 +468,54 @@ def mat2csr(dense):
465468
def csr2mat(csr, num_pre, num_post):
466469
"""convert (indices, indptr) to a dense matrix."""
467470
indices, indptr = csr
468-
np = onp if isinstance(indices, onp.ndarray) else bm
471+
np = jnp if isinstance(indices, jnp.ndarray) else bm
469472

470473
d = np.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post
471474
pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))
472-
d[pre_ids, indices] = True
475+
d = d.at[pre_ids.value if isinstance(pre_ids, bm.JaxArray) else pre_ids, indices.value if isinstance(indices, bm.JaxArray) else indices].set(True)
473476
return d
474477

475478

476479
def ij2mat(ij, num_pre, num_post):
477480
"""convert (indices, indptr) to a dense matrix."""
478481
pre_ids, post_ids = ij
479-
np = onp if isinstance(pre_ids, onp.ndarray) else bm
482+
np = jnp if isinstance(pre_ids, jnp.ndarray) else bm
480483

481484
d = np.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post
482-
d[pre_ids, post_ids] = True
485+
d = d.at[pre_ids.value if isinstance(pre_ids, bm.JaxArray) else pre_ids, post_ids.value if isinstance(post_ids, bm.JaxArray) else post_ids].set(True)
483486
return d
484487

485488

486489
def ij2csr(pre_ids, post_ids, num_pre):
487-
"""convert pre_ids, post_ids to (indices, indptr)."""
488-
np = onp if isinstance(pre_ids, onp.ndarray) else bm
489-
kind = 'quicksort' if isinstance(pre_ids, onp.ndarray) else 'stable'
490-
491-
# sorting
492-
sort_ids = np.argsort(pre_ids, kind=kind)
490+
"""convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'"""
491+
np = jnp if isinstance(pre_ids, jnp.ndarray) else bm
492+
sort_ids = np.argsort(pre_ids)
493493
post_ids = post_ids[sort_ids.value if isinstance(sort_ids, bm.JaxArray) else sort_ids]
494-
495494
indices = post_ids
496495
unique_pre_ids, pre_count = np.unique(pre_ids, return_counts=True)
497-
final_pre_count = np.zeros(num_pre, dtype=IDX_DTYPE)
498-
final_pre_count[unique_pre_ids] = pre_count
496+
final_pre_count = np.zeros(num_pre, dtype=jnp.uint32)
497+
final_pre_count = final_pre_count.at[unique_pre_ids.value if isinstance(unique_pre_ids, bm.JaxArray) else unique_pre_ids].set(pre_count.value if isinstance(pre_count, bm.JaxArray) else pre_count)
499498
indptr = final_pre_count.cumsum()
500499
indptr = np.insert(indptr, 0, 0)
501500

502501
return np.asarray(indices, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE)
502+
503+
def ij2csr2(pre_ids, post_ids, num_pre):
504+
"""convert pre_ids, post_ids to (indices, indptr). and use numba for sort function when'jax_platform_name' = 'cpu'"""
505+
np = jnp if isinstance(pre_ids, jnp.ndarray) else bm
506+
post_ids = onp.asarray(post_ids)
507+
unique_pre_ids, pre_count = onp.unique(pre_ids, return_counts=True)
508+
final_pre_count = onp.zeros(num_pre, dtype=onp.uint32)
509+
final_pre_count[unique_pre_ids] = pre_count
510+
indptr = final_pre_count.cumsum()
511+
indptr = onp.insert(indptr, 0, 0)
512+
@numba_jit (parallel=True, nogil=True)
513+
def single_sort(pre_ids,post_ids,indptr):
514+
pre_tmp = indptr.copy()
515+
indices= onp.zeros((indptr[-1],))
516+
for i in numba_range(indptr[-1]):
517+
indices[pre_tmp[pre_ids[i]]]=post_ids[i]
518+
pre_tmp[pre_ids[i]]+=1
519+
return indices
520+
indices = single_sort(onp.asarray(pre_ids),onp.asarray(post_ids),indptr)
521+
return np.asarray(indices, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE)

0 commit comments

Comments
 (0)