Skip to content

Commit 39c14da

Browse files
authored
Merge pull request #789 from INM-6/enh/array_anno_utils
Improve utilities
2 parents 7491749 + acb13ce commit 39c14da

File tree

1 file changed

+22
-110
lines changed

1 file changed

+22
-110
lines changed

neo/utils.py

Lines changed: 22 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def get_events(container, **properties):
5050
Example:
5151
--------
5252
>>> event = neo.Event(times=[0.5, 10.0, 25.2] * pq.s)
53-
>>> event.annotate(event_type='trial start',
54-
trial_id=[1, 2, 3])
53+
>>> event.annotate(event_type='trial start')
54+
>>> event.array_annotate(trial_id=[1, 2, 3])
5555
>>> seg = neo.Segment()
5656
>>> seg.events = [event]
5757
@@ -218,38 +218,8 @@ def _event_epoch_slice_by_valid_ids(obj, valid_ids):
218218
"""
219219
Internal function
220220
"""
221-
# modify annotations
222-
sparse_annotations = _get_valid_annotations(obj, valid_ids)
223-
224-
# modify array annotations
225-
sparse_array_annotations = {key: value[valid_ids]
226-
for key, value in obj.array_annotations.items() if len(value)}
227-
228-
if obj.labels is not None and obj.labels.size > 0:
229-
labels = obj.labels[valid_ids]
230-
else:
231-
labels = obj.labels
232-
if type(obj) is neo.Event:
233-
sparse_obj = neo.Event(
234-
times=copy.deepcopy(obj.times[valid_ids]),
235-
labels=copy.deepcopy(labels),
236-
units=copy.deepcopy(obj.units),
237-
name=copy.deepcopy(obj.name),
238-
description=copy.deepcopy(obj.description),
239-
file_origin=copy.deepcopy(obj.file_origin),
240-
array_annotations=sparse_array_annotations,
241-
**sparse_annotations)
242-
elif type(obj) is neo.Epoch:
243-
sparse_obj = neo.Epoch(
244-
times=copy.deepcopy(obj.times[valid_ids]),
245-
durations=copy.deepcopy(obj.durations[valid_ids]),
246-
labels=copy.deepcopy(labels),
247-
units=copy.deepcopy(obj.units),
248-
name=copy.deepcopy(obj.name),
249-
description=copy.deepcopy(obj.description),
250-
file_origin=copy.deepcopy(obj.file_origin),
251-
array_annotations=sparse_array_annotations,
252-
**sparse_annotations)
221+
if type(obj) is neo.Event or type(obj) is neo.Epoch:
222+
sparse_obj = copy.deepcopy(obj[valid_ids])
253223
else:
254224
raise TypeError('Can only slice Event and Epoch objects by valid IDs.')
255225

@@ -260,77 +230,24 @@ def _get_valid_ids(obj, annotation_key, annotation_value):
260230
"""
261231
Internal function
262232
"""
263-
# wrap annotation value to be list
264-
if not type(annotation_value) in [list, np.ndarray]:
265-
annotation_value = [annotation_value]
266-
267-
# get all real attributes of object
268-
attributes = inspect.getmembers(obj)
269-
attributes_names = [t[0] for t in attributes if not(
270-
t[0].startswith('__') and t[0].endswith('__'))]
271-
attributes_ids = [i for i, t in enumerate(attributes) if not(
272-
t[0].startswith('__') and t[0].endswith('__'))]
273-
274-
# check if annotation is present
275-
value_avail = False
276-
if annotation_key in obj.annotations:
277-
check_value = obj.annotations[annotation_key]
278-
value_avail = True
279-
elif annotation_key in obj.array_annotations:
280-
check_value = obj.array_annotations[annotation_key]
281-
value_avail = True
282-
elif annotation_key in attributes_names:
283-
check_value = attributes[attributes_ids[
284-
attributes_names.index(annotation_key)]][1]
285-
value_avail = True
286-
287-
if value_avail:
288-
# check if annotation is list and fits to length of object list
289-
if not _is_annotation_list(check_value, len(obj)):
290-
# check if annotation is single value and fits to requested value
291-
if check_value in annotation_value:
292-
valid_mask = np.ones(obj.shape)
293-
else:
294-
valid_mask = np.zeros(obj.shape)
295-
if type(check_value) != str:
296-
warnings.warn(
297-
'Length of annotation "%s" (%s) does not fit '
298-
'to length of object list (%s)' % (
299-
annotation_key, len(check_value), len(obj)))
300-
301-
# extract object entries, which match requested annotation
302-
else:
303-
valid_mask = np.zeros(obj.shape)
304-
for obj_id in range(len(obj)):
305-
if check_value[obj_id] in annotation_value:
306-
valid_mask[obj_id] = True
307-
else:
308-
valid_mask = np.zeros(obj.shape)
309233

