Skip to content

Commit 9c07a97

Browse files
legoueeapdavison
andauthored
Consistent parameters of NWBIO.write_block (#1166)
* Consistent write_block * try to fix failing test * fix another source of test failure * Move building of global annotations to a method, since it needs to be performed when `write_all_blocks()` is called, not on class initialization. Co-authored-by: Andrew Davison <[email protected]>
1 parent 217b299 commit 9c07a97

File tree

1 file changed

+68
-51
lines changed

1 file changed

+68
-51
lines changed

neo/io/nwbio.py

Lines changed: 68 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
import numpy as np
2626
import quantities as pq
27-
2827
from neo.core import (Segment, SpikeTrain, Epoch, Event, AnalogSignal,
2928
IrregularlySampledSignal, Block, ImageSequence)
3029
from neo.io.baseio import BaseIO
@@ -43,8 +42,7 @@
4342
"experiment_description", "session_id", "institution", "keywords", "notes",
4443
"pharmacology", "protocol", "related_publications", "slices", "source_script",
4544
"source_script_file_name", "data_collection", "surgery", "virus", "stimulus_notes",
46-
"lab", "session_description",
47-
"rec_datetime",
45+
"lab", "session_description", "rec_datetime",
4846
)
4947

5048
POSSIBLE_JSON_FIELDS = (
@@ -207,7 +205,7 @@ class NWBIO(BaseIO):
207205
is_writable = True
208206
is_streameable = False
209207

210-
def __init__(self, filename, mode='r'):
208+
def __init__(self, filename, mode='r', **annotations):
211209
"""
212210
Arguments:
213211
filename : the filename
@@ -218,6 +216,9 @@ def __init__(self, filename, mode='r'):
218216
self.filename = filename
219217
self.blocks_written = 0
220218
self.nwb_file_mode = mode
219+
self._blocks = {}
220+
self.annotations = annotations
221+
self._io_nwb = None
221222

222223
def read_all_blocks(self, lazy=False, **kwargs):
223224
"""
@@ -383,17 +384,11 @@ def _read_acquisition_group(self, lazy):
383384
def _read_stimulus_group(self, lazy):
384385
self._read_timeseries_group("stimulus", lazy)
385386

386-
def write_all_blocks(self, blocks, **kwargs):
387-
"""
388-
Write list of blocks to the file
389-
"""
390-
import pynwb
391-
392-
# todo: allow metadata in NWBFile constructor to be taken from kwargs
387+
def _build_global_annotations(self, blocks):
393388
annotations = defaultdict(set)
394389
for annotation_name in GLOBAL_ANNOTATIONS:
395-
if annotation_name in kwargs:
396-
annotations[annotation_name] = kwargs[annotation_name]
390+
if annotation_name in self.annotations:
391+
annotations[annotation_name] = self.annotations[annotation_name]
397392
else:
398393
for block in blocks:
399394
if annotation_name in block.annotations:
@@ -411,65 +406,86 @@ def write_all_blocks(self, blocks, **kwargs):
411406
"We don't yet support multiple values for {}".format(annotation_name))
412407
# take single value from set
413408
annotations[annotation_name], = annotations[annotation_name]
409+
414410
if "identifier" not in annotations:
415-
annotations["identifier"] = self.filename
411+
annotations["identifier"] = str(self.filename)
416412
if "session_description" not in annotations:
417-
annotations["session_description"] = blocks[0].description or self.filename
413+
annotations["session_description"] = blocks[0].description or str(self.filename)
418414
# todo: concatenate descriptions of multiple blocks if different
419-
if "session_start_time" not in annotations:
420-
annotations["session_start_time"] = blocks[0].rec_datetime
421-
if annotations["session_start_time"] is None:
415+
if annotations.get("session_start_time", None) is None:
416+
if "rec_datetime" in annotations:
417+
annotations["session_start_time"] = annotations["rec_datetime"]
418+
else:
422419
raise Exception("Writing to NWB requires an annotation 'session_start_time'")
423-
self.annotations = {"rec_datetime": "rec_datetime"}
424-
self.annotations["rec_datetime"] = blocks[0].rec_datetime
425-
# todo: handle subject
426-
nwbfile = pynwb.NWBFile(**annotations)
427-
assert self.nwb_file_mode in ('w',) # possibly expand to 'a'ppend later
428-
if self.nwb_file_mode == "w" and os.path.exists(self.filename):
429-
os.remove(self.filename)
430-
io_nwb = pynwb.NWBHDF5IO(self.filename, mode=self.nwb_file_mode)
420+
return annotations
421+
422+
def write_all_blocks(self, blocks, validate=True, **kwargs):
423+
"""
424+
Write list of blocks to the file
425+
"""
426+
import pynwb
427+
428+
global_annotations = self._build_global_annotations(blocks)
429+
self._nwbfile = pynwb.NWBFile(**global_annotations)
431430

