Skip to content

Commit 5ef4f01

Browse files
authored
andrew's improvements
Implement support for saving/loading Groups in NeoMatlabIO.
2 parents 8e7e67c + c51b2ce commit 5ef4f01

File tree

3 files changed

+63
-44
lines changed

3 files changed

+63
-44
lines changed

neo/core/spiketrainlist.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ def __getitem__(self, i):
115115
else:
116116
return SpikeTrainList(items=items)
117117

118+
def __setitem__(self, i, value):
119+
if self._items is None:
120+
self._spiketrains_from_array()
121+
self._items[i] = value
122+
118123
def __str__(self):
119124
"""Return str(self)"""
120125
if self._items is None:

neo/io/neomatlabio.py

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def read_block(self, lazy=False):
227227
bl_struct = d['block']
228228
bl = self.create_ob_from_struct(
229229
bl_struct, 'Block')
230+
self._resolve_references(bl)
230231
bl.check_relationships()
231232
return bl
232233

@@ -242,38 +243,23 @@ def write_block(self, bl, **kargs):
242243
seg_struct = self.create_struct_from_obj(seg)
243244
bl_struct['segments'].append(seg_struct)
244245

245-
for anasig in seg.analogsignals:
246-
anasig_struct = self.create_struct_from_obj(anasig)
247-
seg_struct['analogsignals'].append(anasig_struct)
248-
249-
for irrsig in seg.irregularlysampledsignals:
250-
irrsig_struct = self.create_struct_from_obj(irrsig)
251-
seg_struct['irregularlysampledsignals'].append(irrsig_struct)
252-
253-
for ea in seg.events:
254-
ea_struct = self.create_struct_from_obj(ea)
255-
seg_struct['events'].append(ea_struct)
256-
257-
for ea in seg.epochs:
258-
ea_struct = self.create_struct_from_obj(ea)
259-
seg_struct['epochs'].append(ea_struct)
260-
261-
for sptr in seg.spiketrains:
262-
sptr_struct = self.create_struct_from_obj(sptr)
263-
seg_struct['spiketrains'].append(sptr_struct)
264-
265-
for image_sq in seg.imagesequences:
266-
image_sq_structure = self.create_struct_from_obj(image_sq)
267-
seg_struct['imagesequences'].append(image_sq_structure)
246+
for container_name in seg._child_containers:
247+
for child_obj in getattr(seg, container_name):
248+
child_struct = self.create_struct_from_obj(child_obj)
249+
seg_struct[container_name].append(child_struct)
268250

269251
for group in bl.groups:
270252
group_structure = self.create_struct_from_obj(group)
271253
bl_struct['groups'].append(group_structure)
272254

255+
for container_name in group._child_containers:
256+
for child_obj in getattr(group, container_name):
257+
group_structure[container_name].append(id(child_obj))
258+
273259
scipy.io.savemat(self.filename, {'block': bl_struct}, oned_as='row')
274260

275261
def create_struct_from_obj(self, ob):
276-
struct = {}
262+
struct = {"neo_id": id(ob)}
277263

278264
# relationship
279265
for childname in getattr(ob, '_child_containers', []):
@@ -290,11 +276,6 @@ def create_struct_from_obj(self, ob):
290276
for i, attr in enumerate(all_attrs):
291277
attrname, attrtype = attr[0], attr[1]
292278

293-
# ~ if attrname =='':
294-
# ~ struct['array'] = ob.magnitude
295-
# ~ struct['units'] = ob.dimensionality.string
296-
# ~ continue
297-
298279
if (hasattr(ob, '_quantity_attr') and
299280
ob._quantity_attr == attrname):
300281
struct[attrname] = ob.magnitude
@@ -320,13 +301,6 @@ def create_struct_from_obj(self, ob):
320301

321302
def create_ob_from_struct(self, struct, classname):
322303
cl = class_by_name[classname]
323-
# check if inherits Quantity
324-
# ~ is_quantity = False
325-
# ~ for attr in cl._necessary_attrs:
326-
# ~ if attr[0] == '' and attr[1] == pq.Quantity:
327-
# ~ is_quantity = True
328-
# ~ break
329-
# ~ is_quantiy = hasattr(cl, '_quantity_attr')
330304

