Skip to content

Commit 471e3ad

Browse files
authored
Merge pull request #1714 from dandi/add-session-duration
feat: add session end time extraction from NWB files
2 parents 129e521 + 5793656 commit 471e3ad

File tree

3 files changed

+345
-0
lines changed

3 files changed

+345
-0
lines changed

dandi/metadata/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ def extract_session(metadata: dict) -> list[models.Session] | None:
588588
name=session_id or "Acquisition session",
589589
description=metadata.get("session_description"),
590590
startDate=metadata.get("session_start_time"),
591+
endDate=metadata.get("session_end_time"),
591592
used=probes,
592593
)
593594
]

dandi/pynwb_utils.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections import Counter
44
from collections.abc import Callable
5+
from datetime import timedelta
56
import inspect
67
import os
78
import os.path as op
@@ -14,6 +15,7 @@
1415
from fscacher import PersistentCache
1516
import h5py
1617
import hdmf
18+
import numpy as np
1719
from packaging.version import Version
1820
import pynwb
1921
from pynwb import NWBHDF5IO
@@ -262,9 +264,115 @@ def _get_pynwb_metadata(path: str | Path | Readable) -> dict[str, Any]:
262264
# get external_file data:
263265
out["external_file_objects"] = _get_image_series(nwb)
264266

267+
# Calculate session duration for metadata
268+
session_duration = _get_session_duration(nwb)
269+
if session_duration is not None and out.get("session_start_time") is not None:
270+
# Convert to absolute datetime by adding duration to session_start_time
271+
start_time = out["session_start_time"]
272+
out["session_end_time"] = start_time + timedelta(seconds=session_duration)
273+
265274
return out
266275

267276

