Skip to content

Commit 5be43dc

Browse files
author
kleinjohann
committed
Add array annotations
write as properties, read as array annotation if the length is correct
1 parent b4c5950 commit 5be43dc

File tree

1 file changed

+155
-5
lines changed

1 file changed

+155
-5
lines changed

neo/io/nixio.py

Lines changed: 155 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,35 @@ def _nix_to_neo_analogsignal(self, nix_da_group):
397397
else:
398398
t_start = create_quantity(timedim.offset, timedim.unit)
399399

400+
# find annotations with length equal to signal length and turn them into array-annotations
401+
try:
402+
sig_length = signaldata.shape[-1]
403+
# Note: This is because __getitem__[int] returns a scalar Epoch/Event/SpikeTrain
404+
# To be removed if __getitem__[int] is changed
405+
except IndexError:
406+
sig_length = 1
407+
408+
array_annotations = {}
409+
410+
if sig_length > 1:
411+
for attr_key, attr_val in neo_attrs.items():
412+
if isinstance(attr_val, (list, np.ndarray)) or (
413+
isinstance(attr_val, pq.Quantity) and attr_val.shape == ()):
414+
if len(attr_val) == sig_length:
415+
if isinstance(attr_val, list) or (isinstance(attr_val, np.ndarray) and not (
416+
isinstance(attr_val, pq.Quantity) and (
417+
attr_val.shape == () or attr_val.shape == (1,)))):
418+
# Array annotations should only be 1-dimensional
419+
continue
420+
if isinstance(attr_val, dict):
421+
# Dictionaries are not supported as array annotations
422+
continue
423+
array_annotations[attr_key] = attr_val
424+
del neo_attrs[attr_key]
425+
400426
neo_signal = AnalogSignal(
401427
signal=signaldata, sampling_period=sampling_period,
402-
t_start=t_start, **neo_attrs
428+
t_start=t_start, array_annotations=array_annotations, **neo_attrs
403429
)
404430
self._neo_map[neo_attrs["nix_name"]] = neo_signal
405431
# all DAs reference the same sources
@@ -428,8 +454,35 @@ def _nix_to_neo_irregularlysampledsignal(self, nix_da_group):
428454
signaldata = create_quantity(signaldata, unit)
429455
timedim = self._get_time_dimension(nix_da_group[0])
430456
times = create_quantity(timedim.ticks, timedim.unit)
457+
458+
# find annotations with length equal to signal length and turn them into array-annotations
459+
try:
460+
sig_length = signaldata.shape[-1]
461+
# Note: This is because __getitem__[int] returns a scalar Epoch/Event/SpikeTrain
462+
# To be removed if __getitem__[int] is changed
463+
except IndexError:
464+
sig_length = 1
465+
466+
array_annotations = {}
467+
468+
if sig_length > 1:
469+
for attr_key, attr_val in neo_attrs.items():
470+
if isinstance(attr_val, (list, np.ndarray)) or (
471+
isinstance(attr_val, pq.Quantity) and attr_val.shape == ()):
472+
if len(attr_val) == sig_length:
473+
if isinstance(attr_val, list) or (isinstance(attr_val, np.ndarray) and not (
474+
isinstance(attr_val, pq.Quantity) and (
475+
attr_val.shape == () or attr_val.shape == (1,)))):
476+
# Array annotations should only be 1-dimensional
477+
continue
478+
if isinstance(attr_val, dict):
479+
# Dictionaries are not supported as array annotations
480+
continue
481+
array_annotations[attr_key] = attr_val
482+
del neo_attrs[attr_key]
483+
431484
neo_signal = IrregularlySampledSignal(
432-
signal=signaldata, times=times, **neo_attrs
485+
signal=signaldata, times=times, array_annotations=array_annotations, **neo_attrs
433486
)
434487
self._neo_map[neo_attrs["nix_name"]] = neo_signal
435488
# all DAs reference the same sources
@@ -446,7 +499,35 @@ def _nix_to_neo_event(self, nix_mtag):
446499
times = create_quantity(nix_mtag.positions, time_unit)
447500
labels = np.array(nix_mtag.positions.dimensions[0].labels,
448501
dtype="S")
449-
neo_event = Event(times=times, labels=labels, **neo_attrs)
502+
503+
# find annotations with length equal to event length and turn them into array-annotations
504+
try:
505+
sig_length = times.shape[-1]
506+
# Note: This is because __getitem__[int] returns a scalar Epoch/Event/SpikeTrain
507+
# To be removed if __getitem__[int] is changed
508+
except IndexError:
509+
sig_length = 1
510+
511+
array_annotations = {}
512+
513+
if sig_length > 1:
514+
for attr_key, attr_val in neo_attrs.items():
515+
if isinstance(attr_val, (list, np.ndarray)) or (
516+
isinstance(attr_val, pq.Quantity) and attr_val.shape == ()):
517+
if len(attr_val) == sig_length:
518+
if isinstance(attr_val, list) or (isinstance(attr_val, np.ndarray) and not (
519+
isinstance(attr_val, pq.Quantity) and (
520+
attr_val.shape == () or attr_val.shape == (1,)))):
521+
# Array annotations should only be 1-dimensional
522+
continue
523+
if isinstance(attr_val, dict):
524+
# Dictionaries are not supported as array annotations
525+
continue
526+
array_annotations[attr_key] = attr_val
527+
del neo_attrs[attr_key]
528+
529+
neo_event = Event(times=times, labels=labels, array_annotations=array_annotations,
530+
**neo_attrs)
450531
self._neo_map[nix_mtag.name] = neo_event
451532
return neo_event
452533