310-
valid_ids = np.where(valid_mask)[0]
234+
valid_mask = np.zeros(obj.shape)
311235

312-
return valid_ids
236+
if annotation_key in obj.annotations and obj.annotations[annotation_key] == annotation_value:
237+
valid_mask = np.ones(obj.shape)
313238

239+
elif annotation_key in obj.array_annotations:
240+
# wrap annotation value to be list
241+
if not type(annotation_value) in [list, np.ndarray]:
242+
annotation_value = [annotation_value]
243+
valid_mask = np.in1d(obj.array_annotations[annotation_key], annotation_value)
314244

315-
def _get_valid_annotations(obj, valid_ids):
316-
"""
317-
Internal function
318-
"""
319-
sparse_annotations = copy.deepcopy(obj.annotations)
320-
for key in sparse_annotations:
321-
if _is_annotation_list(sparse_annotations[key], len(obj)):
322-
sparse_annotations[key] = list(np.array(sparse_annotations[key])[
323-
valid_ids])
324-
return sparse_annotations
245+
elif hasattr(obj, annotation_key) and getattr(obj, annotation_key) == annotation_value:
246+
valid_mask = np.ones(obj.shape)
325247

248+
valid_ids = np.where(valid_mask)[0]
326249

327-
def _is_annotation_list(value, exp_length):
328-
"""
329-
Internal function
330-
"""
331-
return (
332-
(isinstance(value, list) or (
333-
isinstance(value, np.ndarray) and value.ndim > 0)) and (len(value) == exp_length))
250+
return valid_ids
334251

335252

336253
def add_epoch(
@@ -421,6 +338,7 @@ def add_epoch(
421338
ep = neo.Epoch(times=times, durations=durations, **kwargs)
422339

423340
ep.annotate(**event1.annotations)
341+
ep.array_annotate(**event1.array_annotations)
424342

425343
if attach_result:
426344
segment.epochs.append(ep)
@@ -516,10 +434,10 @@ def cut_block_by_epochs(block, properties=None, reset_time=False):
516434
Contains the Segments to cut according to the Epoch criteria provided
517435
properties: dictionary
518436
A dictionary that contains the Epoch keys and values to filter for.
519-
Each key of the dictionary is matched to an attribute or an an
520-
annotation of the Event. The value of each dictionary entry corresponds
521-
to a valid entry or a list of valid entries of the attribute or
522-
annotation.
437+
Each key of the dictionary is matched to an attribute or an
438+
annotation or an array_annotation of the Event.
439+
The value of each dictionary entry corresponds to a valid entry or a
440+
list of valid entries of the attribute or (array) annotation.
523441
524442
If the value belonging to the key is a list of entries of the same
525443
length as the number of epochs in the Epoch object, the list entries
@@ -619,13 +537,7 @@ def cut_segment_by_epoch(seg, epoch, reset_time=False):
619537
epoch.times[ep_id] + epoch.durations[ep_id],
620538
reset_time=reset_time)
621539

622-
# Add annotations of Epoch
623-
for a in epoch.annotations:
624-
if type(epoch.annotations[a]) is list \
625-
and len(epoch.annotations[a]) == len(epoch):
626-
subseg.annotations[a] = copy.copy(epoch.annotations[a][ep_id])
627-
else:
628-
subseg.annotations[a] = copy.copy(epoch.annotations[a])
540+
subseg.annotate(**copy.copy(epoch.annotations))
629541

630542
# Add array-annotations of Epoch
631543
for key, val in epoch.array_annotations.items():

0 commit comments

Comments
 (0)