432431
if sum(statistics(block)["SpikeTrain"]["count"] for block in blocks) > 0:
433-
nwbfile.add_unit_column('_name', 'the name attribute of the SpikeTrain')
432+
self._nwbfile.add_unit_column('_name', 'the name attribute of the SpikeTrain')
434433
# nwbfile.add_unit_column('_description',
435434
# 'the description attribute of the SpikeTrain')
436-
nwbfile.add_unit_column(
435+
self._nwbfile.add_unit_column(
437436
'segment', 'the name of the Neo Segment to which the SpikeTrain belongs')
438-
nwbfile.add_unit_column(
437+
self._nwbfile.add_unit_column(
439438
'block', 'the name of the Neo Block to which the SpikeTrain belongs')
440439

441440
if sum(statistics(block)["Epoch"]["count"] for block in blocks) > 0:
442-
nwbfile.add_epoch_column('_name', 'the name attribute of the Epoch')
441+
self._nwbfile.add_epoch_column('_name', 'the name attribute of the Epoch')
443442
# nwbfile.add_epoch_column('_description', 'the description attribute of the Epoch')
444-
nwbfile.add_epoch_column(
443+
self._nwbfile.add_epoch_column(
445444
'segment', 'the name of the Neo Segment to which the Epoch belongs')
446-
nwbfile.add_epoch_column('block',
445+
self._nwbfile.add_epoch_column('block',
447446
'the name of the Neo Block to which the Epoch belongs')
448447

449448
for i, block in enumerate(blocks):
450-
self.write_block(nwbfile, block)
451-
io_nwb.write(nwbfile)
449+
self._write_block(block)
450+
451+
assert self.nwb_file_mode in ('w',) # possibly expand to 'a'ppend later
452+
if self.nwb_file_mode == "w" and os.path.exists(self.filename):
453+
os.remove(self.filename)
454+
io_nwb = pynwb.NWBHDF5IO(self.filename, mode=self.nwb_file_mode)
455+
io_nwb.write(self._nwbfile)
452456
io_nwb.close()
453457

458+
if validate:
459+
self.validate_file()
460+
461+
def validate_file(self):
462+
import pynwb
463+
454464
with pynwb.NWBHDF5IO(self.filename, "r") as io_validate:
455465
errors = pynwb.validate(io_validate, namespace="core")
456466
if errors:
457467
raise Exception(f"Errors found when validating {self.filename}")
458468

459-
def write_block(self, nwbfile, block, **kwargs):
469+
def write_block(self, block, **kwargs):
470+
"""
471+
Write a single Block to the file
472+
:param block: Block to be written
473+
"""
474+
return self.write_all_blocks([block], **kwargs)
475+
476+
def _write_block(self, block):
460477
"""
461478
Write a Block to the file
462479
:param block: Block to be written
463-
:param nwbfile: Representation of an NWB file
464480
"""
465-
electrodes = self._write_electrodes(nwbfile, block)
481+
electrodes = self._write_electrodes(self._nwbfile, block)
466482
if not block.name:
467483
block.name = "block%d" % self.blocks_written
468484
for i, segment in enumerate(block.segments):
469485
assert segment.block is block
470486
if not segment.name:
471487
segment.name = "%s : segment%d" % (block.name, i)
472-
self._write_segment(nwbfile, segment, electrodes)
488+
self._write_segment(self._nwbfile, segment, electrodes)
473489
self.blocks_written += 1
474490

475491
def _write_electrodes(self, nwbfile, block):
@@ -485,10 +501,10 @@ def _write_electrodes(self, nwbfile, block):
485501
if elec_meta["device"]["name"] in devices:
486502
device = devices[elec_meta["device"]["name"]]
487503
else:
488-
device = nwbfile.create_device(**elec_meta["device"])
504+
device = self._nwbfile.create_device(**elec_meta["device"])
489505
devices[elec_meta["device"]["name"]] = device
490506
elec_meta.pop("device")
491-
electrodes[elec_meta["name"]] = nwbfile.create_icephys_electrode(
507+
electrodes[elec_meta["name"]] = self._nwbfile.create_icephys_electrode(
492508
device=device, **elec_meta
493509
)
494510
return electrodes
@@ -503,13 +519,13 @@ def _write_segment(self, nwbfile, segment, electrodes):
503519
logging.warning("Warning signal name exists. New name: %s" % (signal.name))
504520
else:
505521
signal.name = "%s : analogsignal%s %i" % (segment.name, signal.name, i)
506-
self._write_signal(nwbfile, signal, electrodes)
522+
self._write_signal(self._nwbfile, signal, electrodes)
507523

508524
for i, train in enumerate(segment.spiketrains):
509525
assert train.segment is segment
510526
if not train.name:
511527
train.name = "%s : spiketrain%d" % (segment.name, i)
512-
self._write_spiketrain(nwbfile, train)
528+
self._write_spiketrain(self._nwbfile, train)
513529

514530
for i, event in enumerate(segment.events):
515531
assert event.segment is segment
@@ -518,12 +534,12 @@ def _write_segment(self, nwbfile, segment, electrodes):
518534
logging.warning("Warning event name exists. New name: %s" % (event.name))
519535
else:
520536
event.name = "%s : event%s %d" % (segment.name, event.name, i)
521-
self._write_event(nwbfile, event)
537+
self._write_event(self._nwbfile, event)
522538

523539
for i, epoch in enumerate(segment.epochs):
524540
if not epoch.name:
525541
epoch.name = "%s : epoch%d" % (segment.name, i)
526-
self._write_epoch(nwbfile, epoch)
542+
self._write_epoch(self._nwbfile, epoch)
527543

528544
def _write_signal(self, nwbfile, signal, electrodes):
529545
import pynwb
@@ -572,8 +588,8 @@ def _write_signal(self, nwbfile, signal, electrodes):
572588
signal.__class__.__name__))
573589
nwb_group = signal.annotations.get("nwb_group", "acquisition")
574590
add_method_map = {
575-
"acquisition": nwbfile.add_acquisition,
576-
"stimulus": nwbfile.add_stimulus
591+
"acquisition": self._nwbfile.add_acquisition,
592+
"stimulus": self._nwbfile.add_stimulus
577593
}
578594
if nwb_group in add_method_map:
579595
add_time_series = add_method_map[nwb_group]
@@ -586,7 +602,8 @@ def _write_spiketrain(self, nwbfile, spiketrain):
586602
segment = spiketrain.segment
587603
if hasattr(spiketrain, 'proxy_for') and spiketrain.proxy_for is SpikeTrain:
588604
spiketrain = spiketrain.load()
589-
nwbfile.add_unit(spike_times=spiketrain.rescale('s').magnitude,
605+
self._nwbfile.add_unit(
606+
spike_times=spiketrain.rescale('s').magnitude,
590607
obs_intervals=[[float(spiketrain.t_start.rescale('s')),
591608
float(spiketrain.t_stop.rescale('s'))]],
592609
_name=spiketrain.name,
@@ -596,7 +613,7 @@ def _write_spiketrain(self, nwbfile, spiketrain):
596613
# todo: handle annotations (using add_unit_column()?)
597614
# todo: handle Neo Units
598615
# todo: handle spike waveforms, if any (see SpikeEventSeries)
599-
return nwbfile.units
616+
return self._nwbfile.units
600617

601618
def _write_event(self, nwbfile, event):
602619
import pynwb
@@ -611,7 +628,7 @@ def _write_event(self, nwbfile, event):
611628
timestamps=event.times.rescale('second').magnitude,
612629
description=event.description or "",
613630
comments=json.dumps(hierarchy))
614-
nwbfile.add_acquisition(tS_evt)
631+
self._nwbfile.add_acquisition(tS_evt)
615632
return tS_evt
616633

617634
def _write_epoch(self, nwbfile, epoch):
@@ -621,11 +638,11 @@ def _write_epoch(self, nwbfile, epoch):
621638
for t_start, duration, label in zip(epoch.rescale('s').magnitude,
622639
epoch.durations.rescale('s').magnitude,
623640
epoch.labels):
624-
nwbfile.add_epoch(t_start, t_start + duration, [label], [],
641+
self._nwbfile.add_epoch(t_start, t_start + duration, [label], [],
625642
_name=epoch.name,
626643
segment=segment.name,
627644
block=segment.block.name)
628-
return nwbfile.epochs
645+
return self._nwbfile.epochs
629646

630647

631648
class AnalogSignalProxy(BaseAnalogSignalProxy):

0 commit comments

Comments
 (0)