
Commit 97055e4

convert assert to errors I
1 parent f401627 commit 97055e4
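
The rationale, not spelled out in the commit message but standard for this kind of change: `assert` statements are stripped entirely when Python runs with optimizations enabled (`python -O`), so validation guarded by them silently disappears, and when they do fire they raise a bare AssertionError rather than a specific, catchable exception type. A minimal before/after sketch of the pattern applied throughout this commit (`check_parent_before` and `check_parent_after` are hypothetical helpers, not part of neo):

# Before: skipped entirely under `python -O`; failure raises a generic AssertionError
def check_parent_before(child, parent):
    assert child.parent is parent, "child/parent mismatch"

# After: always enforced; failure raises a specific, catchable exception type
def check_parent_after(child, parent):
    if child.parent is not parent:
        raise ValueError("child/parent mismatch")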

7 files changed: +94 −66 lines

neo/core/container.py

Lines changed: 2 additions & 1 deletion
@@ -450,7 +450,8 @@ def check_relationships(self, recursive=True):
             else:
                 container = getattr(self, _container_name(child.__class__.__name__))
                 if container.parent is not None:
-                    assert getattr(child, parent_name, None) is self
+                    if getattr(child, parent_name, None) is not self:
+                        raise AttributeError("The child's parent attribute must point back to this container")
         if recursive:
             for child in self.container_children:
                 child.check_relationships(recursive=True)

neo/core/objectlist.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,8 @@ def __init__(self, allowed_contents, parent=None):
             self.allowed_contents = (allowed_contents,)
         else:
             for item in allowed_contents:
-                assert issubclass(item, BaseNeo)
+                if not issubclass(item, BaseNeo):
+                    raise TypeError("Each item in allowed_contents must be a subclass of BaseNeo")
             self.allowed_contents = tuple(allowed_contents)
         self._items = []
         self.parent = parent

neo/core/spiketrain.py

Lines changed: 2 additions & 2 deletions
@@ -891,8 +891,8 @@ def _merge_array_annotations(self, others, sorting=None):
         the spikes.
         :return Merged array_annotations
         """
-
-        assert sorting is not None, "The order of the merged spikes must be known"
+        if sorting is None:
+            raise ValueError("The order of the merged spikes must be known")

         merged_array_annotations = {}

neo/core/spiketrainlist.py

Lines changed: 2 additions & 2 deletions
@@ -87,8 +87,8 @@ def __init__(self, items=None, parent=None):
         self._channel_id_array = None
         self._all_channel_ids = None
         self._spiketrain_metadata = {}
-        if parent is not None:
-            assert parent.__class__.__name__ == "Segment"
+        if parent is not None and parent.__class__.__name__ != "Segment":
+            raise AttributeError("The parent must be a Segment")
         self.segment = parent

     @property

neo/io/basefromrawio.py

Lines changed: 4 additions & 7 deletions
@@ -241,13 +241,10 @@ def read_segment(
         """

         if lazy:
-            assert (
-                time_slice is None
-            ), "For lazy=True you must specify a time_slice when LazyObject.load(time_slice=...)"
-
-            assert (
-                not load_waveforms
-            ), "For lazy=True you must specify load_waveforms when SpikeTrain.load(load_waveforms=...)"
+            if time_slice is not None:
+                raise ValueError("For lazy=True you must specify a time_slice when LazyObject.load(time_slice=...)")
+            if load_waveforms:
+                raise ValueError("For lazy=True you must specify load_waveforms when SpikeTrain.load(load_waveforms=...)")

         if signal_group_mode is None:
             signal_group_mode = self._prefered_signal_group_mode

neo/io/baseio.py

Lines changed: 57 additions & 33 deletions
@@ -167,97 +167,121 @@ def write(self, bl, **kargs):
         """
         if Block in self.writeable_objects:
             if isinstance(bl, Sequence):
-                assert hasattr(
-                    self, "write_all_blocks"
-                ), f"{self.__class__.__name__} does not offer to store a sequence of blocks"
+                if not hasattr(self, "write_all_blocks"):
+                    raise AttributeError(f"{self.__class__.__name__} does not offer to store a sequence of blocks")
                 self.write_all_blocks(bl, **kargs)
             else:
                 self.write_block(bl, **kargs)
         elif Segment in self.writeable_objects:
-            assert len(bl.segments) == 1, (
-                f"{self.__class__.__name__} is based on segment so if you try to write a block it "
-                + "must contain only one Segment"
-            )
+            if len(bl.segments) != 1:
+                raise ValueError(f"{self.__class__.__name__} is based on segment so if you try to write a block it "
+                                 "must contain only one Segment")
             self.write_segment(bl.segments[0], **kargs)
         else:
             raise NotImplementedError

     ######## All individual read methods #######################
     def read_block(self, **kargs):
-        assert Block in self.readable_objects, read_error
+        if Block not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_segment(self, **kargs):
-        assert Segment in self.readable_objects, read_error
+        if Segment not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_spiketrain(self, **kargs):
-        assert SpikeTrain in self.readable_objects, read_error
+        if SpikeTrain not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_analogsignal(self, **kargs):
-        assert AnalogSignal in self.readable_objects, read_error
+        if AnalogSignal not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_imagesequence(self, **kargs):
-        assert ImageSequence in self.readable_objects, read_error
+        if ImageSequence not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_rectangularregionofinterest(self, **kargs):
-        assert RectangularRegionOfInterest in self.readable_objects, read_error
+        if RectangularRegionOfInterest not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_circularregionofinterest(self, **kargs):
-        assert CircularRegionOfInterest in self.readable_objects, read_error
+        if CircularRegionOfInterest not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_polygonregionofinterest(self, **kargs):
-        assert PolygonRegionOfInterest in self.readable_objects, read_error
+        if PolygonRegionOfInterest not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_irregularlysampledsignal(self, **kargs):
-        assert IrregularlySampledSignal in self.readable_objects, read_error
+        if IrregularlySampledSignal not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_channelview(self, **kargs):
-        assert ChannelView in self.readable_objects, read_error
+        if ChannelView not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_event(self, **kargs):
-        assert Event in self.readable_objects, read_error
+        if Event not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_epoch(self, **kargs):
-        assert Epoch in self.readable_objects, read_error
+        if Epoch not in self.readable_objects:
+            raise RuntimeError(read_error)

     def read_group(self, **kargs):
-        assert Group in self.readable_objects, read_error
+        if Group not in self.readable_objects:
+            raise RuntimeError(read_error)

     ######## All individual write methods #######################
     def write_block(self, bl, **kargs):
-        assert Block in self.writeable_objects, write_error
+        if Block not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_segment(self, seg, **kargs):
-        assert Segment in self.writeable_objects, write_error
+        if Segment not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_spiketrain(self, sptr, **kargs):
-        assert SpikeTrain in self.writeable_objects, write_error
+        if SpikeTrain not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_analogsignal(self, anasig, **kargs):
-        assert AnalogSignal in self.writeable_objects, write_error
+        if AnalogSignal not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_imagesequence(self, imseq, **kargs):
-        assert ImageSequence in self.writeable_objects, write_error
+        if ImageSequence not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_rectangularregionofinterest(self, rectroi, **kargs):
-        assert RectangularRegionOfInterest in self.writeable_objects, read_error
+        if RectangularRegionOfInterest not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_circularregionofinterest(self, circroi, **kargs):
-        assert CircularRegionOfInterest in self.writeable_objects, read_error
+        if CircularRegionOfInterest not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_polygonregionofinterest(self, polyroi, **kargs):
-        assert PolygonRegionOfInterest in self.writeable_objects, read_error
+        if PolygonRegionOfInterest not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_irregularlysampledsignal(self, irsig, **kargs):
-        assert IrregularlySampledSignal in self.writeable_objects, write_error
+        if IrregularlySampledSignal not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_channelview(self, chv, **kargs):
-        assert ChannelView in self.writeable_objects, write_error
+        if ChannelView not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_event(self, ev, **kargs):
-        assert Event in self.writeable_objects, write_error
+        if Event not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_epoch(self, ep, **kargs):
-        assert Epoch in self.writeable_objects, write_error
+        if Epoch not in self.writeable_objects:
+            raise RuntimeError(write_error)

     def write_group(self, group, **kargs):
-        assert Group in self.writeable_objects, write_error
+        if Group not in self.writeable_objects:
+            raise RuntimeError(write_error)
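
All of these methods follow one capability-check pattern: a BaseIO subclass declares which neo object types it handles via its readable_objects / writeable_objects class attributes, and each typed read/write method verifies membership before doing any work. A minimal sketch of that pattern (`DemoIO` and the error string are hypothetical stand-ins, not neo's actual API):

# Minimal sketch of the capability-check pattern used above.
class Block: ...
class Segment: ...

READ_ERROR = "This type is not supported by this IO class"

class DemoIO:
    readable_objects = [Block]  # this IO can only read whole blocks

    def read_block(self, **kargs):
        if Block not in self.readable_objects:
            raise RuntimeError(READ_ERROR)
        # a real subclass would parse the file and return a Block here

    def read_segment(self, **kargs):
        if Segment not in self.readable_objects:
            raise RuntimeError(READ_ERROR)

DemoIO().read_segment()  # raises RuntimeError: Segment is not readable by DemoIO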

neo/rawio/baserawio.py

Lines changed: 25 additions & 20 deletions
@@ -497,23 +497,23 @@ def _check_stream_signal_channel_characteristics(self):
         signal_streams = self.header["signal_streams"]
         signal_channels = self.header["signal_channels"]
         if signal_streams.size > 0:
-            assert signal_channels.size > 0, "Signal stream but no signal_channels!!!"
+            if signal_channels.size < 1:
+                raise ValueError("Signal stream exists but there are no signal channels")

         for stream_index in range(signal_streams.size):
             stream_id = signal_streams[stream_index]["id"]
             mask = signal_channels["stream_id"] == stream_id
             characteristics = signal_channels[mask][_common_sig_characteristics]
             unique_characteristics = np.unique(characteristics)
-            assert unique_characteristics.size == 1, (
-                f"Some channels in stream_id {stream_id} "
-                f"do not have same {_common_sig_characteristics} {unique_characteristics}"
-            )
+            if unique_characteristics.size != 1:
+                raise ValueError(f"Some channels in stream_id {stream_id} "
+                                 f"do not have the same {_common_sig_characteristics} {unique_characteristics}")

             # also check that channel_id is unique inside a stream
             channel_ids = signal_channels[mask]["id"]
-            assert (
-                np.unique(channel_ids).size == channel_ids.size
-            ), f"signal_channels do not have unique ids for stream {stream_index}"
+            if np.unique(channel_ids).size != channel_ids.size:
+                raise ValueError(f"signal_channels do not have unique ids for stream {stream_index}")

         self._several_channel_groups = signal_streams.size > 1

@@ -540,7 +540,8 @@ def channel_name_to_index(self, stream_index: int, channel_names: list[str]):
         mask = self.header["signal_channels"]["stream_id"] == stream_id
         signal_channels = self.header["signal_channels"][mask]
         chan_names = list(signal_channels["name"])
-        assert signal_channels.size == np.unique(chan_names).size, "Channel names not unique"
+        if signal_channels.size != np.unique(chan_names).size:
+            raise ValueError("Channel names are not unique")
         channel_indexes = np.array([chan_names.index(name) for name in channel_names])
         return channel_indexes

@@ -621,12 +622,12 @@ def _get_stream_index_from_arg(self, stream_index_arg: int | None):

         """
         if stream_index_arg is None:
-            assert self.header["signal_streams"].size == 1, "stream_index must be given for multiple stream files"
+            if self.header["signal_streams"].size != 1:
+                raise ValueError("stream_index must be given for files with multiple streams")
             stream_index = 0
         else:
-            assert (
-                0 <= stream_index_arg < self.header["signal_streams"].size
-            ), f"stream_index must be between 0 and {self.header['signal_streams'].size}"
+            if stream_index_arg < 0 or stream_index_arg >= self.header["signal_streams"].size:
+                raise ValueError(f"stream_index must be between 0 and {self.header['signal_streams'].size}")
             stream_index = stream_index_arg
         return stream_index

@@ -786,7 +787,9 @@ def get_analogsignal_chunk(

         if isinstance(channel_indexes, np.ndarray):
             if channel_indexes.dtype == "bool":
-                assert self.signal_channels_count(stream_index) == channel_indexes.size
+                if self.signal_channels_count(stream_index) != channel_indexes.size:
+                    raise ValueError("If channel_indexes is a boolean mask it must have the same length as the "
+                                     f"number of channels {self.signal_channels_count(stream_index)}")
                 (channel_indexes,) = np.nonzero(channel_indexes)

         if prefer_slice and isinstance(channel_indexes, np.ndarray):
@@ -1210,9 +1213,8 @@ def setup_cache(self, cache_path: "home" | "same_as_resource", **init_kargs):
         elif cache_path == "same_as_resource":
             dirname = os.path.dirname(resource_name)
         else:
-            assert os.path.exists(
-                cache_path
-            ), 'cache_path does not exists use "home" or "same_as_resource" to make this auto'
+            if not os.path.exists(cache_path):
+                raise ValueError('cache_path does not exist; use "home" or "same_as_resource" to set it automatically')

         # the hash of the resource (dir of file) is done with filename+datetime
         # TODO make something more sophisticated when rawmode='one-dir' that use all
@@ -1233,12 +1235,14 @@ def setup_cache(self, cache_path: "home" | "same_as_resource", **init_kargs):
         self.dump_cache()

     def add_in_cache(self, **kargs):
-        assert self.use_cache
+        if not self.use_cache:
+            raise ValueError("Cannot use add_in_cache if not using cache")
         self._cache.update(kargs)
         self.dump_cache()

     def dump_cache(self):
-        assert self.use_cache
+        if not self.use_cache:
+            raise ValueError("Cannot use dump_cache if not using cache")
         joblib.dump(self._cache, self.cache_filename)

     ##################
@@ -1336,7 +1340,8 @@ def _rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype):

 def pprint_vector(vector, lim: int = 8):
     vector = np.asarray(vector)
-    assert vector.ndim == 1
+    if vector.ndim != 1:
+        raise ValueError(f"`vector` must have a dimension of 1 and not {vector.ndim}")
     if len(vector) > lim:
         part1 = ", ".join(e for e in vector[: lim // 2])
         part2 = " , ".join(e for e in vector[-lim // 2 :])
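
For calling code, the visible effect of this commit is that these checks keep firing even under `python -O` and raise specific exception types (TypeError, ValueError, AttributeError, RuntimeError) instead of AssertionError. A hedged sketch of what a caller might now do (the reader below is a stub written for illustration, not a class from this commit):

# Hypothetical caller; the exception type matches the new code, but the reader is a stub.
class StubRawIO:
    def parse_header(self):
        pass

    def get_analogsignal_chunk(self, stream_index=0):
        # mimics the new behaviour for an out-of-range stream_index
        if stream_index < 0 or stream_index >= 1:
            raise ValueError("stream_index must be between 0 and 1")
        return []

reader = StubRawIO()
reader.parse_header()
try:
    chunk = reader.get_analogsignal_chunk(stream_index=99)
except ValueError as err:
    print(f"invalid request: {err}")  # previously an AssertionError, not catchable by type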
