Skip to content

Commit 972901f

Browse files
Merge pull request #1311 from apdavison/container-add--check-types
Check the types of objects added to a container with the new `add()` method
2 parents 182a7e9 + 74b5aa4 commit 972901f

File tree

6 files changed

+39
-5
lines changed

6 files changed

+39
-5
lines changed

doc/source/authors.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ and may not be the current affiliation of a contributor.
6262
* Elodie Legouée [21]
6363
* Heberto Mayorquin [24]
6464
* Thomas Perret [25]
65+
* Kyle Johnsen [26, 27]
6566

6667
1. Centre de Recherche en Neuroscience de Lyon, CNRS UMR5292 - INSERM U1028 - Universite Claude Bernard Lyon 1
6768
2. Unité de Neuroscience, Information et Complexité, CNRS UPR 3293, Gif-sur-Yvette, France
@@ -88,6 +89,8 @@ and may not be the current affiliation of a contributor.
8889
23. Bio Engineering Laboratory, DBSSE, ETH, Basel, Switzerland
8990
24. CatalystNeuro
9091
25. Institut des Sciences Cognitives Marc Jeannerod, CNRS UMR5229, Lyon, France
92+
26. Georgia Institute of Technology
93+
27. Emory University
9194

9295
If we've somehow missed you off the list we're very sorry - please let us know.
9396

neo/core/container.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,22 @@ def _get_container(self, cls):
342342
def add(self, *objects):
343343
"""Add a new Neo object to the Container"""
344344
for obj in objects:
345-
container = self._get_container(obj.__class__)
346-
container.append(obj)
345+
if (
346+
obj.__class__.__name__ in self._child_objects
347+
or (
348+
hasattr(obj, "proxy_for")
349+
and obj.proxy_for.__name__ in self._child_objects
350+
)
351+
):
352+
container = self._get_container(obj.__class__)
353+
container.append(obj)
354+
else:
355+
raise TypeError(
356+
f"Cannot add object of type {obj.__class__.__name__} "
357+
f"to a {self.__class__.__name__}, can only add objects of the "
358+
f"following types: {self._child_objects}"
359+
)
360+
347361

348362

349363
def filter(self, targdict=None, data=True, container=False, recursive=True,

neo/core/group.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def __init__(self, objects=None, name=None, description=None, file_origin=None,
7676
self.allowed_types = None
7777
else:
7878
self.allowed_types = tuple(allowed_types)
79+
for type_ in self.allowed_types:
80+
if type_.__name__ not in self._child_objects:
81+
raise TypeError(f"Groups can not contain objects of type {type_.__name__}")
7982

8083
if objects:
8184
self.add(*objects)
@@ -140,8 +143,7 @@ def add(self, *objects):
140143
if self.allowed_types and not isinstance(obj, self.allowed_types):
141144
raise TypeError("This Group can only contain {}, but not {}"
142145
"".format(self.allowed_types, type(obj)))
143-
container = self._get_container(obj.__class__)
144-
container.append(obj)
146+
super().add(*objects)
145147

146148
def walk(self):
147149
"""

neo/test/coretest/test_block.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from neo.core import SpikeTrain, AnalogSignal, Event
2424
from neo.test.tools import (assert_neo_object_is_compliant,
2525
assert_same_sub_schema)
26-
from neo.test.generate_datasets import random_block, simple_block
26+
from neo.test.generate_datasets import random_block, simple_block, random_signal
2727

2828

2929
N_EXAMPLES = 5
@@ -493,6 +493,10 @@ def test_add(self):
493493
new_blk.add(*blk.segments)
494494
assert len(new_blk.segments) == n_segs_start + len(blk.segments)
495495

496+
def test_add_invalid_type_raises_Exception(self):
497+
new_blk = Block()
498+
self.assertRaises(TypeError, new_blk.add, random_signal())
499+
496500

497501
if __name__ == "__main__":
498502
unittest.main()

neo/test/coretest/test_group.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from neo.core.segment import Segment
1616
from neo.core.view import ChannelView
1717
from neo.core.group import Group
18+
from neo.core.block import Block
1819

1920

2021
class TestGroup(unittest.TestCase):
@@ -91,3 +92,9 @@ def test_walk(self):
9192
target.extend([children[1], children[2], *grandchildren[2]])
9293
self.assertEqual(flattened,
9394
target)
95+
96+
def test_add_invalid_type_raises_Exception(self):
97+
group = Group()
98+
self.assertRaises(TypeError, group.add, Block())
99+
100+
self.assertRaises(TypeError, Group, allowed_types=[Block])

neo/test/coretest/test_segment.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,10 @@ def test_add(self):
669669
seg.add(proxy_epoch)
670670
assert len(seg.epochs) == 1
671671

672+
def test_add_invalid_type_raises_Exception(self):
673+
seg = Segment()
674+
self.assertRaises(TypeError, seg.add, Block())
675+
672676

673677
if __name__ == "__main__":
674678
unittest.main()

0 commit comments

Comments
 (0)