Skip to content

Commit 85556c5

Browse files
committed
feat: compatible with dict and tuple return types
1 parent 1f973df commit 85556c5

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

brainpy/connect/base.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -228,20 +228,22 @@ def _return_by_ij(self, structures, ij: tuple, all_data: dict):
228228
csr = ij2csr(pre_ids, post_ids, self.pre_num)
229229
self._return_by_csr(structures, csr=csr, all_data=all_data)
230230

231-
def make_returns(self, structures, conn_type, conn_data):
231+
def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):
232232
"""Make the desired synaptic structures and return them.
233233
"""
234-
csr = None
235-
mat = None
236-
ij = None
237-
if conn_type == 'csr':
238-
csr = conn_data
239-
elif conn_type == 'mat':
240-
mat = conn_data
241-
elif conn_type == 'ij':
242-
ij = conn_data
243-
else:
244-
raise ConnectorError(f'conn_type must be one of "csr", "mat" or "ij", but we got "{conn_type}" instead.')
234+
if isinstance(conn_data, dict):
235+
csr = conn_data['csr']
236+
mat = conn_data['mat']
237+
ij = conn_data['ij']
238+
elif isinstance(conn_data, tuple):
239+
if conn_data[0] == 'csr':
240+
csr = conn_data[1]
241+
elif conn_data[0] == 'mat':
242+
mat = conn_data[1]
243+
elif conn_data[0] == 'ij':
244+
ij = conn_data[1]
245+
else:
246+
raise ConnectorError(f'Must provide one of "csr", "mat" or "ij". Got "{conn_data[0]}" instead.')
245247

246248
# checking
247249
all_data = dict()
@@ -281,12 +283,21 @@ def make_returns(self, structures, conn_type, conn_data):
281283
return tuple([all_data[n] for n in structures])
282284

283285
def build_conn(self):
286+
"""build connections with certain data type.
287+
288+
Returns
289+
-------
290+
A tuple with two elements: connection type (str) and connection data.
291+
example: return 'csr', (ind, indptr)
292+
Or a dict with three elements: csr, mat and ij.
293+
example: return dict(csr=(ind, indptr), mat=None, ij=None)
294+
"""
284295
raise NotImplementedError
285296

286297
def require(self, *structures):
287298
self.check(structures)
288-
conn_type, conn_data = self.build_conn()
289-
return self.make_returns(structures, conn_type, conn_data)
299+
conn_data = self.build_conn()
300+
return self.make_returns(structures, conn_data)
290301

291302
def requires(self, *structures):
292303
return self.require(*structures)

brainpy/connect/regular_conn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def build_conn(self):
3939
ind = np.arange(self.pre_num)
4040
indptr = np.arange(self.pre_num + 1)
4141

42-
return 'csr', (ind, indptr)
42+
return dict(csr=(ind, indptr), mat=None, ij=None)
4343

4444

4545
one2one = One2One()
@@ -60,7 +60,7 @@ def build_conn(self):
6060
if not self.include_self:
6161
np.fill_diagonal(mat, False)
6262

63-
return 'mat', mat
63+
return dict(csr=None, mat=mat, ij=None)
6464

6565

6666
all2all = All2All(include_self=True)

0 commit comments

Comments
 (0)