Skip to content

Commit c51b2ce

Browse files
committed
Implement support for saving/loading Groups in NeoMatlabIO.
This is done by saving references to objects inside groups, on the assumption that those objects are already stored somewhere within segments. This assumption may not always be true, so probably we should only save references for objects that would otherwise be duplicated.
1 parent 8e7e67c commit c51b2ce

File tree

3 files changed

+63
-44
lines changed

3 files changed

+63
-44
lines changed

neo/core/spiketrainlist.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ def __getitem__(self, i):
115115
else:
116116
return SpikeTrainList(items=items)
117117

118+
def __setitem__(self, i, value):
119+
if self._items is None:
120+
self._spiketrains_from_array()
121+
self._items[i] = value
122+
118123
def __str__(self):
119124
"""Return str(self)"""
120125
if self._items is None:

neo/io/neomatlabio.py

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def read_block(self, lazy=False):
227227
bl_struct = d['block']
228228
bl = self.create_ob_from_struct(
229229
bl_struct, 'Block')
230+
self._resolve_references(bl)
230231
bl.check_relationships()
231232
return bl
232233

@@ -242,38 +243,23 @@ def write_block(self, bl, **kargs):
242243
seg_struct = self.create_struct_from_obj(seg)
243244
bl_struct['segments'].append(seg_struct)
244245

245-
for anasig in seg.analogsignals:
246-
anasig_struct = self.create_struct_from_obj(anasig)
247-
seg_struct['analogsignals'].append(anasig_struct)
248-
249-
for irrsig in seg.irregularlysampledsignals:
250-
irrsig_struct = self.create_struct_from_obj(irrsig)
251-
seg_struct['irregularlysampledsignals'].append(irrsig_struct)
252-
253-
for ea in seg.events:
254-
ea_struct = self.create_struct_from_obj(ea)
255-
seg_struct['events'].append(ea_struct)
256-
257-
for ea in seg.epochs:
258-
ea_struct = self.create_struct_from_obj(ea)
259-
seg_struct['epochs'].append(ea_struct)
260-
261-
for sptr in seg.spiketrains:
262-
sptr_struct = self.create_struct_from_obj(sptr)
263-
seg_struct['spiketrains'].append(sptr_struct)
264-
265-
for image_sq in seg.imagesequences:
266-
image_sq_structure = self.create_struct_from_obj(image_sq)
267-
seg_struct['imagesequences'].append(image_sq_structure)
246+
for container_name in seg._child_containers:
247+
for child_obj in getattr(seg, container_name):
248+
child_struct = self.create_struct_from_obj(child_obj)
249+
seg_struct[container_name].append(child_struct)
268250

269251
for group in bl.groups:
270252
group_structure = self.create_struct_from_obj(group)
271253
bl_struct['groups'].append(group_structure)
272254

255+
for container_name in group._child_containers:
256+
for child_obj in getattr(group, container_name):
257+
group_structure[container_name].append(id(child_obj))
258+
273259
scipy.io.savemat(self.filename, {'block': bl_struct}, oned_as='row')
274260

275261
def create_struct_from_obj(self, ob):
276-
struct = {}
262+
struct = {"neo_id": id(ob)}
277263

278264
# relationship
279265
for childname in getattr(ob, '_child_containers', []):
@@ -290,11 +276,6 @@ def create_struct_from_obj(self, ob):
290276
for i, attr in enumerate(all_attrs):
291277
attrname, attrtype = attr[0], attr[1]
292278

293-
# ~ if attrname =='':
294-
# ~ struct['array'] = ob.magnitude
295-
# ~ struct['units'] = ob.dimensionality.string
296-
# ~ continue
297-
298279
if (hasattr(ob, '_quantity_attr') and
299280
ob._quantity_attr == attrname):
300281
struct[attrname] = ob.magnitude
@@ -320,13 +301,6 @@ def create_struct_from_obj(self, ob):
320301

321302
def create_ob_from_struct(self, struct, classname):
322303
cl = class_by_name[classname]
323-
# check if inherits Quantity
324-
# ~ is_quantity = False
325-
# ~ for attr in cl._necessary_attrs:
326-
# ~ if attr[0] == '' and attr[1] == pq.Quantity:
327-
# ~ is_quantity = True
328-
# ~ break
329-
# ~ is_quantiy = hasattr(cl, '_quantity_attr')
330304

