
Commit 359dcbe

update brainpylib

1 parent 10d25bb commit 359dcbe

File tree

8 files changed: +120 -71 lines changed

brainpy/dyn/base.py

Lines changed: 2 additions & 2 deletions
@@ -756,8 +756,8 @@ def __init__(

   def __repr__(self):
     names = self.__class__.__name__
-    return (f'{names}(name={self.name}, mode={self.mode}, '
-            f'{" " * len(names)} pre={self.pre}, '
+    return (f'{names}(name={self.name}, mode={self.mode}, \n'
+            f'{" " * len(names)} pre={self.pre}, \n'
             f'{" " * len(names)} post={self.post})')

   def check_pre_attrs(self, *attrs):
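
The change makes a connection's repr print one field per line, with continuation lines padded to align under the class name. A minimal standalone sketch of the pattern (the `Demo` class and its attribute values are ours, purely for illustration):

class Demo:
    def __init__(self, name, mode, pre, post):
        self.name, self.mode, self.pre, self.post = name, mode, pre, post

    def __repr__(self):
        # same multi-line pattern as the diff above
        names = self.__class__.__name__
        return (f'{names}(name={self.name}, mode={self.mode}, \n'
                f'{" " * len(names)} pre={self.pre}, \n'
                f'{" " * len(names)} post={self.post})')

print(Demo('syn0', 'train', 'LIF-A', 'LIF-B'))
# Demo(name=syn0, mode=train,
#      pre=LIF-A,
#      post=LIF-B)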

brainpy/dyn/layers/dropout.py

Lines changed: 0 additions & 4 deletions
@@ -15,10 +15,6 @@ class Dropout(DynamicalSystem):

   In training, to compensate for the fraction of input values dropped (`rate`),
   all surviving values are multiplied by `1 / (1 - rate)`.

-  The parameter `shared_axes` allows to specify a list of axes on which
-  the mask will be shared: we will use size 1 on those axes for dropout mask
-  and broadcast it. Sharing reduces randomness, but can save memory.
-
   This layer is active only during training (`mode='train'`). In other
   circumstances it is a no-op.
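
The docstring's compensation rule is easy to verify in a standalone sketch (plain NumPy, independent of the BrainPy API; seed and sizes are arbitrary):

import numpy as np

rate = 0.5                                # fraction of inputs dropped
rng = np.random.default_rng(0)

x = np.ones(8, dtype=np.float32)
keep = rng.random(x.shape) >= rate        # mask of surviving entries
y = np.where(keep, x / (1 - rate), 0.0)   # survivors scaled by 1 / (1 - rate)

print(y.mean())  # close to x.mean() == 1.0: the expected value is preserved

In evaluation mode the layer returns its input unchanged, so no rescaling is needed at test time.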

extensions/brainpylib/atomic_sum.py

Lines changed: 2 additions & 1 deletion
@@ -115,7 +115,8 @@ def _atomic_sum_translation(c, values, pre_ids, post_ids, *, post_num, platform=
       shape_with_layout=x_shape(np.dtype(values_dtype), (post_num,), (0,)),
     )
   elif platform == 'gpu':
-    if gpu_ops is None: raise ValueError('Cannot find compiled gpu wheels.')
+    if gpu_ops is None:
+      raise ValueError('Cannot find compiled gpu wheels.')

     opaque = gpu_ops.build_atomic_sum_descriptor(conn_size, post_num)
     if values_dim[0] != 1:

extensions/brainpylib/event_sum.py

Lines changed: 30 additions & 30 deletions
@@ -7,14 +7,17 @@

 from functools import partial

+from typing import Union, Tuple
 import jax.numpy as jnp
 import numpy as np
-from jax import core
+from jax import core, dtypes
 from jax.abstract_arrays import ShapedArray
 from jax.interpreters import xla, batching
 from jax.lax import scan
 from jax.lib import xla_client

+from .utils import GPUOperatorNotFound
+
 try:
   from . import gpu_ops
 except ImportError:
@@ -26,7 +29,10 @@

 _event_sum_prim = core.Primitive("event_sum")


-def event_sum(events, pre2post, post_num, values):
+def event_sum(events: jnp.ndarray,
+              pre2post: Tuple[jnp.ndarray, jnp.ndarray],
+              post_num: int,
+              values: Union[float, jnp.ndarray]):
   # events
   if events.dtype != jnp.bool_:
     raise ValueError(f'"events" must be a vector of bool, while we got {events.dtype}')
@@ -39,17 +45,16 @@ def event_sum(events, pre2post, post_num, values):
   if indices.dtype != indptr.dtype:
     raise ValueError(f"The dtype of pre2post[0] must be equal to that of pre2post[1], "
                      f"while we got {(indices.dtype, indptr.dtype)}")
-  if indices.dtype not in [jnp.uint32, jnp.uint64]:
-    raise ValueError(f'The dtype of pre2post must be uint32 or uint64, while we got {indices.dtype}')
+  if indices.dtype not in [jnp.uint32, jnp.uint64, jnp.int32, jnp.int64]:
+    raise ValueError(f'The dtype of pre2post must be integer, while we got {indices.dtype}')

   # output value
-  values = jnp.asarray([values])
-  if values.dtype not in [jnp.float32, jnp.float64]:
-    raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.')
-  if values.size not in [1, indices.size]:
+  dtype = values.dtype if isinstance(values, jnp.ndarray) else dtypes.canonicalize_dtype(type(values))
+  if dtype not in [jnp.float32, jnp.float64]:
+    raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {dtype}.')
+  if np.size(values) not in [1, indices.size]:
     raise ValueError(f'The size of "values" must be 1 (a scalar) or len(pre2post[0]) (a vector), '
-                     f'while we got {values.size} != 1 != {indices.size}')
-  values = values.flatten()
+                     f'while we got {np.size(values)} != 1 != {indices.size}')
   # bind operator
   return _event_sum_prim.bind(events, indices, indptr, values, post_num=post_num)

@@ -58,34 +63,27 @@ def _event_sum_abstract(events, indices, indptr, values, *, post_num):
   return ShapedArray(dtype=values.dtype, shape=(post_num,))


-_event_sum_prim.def_abstract_eval(_event_sum_abstract)
-_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim))
-
-
 def _event_sum_translation(c, events, indices, indptr, values, *, post_num, platform="cpu"):
-  # The pre/post shape
+  # The shape of pre/post
   pre_size = np.array(c.get_shape(events).dimensions()[0], dtype=np.uint32)
   _pre_shape = x_shape(np.dtype(np.uint32), (), ())
   _post_shape = x_shape(np.dtype(np.uint32), (), ())

   # The indices shape
   indices_shape = c.get_shape(indices)
   Itype = indices_shape.element_type()
-  assert Itype in [np.uint32, np.uint64]

   # The value shape
   values_shape = c.get_shape(values)
   Ftype = values_shape.element_type()
-  assert Ftype in [np.float32, np.float64]
   values_dim = values_shape.dimensions()

   # We dispatch a different call depending on the dtype
   f_type = b'_f32' if Ftype == np.float32 else b'_f64'
-  i_type = b'_i32' if Itype == np.uint32 else b'_i64'
+  i_type = b'_i32' if Itype in [np.uint32, np.int32] else b'_i64'

-  # And then the following is what changes between the GPU and CPU
   if platform == "cpu":
-    v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter'
+    v_type = b'_event_sum_homo' if len(values_dim) == 0 else b'_event_sum_heter'
     return x_ops.CustomCallWithLayout(
       c,
       platform.encode() + v_type + f_type + i_type,
@@ -103,9 +101,12 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, platform="cpu"):
         c.get_shape(values)),
       shape_with_layout=x_shape(np.dtype(Ftype), (post_num,), (0,)),
     )
+
+  # GPU platform
   elif platform == 'gpu':
     if gpu_ops is None:
-      raise ValueError('Cannot find compiled gpu wheels.')
+      raise GPUOperatorNotFound('event_sum')
+
     v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter'
     opaque = gpu_ops.build_event_sum_descriptor(pre_size, post_num)
     return x_ops.CustomCallWithLayout(
@@ -127,11 +128,7 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, platform="cpu"):
   raise ValueError("Unsupported platform, we only support 'cpu' or 'gpu'")


-xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu")
-xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu")
-
-
-def _event_sum_batch(args, axes):
+def _event_sum_batch(args, axes, *, post_num):
   batch_axes, batch_args, non_batch_args = [], {}, {}
   for ax_i, ax in enumerate(axes):
     if ax is None:
@@ -143,19 +140,22 @@ def _event_sum_batch(args, axes):
   def f(_, x):
     pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}'])
                   for i in range(len(axes))])
-    return 0, _event_sum_prim.bind(*pars)
+    return 0, _event_sum_prim.bind(*pars, post_num=post_num)
+
   _, outs = scan(f, 0, batch_args)
   return outs, 0


+_event_sum_prim.def_abstract_eval(_event_sum_abstract)
+_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim))
 batching.primitive_batchers[_event_sum_prim] = _event_sum_batch
-
+xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu")
+xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu")

 # ---------------------------
 # event sum kernel 2
 # ---------------------------

-
 _event_sum2_prim = core.Primitive("event_sum2")
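
Putting the new signature together, a hedged usage sketch (toy sizes; we assume the package re-exports `event_sum` and a CSR connection of 3 presynaptic to 4 postsynaptic neurons):

import jax.numpy as jnp
from brainpylib import event_sum  # assumes a built brainpylib wheel

# CSR connectivity: indices[j] is the postsynaptic target of connection j,
# and indptr[i]:indptr[i+1] spans the connections of presynaptic neuron i.
indices = jnp.asarray([0, 1, 2, 1, 3], dtype=jnp.uint32)
indptr = jnp.asarray([0, 2, 3, 5], dtype=jnp.uint32)

events = jnp.asarray([True, False, True])           # presynaptic spikes
out = event_sum(events, (indices, indptr), 4, 1.0)  # homogeneous weight
# out == [1., 2., 0., 1.]: neuron 0 hits targets {0, 1}; neuron 2 hits {1, 3}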

extensions/brainpylib/utils.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+
+
+__all__ = [
+  'GPUOperatorNotFound',
+]
+
+
+class GPUOperatorNotFound(Exception):
+  def __init__(self, name):
+    super(GPUOperatorNotFound, self).__init__(f'''
+GPU operator for "{name}" was not found.
+
+Please compile the brainpylib GPU operators with the guidance in the following link:
+
+https://brainpy.readthedocs.io/en/latest/tutorial_advanced/compile_brainpylib.html
+''')
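
A quick sketch of the error path this enables (the operator name here is illustrative):

from brainpylib.utils import GPUOperatorNotFound

try:
    raise GPUOperatorNotFound('event_sum')  # raised when gpu_ops is None
except GPUOperatorNotFound as err:
    print(err)  # message points users to the compilation guide above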

extensions/lib/event_sum_cpu.cc

Lines changed: 49 additions & 9 deletions
@@ -4,44 +4,82 @@ namespace brainpy_lib {
 namespace{
 template <typename F, typename I>
 void cpu_event_sum_homo(void *out, const void **in) {
-  // Parse the inputs
   const std::uint32_t pre_size = *reinterpret_cast<const std::uint32_t *>(in[0]);
   const std::uint32_t post_size = *reinterpret_cast<const std::uint32_t *>(in[1]);
   const bool *events = reinterpret_cast<const bool *>(in[2]);
   const I *indices = reinterpret_cast<const I *>(in[3]);
   const I *indptr = reinterpret_cast<const I *>(in[4]);
-  const F *values = reinterpret_cast<const F *>(in[5]);
-  const F value = values[0];
+  const F weight = *reinterpret_cast<const F *>(in[5]);
+  F *result = reinterpret_cast<F *>(out);

-  // The output
+  // algorithm
+  memset(&result[0], 0, sizeof(F) * post_size);
+  for (std::uint32_t i=0; i<pre_size; ++i) {
+    if (events[i]){
+      for (I j=indptr[i]; j<indptr[i+1]; ++j) {
+        result[indices[j]] += weight;
+      }
+    }
+  }
+}
+
+// TODO:: batch version of "event_sum_homo" CPU operator
+template <typename F, typename I>
+void cpu_event_sum_batch_homo(void *out, const void **in) {
+  const std::uint32_t pre_size = *reinterpret_cast<const std::uint32_t *>(in[0]);
+  const std::uint32_t post_size = *reinterpret_cast<const std::uint32_t *>(in[1]);
+  const bool *events = reinterpret_cast<const bool *>(in[2]);
+  const I *indices = reinterpret_cast<const I *>(in[3]);
+  const I *indptr = reinterpret_cast<const I *>(in[4]);
+  const F weight = *reinterpret_cast<const F *>(in[5]);
   F *result = reinterpret_cast<F *>(out);

   // algorithm
-  memset(&result[0], 0, sizeof(result[0]) * post_size);
+  memset(&result[0], 0, sizeof(F) * post_size);
   for (std::uint32_t i=0; i<pre_size; ++i) {
     if (events[i]){
       for (I j=indptr[i]; j<indptr[i+1]; ++j) {
-        result[indices[j]] += value;
+        result[indices[j]] += weight;
       }
     }
   }
 }

 template <typename F, typename I>
 void cpu_event_sum_heter(void *out, const void **in) {
-  // Parse the inputs
   const std::uint32_t pre_size = *reinterpret_cast<const std::uint32_t *>(in[0]);
   const std::uint32_t post_size = *reinterpret_cast<const std::uint32_t *>(in[1]);
   const bool *events = reinterpret_cast<const bool *>(in[2]);
   const I *indices = reinterpret_cast<const I *>(in[3]);
   const I *indptr = reinterpret_cast<const I *>(in[4]);
   const F *values = reinterpret_cast<const F *>(in[5]);
+  F *result = reinterpret_cast<F *>(out);
+
+  // algorithm
+  memset(&result[0], 0, sizeof(F) * post_size);
+  for (std::uint32_t i = 0; i < pre_size; ++i) {
+    if (events[i]){
+      for (I j = indptr[i]; j < indptr[i+1]; ++j) {
+        result[indices[j]] += values[j];
+      }
+    }
+  }
+}

-  // The output
+// TODO:: batch version of "event_sum_heter" CPU operator
+template <typename F, typename I>
+void cpu_event_sum_batch_heter(void *out, const void **in) {
+  const std::uint32_t pre_size = *reinterpret_cast<const std::uint32_t *>(in[0]);
+  const std::uint32_t post_size = *reinterpret_cast<const std::uint32_t *>(in[1]);
+  const bool *events = reinterpret_cast<const bool *>(in[2]);
+  const I *indices = reinterpret_cast<const I *>(in[3]);
+  const I *indptr = reinterpret_cast<const I *>(in[4]);
+  const F *values = reinterpret_cast<const F *>(in[5]);
   F *result = reinterpret_cast<F *>(out);

   // algorithm
-  memset(&result[0], 0, sizeof(result[0]) * post_size);
+  memset(&result[0], 0, sizeof(F) * post_size);
   for (std::uint32_t i = 0; i < pre_size; ++i) {
     if (events[i]){
       for (I j = indptr[i]; j < indptr[i+1]; ++j) {

@@ -50,6 +88,8 @@ namespace{
       }
     }
   }
+
+
 }

 void cpu_event_sum_homo_f32_i32(void *out, const void **in){cpu_event_sum_homo<float, std::uint32_t>(out, in);}
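
The CPU kernels above all share one event-driven CSR traversal: zero the output, then for each presynaptic neuron that spiked, scatter its weight(s) into the postsynaptic targets. A NumPy reference of that algorithm (the function name and dtype choice are ours, for illustration only):

import numpy as np

def event_sum_reference(events, indices, indptr, values, post_num):
    # events:  (pre_num,) bool -- which presynaptic neurons spiked
    # indices: (n_conn,) CSR column indices (postsynaptic targets)
    # indptr:  (pre_num + 1,) CSR row pointers
    # values:  scalar (homogeneous) or (n_conn,) array (heterogeneous)
    out = np.zeros(post_num, dtype=np.float64)
    homo = np.ndim(values) == 0
    for i in np.flatnonzero(events):            # visit only spiking rows
        for j in range(indptr[i], indptr[i + 1]):
            out[indices[j]] += values if homo else values[j]
    return out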

extensions/lib/event_sum_gpu.cu

Lines changed: 11 additions & 24 deletions
@@ -458,8 +458,7 @@ namespace brainpy_lib {
         if (threadIdx.x < num_event) {
           const unsigned int pre_i = (r * 32) + threadIdx.x;
           shared_events[threadIdx.x] = events[pre_i];
-          if (shared_events[threadIdx.x])
-          {
+          if (shared_events[threadIdx.x]) {
             shPreStartID[threadIdx.x] = indptr[pre_i];
             shRowLength[threadIdx.x] = indptr[pre_i + 1] - shPreStartID[threadIdx.x];
           }

@@ -532,8 +531,7 @@ namespace brainpy_lib {
         if (threadIdx.x < num_event) {
           const unsigned int pre_i = (r * 32) + threadIdx.x;
           shared_events[threadIdx.x] = events[pre_i];
-          if (shared_events[threadIdx.x])
-          {
+          if (shared_events[threadIdx.x]) {
             shPreStartID[threadIdx.x] = indptr[pre_i];
             shRowLength[threadIdx.x] = indptr[pre_i + 1] - shPreStartID[threadIdx.x];
           }

@@ -553,7 +551,6 @@ namespace brainpy_lib {
     }


-
     template<typename F, typename I>
     inline void gpu_event_sum4_heter(cudaStream_t stream,
                                      void **buffers,

@@ -578,17 +575,16 @@ namespace brainpy_lib {
       cudaMemset(result, 0, sizeof(F) * post_size);
      event_sum4_heter_kernel<F, I><<<grid_dim, block_dim,
          /*dynamic_shared_mem_bytes=*/0, stream>>>(max_post_conn,
-                                                   pre_size,
-                                                   events,
-                                                   indices,
-                                                   indptr,
-                                                   values,
-                                                   result);
+          pre_size,
+          events,
+          indices,
+          indptr,
+          values,
+          result);
       ThrowIfError(cudaGetLastError());
     }


-
   } // namespace

@@ -758,24 +754,15 @@ namespace brainpy_lib {
   }

   // heterogeneous event sum 3
-  void gpu_event_sum3_heter_f32_i32(cudaStream_t stream,
-                                    void **buffers,
-                                    const char *opaque,
-                                    std::size_t opaque_len) {
+  void gpu_event_sum3_heter_f32_i32(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
     gpu_event_sum3_heter<float, std::uint32_t>(stream, buffers, opaque, opaque_len);
   }

-  void gpu_event_sum3_heter_f32_i64(cudaStream_t stream,
-                                    void **buffers,
-                                    const char *opaque,
-                                    std::size_t opaque_len) {
+  void gpu_event_sum3_heter_f32_i64(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
     gpu_event_sum3_heter<float, std::uint64_t>(stream, buffers, opaque, opaque_len);
   }

-  void gpu_event_sum3_heter_f64_i32(cudaStream_t stream,
-                                    void **buffers,
-                                    const char *opaque,
-                                    std::size_t opaque_len) {
+  void gpu_event_sum3_heter_f64_i32(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
     gpu_event_sum3_heter<double, std::uint32_t>(stream, buffers, opaque, opaque_len);
   }
