@@ -50,8 +50,8 @@ def get_events(container, **properties):
5050 Example:
5151 --------
5252 >>> event = neo.Event(times=[0.5, 10.0, 25.2] * pq.s)
53- >>> event.annotate(event_type='trial start',
54- trial_id=[1, 2, 3])
53+ >>> event.annotate(event_type='trial start')
54+ >>> event.array_annotate( trial_id=[1, 2, 3])
5555 >>> seg = neo.Segment()
5656 >>> seg.events = [event]
5757
@@ -218,38 +218,8 @@ def _event_epoch_slice_by_valid_ids(obj, valid_ids):
218218 """
219219 Internal function
220220 """
221- # modify annotations
222- sparse_annotations = _get_valid_annotations (obj , valid_ids )
223-
224- # modify array annotations
225- sparse_array_annotations = {key : value [valid_ids ]
226- for key , value in obj .array_annotations .items () if len (value )}
227-
228- if obj .labels is not None and obj .labels .size > 0 :
229- labels = obj .labels [valid_ids ]
230- else :
231- labels = obj .labels
232- if type (obj ) is neo .Event :
233- sparse_obj = neo .Event (
234- times = copy .deepcopy (obj .times [valid_ids ]),
235- labels = copy .deepcopy (labels ),
236- units = copy .deepcopy (obj .units ),
237- name = copy .deepcopy (obj .name ),
238- description = copy .deepcopy (obj .description ),
239- file_origin = copy .deepcopy (obj .file_origin ),
240- array_annotations = sparse_array_annotations ,
241- ** sparse_annotations )
242- elif type (obj ) is neo .Epoch :
243- sparse_obj = neo .Epoch (
244- times = copy .deepcopy (obj .times [valid_ids ]),
245- durations = copy .deepcopy (obj .durations [valid_ids ]),
246- labels = copy .deepcopy (labels ),
247- units = copy .deepcopy (obj .units ),
248- name = copy .deepcopy (obj .name ),
249- description = copy .deepcopy (obj .description ),
250- file_origin = copy .deepcopy (obj .file_origin ),
251- array_annotations = sparse_array_annotations ,
252- ** sparse_annotations )
221+ if type (obj ) is neo .Event or type (obj ) is neo .Epoch :
222+ sparse_obj = copy .deepcopy (obj [valid_ids ])
253223 else :
254224 raise TypeError ('Can only slice Event and Epoch objects by valid IDs.' )
255225
@@ -260,77 +230,24 @@ def _get_valid_ids(obj, annotation_key, annotation_value):
260230 """
261231 Internal function
262232 """
263- # wrap annotation value to be list
264- if not type (annotation_value ) in [list , np .ndarray ]:
265- annotation_value = [annotation_value ]
266-
267- # get all real attributes of object
268- attributes = inspect .getmembers (obj )
269- attributes_names = [t [0 ] for t in attributes if not (
270- t [0 ].startswith ('__' ) and t [0 ].endswith ('__' ))]
271- attributes_ids = [i for i , t in enumerate (attributes ) if not (
272- t [0 ].startswith ('__' ) and t [0 ].endswith ('__' ))]
273-
274- # check if annotation is present
275- value_avail = False
276- if annotation_key in obj .annotations :
277- check_value = obj .annotations [annotation_key ]
278- value_avail = True
279- elif annotation_key in obj .array_annotations :
280- check_value = obj .array_annotations [annotation_key ]
281- value_avail = True
282- elif annotation_key in attributes_names :
283- check_value = attributes [attributes_ids [
284- attributes_names .index (annotation_key )]][1 ]
285- value_avail = True
286-
287- if value_avail :
288- # check if annotation is list and fits to length of object list
289- if not _is_annotation_list (check_value , len (obj )):
290- # check if annotation is single value and fits to requested value
291- if check_value in annotation_value :
292- valid_mask = np .ones (obj .shape )
293- else :
294- valid_mask = np .zeros (obj .shape )
295- if type (check_value ) != str :
296- warnings .warn (
297- 'Length of annotation "%s" (%s) does not fit '
298- 'to length of object list (%s)' % (
299- annotation_key , len (check_value ), len (obj )))
300-
301- # extract object entries, which match requested annotation
302- else :
303- valid_mask = np .zeros (obj .shape )
304- for obj_id in range (len (obj )):
305- if check_value [obj_id ] in annotation_value :
306- valid_mask [obj_id ] = True
307- else :
308- valid_mask = np .zeros (obj .shape )
309233
310- valid_ids = np .where ( valid_mask )[ 0 ]
234+ valid_mask = np .zeros ( obj . shape )
311235
312- return valid_ids
236+ if annotation_key in obj .annotations and obj .annotations [annotation_key ] == annotation_value :
237+ valid_mask = np .ones (obj .shape )
313238
239+ elif annotation_key in obj .array_annotations :
240+ # wrap annotation value to be list
241+ if not type (annotation_value ) in [list , np .ndarray ]:
242+ annotation_value = [annotation_value ]
243+ valid_mask = np .in1d (obj .array_annotations [annotation_key ], annotation_value )
314244
315- def _get_valid_annotations (obj , valid_ids ):
316- """
317- Internal function
318- """
319- sparse_annotations = copy .deepcopy (obj .annotations )
320- for key in sparse_annotations :
321- if _is_annotation_list (sparse_annotations [key ], len (obj )):
322- sparse_annotations [key ] = list (np .array (sparse_annotations [key ])[
323- valid_ids ])
324- return sparse_annotations
245+ elif hasattr (obj , annotation_key ) and getattr (obj , annotation_key ) == annotation_value :
246+ valid_mask = np .ones (obj .shape )
325247
248+ valid_ids = np .where (valid_mask )[0 ]
326249
327- def _is_annotation_list (value , exp_length ):
328- """
329- Internal function
330- """
331- return (
332- (isinstance (value , list ) or (
333- isinstance (value , np .ndarray ) and value .ndim > 0 )) and (len (value ) == exp_length ))
250+ return valid_ids
334251
335252
336253def add_epoch (
@@ -421,6 +338,7 @@ def add_epoch(
421338 ep = neo .Epoch (times = times , durations = durations , ** kwargs )
422339
423340 ep .annotate (** event1 .annotations )
341+ ep .array_annotate (** event1 .array_annotations )
424342
425343 if attach_result :
426344 segment .epochs .append (ep )
@@ -516,10 +434,10 @@ def cut_block_by_epochs(block, properties=None, reset_time=False):
516434 Contains the Segments to cut according to the Epoch criteria provided
517435 properties: dictionary
518436 A dictionary that contains the Epoch keys and values to filter for.
519- Each key of the dictionary is matched to an attribute or an an
520- annotation of the Event. The value of each dictionary entry corresponds
521- to a valid entry or a list of valid entries of the attribute or
522- annotation.
437+ Each key of the dictionary is matched to an attribute or an
438+ annotation or an array_annotation of the Event.
439+ The value of each dictionary entry corresponds to a valid entry or a
440+ list of valid entries of the attribute or (array) annotation.
523441
524442 If the value belonging to the key is a list of entries of the same
525443 length as the number of epochs in the Epoch object, the list entries
@@ -619,13 +537,7 @@ def cut_segment_by_epoch(seg, epoch, reset_time=False):
619537 epoch .times [ep_id ] + epoch .durations [ep_id ],
620538 reset_time = reset_time )
621539
622- # Add annotations of Epoch
623- for a in epoch .annotations :
624- if type (epoch .annotations [a ]) is list \
625- and len (epoch .annotations [a ]) == len (epoch ):
626- subseg .annotations [a ] = copy .copy (epoch .annotations [a ][ep_id ])
627- else :
628- subseg .annotations [a ] = copy .copy (epoch .annotations [a ])
540+ subseg .annotate (** copy .copy (epoch .annotations ))
629541
630542 # Add array-annotations of Epoch
631543 for key , val in epoch .array_annotations .items ():
0 commit comments