331305
# ~ if is_quantity:
332306
if hasattr(cl, '_quantity_attr'):
@@ -374,20 +348,27 @@ def create_ob_from_struct(self, struct, classname):
374348
# check children
375349
if attrname in getattr(ob, '_child_containers', []):
376350
child_struct = getattr(struct, attrname)
351+
child_class_name = classname_lower_to_upper[attrname[:-1]]
377352
try:
378353
# try must only surround len() or other errors are captured
379354
child_len = len(child_struct)
380355
except TypeError:
381356
# strange scipy.io behavior: if len is 1 there is no len()
382-
child = self.create_ob_from_struct(
383-
child_struct,
384-
classname_lower_to_upper[attrname[:-1]])
357+
if classname == "Group":
358+
child = _Ref(child_struct, child_class_name)
359+
else:
360+
child = self.create_ob_from_struct(
361+
child_struct,
362+
child_class_name)
385363
getattr(ob, attrname.lower()).append(child)
386364
else:
387365
for c in range(child_len):
388-
child = self.create_ob_from_struct(
389-
child_struct[c],
390-
classname_lower_to_upper[attrname[:-1]])
366+
if classname == "Group":
367+
child = _Ref(child_struct[c], child_class_name)
368+
else:
369+
child = self.create_ob_from_struct(
370+
child_struct[c],
371+
child_class_name)
391372
getattr(ob, attrname.lower()).append(child)
392373
continue
393374

@@ -432,4 +413,31 @@ def create_ob_from_struct(self, struct, classname):
432413

433414
setattr(ob, attrname, item)
434415

416+
neo_id = getattr(struct, "neo_id", None)
417+
if neo_id:
418+
setattr(ob, "_id", neo_id)
435419
return ob
420+
421+
def _resolve_references(self, bl):
422+
if bl.groups:
423+
obj_lookup = {}
424+
for ob in bl.children_recur:
425+
if hasattr(ob, "_id"):
426+
obj_lookup[ob._id] = ob
427+
for grp in bl.groups:
428+
for container_name in grp._child_containers:
429+
container = getattr(grp, container_name)
430+
for i, ref in enumerate(container):
431+
assert isinstance(ref, _Ref)
432+
container[i] = obj_lookup[ref.identifier]
433+
434+
435+
class _Ref:
436+
437+
def __init__(self, identifier, target_class_name):
438+
self.identifier = identifier
439+
self.target_cls = class_by_name[target_class_name]
440+
441+
@property
442+
def proxy_for(self):
443+
return self.target_cls

neo/test/iotest/test_neomatlabio.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from neo.core.analogsignal import AnalogSignal
1010
from neo.core.irregularlysampledsignal import IrregularlySampledSignal
11-
from neo import Block, Segment, SpikeTrain, ImageSequence
11+
from neo import Block, Segment, SpikeTrain, ImageSequence, Group
1212
from neo.test.iotest.common_io_test import BaseTestIO
1313
from neo.io.neomatlabio import NeoMatlabIO
1414

@@ -26,7 +26,7 @@ class TestNeoMatlabIO(BaseTestIO, unittest.TestCase):
2626
files_to_download = []
2727

2828
def test_write_read_single_spike(self):
29-
block1 = Block()
29+
block1 = Block(name="test_neomatlabio")
3030
seg = Segment('segment1')
3131
spiketrain1 = SpikeTrain([1] * pq.s, t_stop=10 * pq.s, sampling_rate=1 * pq.Hz)
3232
spiketrain1.annotate(yep='yop')
@@ -43,6 +43,8 @@ def test_write_read_single_spike(self):
4343
seg.irregularlysampledsignals.append(irrsig1)
4444
seg.imagesequences.append(image_sequence)
4545

46+
group1 = Group([spiketrain1, sig1])
47+
block1.groups.append(group1)
4648

4749
# write block
4850
filename = self.get_local_path('matlabiotestfile.mat')
@@ -72,6 +74,10 @@ def test_write_read_single_spike(self):
7274
assert 'yep' in spiketrain2.annotations
7375
assert spiketrain2.annotations['yep'] == 'yop'
7476

77+
# test group retrieval
78+
group2 = block2.groups[0]
79+
assert_array_equal(group1.analogsignals[0], group2.analogsignals[0])
80+
7581

7682
if __name__ == "__main__":
7783
unittest.main()

0 commit comments

Comments
 (0)