@@ -570,7 +570,7 @@ def __init__(
570570 sharding : Optional [Sharding ] = None ,
571571 mode : Optional [bm .Mode ] = None ,
572572 name : Optional [str ] = None ,
573- method : str = 'cusparse' ,
573+ method : str = None ,
574574 transpose : bool = True ,
575575 ):
576576 super ().__init__ (name = name , mode = mode , conn = conn , weight = weight , sharding = sharding , transpose = transpose )
@@ -580,8 +580,7 @@ def update(self, x):
580580 if x .ndim == 1 :
581581 return bm .sparse .csrmv (self .weight , self .indices , self .indptr , x ,
582582 shape = (self .conn .pre_num , self .conn .post_num ),
583- transpose = self .transpose ,
584- method = self .method )
583+ method = self .method , transpose = self .transpose )
585584 elif x .ndim > 1 :
586585 shapes = x .shape [:- 1 ]
587586 x = bm .flatten (x , end_dim = - 2 )
@@ -593,9 +592,7 @@ def update(self, x):
593592 def _batch_csrmv (self , x ):
594593 return bm .sparse .csrmv (self .weight , self .indices , self .indptr , x ,
595594 shape = (self .conn .pre_num , self .conn .post_num ),
596- transpose = self .transpose ,
597- method = self .method )
598-
595+ method = self .method , transpose = self .transpose )
599596
600597class EventCSRLinear (_CSRLayer ):
601598 r"""Synaptic matrix multiplication with event CSR sparse computation.
@@ -646,7 +643,6 @@ def _batch_csrmv(self, x):
646643 shape = (self .conn .pre_num , self .conn .post_num ),
647644 transpose = self .transpose )
648645
649-
650646@numba .njit (nogil = True , fastmath = True , parallel = False )
651647def _cpu_csr_on_pre_update (w , indices , indptr , spike , trace , w_min , w_max , out_w ):
652648 out_w [:] = w
@@ -659,7 +655,6 @@ def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w
659655 # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max)
660656 out_w [k ] = np .minimum (np .maximum (out_w [k ] + trace [j ], w_min ), w_max )
661657
662-
663658csr_on_pre_update_prim = bm .XLACustomOp (_cpu_csr_on_pre_update )
664659
665660
@@ -671,7 +666,6 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
671666 return csr_on_pre_update_prim (w , indices , indptr , spike , trace , w_min , w_max ,
672667 outs = [jax .ShapeDtypeStruct (w .shape , w .dtype )])[0 ]
673668
674-
675669@numba .njit (nogil = True , fastmath = True , parallel = False )
676670def _cpu_csc_on_pre_update (w , post_ids , indptr , w_ids , spike , trace , w_min , w_max , out_w ):
677671 out_w [:] = w
@@ -697,6 +691,7 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_m
697691 outs = [jax .ShapeDtypeStruct (w .shape , w .dtype )])[0 ]
698692
699693
694+
700695class CSCLinear (Layer ):
701696 r"""Synaptic matrix multiplication with CSC sparse computation.
702697
@@ -1080,7 +1075,7 @@ def __init__(
10801075 mode : Optional [bm .Mode ] = None ,
10811076 name : Optional [str ] = None ,
10821077 transpose : bool = False ,
1083- atomic : bool = False ,
1078+ atomic : bool = True ,
10841079 ):
10851080 super ().__init__ (name = name , mode = mode )
10861081
@@ -1161,7 +1156,7 @@ def __init__(
11611156 mode : Optional [bm .Mode ] = None ,
11621157 name : Optional [str ] = None ,
11631158 transpose : bool = False ,
1164- atomic : bool = False ,
1159+ atomic : bool = True ,
11651160 ):
11661161 super ().__init__ (name = name , mode = mode )
11671162
@@ -1239,7 +1234,7 @@ def __init__(
12391234 seed : Optional [int ] = None ,
12401235 sharding : Optional [Sharding ] = None ,
12411236 transpose : bool = False ,
1242- atomic : bool = False ,
1237+ atomic : bool = True ,
12431238 mode : Optional [bm .Mode ] = None ,
12441239 name : Optional [str ] = None ,
12451240 ):
0 commit comments