Skip to content

Commit 13d3fe3

Browse files
Enable NWBIO to write lazy data objects (#1130)
* Enable NWBIO to write lazy data objects * pep8
1 parent ffde176 commit 13d3fe3

File tree

2 files changed

+76
-5
lines changed

2 files changed

+76
-5
lines changed

neo/io/nwbio.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,9 @@ def _write_signal(self, nwbfile, signal, electrodes):
561561
additional_metadata["conversion"] = conversion
562562
else:
563563
units = signal.units
564+
if hasattr(signal, 'proxy_for') and signal.proxy_for in [AnalogSignal,
565+
IrregularlySampledSignal]:
566+
signal = signal.load()
564567
if isinstance(signal, AnalogSignal):
565568
sampling_rate = signal.sampling_rate.rescale("Hz")
566569
tS = timeseries_class(
@@ -597,20 +600,26 @@ def _write_signal(self, nwbfile, signal, electrodes):
597600
return tS
598601

599602
def _write_spiketrain(self, nwbfile, spiketrain):
603+
segment = spiketrain.segment
604+
if hasattr(spiketrain, 'proxy_for') and spiketrain.proxy_for is SpikeTrain:
605+
spiketrain = spiketrain.load()
600606
nwbfile.add_unit(spike_times=spiketrain.rescale('s').magnitude,
601607
obs_intervals=[[float(spiketrain.t_start.rescale('s')),
602608
float(spiketrain.t_stop.rescale('s'))]],
603609
_name=spiketrain.name,
604610
# _description=spiketrain.description,
605-
segment=spiketrain.segment.name,
606-
block=spiketrain.segment.block.name)
611+
segment=segment.name,
612+
block=segment.block.name)
607613
# todo: handle annotations (using add_unit_column()?)
608614
# todo: handle Neo Units
609615
# todo: handle spike waveforms, if any (see SpikeEventSeries)
610616
return nwbfile.units
611617

612618
def _write_event(self, nwbfile, event):
613-
hierarchy = {'block': event.segment.block.name, 'segment': event.segment.name}
619+
segment = event.segment
620+
if hasattr(event, 'proxy_for') and event.proxy_for == Event:
621+
event = event.load()
622+
hierarchy = {'block': segment.block.name, 'segment': segment.name}
614623
tS_evt = AnnotationSeries(
615624
name=event.name,
616625
data=event.labels,
@@ -621,13 +630,16 @@ def _write_event(self, nwbfile, event):
621630
return tS_evt
622631

623632
def _write_epoch(self, nwbfile, epoch):
633+
segment = epoch.segment
634+
if hasattr(epoch, 'proxy_for') and epoch.proxy_for == Epoch:
635+
epoch = epoch.load()
624636
for t_start, duration, label in zip(epoch.rescale('s').magnitude,
625637
epoch.durations.rescale('s').magnitude,
626638
epoch.labels):
627639
nwbfile.add_epoch(t_start, t_start + duration, [label], [],
628640
_name=epoch.name,
629-
segment=epoch.segment.name,
630-
block=epoch.segment.block.name)
641+
segment=segment.name,
642+
block=segment.block.name)
631643
return nwbfile.epochs
632644

633645

neo/test/iotest/test_nwbio.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from neo.core import AnalogSignal, SpikeTrain, Event, Epoch, IrregularlySampledSignal, Segment, \
1818
Block
1919

20+
from neo.rawio.examplerawio import ExampleRawIO
21+
from neo.io.proxyobjects import (AnalogSignalProxy, SpikeTrainProxy, EventProxy, EpochProxy)
22+
2023
try:
2124
import pynwb
2225
from neo.io.nwbio import NWBIO
@@ -250,6 +253,62 @@ def test_roundtrip_with_annotations(self):
250253

251254
os.remove(test_file_name)
252255

256+
def test_write_proxy_objects(self):
257+
test_file_name = self.local_test_dir / "test_round_trip_with_annotations.nwb"
258+
259+
# generate dummy IO as basis for ProxyObjects
260+
self.proxy_reader = ExampleRawIO(filename='my_filename.fake')
261+
self.proxy_reader.parse_header()
262+
263+
# generate test structure with proxy objects
264+
original_block = Block(name='myblock', session_start_time=datetime.now().astimezone(),
265+
session_description=str(test_file_name),
266+
identifier=str(test_file_name))
267+
seg = Segment(name='mysegment')
268+
original_block.segments.append(seg)
269+
270+
# create proxy objects
271+
proxy_anasig = AnalogSignalProxy(rawio=self.proxy_reader, stream_index=0,
272+
inner_stream_channels=None, block_index=0, seg_index=0,)
273+
proxy_anasig.segment = seg
274+
seg.analogsignals.append(proxy_anasig)
275+
276+
proxy_sptr = SpikeTrainProxy(rawio=self.proxy_reader, spike_channel_index=0, block_index=0,
277+
seg_index=0)
278+
proxy_sptr.segment = seg
279+
seg.spiketrains.append(proxy_sptr)
280+
281+
proxy_event = EventProxy(rawio=self.proxy_reader, event_channel_index=0, block_index=0,
282+
seg_index=0)
283+
proxy_event.segment = seg
284+
seg.events.append(proxy_event)
285+
286+
proxy_epoch = EpochProxy(rawio=self.proxy_reader, event_channel_index=1, block_index=0,
287+
seg_index=0)
288+
proxy_epoch.segment = seg
289+
seg.epochs.append(proxy_epoch)
290+
291+
original_block.create_relationship()
292+
293+
iow = NWBIO(filename=test_file_name, mode='w')
294+
295+
# writing data via proxyobjects
296+
iow.write_all_blocks([original_block])
297+
298+
# checking written data
299+
ior = NWBIO(filename=test_file_name, mode='r')
300+
retrieved_block = ior.read_all_blocks()[0]
301+
302+
for original_segment, retrieved_segment in zip(original_block.segments,
303+
retrieved_block.segments):
304+
assert_array_equal(original_segment.analogsignals[0].load().magnitude,
305+
retrieved_segment.analogsignals[0].magnitude)
306+
assert_array_equal(original_segment.spiketrains[0].load().magnitude,
307+
retrieved_segment.spiketrains[0].magnitude)
308+
assert_array_equal(original_segment.events[0].load().magnitude,
309+
retrieved_segment.events[0].magnitude)
310+
assert_array_equal(original_segment.epochs[0].load().magnitude,
311+
retrieved_segment.epochs[0].magnitude)
253312

254313
if __name__ == "__main__":
255314
if HAVE_PYNWB:

0 commit comments

Comments
 (0)