2424
2525import numpy as np
2626import quantities as pq
27-
2827from neo .core import (Segment , SpikeTrain , Epoch , Event , AnalogSignal ,
2928 IrregularlySampledSignal , Block , ImageSequence )
3029from neo .io .baseio import BaseIO
4342 "experiment_description" , "session_id" , "institution" , "keywords" , "notes" ,
4443 "pharmacology" , "protocol" , "related_publications" , "slices" , "source_script" ,
4544 "source_script_file_name" , "data_collection" , "surgery" , "virus" , "stimulus_notes" ,
46- "lab" , "session_description" ,
47- "rec_datetime" ,
45+ "lab" , "session_description" , "rec_datetime" ,
4846)
4947
5048POSSIBLE_JSON_FIELDS = (
@@ -207,7 +205,7 @@ class NWBIO(BaseIO):
207205 is_writable = True
208206 is_streameable = False
209207
210- def __init__ (self , filename , mode = 'r' ):
208+ def __init__ (self , filename , mode = 'r' , ** annotations ):
211209 """
212210 Arguments:
213211 filename : the filename
@@ -218,6 +216,9 @@ def __init__(self, filename, mode='r'):
218216 self .filename = filename
219217 self .blocks_written = 0
220218 self .nwb_file_mode = mode
219+ self ._blocks = {}
220+ self .annotations = annotations
221+ self ._io_nwb = None
221222
222223 def read_all_blocks (self , lazy = False , ** kwargs ):
223224 """
@@ -383,17 +384,11 @@ def _read_acquisition_group(self, lazy):
383384 def _read_stimulus_group (self , lazy ):
384385 self ._read_timeseries_group ("stimulus" , lazy )
385386
386- def write_all_blocks (self , blocks , ** kwargs ):
387- """
388- Write list of blocks to the file
389- """
390- import pynwb
391-
392- # todo: allow metadata in NWBFile constructor to be taken from kwargs
387+ def _build_global_annotations (self , blocks ):
393388 annotations = defaultdict (set )
394389 for annotation_name in GLOBAL_ANNOTATIONS :
395- if annotation_name in kwargs :
396- annotations [annotation_name ] = kwargs [annotation_name ]
390+ if annotation_name in self . annotations :
391+ annotations [annotation_name ] = self . annotations [annotation_name ]
397392 else :
398393 for block in blocks :
399394 if annotation_name in block .annotations :
@@ -411,65 +406,86 @@ def write_all_blocks(self, blocks, **kwargs):
411406 "We don't yet support multiple values for {}" .format (annotation_name ))
412407 # take single value from set
413408 annotations [annotation_name ], = annotations [annotation_name ]
409+
414410 if "identifier" not in annotations :
415- annotations ["identifier" ] = self .filename
411+ annotations ["identifier" ] = str ( self .filename )
416412 if "session_description" not in annotations :
417- annotations ["session_description" ] = blocks [0 ].description or self .filename
413+ annotations ["session_description" ] = blocks [0 ].description or str ( self .filename )
418414 # todo: concatenate descriptions of multiple blocks if different
419- if "session_start_time" not in annotations :
420- annotations ["session_start_time" ] = blocks [0 ].rec_datetime
421- if annotations ["session_start_time" ] is None :
415+ if annotations .get ("session_start_time" , None ) is None :
416+ if "rec_datetime" in annotations :
417+ annotations ["session_start_time" ] = annotations ["rec_datetime" ]
418+ else :
422419 raise Exception ("Writing to NWB requires an annotation 'session_start_time'" )
423- self .annotations = {"rec_datetime" : "rec_datetime" }
424- self .annotations ["rec_datetime" ] = blocks [0 ].rec_datetime
425- # todo: handle subject
426- nwbfile = pynwb .NWBFile (** annotations )
427- assert self .nwb_file_mode in ('w' ,) # possibly expand to 'a'ppend later
428- if self .nwb_file_mode == "w" and os .path .exists (self .filename ):
429- os .remove (self .filename )
430- io_nwb = pynwb .NWBHDF5IO (self .filename , mode = self .nwb_file_mode )
420+ return annotations
421+
422+ def write_all_blocks (self , blocks , validate = True , ** kwargs ):
423+ """
424+ Write list of blocks to the file
425+ """
426+ import pynwb
427+
428+ global_annotations = self ._build_global_annotations (blocks )
429+ self ._nwbfile = pynwb .NWBFile (** global_annotations )
431430
432431 if sum (statistics (block )["SpikeTrain" ]["count" ] for block in blocks ) > 0 :
433- nwbfile .add_unit_column ('_name' , 'the name attribute of the SpikeTrain' )
432+ self . _nwbfile .add_unit_column ('_name' , 'the name attribute of the SpikeTrain' )
434433 # nwbfile.add_unit_column('_description',
435434 # 'the description attribute of the SpikeTrain')
436- nwbfile .add_unit_column (
435+ self . _nwbfile .add_unit_column (
437436 'segment' , 'the name of the Neo Segment to which the SpikeTrain belongs' )
438- nwbfile .add_unit_column (
437+ self . _nwbfile .add_unit_column (
439438 'block' , 'the name of the Neo Block to which the SpikeTrain belongs' )
440439
441440 if sum (statistics (block )["Epoch" ]["count" ] for block in blocks ) > 0 :
442- nwbfile .add_epoch_column ('_name' , 'the name attribute of the Epoch' )
441+ self . _nwbfile .add_epoch_column ('_name' , 'the name attribute of the Epoch' )
443442 # nwbfile.add_epoch_column('_description', 'the description attribute of the Epoch')
444- nwbfile .add_epoch_column (
443+ self . _nwbfile .add_epoch_column (
445444 'segment' , 'the name of the Neo Segment to which the Epoch belongs' )
446- nwbfile .add_epoch_column ('block' ,
445+ self . _nwbfile .add_epoch_column ('block' ,
447446 'the name of the Neo Block to which the Epoch belongs' )
448447
449448 for i , block in enumerate (blocks ):
450- self .write_block (nwbfile , block )
451- io_nwb .write (nwbfile )
449+ self ._write_block (block )
450+
451+ assert self .nwb_file_mode in ('w' ,) # possibly expand to 'a'ppend later
452+ if self .nwb_file_mode == "w" and os .path .exists (self .filename ):
453+ os .remove (self .filename )
454+ io_nwb = pynwb .NWBHDF5IO (self .filename , mode = self .nwb_file_mode )
455+ io_nwb .write (self ._nwbfile )
452456 io_nwb .close ()
453457
458+ if validate :
459+ self .validate_file ()
460+
461+ def validate_file (self ):
462+ import pynwb
463+
454464 with pynwb .NWBHDF5IO (self .filename , "r" ) as io_validate :
455465 errors = pynwb .validate (io_validate , namespace = "core" )
456466 if errors :
457467 raise Exception (f"Errors found when validating { self .filename } " )
458468
459- def write_block (self , nwbfile , block , ** kwargs ):
469+ def write_block (self , block , ** kwargs ):
470+ """
471+ Write a single Block to the file
472+ :param block: Block to be written
473+ """
474+ return self .write_all_blocks ([block ], ** kwargs )
475+
476+ def _write_block (self , block ):
460477 """
461478 Write a Block to the file
462479 :param block: Block to be written
463- :param nwbfile: Representation of an NWB file
464480 """
465- electrodes = self ._write_electrodes (nwbfile , block )
481+ electrodes = self ._write_electrodes (self . _nwbfile , block )
466482 if not block .name :
467483 block .name = "block%d" % self .blocks_written
468484 for i , segment in enumerate (block .segments ):
469485 assert segment .block is block
470486 if not segment .name :
471487 segment .name = "%s : segment%d" % (block .name , i )
472- self ._write_segment (nwbfile , segment , electrodes )
488+ self ._write_segment (self . _nwbfile , segment , electrodes )
473489 self .blocks_written += 1
474490
475491 def _write_electrodes (self , nwbfile , block ):
@@ -485,10 +501,10 @@ def _write_electrodes(self, nwbfile, block):
485501 if elec_meta ["device" ]["name" ] in devices :
486502 device = devices [elec_meta ["device" ]["name" ]]
487503 else :
488- device = nwbfile .create_device (** elec_meta ["device" ])
504+ device = self . _nwbfile .create_device (** elec_meta ["device" ])
489505 devices [elec_meta ["device" ]["name" ]] = device
490506 elec_meta .pop ("device" )
491- electrodes [elec_meta ["name" ]] = nwbfile .create_icephys_electrode (
507+ electrodes [elec_meta ["name" ]] = self . _nwbfile .create_icephys_electrode (
492508 device = device , ** elec_meta
493509 )
494510 return electrodes
@@ -503,13 +519,13 @@ def _write_segment(self, nwbfile, segment, electrodes):
503519 logging .warning ("Warning signal name exists. New name: %s" % (signal .name ))
504520 else :
505521 signal .name = "%s : analogsignal%s %i" % (segment .name , signal .name , i )
506- self ._write_signal (nwbfile , signal , electrodes )
522+ self ._write_signal (self . _nwbfile , signal , electrodes )
507523
508524 for i , train in enumerate (segment .spiketrains ):
509525 assert train .segment is segment
510526 if not train .name :
511527 train .name = "%s : spiketrain%d" % (segment .name , i )
512- self ._write_spiketrain (nwbfile , train )
528+ self ._write_spiketrain (self . _nwbfile , train )
513529
514530 for i , event in enumerate (segment .events ):
515531 assert event .segment is segment
@@ -518,12 +534,12 @@ def _write_segment(self, nwbfile, segment, electrodes):
518534 logging .warning ("Warning event name exists. New name: %s" % (event .name ))
519535 else :
520536 event .name = "%s : event%s %d" % (segment .name , event .name , i )
521- self ._write_event (nwbfile , event )
537+ self ._write_event (self . _nwbfile , event )
522538
523539 for i , epoch in enumerate (segment .epochs ):
524540 if not epoch .name :
525541 epoch .name = "%s : epoch%d" % (segment .name , i )
526- self ._write_epoch (nwbfile , epoch )
542+ self ._write_epoch (self . _nwbfile , epoch )
527543
528544 def _write_signal (self , nwbfile , signal , electrodes ):
529545 import pynwb
@@ -572,8 +588,8 @@ def _write_signal(self, nwbfile, signal, electrodes):
572588 signal .__class__ .__name__ ))
573589 nwb_group = signal .annotations .get ("nwb_group" , "acquisition" )
574590 add_method_map = {
575- "acquisition" : nwbfile .add_acquisition ,
576- "stimulus" : nwbfile .add_stimulus
591+ "acquisition" : self . _nwbfile .add_acquisition ,
592+ "stimulus" : self . _nwbfile .add_stimulus
577593 }
578594 if nwb_group in add_method_map :
579595 add_time_series = add_method_map [nwb_group ]
@@ -586,7 +602,8 @@ def _write_spiketrain(self, nwbfile, spiketrain):
586602 segment = spiketrain .segment
587603 if hasattr (spiketrain , 'proxy_for' ) and spiketrain .proxy_for is SpikeTrain :
588604 spiketrain = spiketrain .load ()
589- nwbfile .add_unit (spike_times = spiketrain .rescale ('s' ).magnitude ,
605+ self ._nwbfile .add_unit (
606+ spike_times = spiketrain .rescale ('s' ).magnitude ,
590607 obs_intervals = [[float (spiketrain .t_start .rescale ('s' )),
591608 float (spiketrain .t_stop .rescale ('s' ))]],
592609 _name = spiketrain .name ,
@@ -596,7 +613,7 @@ def _write_spiketrain(self, nwbfile, spiketrain):
596613 # todo: handle annotations (using add_unit_column()?)
597614 # todo: handle Neo Units
598615 # todo: handle spike waveforms, if any (see SpikeEventSeries)
599- return nwbfile .units
616+ return self . _nwbfile .units
600617
601618 def _write_event (self , nwbfile , event ):
602619 import pynwb
@@ -611,7 +628,7 @@ def _write_event(self, nwbfile, event):
611628 timestamps = event .times .rescale ('second' ).magnitude ,
612629 description = event .description or "" ,
613630 comments = json .dumps (hierarchy ))
614- nwbfile .add_acquisition (tS_evt )
631+ self . _nwbfile .add_acquisition (tS_evt )
615632 return tS_evt
616633
617634 def _write_epoch (self , nwbfile , epoch ):
@@ -621,11 +638,11 @@ def _write_epoch(self, nwbfile, epoch):
621638 for t_start , duration , label in zip (epoch .rescale ('s' ).magnitude ,
622639 epoch .durations .rescale ('s' ).magnitude ,
623640 epoch .labels ):
624- nwbfile .add_epoch (t_start , t_start + duration , [label ], [],
641+ self . _nwbfile .add_epoch (t_start , t_start + duration , [label ], [],
625642 _name = epoch .name ,
626643 segment = segment .name ,
627644 block = segment .block .name )
628- return nwbfile .epochs
645+ return self . _nwbfile .epochs
629646
630647
631648class AnalogSignalProxy (BaseAnalogSignalProxy ):
0 commit comments