Commit a2c399b

customizing operator using Numba and hand-written ops
1 parent 07533d9 commit a2c399b

File tree: 5 files changed, +127 -32 lines changed

brainpy/connect/regular_conn.py

Lines changed: 6 additions & 6 deletions

@@ -38,21 +38,21 @@ def build_coo(self):
     if self.pre_num != self.post_num:
       raise ConnectorError(f'One2One connection must be defined in two groups with the '
                            f'same size, but {self.pre_num} != {self.post_num}.')
-    return bm.arange(self.pre_num, dtype=IDX_DTYPE), bm.arange(self.post_num, dtype=IDX_DTYPE),
+    return np.arange(self.pre_num, dtype=IDX_DTYPE), np.arange(self.post_num, dtype=IDX_DTYPE),
 
   def build_csr(self):
     if self.pre_num != self.post_num:
       raise ConnectorError(f'One2One connection must be defined in two groups with the '
                            f'same size, but {self.pre_num} != {self.post_num}.')
-    ind = bm.arange(self.pre_num)
+    ind = np.arange(self.pre_num)
     indptr = np.arange(self.pre_num + 1)
-    return bm.asarray(ind, dtype=IDX_DTYPE), bm.arange(indptr, dtype=IDX_DTYPE),
+    return np.asarray(ind, dtype=IDX_DTYPE), np.arange(indptr, dtype=IDX_DTYPE),
 
   def build_mat(self, pre_size=None, post_size=None):
     if self.pre_num != self.post_num:
       raise ConnectorError(f'One2One connection must be defined in two groups with the '
                            f'same size, but {self.pre_num} != {self.post_num}.')
-    return bm.fill_diagonal(bm.zeros((self.pre_num, self.post_num), dtype=MAT_DTYPE), True)
+    return np.fill_diagonal(np.zeros((self.pre_num, self.post_num), dtype=MAT_DTYPE), True)
 
 
 one2one = One2One()

@@ -72,9 +72,9 @@ def __repr__(self):
     return f'{self.__class__.__name__}(include_self={self.include_self})'
 
   def build_mat(self):
-    mat = bm.ones((self.pre_num, self.post_num), dtype=MAT_DTYPE)
+    mat = np.ones((self.pre_num, self.post_num), dtype=MAT_DTYPE)
     if not self.include_self:
-      bm.fill_diagonal(mat, False)
+      np.fill_diagonal(mat, False)
     return mat
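Note: two of the converted lines look suspect even after the bm-to-np swap. `np.arange(indptr, dtype=IDX_DTYPE)` in `build_csr` is presumably meant to be `np.asarray(indptr, dtype=IDX_DTYPE)` (NumPy's `arange` does not accept an existing array), and `np.fill_diagonal` fills in place and returns `None`, so `build_mat` as committed returns `None`. A minimal sketch of the in-place pattern, assuming `MAT_DTYPE` is a boolean dtype:

import numpy as np

MAT_DTYPE = np.bool_  # assumption: stands in for brainpy's MAT_DTYPE

def build_one2one_mat(pre_num, post_num):
  # np.fill_diagonal mutates its argument and returns None, so build
  # the matrix first and return it explicitly.
  mat = np.zeros((pre_num, post_num), dtype=MAT_DTYPE)
  np.fill_diagonal(mat, True)
  return mat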

extensions/brainpylib/custom_op/cpu.py

Lines changed: 14 additions & 7 deletions

@@ -2,10 +2,9 @@
 
 import ctypes
 
-import numpy as np
+from jax import dtypes
 from jax.abstract_arrays import ShapedArray
 from jax.lib import xla_client
-from jax import dtypes
 from numba import types, carray, cfunc
 
 _lambda_no = 0

@@ -17,8 +16,14 @@
 ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
 
 
-def _compile_cpu_signature(func, input_dtypes, input_shapes,
-                           output_dtypes, output_shapes, debug=False):
+def _compile_cpu_signature(
+    func,
+    input_dtypes,
+    input_shapes,
+    output_dtypes,
+    output_shapes,
+    debug=False
+):
   code_scope = dict(
     func_to_call=func,
     input_shapes=input_shapes,

@@ -53,7 +58,7 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
 
   new_f = code_scope['xla_cpu_custom_call_target']
   xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr),
-                     types.CPointer(types.voidptr)))(new_f)
+                                types.CPointer(types.voidptr)))(new_f)
   target_name = xla_c_rule.native_name.encode("ascii")
   capsule = ctypes.pythonapi.PyCapsule_New(
     xla_c_rule.address,  # A CFFI pointer to a function

@@ -83,8 +88,10 @@ def func_cpu_translation(func, abs_eval_fn, c, *inputs, **info):
                        for arg in zip(output_dtypes, output_shapes, output_layouts)]
   xla_output_shape = xla_client.Shape.tuple_shape(xla_output_shapes)
   target_name = _compile_cpu_signature(func,
-                                       input_dtypes, input_dimensions,
-                                       output_dtypes, output_shapes)
+                                       input_dtypes,
+                                       input_dimensions,
+                                       output_dtypes,
+                                       output_shapes)
 
   return xla_client.ops.CustomCallWithLayout(
     c,
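For orientation, the generated `xla_cpu_custom_call_target` follows Numba's `cfunc` pattern for XLA CPU custom calls: XLA hands the target a void** of output buffers and a void** of input buffers, and `carray` views each pointer as a typed array. A minimal sketch with a hypothetical add-one kernel and a fixed shape standing in for the values baked into `code_scope`:

import numpy as np
from numba import types, carray, cfunc

shape = (4,)  # assumption: stands in for the generated input_shapes

@cfunc(types.void(types.CPointer(types.voidptr),
                  types.CPointer(types.voidptr)))
def add_one_target(output_ptrs, input_ptrs):
  # View the raw XLA pointers as typed 1-D arrays.
  x = carray(input_ptrs[0], shape, dtype=np.float32)
  out = carray(output_ptrs[0], shape, dtype=np.float32)
  for i in range(shape[0]):
    out[i] = x[i] + 1.0

# add_one_target.address is the function pointer that gets wrapped in a
# PyCapsule and registered with XLA under target_name.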

extensions/brainpylib/custom_op/gpu.py

Lines changed: 13 additions & 10 deletions

@@ -64,7 +64,6 @@ def find_path_of_symbol_in_library(symbol):
 except Exception:
   numba_cffi_loaded = False
 
-
 if numba_cffi_loaded:
   # functions needed
   ffi = FFI()

@@ -86,7 +85,6 @@ def find_path_of_symbol_in_library(symbol):
   memcpyDeviceToHost = types.int32(2)
   memcpyDeviceToDevice = types.int32(3)
 
-
   _lambda_no = 0
   ctypes.pythonapi.PyCapsule_New.argtypes = [
     ctypes.c_void_p,  # void* pointer

@@ -96,8 +94,14 @@ def find_path_of_symbol_in_library(symbol):
   ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
 
 
-  def _compile_gpu_signature(func, input_dtypes, input_shapes,
-                             output_dtypes, output_shapes):
+  def _compile_gpu_signature(
+      func,
+      input_dtypes,
+      input_shapes,
+      output_dtypes,
+      output_shapes,
+      debug=False
+  ):
     input_byte_size = tuple(
       np.prod(shape) * dtype.itemsize
       for (shape, dtype) in zip(input_shapes, input_dtypes)

@@ -159,7 +163,7 @@ def xla_gpu_custom_call_target(stream, inout_gpu_ptrs, opaque, opaque_len):
       args_out="\n ".join(args_out),
       cuMemcpyAsync_in="\n ".join(cuMemcpyAsync_in),
       cuMemcpyAsync_out="\n ".join(cuMemcpyAsync_out))
-    # print(code_string)
+    if debug: print(code_string)
     exec(compile(code_string.strip(), '', 'exec'), code_scope)
 
     new_f = code_scope['xla_gpu_custom_call_target']

@@ -197,8 +201,10 @@ def func_gpu_translation(func, abs_eval_fn, c, *inputs, **info):
                          for arg in zip(output_dtypes, output_shapes, output_layouts)]
     xla_output_shape = xla_client.Shape.tuple_shape(xla_output_shapes)
     target_name = _compile_gpu_signature(func,
-                                         input_dtypes, input_dimensions,
-                                         output_dtypes, output_shapes)
+                                         input_dtypes,
+                                         input_dimensions,
+                                         output_dtypes,
+                                         output_shapes)
 
     return xla_client.ops.CustomCallWithLayout(
       c,

@@ -207,6 +213,3 @@ def func_gpu_translation(func, abs_eval_fn, c, *inputs, **info):
       operand_shapes_with_layout=input_shapes,
       shape_with_layout=xla_output_shape,
     )
-
-
-
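The GPU path now accepts the same `debug` flag as the CPU path, replacing the commented-out `print` with an opt-in dump of the generated wrapper source before it is exec-compiled. A hedged usage sketch, reusing the argument names from `func_gpu_translation` (the committed call site does not pass the flag):

# Hypothetical debugging call: print the generated
# xla_gpu_custom_call_target source while compiling the signature.
target_name = _compile_gpu_signature(func,
                                     input_dtypes,
                                     input_dimensions,
                                     output_dtypes,
                                     output_shapes,
                                     debug=True)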

extensions/brainpylib/custom_op/tests/a.py

Lines changed: 14 additions & 9 deletions

@@ -1,7 +1,10 @@
+
 import brainpy.math as bm
 import brainpy as bp
 from jax.abstract_arrays import ShapedArray
+import numba
 
+bm.set_platform('cpu')
 
 def try1():
   def abs_eval(events, indices, indptr, *, weight, post_num):

@@ -11,18 +14,20 @@ def con_compute(outs, ins):
     post_val, = outs
     post_val.fill(0)
     events, indices, indptr, weight, _ = ins
-    weight = weight[()]
+    # weight = weight[()]
+    weight = weight
+    print(weight)
     for i in range(events.size):
       if events[i]:
-        for j in range(indptr[i], indptr[i + 1]):
-          index = indices[j]
-          post_val[index] += weight
+        for j in numba.prange(indptr[i], indptr[i + 1]):
+          post_val[indices[j]] += weight
 
   event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute)
 
-  events = bm.random.rand(10) < 0.2
-  indices, indptr = bp.conn.FixedProb(0.1)(10, 20).require('pre2post')
-  print(bm.jit(event_sum, static_argnames=('weight', 'post_num'))(events, indices, indptr, weight=1., post_num=20))
+  events = bm.random.RandomState(123).rand(10) < 0.2
+  indices, indptr = bp.conn.FixedProb(0.1, seed=123)(10, 20).require('pre2post')
+  # print(bm.jit(, static_argnames=('weight', 'post_num'))(events, indices, indptr, weight=1., post_num=20))
+  print(event_sum(events, indices, indptr, weight=1., post_num=20))
 
 
 def try2():

@@ -41,8 +46,8 @@ def con_compute(outs, ins):
 
   event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute)
 
-  events = bm.random.rand(10) < 0.2
-  indices, indptr = bp.conn.FixedProb(0.1)(10, 20).require('pre2post')
+  events = bm.random.RandomState(123).rand(10) < 0.2
+  indices, indptr = bp.conn.FixedProb(0.1, seed=123)(10, 20).require('pre2post')
   print(bm.jit(event_sum)(events, indices, indptr, bm.zeros(20), 1.))
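One caveat on the range-to-`numba.prange` change: `prange` only parallelizes loops inside functions compiled with `numba.njit(parallel=True)`; in other compilation contexts it behaves like a plain sequential `range`. A self-contained illustration of the parallel case:

import numba
import numpy as np

@numba.njit(parallel=True)
def event_count(events):
  # prange iterations may run across threads; Numba recognizes the
  # scalar += as a reduction and handles it safely.
  total = 0
  for i in numba.prange(events.size):
    if events[i]:
      total += 1
  return total

print(event_count(np.random.rand(1000) < 0.2))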

Lines changed: 80 additions & 0 deletions (new file)

@@ -0,0 +1,80 @@
+import brainpy.math as bm
+import brainpy as bp
+from jax.abstract_arrays import ShapedArray
+
+bm.set_platform('cpu')
+
+
+def abs_eval(events, indices, indptr, *, weight, post_num):
+  return ShapedArray((post_num,), bm.float32)
+
+
+def con_compute(outs, ins):
+  post_val, = outs
+  post_val.fill(0)
+  events, indices, indptr, weight, _ = ins
+  weight = weight[()]
+  for i in range(events.size):
+    if events[i]:
+      for j in range(indptr[i], indptr[i + 1]):
+        index = indices[j]
+        post_val[index] += weight
+
+
+event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute, apply_cpu_func_to_gpu=True)
+
+
+class ExponentialV2(bp.dyn.TwoEndConn):
+  """Exponential synapse model using a customized operator written with Numba."""
+
+  def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.):
+    super(ExponentialV2, self).__init__(pre=pre, post=post, conn=conn)
+    self.check_pre_attrs('spike')
+    self.check_post_attrs('input', 'V')
+
+    # parameters
+    self.E = E
+    self.tau = tau
+    self.delay = delay
+    self.g_max = g_max
+    self.pre2post = self.conn.require('pre2post')
+
+    # variables
+    self.g = bm.Variable(bm.zeros(self.post.num))
+
+    # function
+    self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto')
+
+  def update(self, tdi):
+    self.g.value = self.integral(self.g, tdi.t, tdi.dt)
+    self.g += event_sum(self.pre.spike,
+                        self.pre2post[0],
+                        self.pre2post[1],
+                        weight=self.g_max,
+                        post_num=self.post.num)
+    self.post.input += self.g * (self.E - self.post.V)
+
+
+class EINet(bp.dyn.Network):
+  def __init__(self, scale):
+    # neurons
+    pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+                V_initializer=bp.init.Normal(-55., 2.))
+    E = bp.neurons.LIF(int(3200 * scale), **pars, method='exp_auto')
+    I = bp.neurons.LIF(int(800 * scale), **pars, method='exp_auto')
+
+    # synapses
+    E2E = ExponentialV2(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
+    E2I = ExponentialV2(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.)
+    I2E = ExponentialV2(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
+    I2I = ExponentialV2(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.)
+
+    super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I)
+
+
+net2 = EINet(scale=10.)
+runner2 = bp.dyn.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)])
+t, _ = runner2.predict(100., eval_time=True)
+print(t)
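A design note on `apply_cpu_func_to_gpu=True`: judging from the `cuMemcpyAsync` wrappers in `custom_op/gpu.py` above, the GPU build of the op stages operands device-to-host, runs the Numba CPU function, and copies outputs back, trading transfer overhead for not having to write a CUDA kernel. The two registrations side by side, a sketch reusing `abs_eval` and `con_compute` from this file:

# CPU-only registration: con_compute runs natively on the host.
event_sum_cpu = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute)

# GPU registration backed by the same host kernel: operands are copied
# device->host, con_compute runs, and results are copied back to the device.
event_sum_gpu = bm.XLACustomOp(eval_shape=abs_eval,
                               cpu_func=con_compute,
                               apply_cpu_func_to_gpu=True)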
