Skip to content

Commit f6eaf33

Browse files
authored
hotfix: recover source_data field, including validator (#1534)
* hotfix: recover source_data field, including validator
* tests: additional coverage for condition where source_data is passed in
* chore: lint
1 parent 79d3807 commit f6eaf33

File tree

2 files changed

+146
-17
lines changed

2 files changed

+146
-17
lines changed

src/aind_data_schema/core/data_description.py

Lines changed: 28 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,11 @@ class DataDescription(DataCoreModel):
105105
" of any technology or formal procedure to generate data for a study",
106106
title="Modalities",
107107
)
108+
source_data: Optional[List[str]] = Field(
109+
default=None,
110+
description="For derived assets, list the source data asset names used to create this data",
111+
title="Source data",
112+
)
108113
data_summary: Optional[str] = Field(
109114
default=None, title="Data summary", description="Semantic summary of experimental goal"
110115
)
@@ -156,8 +161,17 @@ def build_name(self):
156161

157162
return self
158163

164+
@model_validator(mode="after")
165+
def source_data_when_raw(self):
166+
"""Ensure that source_data is not provided when data_level is RAW"""
167+
if self.data_level == DataLevel.RAW and self.source_data is not None:
168+
raise ValueError("source_data must not be set when data_level is 'raw'")
169+
return self
170+
159171
@classmethod
160-
def from_raw(cls, data_description: "DataDescription", process_name: str, **kwargs) -> "DataDescription":
172+
def from_raw(
173+
cls, data_description: "DataDescription", process_name: str, source_data: Optional[List[str]] = None, **kwargs
174+
) -> "DataDescription":
161175
"""
162176
Create a DataLevel.DERIVED DataDescription from a DataLevel.RAW DataDescription object.
163177
@@ -205,10 +219,21 @@ def get_or_default(field_name: str) -> Any:
205219
raise ValueError(f"creation_time({creation_time}) must be a datetime object")
206220

207221
# Upgrade name
208-
derived_name = f"{data_description.name}_{process_name}_{datetime_to_name_string(creation_time)}"
222+
original_name = data_description.name
223+
derived_name = f"{original_name}_{process_name}_{datetime_to_name_string(creation_time)}"
209224
if not re.match(DataRegex.DERIVED.value, derived_name): # pragma: no cover
210225
raise ValueError(f"Derived name({derived_name}) does not match allowed Regex pattern")
211226

227+
# Upgrade source_data
228+
if source_data is not None:
229+
new_source_data = (
230+
source_data if not data_description.source_data else data_description.source_data + source_data
231+
)
232+
else:
233+
new_source_data = (
234+
[original_name] if not data_description.source_data else data_description.source_data + [original_name]
235+
)
236+
212237
return cls(
213238
subject_id=get_or_default("subject_id"),
214239
creation_time=creation_time,
@@ -223,4 +248,5 @@ def get_or_default(field_name: str) -> Any:
223248
restrictions=get_or_default("restrictions"),
224249
modalities=get_or_default("modalities"),
225250
data_summary=get_or_default("data_summary"),
251+
source_data=new_source_data,
226252
)

tests/test_data_description.py

Lines changed: 118 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ def test_raw_data_description_construction(self):
3535
da = DataDescription(
3636
creation_time=dt,
3737
institution=Organization.AIND,
38-
data_level="raw",
38+
data_level=DataLevel.RAW,
3939
funding_source=[f],
4040
modalities=[Modality.ECEPHYS],
4141
subject_id="12345",
@@ -75,7 +75,7 @@ def test_derived_data_description_construction(self):
7575
da = DataDescription(
7676
creation_time=dt,
7777
institution=Organization.AIND,
78-
data_level="raw",
78+
data_level=DataLevel.RAW,
7979
funding_source=[f],
8080
modalities=[Modality.ECEPHYS],
8181
subject_id="12345",
@@ -92,7 +92,7 @@ def test_nested_derived_data_description_construction(self):
9292
da = DataDescription(
9393
creation_time=dt,
9494
institution=Organization.AIND,
95-
data_level="raw",
95+
data_level=DataLevel.RAW,
9696
funding_source=[f],
9797
modalities=[Modality.ECEPHYS],
9898
subject_id="12345",
@@ -111,7 +111,7 @@ def test_data_description_construction(self):
111111
dd = DataDescription(
112112
modalities=[Modality.SPIM],
113113
subject_id="1234",
114-
data_level="raw",
114+
data_level=DataLevel.RAW,
115115
creation_time=dt,
116116
institution=Organization.AIND,
117117
funding_source=[f],
@@ -128,7 +128,7 @@ def test_data_description_construction_failure(self):
128128
DataDescription(
129129
modalities=[Modality.SPIM],
130130
subject_id="",
131-
data_level="raw",
131+
data_level=DataLevel.RAW,
132132
creation_time=dt,
133133
institution=Organization.AIND,
134134
funding_source=[f],
@@ -141,7 +141,6 @@ def test_parse_name_invalid(self):
141141

142142
with self.assertRaises(ValueError) as context:
143143
DataDescription.parse_name("name", "invalid_data_level")
144-
145144
self.assertIn("DataLevel", str(context.exception))
146145

147146
def test_derived_valid(self):
@@ -152,7 +151,7 @@ def test_derived_valid(self):
152151
dr = DataDescription(
153152
modalities=[Modality.SPIM],
154153
subject_id="1234",
155-
data_level="raw",
154+
data_level=DataLevel.RAW,
156155
creation_time=dt,
157156
institution=Organization.AIND,
158157
funding_source=[f],
@@ -172,7 +171,7 @@ def test_raw_no_subject_id(self):
172171
DataDescription(
173172
creation_time=dt,
174173
institution=Organization.AIND,
175-
data_level="raw",
174+
data_level=DataLevel.RAW,
176175
funding_source=[Funding(funder=Organization.NINDS, grant_number="grant001")],
177176
modalities=[Modality.ECEPHYS],
178177
investigators=[Person(name="Jane Smith")],
@@ -188,7 +187,7 @@ def test_derived_bad_creation_time(self):
188187
da = DataDescription(
189188
creation_time=dt,
190189
institution=Organization.AIND,
191-
data_level="raw",
190+
data_level=DataLevel.RAW,
192191
funding_source=[Funding(funder=Organization.NINDS, grant_number="grant001")],
193192
modalities=[Modality.ECEPHYS],
194193
subject_id="12345",
@@ -203,19 +202,16 @@ def test_derived_bad_creation_time(self):
203202

204203
def test_data_description_missing_fields(self):
205204
"""Test DataDescription missing fields"""
206-
dt = datetime.datetime.now()
207-
with self.assertRaises(ValueError):
205+
with self.assertRaises(ValidationError):
208206
DataDescription()
209-
with self.assertRaises(ValueError):
210-
DataDescription(creation_time=dt)
211207

212208
def test_pattern_errors(self):
213209
"""Tests that errors are raised if malformed strings are input"""
214210
with self.assertRaises(ValidationError) as e:
215211
DataDescription(
216212
modalities=[Modality.SPIM],
217213
subject_id="1234",
218-
data_level="raw",
214+
data_level=DataLevel.RAW,
219215
project_name="a_32r&!#R$&#",
220216
creation_time=datetime.datetime(2020, 10, 10, 10, 10, 10),
221217
institution=Organization.AIND,
@@ -240,7 +236,7 @@ def test_round_trip(self):
240236
da1 = DataDescription(
241237
creation_time=dt,
242238
institution=Organization.AIND,
243-
data_level="raw",
239+
data_level=DataLevel.RAW,
244240
funding_source=[Funding(funder=Organization.NINDS, grant_number="grant001")],
245241
modalities=[Modality.SPIM],
246242
subject_id="12345",
@@ -275,6 +271,113 @@ def test_unique_abbreviations(self):
275271
modality_abbreviations = [m().abbreviation for m in Modality.ALL]
276272
self.assertEqual(len(set(modality_abbreviations)), len(modality_abbreviations))
277273

274+
def test_source_data_field(self):
275+
"""Tests the source_data field behavior"""
276+
277+
# source_data should not be set for raw data
278+
with self.assertRaises(ValueError) as context:
279+
DataDescription(
280+
modalities=[Modality.SPIM],
281+
subject_id="1234",
282+
data_level=DataLevel.RAW,
283+
creation_time=datetime.datetime.now(),
284+
institution=Organization.AIND,
285+
funding_source=[Funding(funder=Organization.NINDS, grant_number="grant001")],
286+
investigators=[Person(name="Jane Smith")],
287+
project_name="Test",
288+
source_data=["some_source_data"],
289+
)
290+
self.assertIn("source_data must not be set when data_level is 'raw'", str(context.exception))
291+
292+
# source_data should be set correctly for derived data
293+
dt = datetime.datetime.now()
294+
f = Funding(funder=Organization.NINDS, grant_number="grant001")
295+
da = DataDescription(
296+
creation_time=dt,
297+
institution=Organization.AIND,
298+
data_level=DataLevel.RAW,
299+
funding_source=[f],
300+
modalities=[Modality.ECEPHYS],
301+
subject_id="12345",
302+
investigators=[Person(name="Jane Smith")],
303+
project_name="Test",
304+
)
305+
r1 = DataDescription.from_raw(da, "spikesort-ks25", creation_time=dt)
306+
self.assertIsNotNone(r1.source_data)
307+
self.assertEqual(len(r1.source_data), 1)
308+
self.assertEqual(r1.source_data[0], da.name)
309+
310+
def test_from_raw_with_explicit_source_data(self):
311+
"""Test from_raw with explicitly provided source_data parameter"""
312+
dt = datetime.datetime.now()
313+
f = Funding(funder=Organization.NINDS, grant_number="grant001")
314+
315+
# Create a raw DataDescription
316+
da = DataDescription(
317+
creation_time=dt,
318+
institution=Organization.AIND,
319+
data_level=DataLevel.RAW,
320+
funding_source=[f],
321+
modalities=[Modality.ECEPHYS],
322+
subject_id="12345",
323+
investigators=[Person(name="Jane Smith")],
324+
project_name="Test",
325+
)
326+
327+
# Test scenario 3: RAW data → DERIVED with explicit source_data
328+
explicit_source = ["external_dataset_1", "external_dataset_2"]
329+
r1 = DataDescription.from_raw(da, "spikesort-ks25", source_data=explicit_source, creation_time=dt)
330+
331+
# Should use the explicit source_data instead of the original name
332+
self.assertIsNotNone(r1.source_data)
333+
self.assertEqual(len(r1.source_data), 2)
334+
self.assertEqual(r1.source_data, explicit_source)
335+
self.assertNotIn(da.name, r1.source_data)
336+
337+
# Test scenario 4: DERIVED data → DERIVED with additional source_data
338+
additional_source = ["another_external_dataset"]
339+
r2 = DataDescription.from_raw(r1, "clustering", source_data=additional_source, creation_time=dt)
340+
341+
# Should combine existing source_data with new source_data
342+
self.assertIsNotNone(r2.source_data)
343+
self.assertEqual(len(r2.source_data), 3) # 2 from r1 + 1 additional
344+
self.assertEqual(r2.source_data[:2], explicit_source) # First two should be from r1
345+
self.assertEqual(r2.source_data[2:], additional_source) # Last one should be additional
346+
347+
def test_from_raw_chained_source_data_behavior(self):
348+
"""Test source_data behavior in chained derived data without explicit source_data"""
349+
dt = datetime.datetime.now()
350+
f = Funding(funder=Organization.NINDS, grant_number="grant001")
351+
352+
# Create a raw DataDescription
353+
da = DataDescription(
354+
creation_time=dt,
355+
institution=Organization.AIND,
356+
data_level=DataLevel.RAW,
357+
funding_source=[f],
358+
modalities=[Modality.ECEPHYS],
359+
subject_id="12345",
360+
investigators=[Person(name="Jane Smith")],
361+
project_name="Test",
362+
)
363+
364+
# First derivation: RAW → DERIVED (should set source_data to original name)
365+
r1 = DataDescription.from_raw(da, "spikesort-ks25", creation_time=dt)
366+
self.assertEqual(r1.source_data, [da.name])
367+
368+
# Second derivation: DERIVED → DERIVED (should append to existing source_data)
369+
r2 = DataDescription.from_raw(r1, "clustering", creation_time=dt)
370+
self.assertEqual(len(r2.source_data), 2)
371+
self.assertEqual(r2.source_data[0], da.name) # Original source
372+
self.assertEqual(r2.source_data[1], r1.name) # Previous derived data
373+
374+
# Third derivation: should continue the chain
375+
r3 = DataDescription.from_raw(r2, "analysis", creation_time=dt)
376+
self.assertEqual(len(r3.source_data), 3)
377+
self.assertEqual(r3.source_data[0], da.name) # Original source
378+
self.assertEqual(r3.source_data[1], r1.name) # First derived
379+
self.assertEqual(r3.source_data[2], r2.name) # Second derived
380+
278381

279382
if __name__ == "__main__":
280383
unittest.main()

0 commit comments

Comments
 (0)