331305
# ~ if is_quantity:
332306
if hasattr(cl, '_quantity_attr'):
@@ -374,20 +348,27 @@ def create_ob_from_struct(self, struct, classname):
374348
# check children
375349
if attrname in getattr(ob, '_child_containers', []):
376350
child_struct = getattr(struct, attrname)
351+
child_class_name = classname_lower_to_upper[attrname[:-1]]
377352
try:
378353
# try must only surround len() or other errors are captured
379354
child_len = len(child_struct)
380355
except TypeError:
381356
# strange scipy.io behavior: if len is 1 there is no len()
382-
child = self.create_ob_from_struct(
383-
child_struct,
384-
classname_lower_to_upper[attrname[:-1]])
357+
if classname == "Group":
358+
child = _Ref(child_struct, child_class_name)
359+
else:
360+
child = self.create_ob_from_struct(
361+
child_struct,
362+
child_class_name)
385363
getattr(ob, attrname.lower()).append(child)
386364
else:
387365
for c in range(child_len):
388-
child = self.create_ob_from_struct(
389-
child_struct[c],
390-
classname_lower_to_upper[attrname[:-1]])
366+
if classname == "Group":
367+
child = _Ref(child_struct[c], child_class_name)
368+
else:
369+
child = self.create_ob_from_struct(
370+
child_struct[c],
371+
child_class_name)
391372
getattr(ob, attrname.lower()).append(child)
392373
continue
393374

@@ -432,4 +413,31 @@ def create_ob_from_struct(self, struct, classname):
432413

433414
setattr(ob, attrname, item)
434415

416+
neo_id = getattr(struct, "neo_id", None)
417+
if neo_id:
418+
setattr(ob, "_id", neo_id)
435419
return ob
420+
421+
def _resolve_references(self, bl):
422+
if bl.groups:
423+
obj_lookup = {}
424+
for ob in bl.children_recur:
425+
if hasattr(ob, "_id"):
426+
obj_lookup[ob._id] = ob
427+
for grp in bl.groups:
428+
for container_name in grp._child_containers:
429+
container = getattr(grp, container_name)
430+
for i, ref in enumerate(container):
431+
assert isinstance(ref, _Ref)
432+
container[i] = obj_lookup[ref.identifier]
433+
434+
435+
class _Ref:
436+
437+
def __init__(self, identifier, target_class_name):
438+
self.identifier = identifier
439+
self.target_cls = class_by_name[target_class_name]
440+
441+
@property
442+
def proxy_for(self):
443+
return self.target_cls

neo/test/iotest/test_neomatlabio.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from neo.core.analogsignal import AnalogSignal
1010
from neo.core.irregularlysampledsignal import IrregularlySampledSignal
11-
from neo import Block, Segment, SpikeTrain, ImageSequence
11+
from neo import Block, Segment, SpikeTrain, ImageSequence, Group
1212
from neo.test.iotest.common_io_test import BaseTestIO
1313
from neo.io.neomatlabio import NeoMatlabIO
1414

@@ -26,7 +26,7 @@ class TestNeoMatlabIO(BaseTestIO, unittest.TestCase):
2626
files_to_download = []
2727

2828
def test_write_read_single_spike(self):
29-
block1 = Block()
29+
block1 = Block(name="test_neomatlabio")
3030
seg = Segment('segment1')
3131
spiketrain1 = SpikeTrain([1] * pq.s, t_stop=10 * pq.s, sampling_rate=1 * pq.Hz)
3232
spiketrain1.annotate(yep='yop')
@@ -43,6 +43,8 @@ def test_write_read_single_spike(self):
4343
seg.irregularlysampledsignals.append(irrsig1)
4444
seg.imagesequences.append(image_sequence)
4545

46+
group1 = Group([spiketrain1, sig1])
47+
block1.groups.append(group1)
4648

4749
# write block
4850
filename = self.get_local_path('matlabiotestfile.mat')
@@ -72,6 +74,10 @@ def test_write_read_single_spike(self):
7274
assert 'yep' in spiketrain2.annotations
7375
assert spiketrain2.annotations['yep'] == 'yop'
7476

77+
# test group retrieval
78+
group2 = block2.groups[0]
79+
assert_array_equal(group1.analogsignals[0], group2.analogsignals[0])
80+
7581

7682
if __name__ == "__main__":
7783
unittest.main()

0 commit comments

Comments
 (0)