Skip to content

Commit 237607d

Browse files
committed
[test] update csrmv tests
1 parent a25fa78 commit 237607d

File tree

7 files changed

+148
-158
lines changed

7 files changed

+148
-158
lines changed

brainpy/_src/dyn/neurons/input_groups.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -141,30 +141,28 @@ def __init__(
141141
# variables
142142
self.reset_state(self.mode)
143143

144-
# functions
145-
def cond_fun(t):
146-
i = self.i.value
147-
return bm.logical_and(i < self.num_times, t >= self.times[i])
148-
149-
def body_fun(t):
150-
i = self.i.value
151-
if isinstance(self.mode, bm.BatchingMode):
152-
self.spike[:, self.indices[i]] = True
153-
else:
154-
self.spike[self.indices[i]] = True
155-
self.i += 1
156-
157-
self._run = bm.make_while(cond_fun, body_fun, dyn_vars=self.vars())
158-
159144
def reset_state(self, batch_size=None):
160145
self.i = bm.Variable(bm.asarray(0))
161146
self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
162147

163148
def update(self):
164149
self.spike.value = bm.zeros_like(self.spike)
165-
self._run(share.load('t'))
150+
bm.while_loop(self._cond_fun, self._body_fun, share.load('t'))
166151
return self.spike.value
167152

153+
# functions
154+
def _cond_fun(self, t):
155+
i = self.i.value
156+
return bm.logical_and(i < self.num_times, t >= self.times[i])
157+
158+
def _body_fun(self, t):
159+
i = self.i.value
160+
if isinstance(self.mode, bm.BatchingMode):
161+
self.spike[:, self.indices[i]] = True
162+
else:
163+
self.spike[self.indices[i]] = True
164+
self.i += 1
165+
168166

169167
class PoissonGroup(NeuGroupNS):
170168
"""Poisson Neuron Group.

brainpy/_src/dyn/rates/populations.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,6 @@ def reset_state(self, batch_size=None):
10411041
self.Ii.value = variable(bm.zeros, batch_size, self.varshape)
10421042

10431043
def update(self, x1=None, x2=None):
1044-
t = share.load('t')
10451044
dt = share.load('dt')
10461045

10471046
# input

brainpy/_src/math/sparse/_bsr_mm.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
11
# -*- coding: utf-8 -*-
22

3-
import warnings
43
from functools import partial
5-
from typing import Union, Tuple
64

75
import jax.lax
86
import numba
97
import numpy as np
10-
from jax import core, numpy as jnp, dtypes, default_backend, random
11-
from jax.interpreters import ad, mlir, xla
12-
from jax.lib import xla_client
8+
from jax import numpy as jnp
139
from jax.core import Primitive, ShapedArray
14-
from jaxlib import gpu_sparse
10+
from jax.interpreters import ad, xla
11+
from jax.lib import xla_client
1512

16-
from brainpy._src.math.op_registers import (compile_cpu_signature_with_numba,
17-
register_general_batching)
18-
from brainpy._src.math.sparse._utils import csr_to_coo
1913
from brainpy._src.math.interoperability import as_jax
14+
from brainpy._src.math.op_registers import (compile_cpu_signature_with_numba,
15+
register_general_batching)
2016
from brainpy.errors import GPUOperatorNotFound
2117

22-
import brainpylib as bl
23-
2418
try:
2519
from brainpylib import gpu_ops
2620
except ImportError:

brainpy/_src/math/sparse/_bsr_mv.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11

2-
from typing import Union, Tuple
32
from functools import partial
3+
from typing import Union, Tuple
4+
45
import numba
56
import numpy as np
6-
from jax import numpy as jnp, ensure_compile_time_eval
7+
from jax import numpy as jnp
78
from jax.core import ShapedArray, Primitive
8-
from jax.lib import xla_client
99
from jax.interpreters import ad, xla
10+
from jax.lib import xla_client
1011

11-
from brainpy.errors import GPUOperatorNotFound
12-
from brainpy._src.math.op_registers import (compile_cpu_signature_with_numba,register_op_with_numba,
13-
register_general_batching)
1412
from brainpy._src.math.interoperability import as_jax
13+
from brainpy._src.math.op_registers import (compile_cpu_signature_with_numba,
14+
register_general_batching)
1515
from brainpy._src.math.sparse._utils import csr_to_coo
16-
16+
from brainpy.errors import GPUOperatorNotFound
1717

1818
try:
1919
from brainpylib import gpu_ops

0 commit comments

Comments
 (0)