Skip to content

Commit 8baaeb3

Browse files
committed
Replace simple lists with a list-like object that can manage relationships
1 parent 98c658f commit 8baaeb3

File tree

8 files changed

+275
-26
lines changed

8 files changed

+275
-26
lines changed

neo/core/block.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from datetime import datetime
1111

1212
from neo.core.container import Container, unique_objs
13+
from neo.core.group import Group
14+
from neo.core.objectlist import ObjectList
15+
from neo.core.regionofinterest import RegionOfInterest
16+
from neo.core.segment import Segment
1317

1418

1519
class Block(Container):
@@ -85,11 +89,27 @@ def __init__(self, name=None, description=None, file_origin=None,
8589
self.file_datetime = file_datetime
8690
self.rec_datetime = rec_datetime
8791
self.index = index
88-
self.segments = []
89-
self.groups = []
90-
self.regionsofinterest = [] # temporary workaround.
91-
# the goal is to store all sub-classes of RegionOfInterest in a single list
92-
# but this will need substantial changes to container handling
92+
self._segments = ObjectList(Segment)
93+
self._groups = ObjectList(Group)
94+
self._regionsofinterest = ObjectList(RegionOfInterest)
95+
96+
segments = property(
97+
fget=lambda self: self._get_object_list("_segments"),
98+
fset=lambda self, value: self._set_object_list("_segments", value),
99+
doc="todo"
100+
)
101+
102+
groups = property(
103+
fget=lambda self: self._get_object_list("_groups"),
104+
fset=lambda self, value: self._set_object_list("_groups", value),
105+
doc="todo"
106+
)
107+
108+
regionsofinterest = property(
109+
fget=lambda self: self._get_object_list("_regionsofinterest"),
110+
fset=lambda self, value: self._set_object_list("_regionsofinterest", value),
111+
doc="todo"
112+
)
93113

94114
@property
95115
def data_children_recur(self):

neo/core/container.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,22 @@ def __init__(self, name=None, description=None, file_origin=None,
198198
super().__init__(name=name, description=description,
199199
file_origin=file_origin, **annotations)
200200

201+
def _get_object_list(self, name):
202+
"""
203+
204+
"""
205+
return getattr(self, name)
206+
207+
def _set_object_list(self, name, value):
208+
"""
209+
210+
"""
211+
assert isinstance(value, list)
212+
object_list = getattr(self, name)
213+
if len(object_list) > 0:
214+
raise Exception("Object list not empty")
215+
object_list.extend(value)
216+
201217
@property
202218
def _child_objects(self):
203219
"""

neo/core/group.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88

99
from os import close
1010
from neo.core.container import Container
11+
from neo.core.analogsignal import AnalogSignal
12+
from neo.core.container import Container
13+
from neo.core.objectlist import ObjectList
14+
from neo.core.epoch import Epoch
15+
from neo.core.event import Event
16+
from neo.core.imagesequence import ImageSequence
17+
from neo.core.irregularlysampledsignal import IrregularlySampledSignal
18+
from neo.core.segment import Segment
19+
from neo.core.spiketrainlist import SpikeTrainList
20+
from neo.core.view import ChannelView
1121

1222

1323
class Group(Container):
@@ -49,16 +59,15 @@ def __init__(self, objects=None, name=None, description=None, file_origin=None,
4959
super().__init__(name=name, description=description,
5060
file_origin=file_origin, **annotations)
5161

52-
self.analogsignals = []
53-
self.irregularlysampledsignals = []
54-
self.spiketrains = []
55-
self.events = []
56-
self.epochs = []
57-
self.channelviews = []
58-
self.imagesequences = []
59-
self.segments = [] # to remove?
60-
self.groups = []
61-
self.block = None
62+
self._analogsignals = ObjectList(AnalogSignal)
63+
self._irregularlysampledsignals = ObjectList(IrregularlySampledSignal)
64+
self.spiketrains = SpikeTrainList(segment=self)
65+
self._events = ObjectList(Event)
66+
self._epochs = ObjectList(Epoch)
67+
self._channelviews = ObjectList(ChannelView)
68+
self._imagesequences = ObjectList(ImageSequence)
69+
self.segments = ObjectList(Segment) # to remove?
70+
self.groups = ObjectList(Group)
6271

6372
if allowed_types is None:
6473
self.allowed_types = None
@@ -68,6 +77,42 @@ def __init__(self, objects=None, name=None, description=None, file_origin=None,
6877
if objects:
6978
self.add(*objects)
7079

80+
analogsignals = property(
81+
fget=lambda self: self._get_object_list("_analogsignals"),
82+
fset=lambda self, value: self._set_object_list("_analogsignals", value),
83+
doc="todo"
84+
)
85+
86+
irregularlysampledsignals = property(
87+
fget=lambda self: self._get_object_list("_irregularlysampledsignals"),
88+
fset=lambda self, value: self._set_object_list("_irregularlysampledsignals", value),
89+
doc="todo"
90+
)
91+
92+
events = property(
93+
fget=lambda self: self._get_object_list("_events"),
94+
fset=lambda self, value: self._set_object_list("_events", value),
95+
doc="todo"
96+
)
97+
98+
epochs = property(
99+
fget=lambda self: self._get_object_list("_epochs"),
100+
fset=lambda self, value: self._set_object_list("_epochs", value),
101+
doc="todo"
102+
)
103+
104+
channelviews = property(
105+
fget=lambda self: self._get_object_list("_channelviews"),
106+
fset=lambda self, value: self._set_object_list("_channelviews", value),
107+
doc="todo"
108+
)
109+
110+
imagesequences = property(
111+
fget=lambda self: self._get_object_list("_imagesequences"),
112+
fset=lambda self, value: self._set_object_list("_imagesequences", value),
113+
doc="todo"
114+
)
115+
71116
@property
72117
def _container_lookup(self):
73118
return {

neo/core/objectlist.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""
2+
3+
4+
"""
5+
6+
from neo.core.baseneo import BaseNeo
7+
8+
9+
class ObjectList:
10+
"""
11+
handle relationships within Neo hierarchy
12+
"""
13+
14+
def __init__(self, allowed_contents):
15+
# validate allowed_contents and normalize it to a tuple
16+
if isinstance(allowed_contents, type) and issubclass(allowed_contents, BaseNeo):
17+
self.allowed_contents = (allowed_contents,)
18+
else:
19+
for item in allowed_contents:
20+
assert issubclass(item, BaseNeo)
21+
self.allowed_contents = tuple(allowed_contents)
22+
self.contents = []
23+
24+
def _handle_append(self, obj):
25+
if not (
26+
isinstance(obj, self.allowed_contents)
27+
or ( # also allow proxy objects of the correct type
28+
hasattr(obj, "proxy_for") and obj.proxy_for in self.allowed_contents
29+
)
30+
):
31+
raise TypeError(f"Object is a {type(obj)}. It should be one of {self.allowed_contents}.")
32+
33+
def __str__(self):
34+
return str(self.contents)
35+
36+
def __repr__(self):
37+
return repr(self.contents)
38+
39+
def __add__(self, objects):
40+
# todo: decision: return a list, or a new DataObjectList?
41+
if isinstance(objects, ObjectList):
42+
return self.contents + objects.contents
43+
else:
44+
return self.contents + objects
45+
46+
def __radd__(self, objects):
47+
if isinstance(objects, ObjectList):
48+
return objects.contents + self.contents
49+
else:
50+
return objects + self.contents
51+
52+
def __contains__(self, key):
53+
return key in self.contents
54+
55+
def __iadd__(self, objects):
56+
for obj in objects:
57+
self._handle_append(obj)
58+
self.contents.extend(objects)
59+
60+
def __iter__(self):
61+
return iter(self.contents)
62+
63+
def __getitem__(self, i):
64+
return self.contents[i]
65+
66+
def __len__(self):
67+
return len(self.contents)
68+
69+
def __reversed__(self):
70+
raise NotImplementedError
71+
72+
def __setitem__(self, i):
73+
raise NotImplementedError
74+
75+
def append(self, obj):
76+
self._handle_append(obj)
77+
self.contents.append(obj)
78+
79+
def extend(self, objects):
80+
for obj in objects:
81+
self._handle_append(obj)
82+
self.contents.extend(objects)
83+
84+
def clear(self):
85+
raise NotImplementedError
86+
87+
def copy(self):
88+
raise NotImplementedError
89+
90+
def count(self):
91+
raise NotImplementedError
92+
93+
def index(self):
94+
raise NotImplementedError
95+
96+
def insert(self):
97+
raise NotImplementedError
98+
99+
def pop(self):
100+
raise NotImplementedError
101+
102+
def remove(self):
103+
raise NotImplementedError
104+
105+
def reverse(self):
106+
raise NotImplementedError
107+
108+
def sort(self):
109+
raise NotImplementedError

neo/core/regionofinterest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from math import floor, ceil
22

3+
from neo.core.baseneo import BaseNeo
34

4-
class RegionOfInterest:
5+
6+
class RegionOfInterest(BaseNeo):
57
"""Abstract base class"""
68
pass
79

neo/core/segment.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,20 @@
77
'''
88

99
from datetime import datetime
10+
from copy import deepcopy
1011

1112
import numpy as np
1213

13-
from copy import deepcopy
1414

15+
from neo.core.analogsignal import AnalogSignal
1516
from neo.core.container import Container
17+
from neo.core.objectlist import ObjectList
18+
from neo.core.epoch import Epoch
19+
from neo.core.event import Event
20+
from neo.core.imagesequence import ImageSequence
21+
from neo.core.irregularlysampledsignal import IrregularlySampledSignal
1622
from neo.core.spiketrainlist import SpikeTrainList
23+
from neo.core.view import ChannelView
1724

1825

1926
class Segment(Container):
@@ -92,19 +99,55 @@ def __init__(self, name=None, description=None, file_origin=None,
9299
super().__init__(name=name, description=description,
93100
file_origin=file_origin, **annotations)
94101

95-
self.analogsignals = []
96-
self.irregularlysampledsignals = []
102+
self._analogsignals = ObjectList(AnalogSignal)
103+
self._irregularlysampledsignals = ObjectList(IrregularlySampledSignal)
97104
self.spiketrains = SpikeTrainList(segment=self)
98-
self.events = []
99-
self.epochs = []
100-
self.channelviews = []
101-
self.imagesequences = []
105+
self._events = ObjectList(Event)
106+
self._epochs = ObjectList(Epoch)
107+
self._channelviews = ObjectList(ChannelView)
108+
self._imagesequences = ObjectList(ImageSequence)
102109
self.block = None
103110

104111
self.file_datetime = file_datetime
105112
self.rec_datetime = rec_datetime
106113
self.index = index
107114

115+
analogsignals = property(
116+
fget=lambda self: self._get_object_list("_analogsignals"),
117+
fset=lambda self, value: self._set_object_list("_analogsignals", value),
118+
doc="todo"
119+
)
120+
121+
irregularlysampledsignals = property(
122+
fget=lambda self: self._get_object_list("_irregularlysampledsignals"),
123+
fset=lambda self, value: self._set_object_list("_irregularlysampledsignals", value),
124+
doc="todo"
125+
)
126+
127+
events = property(
128+
fget=lambda self: self._get_object_list("_events"),
129+
fset=lambda self, value: self._set_object_list("_events", value),
130+
doc="todo"
131+
)
132+
133+
epochs = property(
134+
fget=lambda self: self._get_object_list("_epochs"),
135+
fset=lambda self, value: self._set_object_list("_epochs", value),
136+
doc="todo"
137+
)
138+
139+
channelviews = property(
140+
fget=lambda self: self._get_object_list("_channelviews"),
141+
fset=lambda self, value: self._set_object_list("_channelviews", value),
142+
doc="todo"
143+
)
144+
145+
imagesequences = property(
146+
fget=lambda self: self._get_object_list("_imagesequences"),
147+
fset=lambda self, value: self._set_object_list("_imagesequences", value),
148+
doc="todo"
149+
)
150+
108151
# t_start attribute is handled as a property so type checking can be done
109152
@property
110153
def t_start(self):

neo/core/spiketrainlist.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111
import quantities as pq
1212
from .spiketrain import SpikeTrain, normalize_times_array
13+
from .objectlist import ObjectList
1314

1415

1516
def is_spiketrain_or_proxy(obj):
@@ -32,7 +33,7 @@ def unique(quantities):
3233

3334

3435

35-
class SpikeTrainList(object):
36+
class SpikeTrainList(ObjectList):
3637
"""
3738
This class contains multiple spike trains, and can represent them
3839
either as a list of SpikeTrain objects or as a pair of arrays

0 commit comments

Comments
 (0)