
Commit 7e8dd81

[math] taichi operators as default customized operators (#598)
* [dnn] Add dnn.linear taichi implementation
* [math] Remove multiple results of event csrmv and csrmv
* [dnn] Fix bugs
* [dnn] Update jitconn event atomic=True
* [dnn] Replace brainpylib operators with taichi customized operators
* Update linear.py
* Update test_linear.py
* [dnn, math] Fix bugs
* [math] Fix bugs
* Update linear.py
* Refactor operators
* [math] Fix bugs
* [dnn] Fix bugs
* [math] Fix bugs
* [math] Fix jitconn matvec bugs
* Update linear.py
* [math] Update operators
* [math] Update pytests
* [math] Fix pytest bugs
* Update test_csrmv.py
* Update test_matvec.py
* Update test_event_matvec.py
* Update test_event_csrmv.py
* [math] Update pytests
* [math] Fix test case bugs
* [math] Add more tolerance for jitconn operators
* Format the code

---------

Co-authored-by: Chaoming Wang <[email protected]>
1 parent 8c57f66 commit 7e8dd81

31 files changed (+5368, −5390 lines)

brainpy/_src/dnn/linear.py

Lines changed: 7 additions & 12 deletions
@@ -570,7 +570,7 @@ def __init__(
       sharding: Optional[Sharding] = None,
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
-      method: str = 'cusparse',
+      method: str = None,
       transpose: bool = True,
   ):
     super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose)
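Note on the hunk above: the `method` argument of the CSR layer now defaults to `None` instead of `'cusparse'`, so callers that do not pick a backend explicitly get whatever `bm.sparse.csrmv` dispatches to by default (per the commit title, the taichi customized operators). A minimal sketch of the call, with a made-up 3x4 CSR matrix purely for illustration:

import brainpy.math as bm

# Hypothetical 3x4 CSR matrix with three nonzeros (illustration only).
data = bm.asarray([1., 2., 3.])
indices = bm.asarray([0, 2, 1])     # column index of each nonzero
indptr = bm.asarray([0, 1, 2, 3])   # row pointers, length = n_rows + 1
x = bm.ones(4)

# Leaving `method` at its new default lets the dispatcher choose the backend
# rather than forcing the old 'cusparse' path.
y = bm.sparse.csrmv(data, indices, indptr, x, shape=(3, 4), transpose=False)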
@@ -580,8 +580,7 @@ def update(self, x):
     if x.ndim == 1:
       return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
                              shape=(self.conn.pre_num, self.conn.post_num),
-                             transpose=self.transpose,
-                             method=self.method)
+                             method=self.method, transpose=self.transpose)
     elif x.ndim > 1:
       shapes = x.shape[:-1]
       x = bm.flatten(x, end_dim=-2)
@@ -593,9 +592,7 @@ def update(self, x):
   def _batch_csrmv(self, x):
     return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x,
                            shape=(self.conn.pre_num, self.conn.post_num),
-                           transpose=self.transpose,
-                           method=self.method)
-
+                           method=self.method, transpose=self.transpose)
 
 class EventCSRLinear(_CSRLayer):
   r"""Synaptic matrix multiplication with event CSR sparse computation.
@@ -646,7 +643,6 @@ def _batch_csrmv(self, x):
                                shape=(self.conn.pre_num, self.conn.post_num),
                                transpose=self.transpose)
 
-
 @numba.njit(nogil=True, fastmath=True, parallel=False)
 def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w):
   out_w[:] = w
@@ -659,7 +655,6 @@ def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w):
         # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max)
         out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max)
 
-
 csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update)
 
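The commented-out `np.clip` line in the hunk above indicates that the nested `np.minimum(np.maximum(...))` is intended as an element-wise clip; the rewrite is presumably for the benefit of the numba-compiled kernel, though that rationale is an assumption. The two forms are equivalent:

import numpy as np

x = np.array([-2.0, 0.5, 3.0])
lo, hi = -1.0, 1.0
# min(max(x, lo), hi) clips x into [lo, hi], exactly like np.clip.
assert np.allclose(np.minimum(np.maximum(x, lo), hi), np.clip(x, lo, hi))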

@@ -671,7 +666,6 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
   return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
                                 outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
 
-
 @numba.njit(nogil=True, fastmath=True, parallel=False)
 def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
   out_w[:] = w
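The two hunks above show the `bm.XLACustomOp` pattern these custom plasticity updates rely on: a numba-compiled CPU kernel writes its result into its trailing output argument, and the wrapped primitive is called with `outs` describing the output buffers. A stripped-down sketch of that pattern, using a toy kernel rather than any of the library's kernels, and assuming only the registration and call style visible in this diff:

import jax
import numba
import brainpy.math as bm

@numba.njit(nogil=True, fastmath=True)
def _double_kernel(x, out):
  # The kernel fills the pre-allocated trailing argument in place.
  out[:] = x * 2.0

double_prim = bm.XLACustomOp(_double_kernel)

x = bm.random.rand(10)
# `outs` declares the shape/dtype of each output buffer to allocate.
y = double_prim(x, outs=[jax.ShapeDtypeStruct(x.shape, x.dtype)])[0]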
@@ -697,6 +691,7 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None):
                                 outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
 
 
+
 class CSCLinear(Layer):
   r"""Synaptic matrix multiplication with CSC sparse computation.
 
@@ -1080,7 +1075,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
       transpose: bool = False,
-      atomic: bool = False,
+      atomic: bool = True,
   ):
     super().__init__(name=name, mode=mode)
 
@@ -1161,7 +1156,7 @@ def __init__(
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
      transpose: bool = False,
-      atomic: bool = False,
+      atomic: bool = True,
   ):
     super().__init__(name=name, mode=mode)
 
@@ -1239,7 +1234,7 @@ def __init__(
       seed: Optional[int] = None,
       sharding: Optional[Sharding] = None,
       transpose: bool = False,
-      atomic: bool = False,
+      atomic: bool = True,
       mode: Optional[bm.Mode] = None,
       name: Optional[str] = None,
   ):

brainpy/_src/dnn/tests/test_linear.py

Lines changed: 0 additions & 1 deletion
@@ -213,6 +213,5 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
     self.assertTrue(y2.shape == shape + (200,))
     bm.clear_buffer_memory()
 
-
 if __name__ == '__main__':
   absltest.main()
Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 
 from ._info_collection import *
 from ._csr_matvec import *
-from ._csr_matvec_taichi import *
 