Skip to content

Commit fed439d

Browse files
authored
feat: DataStream.__add__ and Acquistion._merge_data_stream_lists (#1634)
* feat: DataStream.__add__ and Acquistion._merge_data_stream_lists * test: merging now merges data streams! * tests: coverage for un-mergeable streams * chore: overzealous codespell * refactor: allow configurable overlap range, defaults to 2 minutes * chore: lint * refactor: merge_streams instead of two different lists * tests: match new _merge_data_streams input type
1 parent 003150c commit fed439d

File tree

4 files changed

+243
-3
lines changed

4 files changed

+243
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,4 @@ fail-under = 100
103103

104104
[tool.codespell]
105105
skip = '.git,*.pdf,*.svg'
106-
ignore-words-list = 'nd,assertIn'
106+
ignore-words-list = 'nd,assertIn,DeviceC'

src/aind_data_schema/core/acquisition.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Schema describing data acquisition metadata and configurations"""
22

33
from decimal import Decimal
4+
import logging
45
from typing import Annotated, List, Literal, Optional
56

67
from aind_data_schema_models.modalities import Modality
@@ -188,6 +189,54 @@ def check_connections(self):
188189

189190
return self
190191

192+
@classmethod
193+
def overlapping(cls, stream1: "DataStream", stream2: "DataStream", overlap_s: int) -> bool:
194+
"""Check if two DataStream objects have overlapping start and end times"""
195+
start_diff = abs((stream1.stream_start_time - stream2.stream_start_time).total_seconds())
196+
end_diff = abs((stream1.stream_end_time - stream2.stream_end_time).total_seconds())
197+
return start_diff <= overlap_s and end_diff <= overlap_s
198+
199+
def __add__(self, other: "DataStream", overlap_s: int = 120) -> "DataStream":
200+
"""Combine two DataStream objects"""
201+
202+
if not DataStream.overlapping(self, other, overlap_s=overlap_s):
203+
raise ValueError("Cannot combine DataStreams with non-overlapping start and end times.")
204+
205+
min_start_time = min(self.stream_start_time, other.stream_start_time)
206+
max_end_time = max(self.stream_end_time, other.stream_end_time)
207+
208+
# Combine modalities
209+
modalities = self.modalities + other.modalities
210+
modalities = remove_duplicates(modalities)
211+
212+
# Combine active devices
213+
active_devices = self.active_devices + other.active_devices
214+
len_orig_devices = len(active_devices)
215+
active_devices = remove_duplicates(active_devices)
216+
if len(active_devices) < len_orig_devices:
217+
logging.warning(
218+
"Duplicate active devices were removed. Only DAQ devices should be shared in overlapped " "DataStreams."
219+
)
220+
221+
# Combine configurations
222+
configurations = self.configurations + other.configurations
223+
224+
# Combine connections
225+
connections = self.connections + other.connections
226+
227+
# Combine notes
228+
notes = merge_notes(self.notes, other.notes)
229+
230+
return DataStream(
231+
stream_start_time=min_start_time,
232+
stream_end_time=max_end_time,
233+
modalities=modalities,
234+
active_devices=active_devices,
235+
configurations=configurations,
236+
connections=connections,
237+
notes=notes,
238+
)
239+
191240

192241
class StimulusEpoch(DataModel):
193242
"""All stimuli being presented to the subject. starting and stopping at approximately the
@@ -361,6 +410,37 @@ def specimen_required(self):
361410

362411
return self
363412

413+
@classmethod
414+
def _merge_data_streams(
415+
cls, streams: List[DataStream], overlap_s: int = 120
416+
) -> List[DataStream]:
417+
"""Merge two lists of data streams"""
418+
groups = []
419+
visited = set()
420+
for i in range(len(streams)):
421+
if i in visited:
422+
continue
423+
group = [streams[i]]
424+
visited.add(i)
425+
for j in range(i + 1, len(streams)):
426+
if j not in visited and DataStream.overlapping(streams[i], streams[j], overlap_s=overlap_s):
427+
group.append(streams[j])
428+
visited.add(j)
429+
groups.append(group)
430+
431+
# Construct the final set of streams, including merged streams where applicable
432+
merged_streams = []
433+
for group in groups:
434+
if len(group) == 1:
435+
merged_streams.append(group[0])
436+
else:
437+
merged_stream = group[0]
438+
for stream in group[1:]:
439+
merged_stream = merged_stream + stream
440+
merged_streams.append(merged_stream)
441+
442+
return merged_streams
443+
364444
def __add__(self, other: "Acquisition") -> "Acquisition":
365445
"""Combine two Acquisition objects"""
366446

@@ -400,7 +480,7 @@ def __add__(self, other: "Acquisition") -> "Acquisition":
400480
ethics_review_id = merge_optional_list(self.ethics_review_id, other.ethics_review_id)
401481
calibrations = self.calibrations + other.calibrations
402482
maintenance = self.maintenance + other.maintenance
403-
data_streams = self.data_streams + other.data_streams
483+
data_streams = Acquisition._merge_data_streams(self.data_streams + other.data_streams)
404484
stimulus_epochs = self.stimulus_epochs + other.stimulus_epochs
405485

406486
# Remove duplicates

tests/test_acquisition.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,166 @@ def extract_union_types(annotation):
392392
# If we get here, all device config subclasses are properly covered
393393
self.assertTrue(True, "All DeviceConfig subclasses are properly covered in discriminated unions")
394394

395+
def test_datastream_add_basic(self):
396+
"""Test combining two DataStream objects"""
397+
acq = ephys_acquisition.model_copy()
398+
399+
stream1 = acq.data_streams[0].model_copy()
400+
stream2 = acq.data_streams[0].model_copy()
401+
402+
stream1.active_devices = ["Device1", "Device2"]
403+
stream2.active_devices = ["Device3", "Device4"]
404+
405+
combined_stream = stream1 + stream2
406+
407+
self.assertIsNotNone(combined_stream)
408+
self.assertEqual(combined_stream.stream_start_time, stream1.stream_start_time)
409+
self.assertEqual(combined_stream.stream_end_time, stream1.stream_end_time)
410+
self.assertEqual(len(combined_stream.modalities), 1)
411+
self.assertEqual(len(combined_stream.active_devices), 4)
412+
self.assertEqual(len(combined_stream.configurations), 4)
413+
414+
# Also check that an error is raised if the streams cannot be combined
415+
416+
stream2.stream_end_time = stream2.stream_end_time.replace(year=2100)
417+
418+
with self.assertRaises(ValueError):
419+
_ = stream1 + stream2
420+
421+
def test_datastream_add_combines_notes(self):
422+
"""Test that notes are properly merged when combining DataStreams"""
423+
acq = ephys_acquisition.model_copy()
424+
425+
stream1 = acq.data_streams[0].model_copy()
426+
stream2 = acq.data_streams[0].model_copy()
427+
428+
stream1.active_devices = ["Device1", "Device2"]
429+
stream2.active_devices = ["Device3", "Device4"]
430+
stream1.notes = "Note 1"
431+
stream2.notes = "Note 2"
432+
433+
combined_stream = stream1 + stream2
434+
435+
self.assertIn("Note 1", combined_stream.notes)
436+
self.assertIn("Note 2", combined_stream.notes)
437+
438+
def test_datastream_add_with_duplicate_devices(self):
439+
"""Test that overlapping active devices are logged as warning when combining"""
440+
acq = ephys_acquisition.model_copy()
441+
stream1 = acq.data_streams[0]
442+
stream2 = acq.data_streams[0].model_copy()
443+
444+
combined_stream = stream1 + stream2
445+
446+
self.assertIsNotNone(combined_stream)
447+
448+
def test_datastream_add_combines_connections(self):
449+
"""Test that connections are properly combined"""
450+
acq = ephys_acquisition.model_copy()
451+
452+
stream1 = acq.data_streams[0].model_copy()
453+
stream2 = acq.data_streams[0].model_copy()
454+
455+
stream1.active_devices = ["Device1", "Device2"]
456+
stream2.active_devices = ["Device3", "Device4"]
457+
stream1.connections = [Connection(source_device="Device1", target_device="Device2")]
458+
stream2.connections = [Connection(source_device="Device3", target_device="Device4")]
459+
460+
combined_stream = stream1 + stream2
461+
462+
self.assertEqual(len(combined_stream.connections), 2)
463+
464+
def test_merge_data_stream_lists_single_streams(self):
465+
"""Test merging lists with single streams"""
466+
acq = ephys_acquisition.model_copy()
467+
stream1 = acq.data_streams[0].model_copy()
468+
stream2 = acq.data_streams[1].model_copy()
469+
470+
stream1.active_devices = ["Device1"]
471+
stream2.active_devices = ["Device2"]
472+
473+
merged = Acquisition._merge_data_streams([stream1] + [stream2])
474+
475+
self.assertEqual(len(merged), 2)
476+
477+
def test_merge_data_stream_lists_overlapping_streams(self):
478+
"""Test merging streams with overlapping start/end times"""
479+
acq = ephys_acquisition.model_copy()
480+
481+
stream1 = acq.data_streams[0].model_copy()
482+
stream2 = acq.data_streams[0].model_copy()
483+
484+
stream1.active_devices = ["Device1", "Device2"]
485+
stream2.active_devices = ["Device3", "Device4"]
486+
487+
merged = Acquisition._merge_data_streams([stream1] + [stream2])
488+
489+
self.assertEqual(len(merged), 1)
490+
self.assertEqual(len(merged[0].active_devices), 4)
491+
492+
def test_merge_data_stream_lists_non_overlapping_streams(self):
493+
"""Test merging streams with different start/end times"""
494+
acq = ephys_acquisition.model_copy()
495+
496+
stream1 = acq.data_streams[1].model_copy()
497+
stream2 = acq.data_streams[0].model_copy()
498+
499+
stream1.active_devices = ["Device1"]
500+
stream2.active_devices = ["Device2"]
501+
502+
merged = Acquisition._merge_data_streams([stream1] + [stream2])
503+
504+
self.assertEqual(len(merged), 2)
505+
506+
def test_merge_data_stream_lists_multiple_overlapping_groups(self):
507+
"""Test merging multiple streams with multiple overlapping groups"""
508+
acq = ephys_acquisition.model_copy()
509+
510+
stream1 = acq.data_streams[0].model_copy()
511+
stream2 = acq.data_streams[0].model_copy()
512+
stream3 = acq.data_streams[0].model_copy()
513+
514+
stream1.active_devices = ["DeviceA", "DeviceB"]
515+
stream2.active_devices = ["DeviceC"]
516+
stream3.active_devices = ["DeviceD", "DeviceE"]
517+
518+
start1 = datetime(year=2023, month=4, day=25, hour=2, minute=0, second=0, tzinfo=timezone.utc)
519+
end1 = datetime(year=2023, month=4, day=25, hour=2, minute=30, second=0, tzinfo=timezone.utc)
520+
stream1.stream_start_time = start1
521+
stream1.stream_end_time = end1
522+
523+
start2 = datetime(year=2023, month=4, day=25, hour=2, minute=0, second=30, tzinfo=timezone.utc)
524+
end2 = datetime(year=2023, month=4, day=25, hour=2, minute=29, second=30, tzinfo=timezone.utc)
525+
stream2.stream_start_time = start2
526+
stream2.stream_end_time = end2
527+
528+
start3 = datetime(year=2023, month=4, day=25, hour=3, minute=0, second=0, tzinfo=timezone.utc)
529+
end3 = datetime(year=2023, month=4, day=25, hour=3, minute=30, second=0, tzinfo=timezone.utc)
530+
stream3.stream_start_time = start3
531+
stream3.stream_end_time = end3
532+
533+
merged = Acquisition._merge_data_streams([stream1] + [stream2, stream3])
534+
535+
for m in merged:
536+
print(m.stream_start_time, m.stream_end_time, m.active_devices)
537+
538+
self.assertEqual(len(merged), 2)
539+
self.assertEqual(len(merged[0].active_devices), 3)
540+
541+
def test_datastream_add_with_exaspim_example(self):
542+
"""Test combining DataStreams using ExaSPIM example"""
543+
acq = exaspim_acquisition.model_copy()
544+
stream1 = acq.data_streams[0].model_copy()
545+
546+
stream2 = acq.data_streams[0].model_copy()
547+
stream2.active_devices = ["Device99"]
548+
549+
combined_stream = stream1 + stream2
550+
551+
self.assertIsNotNone(combined_stream)
552+
self.assertIn(Modality.SPIM, combined_stream.modalities)
553+
self.assertEqual(len(combined_stream.active_devices), 3)
554+
395555

396556
if __name__ == "__main__":
397557
unittest.main()

tests/test_composability_merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def test_merge_acquisition(self):
143143
self.assertEqual(len(merged_acq.experimenters), 1)
144144
self.assertEqual(len(merged_acq.maintenance), 2)
145145
self.assertEqual(len(merged_acq.calibrations), 2)
146-
self.assertEqual(len(merged_acq.data_streams), 2)
146+
self.assertEqual(len(merged_acq.data_streams), 1)
147147
self.assertEqual(merged_acq.acquisition_start_time, t)
148148
self.assertEqual(merged_acq.acquisition_end_time, t)
149149
self.assertEqual(merged_acq.acquisition_type, "ExaSPIM")

0 commit comments

Comments
 (0)