Skip to content

Commit 00e604b

Browse files
authored
fix: populate source_data properly (#1656)
1 parent 874b398 commit 00e604b

File tree

2 files changed: 22 additions and 43 deletions.

src/aind_data_schema/core/data_description.py

Lines changed: 2 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -226,13 +226,6 @@ def get_or_default(field_name: str) -> Any:
226226
if not re.match(DataRegex.DERIVED.value, derived_name): # pragma: no cover
227227
raise ValueError(f"Derived name({derived_name}) does not match allowed Regex pattern")
228228

229-
# Upgrade source_data
230-
current_source_data = data_description.source_data or []
231-
if source_data is not None:
232-
new_source_data = current_source_data + source_data
233-
else:
234-
new_source_data = current_source_data + [original_name]
235-
236229
return cls(
237230
subject_id=get_or_default("subject_id"),
238231
creation_time=creation_time,
@@ -247,7 +240,7 @@ def get_or_default(field_name: str) -> Any:
247240
restrictions=get_or_default("restrictions"),
248241
modalities=get_or_default("modalities"),
249242
data_summary=get_or_default("data_summary"),
250-
source_data=new_source_data,
243+
source_data=source_data if source_data else [original_name],
251244
)
252245

253246
@classmethod
@@ -320,13 +313,6 @@ def get_or_default(field_name: str) -> Any:
320313
if not re.match(DataRegex.DERIVED.value, derived_name): # pragma: no cover
321314
raise ValueError(f"Derived name({derived_name}) does not match allowed Regex pattern")
322315

323-
# Upgrade source_data
324-
current_source_data = data_description.source_data or []
325-
if source_data is not None:
326-
new_source_data = current_source_data + source_data
327-
else:
328-
new_source_data = current_source_data + [data_description.name]
329-
330316
return cls(
331317
subject_id=get_or_default("subject_id"),
332318
creation_time=creation_time,
@@ -341,7 +327,7 @@ def get_or_default(field_name: str) -> Any:
341327
restrictions=get_or_default("restrictions"),
342328
modalities=get_or_default("modalities"),
343329
data_summary=get_or_default("data_summary"),
344-
source_data=new_source_data,
330+
source_data=source_data if source_data else [data_description.name],
345331
)
346332

347333
@classmethod

tests/test_data_description.py

Lines changed: 20 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -336,15 +336,14 @@ def test_from_raw_with_explicit_source_data(self):
336336
self.assertEqual(r1.source_data, explicit_source)
337337
self.assertNotIn(da.name, r1.source_data)
338338

339-
# Test scenario 4: DERIVED data → DERIVED with additional source_data
339+
# Test scenario 4: DERIVED data → DERIVED with explicit source_data
340340
additional_source = ["another_external_dataset"]
341341
r2 = DataDescription.from_derived(r1, "clustering", source_data=additional_source, creation_time=dt)
342342

343-
# Should combine existing source_data with new source_data
343+
# Should use the explicit source_data (not combine with existing)
344344
self.assertIsNotNone(r2.source_data)
345-
self.assertEqual(len(r2.source_data), 3) # 2 from r1 + 1 additional
346-
self.assertEqual(r2.source_data[:2], explicit_source) # First two should be from r1
347-
self.assertEqual(r2.source_data[2:], additional_source) # Last one should be additional
345+
self.assertEqual(len(r2.source_data), 1) # Just the new source_data
346+
self.assertEqual(r2.source_data, additional_source)
348347

349348
def test_from_raw_chained_source_data_behavior(self):
350349
"""Test source_data behavior in chained derived data without explicit source_data"""
@@ -367,18 +366,15 @@ def test_from_raw_chained_source_data_behavior(self):
367366
r1 = DataDescription.from_raw(da, "spikesort-ks25", creation_time=dt)
368367
self.assertEqual(r1.source_data, [da.name])
369368

370-
# Second derivation: DERIVED → DERIVED (should append to existing source_data)
369+
# Second derivation: DERIVED → DERIVED (should use only the immediate predecessor)
371370
r2 = DataDescription.from_derived(r1, "clustering", creation_time=dt)
372-
self.assertEqual(len(r2.source_data), 2)
373-
self.assertEqual(r2.source_data[0], da.name) # Original source
374-
self.assertEqual(r2.source_data[1], r1.name) # Previous derived data
371+
self.assertEqual(len(r2.source_data), 1)
372+
self.assertEqual(r2.source_data[0], r1.name) # Only the immediate predecessor
375373

376-
# Third derivation: should continue the chain
374+
# Third derivation: should only reference the immediate predecessor
377375
r3 = DataDescription.from_derived(r2, "analysis", creation_time=dt)
378-
self.assertEqual(len(r3.source_data), 3)
379-
self.assertEqual(r3.source_data[0], da.name) # Original source
380-
self.assertEqual(r3.source_data[1], r1.name) # First derived
381-
self.assertEqual(r3.source_data[2], r2.name) # Second derived
376+
self.assertEqual(len(r3.source_data), 1)
377+
self.assertEqual(r3.source_data[0], r2.name) # Only the immediate predecessor
382378

383379
def test_from_derived_basic_functionality(self):
384380
"""Test from_derived creates derived data using original input name"""
@@ -402,9 +398,8 @@ def test_from_derived_basic_functionality(self):
402398
self.assertIn("quality_control", derived2.name)
403399
self.assertNotIn("spike_sorting", derived2.name) # Should not chain process names
404400
self.assertEqual(derived2.data_level, DataLevel.DERIVED)
405-
self.assertEqual(len(derived2.source_data), 2)
406-
self.assertEqual(derived2.source_data[0], example_data_description.name) # Original
407-
self.assertEqual(derived2.source_data[1], derived1.name) # Previous derived
401+
self.assertEqual(len(derived2.source_data), 1)
402+
self.assertEqual(derived2.source_data[0], derived1.name) # Only immediate predecessor
408403

409404
# Verify the names have the expected structure
410405
expected_derived1_prefix = f"{example_data_description.name}_spike_sorting_"
@@ -433,10 +428,9 @@ def test_from_derived_with_explicit_source_data(self):
433428
explicit_source = ["external_dataset_1", "external_dataset_2"]
434429
derived2 = DataDescription.from_derived(derived1, "analysis", source_data=explicit_source, creation_time=dt2)
435430

436-
# Should combine existing source_data with new source_data
437-
self.assertEqual(len(derived2.source_data), 3) # 1 from derived1 + 2 explicit
438-
self.assertEqual(derived2.source_data[0], example_data_description.name) # From derived1
439-
self.assertEqual(derived2.source_data[1:], explicit_source) # Explicit source_data
431+
# Should use the explicit source_data (not combine with existing)
432+
self.assertEqual(len(derived2.source_data), 2) # Just the explicit source_data
433+
self.assertEqual(derived2.source_data, explicit_source) # Explicit source_data only
440434

441435
def test_from_derived_chained_behavior(self):
442436
"""Test chained from_derived calls maintain original input name"""
@@ -455,10 +449,10 @@ def test_from_derived_chained_behavior(self):
455449
self.assertTrue(derived2.name.startswith(f"{original_prefix}_process2_"))
456450
self.assertTrue(derived3.name.startswith(f"{original_prefix}_process3_"))
457451

458-
# Verify source_data chains correctly
452+
# Verify source_data only contains immediate predecessor
459453
self.assertEqual(derived1.source_data, [example_data_description.name])
460-
self.assertEqual(derived2.source_data, [example_data_description.name, derived1.name])
461-
self.assertEqual(derived3.source_data, [example_data_description.name, derived1.name, derived2.name])
454+
self.assertEqual(derived2.source_data, [derived1.name])
455+
self.assertEqual(derived3.source_data, [derived2.name])
462456

463457
def test_from_derived_name_parsing(self):
464458
"""Test from_derived correctly parses complex derived names"""
@@ -552,9 +546,8 @@ def test_from_data_description_with_kwargs_and_source_data(self):
552546
)
553547

554548
self.assertEqual(result_derived.tags, custom_tags)
555-
# Should combine existing source_data with explicit source_data
556-
expected_source_data = [example_data_description.name] + explicit_source
557-
self.assertEqual(result_derived.source_data, expected_source_data)
549+
# Should use the explicit source_data (not combine with existing)
550+
self.assertEqual(result_derived.source_data, explicit_source)
558551

559552
def test_from_derived_with_invalid_creation_time(self):
560553
"""Test from_derived error when creation_time is not a datetime object"""

0 commit comments

Comments
 (0)