277+
def _get_session_duration(nwb: pynwb.NWBFile) -> float | None:
278+
"""Calculate the duration of a recording session from NWB file contents.
279+
280+
This function finds the minimum and maximum timestamps across all TimeSeries
281+
and DynamicTable objects with time information, then returns the duration as
282+
max - min.
283+
284+
Parameters
285+
----------
286+
nwb: pynwb.NWBFile
287+
An open NWB file object
288+
289+
Returns
290+
-------
291+
float or None
292+
The session duration in seconds (max_time - min_time),
293+
or None if no time information could be extracted
294+
"""
295+
start_times: list[float] = []
296+
end_times: list[float] = []
297+
298+
# Iterate through all objects in the NWB file
299+
for obj in nwb.objects.values():
300+
# Handle TimeSeries objects
301+
if isinstance(obj, pynwb.base.TimeSeries):
302+
if obj.timestamps is not None and len(obj.timestamps) > 0:
303+
# Use first and last timestamps
304+
start_times.append(float(obj.timestamps[0]))
305+
end_times.append(float(obj.timestamps[-1]))
306+
elif (
307+
obj.starting_time is not None
308+
and obj.rate is not None
309+
and obj.data is not None
310+
):
311+
# Calculate start and end time
312+
start_times.append(float(obj.starting_time))
313+
num_samples = len(obj.data)
314+
if obj.rate == 0:
315+
continue
316+
end_times.append(float(obj.starting_time + (num_samples / obj.rate)))
317+
318+
# Handle DynamicTable objects with time columns
319+
elif isinstance(obj, hdmf.common.DynamicTable):
320+
# Handle start_time and stop_time columns (e.g., trials)
321+
if "start_time" in obj.colnames and len(obj["start_time"]):
322+
start_times.append(float(obj["start_time"][0]))
323+
if "stop_time" in obj.colnames and len(obj["stop_time"]):
324+
end_times.append(float(obj["stop_time"][-1]))
325+
326+
# Handle spike_times column (e.g., Units table)
327+
# Assume spike times are ordered within each unit
328+
# Read only the first and last spike time from each unit
329+
if "spike_times" in obj.colnames and len(obj["spike_times"]):
330+
idxs = obj["spike_times"].data[:]
331+
332+
# handle bug if the first unit has no spikes
333+
if idxs[0] == 0:
334+
idxs = idxs[1:]
335+
336+
st_data = obj["spike_times"].target
337+
338+
if len(idxs) > 1:
339+
start = float(np.min(np.r_[st_data[0], st_data[idxs[:-1]]]))
340+
else:
341+
start = float(st_data[0])
342+
343+
end = float(np.max(st_data[idxs - 1]))
344+
start_times.append(float(start))
345+
end_times.append(float(end))
346+
347+
# Handle timestamp column (e.g., EventsTable)
348+
if "timestamp" in obj.colnames and len(obj["timestamp"]):
349+
timestamp_data = obj["timestamp"]
350+
start_times.append(float(timestamp_data[0]))
351+
# Check if duration column exists to calculate end times
352+
if "duration" in obj.colnames:
353+
duration_data = obj["duration"]
354+
end_times.append(float(timestamp_data[-1] + duration_data[-1]))
355+
else:
356+
# No duration, use max timestamp as end
357+
end_times.append(float(timestamp_data[-1]))
358+
359+
# Return duration as max - min
360+
if start_times and end_times:
361+
duration = max(end_times) - min(start_times)
362+
if (
363+
duration < 3600 * 24 * 365 * 5
364+
): # if duration is over 5 years, something went wrong
365+
return duration
366+
else:
367+
lgr.warning(
368+
"Session duration of %.2f seconds (%.2f years) exceeds 5-year limit; "
369+
"returning None as this likely indicates an error in timestamps",
370+
duration,
371+
duration / (3600 * 24 * 365),
372+
)
373+
return None
374+
375+
268376
def _get_image_series(nwb: pynwb.NWBFile) -> list[dict]:
269377
"""Retrieves all ImageSeries related metadata from an open nwb file.
270378

dandi/tests/test_metadata.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@
2828
)
2929
from dandischema.models import Dandiset as DandisetMeta
3030
from dateutil.tz import tzutc
31+
from hdmf.common import DynamicTable
32+
import numpy as np
3133
from pydantic import ByteSize
34+
from pynwb import NWBHDF5IO, NWBFile, TimeSeries
3235
import pytest
3336
import requests
3437
from semantic_version import Version
@@ -471,6 +474,239 @@ def test_time_extract_gest() -> None:
471474
)
472475

473476

477+
@pytest.mark.ai_generated
478+
def test_session_duration_extraction(tmp_path: Path) -> None:
479+
"""Test that session duration is extracted and included in Session activity"""
480+
# Create a test NWB file with TimeSeries data
481+
nwb_path = tmp_path / "test_duration.nwb"
482+
session_start = datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc())
483+
484+
nwbfile = NWBFile(
485+
session_description="test session for duration",
486+
identifier="test_duration_123",
487+
session_start_time=session_start,
488+
)
489+
490+
# Add a TimeSeries that spans 100 seconds (timestamps from 0 to 100)
491+
data = np.random.rand(1000)
492+
timestamps = np.linspace(0, 100, 1000)
493+
ts1 = TimeSeries(name="timeseries1", data=data, unit="volts", timestamps=timestamps)
494+
nwbfile.add_acquisition(ts1)
495+
496+
# Add another TimeSeries using starting_time and rate
497+
# This one goes from 50s to 150s (100 samples at 1 Hz)
498+
data2 = np.random.rand(100)
499+
ts2 = TimeSeries(
500+
name="timeseries2", data=data2, unit="volts", starting_time=50.0, rate=1.0
501+
)
502+
nwbfile.add_acquisition(ts2)
503+
504+
# Write the file
505+
with NWBHDF5IO(str(nwb_path), "w") as io:
506+
io.write(nwbfile)
507+
508+
# Extract metadata
509+
from ..metadata.nwb import get_metadata, nwb2asset
510+
511+
metadata = get_metadata(nwb_path)
512+
513+
# Check that session_end_time was calculated
514+
assert "session_start_time" in metadata
515+
assert "session_end_time" in metadata
516+
517+
# Calculate duration - should be 150 seconds (max) - 0 seconds (min)
518+
duration = (
519+
metadata["session_end_time"] - metadata["session_start_time"]
520+
).total_seconds()
521+
assert abs(duration - 150.0) < 1.0 # Allow small floating point errors
522+
523+
# Check that Session activity includes endDate
524+
asset = nwb2asset(nwb_path, digest=DUMMY_DANDI_ETAG)
525+
assert asset.wasGeneratedBy is not None
526+
527+
# Find Session activities
528+
sessions = [act for act in asset.wasGeneratedBy if act.schemaKey == "Session"]
529+
assert len(sessions) > 0
530+
531+
session = sessions[0]
532+
assert session.startDate is not None
533+
assert session.endDate is not None
534+
assert session.startDate == metadata["session_start_time"]
535+
assert session.endDate == metadata["session_end_time"]
536+
537+
538+
@pytest.mark.ai_generated
539+
def test_session_duration_with_trials(tmp_path: Path) -> None:
540+
"""Test that session duration includes trials table timestamps"""
541+
# Create a test NWB file with trials
542+
nwb_path = tmp_path / "test_duration_trials.nwb"
543+
session_start = datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc())
544+
545+
nwbfile = NWBFile(
546+
session_description="test session with trials",
547+
identifier="test_trials_123",
548+
session_start_time=session_start,
549+
)
550+
551+
# Add a TimeSeries that spans from 10 to 50 seconds
552+
data = np.random.rand(400)
553+
timestamps = np.linspace(10, 50, 400)
554+
ts = TimeSeries(name="timeseries1", data=data, unit="volts", timestamps=timestamps)
555+
nwbfile.add_acquisition(ts)
556+
557+
# Add trials that extend the session to 200 seconds
558+
nwbfile.add_trial_column(
559+
name="correct", description="whether the trial was correct"
560+
)
561+
nwbfile.add_trial(start_time=5.0, stop_time=15.0, correct=True)
562+
nwbfile.add_trial(start_time=20.0, stop_time=30.0, correct=False)
563+
nwbfile.add_trial(start_time=100.0, stop_time=200.0, correct=True)
564+
565+
# Write the file
566+
with NWBHDF5IO(str(nwb_path), "w") as io:
567+
io.write(nwbfile)
568+
569+
# Extract metadata
570+
from ..metadata.nwb import get_metadata, nwb2asset
571+
572+
metadata = get_metadata(nwb_path)
573+
574+
# Check that session_end_time was calculated
575+
assert "session_start_time" in metadata
576+
assert "session_end_time" in metadata
577+
578+
# Calculate duration - should be 200 (max from trials) - 5 (min from trials) = 195 seconds
579+
duration = (
580+
metadata["session_end_time"] - metadata["session_start_time"]
581+
).total_seconds()
582+
assert abs(duration - 195.0) < 1.0 # Allow small floating point errors
583+
584+
# Check that Session activity includes endDate
585+
asset = nwb2asset(nwb_path, digest=DUMMY_DANDI_ETAG)
586+
assert asset.wasGeneratedBy is not None
587+
588+
# Find Session activities
589+
sessions = [act for act in asset.wasGeneratedBy if act.schemaKey == "Session"]
590+
assert len(sessions) > 0
591+
592+
session = sessions[0]
593+
assert session.startDate is not None
594+
assert session.endDate is not None
595+
assert session.startDate == metadata["session_start_time"]
596+
assert session.endDate == metadata["session_end_time"]
597+
598+
599+
@pytest.mark.ai_generated
600+
def test_session_duration_with_units(tmp_path: Path) -> None:
601+
"""Test that session duration includes spike_times from Units table"""
602+
# Create a test NWB file with Units table
603+
nwb_path = tmp_path / "test_duration_units.nwb"
604+
session_start = datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc())
605+
606+
nwbfile = NWBFile(
607+
session_description="test session with units",
608+
identifier="test_units_123",
609+
session_start_time=session_start,
610+
)
611+
612+
# Add a simple TimeSeries that spans from 10 to 30 seconds
613+
data = np.random.rand(200)
614+
timestamps = np.linspace(10, 30, 200)
615+
ts = TimeSeries(name="timeseries1", data=data, unit="volts", timestamps=timestamps)
616+
nwbfile.add_acquisition(ts)
617+
618+
# Add Units with spike_times that extend session to 250 seconds
619+
# Unit 1: spikes from 5s to 100s
620+
# Unit 2: spikes from 50s to 250s
621+
nwbfile.add_unit(spike_times=np.array([5.0, 10.0, 20.0, 50.0, 100.0]))
622+
nwbfile.add_unit(spike_times=np.array([50.0, 100.0, 150.0, 200.0, 250.0]))
623+
624+
# Write the file
625+
with NWBHDF5IO(str(nwb_path), "w") as io:
626+
io.write(nwbfile)
627+
628+
# Extract metadata
629+
from ..metadata.nwb import get_metadata
630+
631+
metadata = get_metadata(nwb_path)
632+
633+
# Check that session_end_time was calculated
634+
assert "session_start_time" in metadata
635+
assert "session_end_time" in metadata
636+
637+
# Duration should be 250 (max spike) - 5 (min spike) = 245 seconds
638+
duration = (
639+
metadata["session_end_time"] - metadata["session_start_time"]
640+
).total_seconds()
641+
assert abs(duration - 245.0) < 1.0 # Allow small floating point errors
642+
643+
644+
@pytest.mark.ai_generated
645+
def test_session_duration_with_events(tmp_path: Path) -> None:
646+
"""Test that session duration includes timestamp/duration from DynamicTable"""
647+
# Create a test NWB file with a DynamicTable containing timestamp and duration
648+
nwb_path = tmp_path / "test_duration_events.nwb"
649+
session_start = datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc())
650+
651+
nwbfile = NWBFile(
652+
session_description="test session with events",
653+
identifier="test_events_123",
654+
session_start_time=session_start,
655+
)
656+
657+
# Add a simple TimeSeries that spans from 5 to 20 seconds
658+
data = np.random.rand(150)
659+
timestamps = np.linspace(5, 20, 150)
660+
ts = TimeSeries(name="timeseries1", data=data, unit="volts", timestamps=timestamps)
661+
nwbfile.add_acquisition(ts)
662+
663+
# Create a DynamicTable with timestamp and duration columns (similar to EventsTable)
664+
665+
events_table = DynamicTable(
666+
name="events",
667+
description="test events with timestamps and durations",
668+
)
669+
events_table.add_column(
670+
name="timestamp",
671+
description="event timestamps",
672+
)
673+
events_table.add_column(
674+
name="duration",
675+
description="event durations",
676+
)
677+
678+
# Add events: event at 3s lasting 2s (ends at 5s)
679+
# event at 100s lasting 80s (ends at 180s)
680+
events_table.add_row(timestamp=3.0, duration=2.0)
681+
events_table.add_row(timestamp=100.0, duration=30.0)
682+
events_table.add_row(timestamp=150.0, duration=10.0)
683+
684+
# Add the table to a processing module
685+
processing_module = nwbfile.create_processing_module(
686+
name="behavior", description="behavioral data"
687+
)
688+
processing_module.add(events_table)
689+
690+
# Write the file
691+
with NWBHDF5IO(str(nwb_path), "w") as io:
692+
io.write(nwbfile)
693+
694+
# Extract metadata
695+
from ..metadata.nwb import get_metadata
696+
697+
metadata = get_metadata(nwb_path)
698+
699+
# Check that session_end_time was calculated
700+
assert "session_start_time" in metadata
701+
assert "session_end_time" in metadata
702+
703+
# Duration should be 180 (100 + 80, max end) - 3 (min timestamp) = 177 seconds
704+
duration = (
705+
metadata["session_end_time"] - metadata["session_start_time"]
706+
).total_seconds()
707+
assert abs(duration - 157.0) < 1.0 # Allow small floating point errors
708+
709+
474710
@mark_xfail_ontobee
475711
@mark.skipif_no_network
476712
@pytest.mark.obolibrary

0 commit comments

Comments
 (0)