@@ -228,20 +228,22 @@ def _return_by_ij(self, structures, ij: tuple, all_data: dict):
228228 csr = ij2csr (pre_ids , post_ids , self .pre_num )
229229 self ._return_by_csr (structures , csr = csr , all_data = all_data )
230230
231- def make_returns (self , structures , conn_type , conn_data ):
231+ def make_returns (self , structures , conn_data , csr = None , mat = None , ij = None ):
232232 """Make the desired synaptic structures and return them.
233233 """
234- csr = None
235- mat = None
236- ij = None
237- if conn_type == 'csr' :
238- csr = conn_data
239- elif conn_type == 'mat' :
240- mat = conn_data
241- elif conn_type == 'ij' :
242- ij = conn_data
243- else :
244- raise ConnectorError (f'conn_type must be one of "csr", "mat" or "ij", but we got "{ conn_type } " instead.' )
234+ if isinstance (conn_data , dict ):
235+ csr = conn_data ['csr' ]
236+ mat = conn_data ['mat' ]
237+ ij = conn_data ['ij' ]
238+ elif isinstance (conn_data , tuple ):
239+ if conn_data [0 ] == 'csr' :
240+ csr = conn_data [1 ]
241+ elif conn_data [0 ] == 'mat' :
242+ mat = conn_data [1 ]
243+ elif conn_data [0 ] == 'ij' :
244+ ij = conn_data [1 ]
245+ else :
246+ raise ConnectorError (f'Must provide one of "csr", "mat" or "ij". Got "{ conn_data [0 ]} " instead.' )
245247
246248 # checking
247249 all_data = dict ()
@@ -281,12 +283,21 @@ def make_returns(self, structures, conn_type, conn_data):
281283 return tuple ([all_data [n ] for n in structures ])
282284
283285 def build_conn (self ):
286+ """build connections with certain data type.
287+
288+ Returns
289+ -------
290+ A tuple with two elements: connection type (str) and connection data.
291+ example: return 'csr', (ind, indptr)
292+ Or a dict with three elements: csr, mat and ij.
293+ example: return dict(csr=(ind, indptr), mat=None, ij=None)
294+ """
284295 raise NotImplementedError
285296
286297 def require (self , * structures ):
287298 self .check (structures )
288- conn_type , conn_data = self .build_conn ()
289- return self .make_returns (structures , conn_type , conn_data )
299+ conn_data = self .build_conn ()
300+ return self .make_returns (structures , conn_data )
290301
291302 def requires (self , * structures ):
292303 return self .require (* structures )
0 commit comments