Skip to content

Commit abe4b15

Browse files
committed
fix some bugs that were found by running iotests
1 parent 5adbe0a commit abe4b15

File tree

7 files changed

+98
-46
lines changed

7 files changed

+98
-46
lines changed

neo/core/container.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from copy import deepcopy
99
from neo.core.baseneo import BaseNeo, _reference_name, _container_name
10+
from neo.core.objectlist import ObjectList
1011
from neo.core.spiketrain import SpikeTrain
1112
from neo.core.spiketrainlist import SpikeTrainList
1213

@@ -211,11 +212,14 @@ def _set_object_list(self, name, value):
211212
Example:
212213
>>> segment._set_object_list("_analogsignals", [sig1, sig2])
213214
"""
214-
assert isinstance(value, list)
215-
object_list = getattr(self, name)
216-
if len(object_list) > 0:
217-
raise Exception("Object list not empty")
218-
object_list.extend(value)
215+
if isinstance(value, list):
216+
object_list = getattr(self, name)
217+
object_list.clear()
218+
object_list.extend(value)
219+
elif isinstance(value, ObjectList): # from __iadd__
220+
setattr(self, name, value)
221+
else:
222+
TypeError("value must be a list or an ObjectList")
219223

220224
@property
221225
def _child_objects(self):
@@ -365,11 +369,7 @@ def filter(self, targdict=None, data=True, container=False, recursive=True,
365369
data = True
366370
container = True
367371

368-
if objects == SpikeTrain:
369-
children = SpikeTrainList()
370-
else:
371-
children = []
372-
372+
children = []
373373
# get the objects we want
374374
if data:
375375
if recursive:
@@ -382,8 +382,12 @@ def filter(self, targdict=None, data=True, container=False, recursive=True,
382382
else:
383383
children.extend(self.container_children)
384384

385-
return filterdata(children, objects=objects,
386-
targdict=targdict, **kwargs)
385+
filtered = filterdata(children, objects=objects,
386+
targdict=targdict, **kwargs)
387+
if objects == SpikeTrain:
388+
return SpikeTrainList(items=filtered)
389+
else:
390+
return filtered
387391

388392
def list_children_by_class(self, cls):
389393
"""
@@ -406,7 +410,12 @@ def check_relationships(self, recursive=True):
406410
"""
407411
parent_name = _reference_name(self.__class__.__name__)
408412
for child in self._single_children:
409-
assert getattr(child, parent_name, None) is self
413+
if hasattr(child, "proxy_for"):
414+
container = getattr(self, _container_name(child.proxy_for.__name__))
415+
else:
416+
container = getattr(self, _container_name(child.__class__.__name__))
417+
if container.parent is not None:
418+
assert getattr(child, parent_name, None) is self
410419
if recursive:
411420
for child in self.container_children:
412421
child.check_relationships(recursive=True)

neo/core/group.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(self, objects=None, name=None, description=None, file_origin=None,
6464
# they are contained in.
6565
self._analogsignals = ObjectList(AnalogSignal)
6666
self._irregularlysampledsignals = ObjectList(IrregularlySampledSignal)
67-
self._spiketrains = SpikeTrainList(parent=self)
67+
self._spiketrains = SpikeTrainList()
6868
self._events = ObjectList(Event)
6969
self._epochs = ObjectList(Epoch)
7070
self._channelviews = ObjectList(ChannelView)

neo/core/objectlist.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
"""
2-
3-
2+
This module implements the ObjectList class, which is used to peform type checks
3+
and handle relationships within the Neo Block-Segment-Data hierarchy.
44
"""
55

6+
import sys
67
from neo.core.baseneo import BaseNeo
78

89

910
class ObjectList:
1011
"""
11-
handle relationships within Neo hierarchy
12+
This class behaves like a list, but has additional functionality
13+
to handle relationships within Neo hierarchy, and perform type checks.
1214
"""
1315

1416
def __init__(self, allowed_contents, parent=None):
@@ -33,6 +35,8 @@ def _handle_append(self, obj):
3335
# set the child-parent relationship
3436
if self.parent:
3537
relationship_name = self.parent.__class__.__name__.lower()
38+
if relationship_name == "group":
39+
raise Exception("Objects in groups should not link to the group as their parent")
3640
current_parent = getattr(obj, relationship_name)
3741
if current_parent != self.parent:
3842
setattr(obj, relationship_name, self.parent)
@@ -63,6 +67,7 @@ def __iadd__(self, objects):
6367
for obj in objects:
6468
self._handle_append(obj)
6569
self.contents.extend(objects)
70+
return self
6671

6772
def __iter__(self):
6873
return iter(self.contents)
@@ -73,11 +78,8 @@ def __getitem__(self, i):
7378
def __len__(self):
7479
return len(self.contents)
7580

76-
def __reversed__(self):
77-
raise NotImplementedError
78-
79-
def __setitem__(self, i):
80-
raise NotImplementedError
81+
def __setitem__(self, key, value):
82+
self.contents[key] = value
8183

8284
def append(self, obj):
8385
self._handle_append(obj)
@@ -89,28 +91,26 @@ def extend(self, objects):
8991
self.contents.extend(objects)
9092

9193
def clear(self):
92-
raise NotImplementedError
93-
94-
def copy(self):
95-
raise NotImplementedError
94+
self.contents = []
9695

97-
def count(self):
98-
raise NotImplementedError
96+
def count(self, value):
97+
return self.contents.count(value)
9998

100-
def index(self):
101-
raise NotImplementedError
99+
def index(self, value, start=0, stop=sys.maxsize):
100+
return self.contents.index(value, start, stop)
102101

103-
def insert(self):
104-
raise NotImplementedError
102+
def insert(self, index, obj):
103+
self._handle_append(obj)
104+
self.contents[index] = obj
105105

106-
def pop(self):
107-
raise NotImplementedError
106+
def pop(self, index=-1):
107+
return self.contents.pop(index)
108108

109-
def remove(self):
110-
raise NotImplementedError
109+
def remove(self, value):
110+
return self.contents.remove(value)
111111

112112
def reverse(self):
113-
raise NotImplementedError
113+
raise self.contents.reverse()
114114

115-
def sort(self):
116-
raise NotImplementedError
115+
def sort(self, *args, key=None, reverse=False):
116+
self.contents.sort(*args, key=key, reverse=reverse)

neo/core/spiketrainlist.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class SpikeTrainList(ObjectList):
7474
<SpikeTrain(array([], dtype=float64) * ms, [0.0 ms, 100.0 ms])>]
7575
7676
"""
77+
allowed_contents = (SpikeTrain,)
7778

