Skip to content

Commit 978ea66

Browse files
committed
updates
1 parent 12615a3 commit 978ea66

File tree

3 files changed

+21
-304
lines changed

3 files changed

+21
-304
lines changed

.github/ISSUE_TEMPLATE/feature_request.md

Lines changed: 0 additions & 10 deletions
This file was deleted.

brainpy/dyn/synapses/abstract_models.py

Lines changed: 15 additions & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -104,33 +104,8 @@ def __init__(
104104
self.check_post_attrs('refractory')
105105

106106
# connections and weights
107-
self.conn_type = conn_type
108-
if conn_type not in ['sparse', 'dense']:
109-
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
110-
if self.conn is None:
111-
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
112-
if isinstance(self.conn, One2One):
113-
self.weights = init_param(weights, (self.pre.num,), allow_none=False)
114-
self.weight_type = 'heter' if bm.size(self.weights) != 1 else 'homo'
115-
elif isinstance(self.conn, All2All):
116-
self.weights = init_param(weights, (self.pre.num, self.post.num), allow_none=False)
117-
if bm.size(self.weights) != 1:
118-
self.weight_type = 'heter'
119-
bm.fill_diagonal(self.weights, 0.)
120-
else:
121-
self.weight_type = 'homo'
122-
else:
123-
if conn_type == 'sparse':
124-
self.pre2post = self.conn.require('pre2post')
125-
self.weights = init_param(weights, self.pre2post[1].shape, allow_none=False)
126-
self.weight_type = 'heter' if bm.size(self.weights) != 1 else 'homo'
127-
elif conn_type == 'dense':
128-
self.weights = init_param(weights, (self.pre.num, self.post.num), allow_none=False)
129-
self.weight_type = 'heter' if bm.size(self.weights) != 1 else 'homo'
130-
if self.weight_type == 'homo':
131-
self.conn_mat = self.conn.require('conn_mat')
132-
else:
133-
raise ValueError(f'Unknown connection type: {conn_type}')
107+
self.weights = weights
108+
self.pre2post = self.conn.require('pre2post')
134109

135110
# variables
136111
self.delay_step = self.register_delay(f"{self.pre.name}.spike",
@@ -144,33 +119,7 @@ def update(self, t, dt):
144119
# delays
145120
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", delay_step=self.delay_step)
146121

147-
# post values
148-
assert self.weight_type in ['homo', 'heter']
149-
assert self.conn_type in ['sparse', 'dense']
150-
if isinstance(self.conn, All2All):
151-
pre_spike = pre_spike.astype(bm.float_)
152-
if self.weight_type == 'homo':
153-
post_vs = bm.sum(pre_spike)
154-
if not self.conn.include_self:
155-
post_vs = post_vs - pre_spike
156-
post_vs *= self.weights
157-
else:
158-
post_vs = pre_spike @ self.weights
159-
elif isinstance(self.conn, One2One):
160-
pre_spike = pre_spike.astype(bm.float_)
161-
post_vs = pre_spike * self.weights
162-
else:
163-
if self.conn_type == 'sparse':
164-
post_vs = bm.pre2post_event_sum(pre_spike,
165-
self.pre2post,
166-
self.post.num,
167-
self.weights)
168-
else:
169-
pre_spike = pre_spike.astype(bm.float_)
170-
if self.weight_type == 'homo':
171-
post_vs = self.weights * (pre_spike @ self.conn_mat)
172-
else:
173-
post_vs = pre_spike @ self.weights
122+
post_vs = bm.pre2post_event_sum(pre_spike, self.pre2post, self.post.num, self.weights)
174123

175124
# update outputs
176125
target = getattr(self.post, self.post_key)
@@ -299,33 +248,8 @@ def __init__(
299248
f'But we got {self.tau}')
300249

301250
# connections and weights
302-
self.conn_type = conn_type
303-
if conn_type not in ['sparse', 'dense']:
304-
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
305-
if self.conn is None:
306-
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
307-
if isinstance(self.conn, One2One):
308-
self.g_max = init_param(g_max, (self.pre.num,), allow_none=False)
309-
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
310-
elif isinstance(self.conn, All2All):
311-
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
312-
if bm.size(self.g_max) != 1:
313-
self.weight_type = 'heter'
314-
bm.fill_diagonal(self.g_max, 0.)
315-
else:
316-
self.weight_type = 'homo'
317-
else:
318-
if conn_type == 'sparse':
319-
self.pre2post = self.conn.require('pre2post')
320-
self.g_max = init_param(g_max, self.pre2post[1].shape, allow_none=False)
321-
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
322-
elif conn_type == 'dense':
323-
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
324-
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
325-
if self.weight_type == 'homo':
326-
self.conn_mat = self.conn.require('conn_mat')
327-
else:
328-
raise ValueError(f'Unknown connection type: {conn_type}')
251+
self.pre2post = self.conn.require('pre2post')
252+
self.g_max = init_param(g_max, self.pre2post[1].shape, allow_none=False)
329253

330254
# variables
331255
self.g = bm.Variable(bm.zeros(self.post.num))
@@ -344,33 +268,10 @@ def update(self, t, dt):
344268
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
345269

346270
# post values
347-
assert self.weight_type in ['homo', 'heter']
348-
assert self.conn_type in ['sparse', 'dense']
349-
if isinstance(self.conn, All2All):
350-
pre_spike = pre_spike.astype(bm.float_)
351-
if self.weight_type == 'homo':
352-
post_vs = bm.sum(pre_spike)
353-
if not self.conn.include_self:
354-
post_vs = post_vs - pre_spike
355-
post_vs = self.g_max * post_vs
356-
else:
357-
post_vs = pre_spike @ self.g_max
358-
elif isinstance(self.conn, One2One):
359-
pre_spike = pre_spike.astype(bm.float_)
360-
post_vs = pre_spike * self.g_max
361-
else:
362-
if self.conn_type == 'sparse':
363-
post_vs = bm.pre2post_event_sum(pre_spike,
364-
self.pre2post,
365-
self.post.num,
366-
self.g_max)
367-
else:
368-
pre_spike = pre_spike.astype(bm.float_)
369-
if self.weight_type == 'homo':
370-
post_vs = self.g_max * (pre_spike @ self.conn_mat)
371-
else:
372-
post_vs = pre_spike @ self.g_max
373-
271+
post_vs = bm.pre2post_event_sum(pre_spike,
272+
self.pre2post,
273+
self.post.num,
274+
self.g_max)
374275
# updates
375276
self.g.value = self.integral(self.g.value, t, dt=dt) + post_vs
376277
self.post.input += self.output(self.g)
@@ -619,33 +520,8 @@ def __init__(
619520
f'But we got {self.tau_decay}')
620521

621522
# connections
622-
self.conn_type = conn_type
623-
if conn_type not in ['sparse', 'dense']:
624-
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
625-
if self.conn is None:
626-
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
627-
if isinstance(self.conn, One2One):
628-
self.g_max = init_param(g_max, (self.pre.num,), allow_none=False)
629-
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
630-
elif isinstance(self.conn, All2All):
631-
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
632-
if bm.size(self.g_max) != 1:
633-
self.weight_type = 'heter'
634-
bm.fill_diagonal(self.g_max, 0.)
635-
else:
636-
self.weight_type = 'homo'
637-
else:
638-
if conn_type == 'sparse':
639-
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
640-
self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False)
641-
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
642-
elif conn_type == 'dense':
643-
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
644-
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
645-
if self.weight_type == 'homo':
646-
self.conn_mat = self.conn.require('conn_mat')
647-
else:
648-
raise ValueError(f'Unknown connection type: {conn_type}')
523+
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
524+
self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False)
649525

