Skip to content

Commit 74b5aa4

Browse files
committed
raise a TypeError is an object of the wrong type is added to a container.
1 parent da4d0ec commit 74b5aa4

File tree

5 files changed

+36
-5
lines changed

5 files changed

+36
-5
lines changed

neo/core/container.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,22 @@ def _get_container(self, cls):
342342
def add(self, *objects):
343343
"""Add a new Neo object to the Container"""
344344
for obj in objects:
345-
container = self._get_container(obj.__class__)
346-
container.append(obj)
345+
if (
346+
obj.__class__.__name__ in self._child_objects
347+
or (
348+
hasattr(obj, "proxy_for")
349+
and obj.proxy_for.__name__ in self._child_objects
350+
)
351+
):
352+
container = self._get_container(obj.__class__)
353+
container.append(obj)
354+
else:
355+
raise TypeError(
356+
f"Cannot add object of type {obj.__class__.__name__} "
357+
f"to a {self.__class__.__name__}, can only add objects of the "
358+
f"following types: {self._child_objects}"
359+
)
360+
347361

348362

349363
def filter(self, targdict=None, data=True, container=False, recursive=True,

neo/core/group.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def __init__(self, objects=None, name=None, description=None, file_origin=None,
7676
self.allowed_types = None
7777
else:
7878
self.allowed_types = tuple(allowed_types)
79+
for type_ in self.allowed_types:
80+
if type_.__name__ not in self._child_objects:
81+
raise TypeError(f"Groups can not contain objects of type {type_.__name__}")
7982

8083
if objects:
8184
self.add(*objects)
@@ -140,8 +143,7 @@ def add(self, *objects):
140143
if self.allowed_types and not isinstance(obj, self.allowed_types):
141144
raise TypeError("This Group can only contain {}, but not {}"
142145
"".format(self.allowed_types, type(obj)))
143-
container = self._get_container(obj.__class__)
144-
container.append(obj)
146+
super().add(*objects)
145147

146148
def walk(self):
147149
"""

neo/test/coretest/test_block.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from neo.core import SpikeTrain, AnalogSignal, Event
2424
from neo.test.tools import (assert_neo_object_is_compliant,
2525
assert_same_sub_schema)
26-
from neo.test.generate_datasets import random_block, simple_block
26+
from neo.test.generate_datasets import random_block, simple_block, random_signal
2727

2828

2929
N_EXAMPLES = 5
@@ -493,6 +493,10 @@ def test_add(self):
493493
new_blk.add(*blk.segments)
494494
assert len(new_blk.segments) == n_segs_start + len(blk.segments)
495495

496+
def test_add_invalid_type_raises_Exception(self):
497+
new_blk = Block()
498+
self.assertRaises(TypeError, new_blk.add, random_signal())
499+
496500

497501
if __name__ == "__main__":
498502
unittest.main()

neo/test/coretest/test_group.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from neo.core.segment import Segment
1616
from neo.core.view import ChannelView
1717
from neo.core.group import Group
18+
from neo.core.block import Block
1819

1920

2021
class TestGroup(unittest.TestCase):
@@ -91,3 +92,9 @@ def test_walk(self):
9192
target.extend([children[1], children[2], *grandchildren[2]])
9293
self.assertEqual(flattened,
9394
target)
95+
96+
def test_add_invalid_type_raises_Exception(self):
97+
group = Group()
98+
self.assertRaises(TypeError, group.add, Block())
99+
100+
self.assertRaises(TypeError, Group, allowed_types=[Block])

neo/test/coretest/test_segment.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,10 @@ def test_add(self):
669669
seg.add(proxy_epoch)
670670
assert len(seg.epochs) == 1
671671

672+
def test_add_invalid_type_raises_Exception(self):
673+
seg = Segment()
674+
self.assertRaises(TypeError, seg.add, Block())
675+
672676

673677
if __name__ == "__main__":
674678
unittest.main()

0 commit comments

Comments
 (0)