7879
def __init__(self, items=None, parent=None):
7980
"""Initialize self"""
@@ -89,8 +90,14 @@ def __init__(self, items=None, parent=None):
8990
self._channel_id_array = None
9091
self._all_channel_ids = None
9192
self._spiketrain_metadata = {}
93+
if parent is not None:
94+
assert parent.__class__.__name__ == "Segment"
9295
self.segment = parent
9396

97+
@property
98+
def parent(self):
99+
return self.segment
100+
94101
def __iter__(self):
95102
"""Implement iter(self)"""
96103
if self._items is None:
@@ -120,6 +127,9 @@ def __str__(self):
120127
else:
121128
return str(self._items)
122129

130+
def __repr__(self):
131+
return "<SpikeTrainList>"
132+
123133
def __len__(self):
124134
"""Return len(self)"""
125135
if self._items is None:
@@ -196,7 +206,7 @@ def __iadd__(self, other):
196206
return self._add_spiketrainlists(other, in_place=True)
197207
elif other and is_spiketrain_or_proxy(other[0]):
198208
for obj in other:
199-
obj.segment = self.segment
209+
self._handle_append(obj)
200210
if self._items is None:
201211
self._spiketrains_from_array()
202212
self._items.extend(other)
@@ -228,15 +238,15 @@ def append(self, obj):
228238
raise ValueError("Can only append SpikeTrain objects")
229239
if self._items is None:
230240
self._spiketrains_from_array()
231-
obj.segment = self.segment
241+
self._handle_append(obj)
232242
self._items.append(obj)
233243

234244
def extend(self, iterable):
235245
"""L.extend(iterable) -> None -- extend list by appending elements from the iterable"""
236246
if self._items is None:
237247
self._spiketrains_from_array()
238248
for obj in iterable:
239-
obj.segment = self.segment
249+
self._handle_append(obj)
240250
self._items.extend(iterable)
241251

242252
@classmethod

neo/io/neomatlabio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def create_struct_from_obj(self, ob):
265265
struct = {}
266266

267267
# relationship
268-
for childname in getattr(ob, '_single_child_containers', []):
268+
for childname in getattr(ob, '_child_containers', []):
269269
supported_containers = [subob.__name__.lower() + 's' for subob in
270270
self.supported_objects]
271271
if childname in supported_containers:
@@ -356,7 +356,7 @@ def create_ob_from_struct(self, struct, classname):
356356

357357
for attrname in struct._fieldnames:
358358
# check children
359-
if attrname in getattr(ob, '_single_child_containers', []):
359+
if attrname in getattr(ob, '_child_containers', []):
360360
child_struct = getattr(struct, attrname)
361361
try:
362362
# try must only surround len() or other errors are captured

neo/test/coretest/test_block.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
HAVE_IPYTHON = True
1919

2020
from neo.core.block import Block
21+
from neo.core.segment import Segment
2122
from neo.core.container import filterdata
2223
from neo.core import SpikeTrain, AnalogSignal, Event
2324
from neo.test.tools import (assert_neo_object_is_compliant,
@@ -451,6 +452,20 @@ def test__deepcopy(self):
451452
for sptr in segment.spiketrains:
452453
self.assertEqual(id(sptr.segment), id(segment))
453454

455+
def test_segment_list(self):
456+
blk = Block()
457+
assert len(blk.segments) == 0
458+
blk.segments.append(Segment())
459+
assert len(blk.segments) == 1
460+
blk.segments.extend([Segment(), Segment()])
461+
assert len(blk.segments) == 3
462+
blk.segments = []
463+
assert len(blk.segments) == 0
464+
blk.segments = [Segment()]
465+
assert len(blk.segments) == 1
466+
blk.segments += [Segment(), Segment()]
467+
assert len(blk.segments) == 3
468+
454469

455470
if __name__ == "__main__":
456471
unittest.main()

neo/test/iotest/test_cedio.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
11
import unittest
2+
from platform import system
3+
from sys import maxsize
24

3-
from neo.io import CedIO
4-
from neo.test.iotest.common_io_test import BaseTestIO
5+
try:
6+
if system() == 'Windows':
7+
if maxsize > 2**32:
8+
import sonpy.amd64.sonpy
9+
else:
10+
import sonpy.win32.sonpy
11+
elif system() == 'Darwin':
12+
import sonpy.darwin.sonpy
13+
elif system() == 'Linux':
14+
import sonpy.linux.sonpy
15+
from neo.io import CedIO
16+
except ImportError:
17+
HAVE_SONPY = False
18+
CedIO = None
19+
else:
20+
HAVE_SONPY = True
521

22+
from neo.test.iotest.common_io_test import BaseTestIO
623

24+
@unittest.skipUnless(HAVE_SONPY, "sonpy")
725
class TestCedIO(BaseTestIO, unittest.TestCase, ):
826
ioclass = CedIO
927
entities_to_test = [

0 commit comments

Comments
 (0)