650526
# variables
651527
self.h = bm.Variable(bm.zeros(self.pre.num))
@@ -674,26 +550,7 @@ def update(self, t, dt):
674550
self.h += pre_spike
675551

676552
# post-synaptic values
677-
assert self.weight_type in ['homo', 'heter']
678-
assert self.conn_type in ['sparse', 'dense']
679-
if isinstance(self.conn, All2All):
680-
if self.weight_type == 'homo':
681-
post_vs = bm.sum(self.g)
682-
if not self.conn.include_self:
683-
post_vs = post_vs - self.g
684-
post_vs = self.g_max * post_vs
685-
else:
686-
post_vs = self.g @ self.g_max
687-
elif isinstance(self.conn, One2One):
688-
post_vs = self.g_max * self.g
689-
else:
690-
if self.conn_type == 'sparse':
691-
post_vs = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids)
692-
else:
693-
if self.weight_type == 'homo':
694-
post_vs = (self.g_max * self.g) @ self.conn_mat
695-
else:
696-
post_vs = self.g @ self.g_max
553+
post_vs = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids)
697554

698555
# output
699556
self.post.input += self.output(post_vs)
@@ -1199,33 +1056,8 @@ def __init__(
11991056
raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. But we got {tau_rise}')
12001057

12011058
# connections and weights
1202-
self.conn_type = conn_type
1203-
if conn_type not in ['sparse', 'dense']:
1204-
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
1205-
if self.conn is None:
1206-
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
1207-
if isinstance(self.conn, One2One):
1208-
self.g_max = init_param(g_max, (self.pre.num,), allow_none=False)
1209-
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
1210-
elif isinstance(self.conn, All2All):
1211-
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
1212-
if bm.size(self.g_max) != 1:
1213-
self.weight_type = 'heter'
1214-
bm.fill_diagonal(self.g_max, 0.)
1215-
else:
1216-
self.weight_type = 'homo'
1217-
else:
1218-
if conn_type == 'sparse':
1219-
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
1220-
self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False)
1221-
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
1222-
elif conn_type == 'dense':
1223-
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
1224-
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
1225-
if self.weight_type == 'homo':
1226-
self.conn_mat = self.conn.require('conn_mat')
1227-
else:
1228-
raise ValueError(f'Unknown connection type: {conn_type}')
1059+
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
1060+
self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False)
12291061

12301062
# variables
12311063
self.g = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
@@ -1254,26 +1086,7 @@ def update(self, t, dt):
12541086
self.x += delayed_pre_spike
12551087

12561088
# post-synaptic value
1257-
assert self.weight_type in ['homo', 'heter']
1258-
assert self.conn_type in ['sparse', 'dense']
1259-
if isinstance(self.conn, All2All):
1260-
if self.weight_type == 'homo':
1261-
post_g = bm.sum(self.g)
1262-
if not self.conn.include_self:
1263-
post_g = post_g - self.g
1264-
post_g = post_g * self.g_max
1265-
else:
1266-
post_g = self.g @ self.g_max
1267-
elif isinstance(self.conn, One2One):
1268-
post_g = self.g_max * self.g
1269-
else:
1270-
if self.conn_type == 'sparse':
1271-
post_g = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids)
1272-
else:
1273-
if self.weight_type == 'homo':
1274-
post_g = (self.g_max * self.g) @ self.conn_mat
1275-
else:
1276-
post_g = self.g @ self.g_max
1089+
post_g = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids)
12771090

12781091
# output
12791092
g_inf = 1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post.V)

0 commit comments

Comments
 (0)