@@ -456,21 +537,75 @@ def _nix_to_neo_epoch(self, nix_mtag):
456537
times = create_quantity(nix_mtag.positions, time_unit)
457538
durations = create_quantity(nix_mtag.extents,
458539
nix_mtag.extents.unit)
540+
541+
# find annotations with length equal to event length and turn them into array-annotations
542+
try:
543+
sig_length = times.shape[-1]
544+
# Note: This is because __getitem__[int] returns a scalar Epoch/Event/SpikeTrain
545+
# To be removed if __getitem__[int] is changed
546+
except IndexError:
547+
sig_length = 1
548+
549+
array_annotations = {}
550+
551+
if sig_length > 1:
552+
for attr_key, attr_val in neo_attrs.items():
553+
if isinstance(attr_val, (list, np.ndarray)) or (
554+
isinstance(attr_val, pq.Quantity) and attr_val.shape == ()):
555+
if len(attr_val) == sig_length:
556+
if isinstance(attr_val, list) or (isinstance(attr_val, np.ndarray) and not (
557+
isinstance(attr_val, pq.Quantity) and (
558+
attr_val.shape == () or attr_val.shape == (1,)))):
559+
# Array annotations should only be 1-dimensional
560+
continue
561+
if isinstance(attr_val, dict):
562+
# Dictionaries are not supported as array annotations
563+
continue
564+
array_annotations[attr_key] = attr_val
565+
del neo_attrs[attr_key]
566+
459567
if len(nix_mtag.positions.dimensions[0].labels) > 0:
460568
labels = np.array(nix_mtag.positions.dimensions[0].labels,
461569
dtype="S")
462570
else:
463571
labels = None
464572
neo_epoch = Epoch(times=times, durations=durations, labels=labels,
465-
**neo_attrs)
573+
array_annotations=array_annotations, **neo_attrs)
466574
self._neo_map[nix_mtag.name] = neo_epoch
467575
return neo_epoch
468576

469577
def _nix_to_neo_spiketrain(self, nix_mtag):
470578
neo_attrs = self._nix_attr_to_neo(nix_mtag)
471579
time_unit = nix_mtag.positions.unit
472580
times = create_quantity(nix_mtag.positions, time_unit)
473-
neo_spiketrain = SpikeTrain(times=times, **neo_attrs)
581+
582+
# find annotations with length equal to event length and turn them into array-annotations
583+
try:
584+
sig_length = times.shape[-1]
585+
# Note: This is because __getitem__[int] returns a scalar Epoch/Event/SpikeTrain
586+
# To be removed if __getitem__[int] is changed
587+
except IndexError:
588+
sig_length = 1
589+
590+
array_annotations = {}
591+
592+
if sig_length > 1:
593+
for attr_key, attr_val in neo_attrs.items():
594+
if isinstance(attr_val, (list, np.ndarray)) or (
595+
isinstance(attr_val, pq.Quantity) and attr_val.shape == ()):
596+
if len(attr_val) == sig_length:
597+
if isinstance(attr_val, list) or (isinstance(attr_val, np.ndarray) and not (
598+
isinstance(attr_val, pq.Quantity) and (
599+
attr_val.shape == () or attr_val.shape == (1,)))):
600+
# Array annotations should only be 1-dimensional
601+
continue
602+
if isinstance(attr_val, dict):
603+
# Dictionaries are not supported as array annotations
604+
continue
605+
array_annotations[attr_key] = attr_val
606+
del neo_attrs[attr_key]
607+
608+
neo_spiketrain = SpikeTrain(times=times, array_annotations=array_annotations, **neo_attrs)
474609
if nix_mtag.features:
475610
wfda = nix_mtag.features[0].data
476611
wftime = self._get_time_dimension(wfda)
@@ -724,6 +859,9 @@ def _write_analogsignal(self, anasig, nixblock, nixgroup):
724859
if anasig.annotations:
725860
for k, v in anasig.annotations.items():
726861
self._write_property(metadata, k, v)
862+
if anasig.array_annotations:
863+
for k, v in anasig.array_annotations.items():
864+
self._write_property(metadata, k, v)
727865

728866
self._signal_map[nix_name] = nixdas
729867

@@ -787,6 +925,9 @@ def _write_irregularlysampledsignal(self, irsig, nixblock, nixgroup):
787925
if irsig.annotations:
788926
for k, v in irsig.annotations.items():
789927
self._write_property(metadata, k, v)
928+
if irsig.array_annotations:
929+
for k, v in irsig.array_annotations.items():
930+
self._write_property(metadata, k, v)
790931

791932
self._signal_map[nix_name] = nixdas
792933

@@ -834,6 +975,9 @@ def _write_event(self, event, nixblock, nixgroup):
834975
if event.annotations:
835976
for k, v in event.annotations.items():
836977
self._write_property(metadata, k, v)
978+
if event.array_annotations:
979+
for k, v in event.array_annotations.items():
980+
self._write_property(metadata, k, v)
837981

838982
nixgroup.multi_tags.append(nixmt)
839983

@@ -896,6 +1040,9 @@ def _write_epoch(self, epoch, nixblock, nixgroup):
8961040
if epoch.annotations:
8971041
for k, v in epoch.annotations.items():
8981042
self._write_property(metadata, k, v)
1043+
if epoch.array_annotations:
1044+
for k, v in epoch.array_annotations.items():
1045+
self._write_property(metadata, k, v)
8991046

9001047
nixgroup.multi_tags.append(nixmt)
9011048

@@ -950,6 +1097,9 @@ def _write_spiketrain(self, spiketrain, nixblock, nixgroup):
9501097
if spiketrain.annotations:
9511098
for k, v in spiketrain.annotations.items():
9521099
self._write_property(metadata, k, v)
1100+
if spiketrain.array_annotations:
1101+
for k, v in spiketrain.array_annotations.items():
1102+
self._write_property(metadata, k, v)
9531103

9541104
if nixgroup:
9551105
nixgroup.multi_tags.append(nixmt)

0 commit comments

Comments
 (0)