Skip to content

Commit af1302d

Browse files
committed
refactor: allow dictionary tags
1 parent eefa5e0 commit af1302d

File tree

5 files changed

+83
-112
lines changed

5 files changed

+83
-112
lines changed

.github/copilot-instructions.md

Lines changed: 0 additions & 4 deletions
This file was deleted.

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,4 @@ ephys_instrument.json
145145
examples/*.json
146146
*.json
147147
docs/base/models/*
148+
.github/copilot-instructions.md

examples/quality_control.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
status_history=[sp],
4848
tags={
4949
"probe": "Probe A",
50-
"type": "drift map",
5150
}
5251
),
5352
QCMetric(
@@ -60,7 +59,6 @@
6059
status_history=[sp],
6160
tags={
6261
"probe": "Probe B",
63-
"type": "drift map",
6462
}
6563
),
6664
QCMetric(
@@ -73,33 +71,8 @@
7371
status_history=[s],
7472
tags={
7573
"probe": "Probe C",
76-
"type": "drift map",
7774
}
7875
),
79-
QCMetric(
80-
name="Video 1 frame count",
81-
modality=Modality.BEHAVIOR_VIDEOS,
82-
stage=Stage.RAW,
83-
description="Pass when frame count matches expected",
84-
value=662,
85-
status_history=[s],
86-
tags={
87-
"video": "Video 1",
88-
"type": "Frame count checks",
89-
},
90-
),
91-
QCMetric(
92-
name="Video 2 num frames",
93-
modality=Modality.BEHAVIOR_VIDEOS,
94-
stage=Stage.RAW,
95-
description="Pass when frame count matches expected",
96-
value=662,
97-
status_history=[s],
98-
tags={
99-
"video": "Video 2",
100-
"type": "Frame count checks",
101-
},
102-
),
10376
QCMetric(
10477
name="ProbeA",
10578
modality=Modality.ECEPHYS,
@@ -109,7 +82,6 @@
10982
status_history=[s],
11083
tags={
11184
"probe": "Probe A",
112-
"type": "Probes present",
11385
},
11486
),
11587
QCMetric(
@@ -121,7 +93,6 @@
12193
status_history=[s],
12294
tags={
12395
"probe": "Probe B",
124-
"type": "Probes present",
12596
},
12697
),
12798
QCMetric(
@@ -133,15 +104,38 @@
133104
status_history=[s],
134105
tags={
135106
"probe": "Probe C",
136-
"type": "Probes present",
107+
},
108+
),
109+
QCMetric(
110+
name="Video 1 frame count",
111+
modality=Modality.BEHAVIOR_VIDEOS,
112+
stage=Stage.RAW,
113+
description="Pass when frame count matches expected",
114+
value=662,
115+
status_history=[s],
116+
tags={
117+
"video": "Video 1",
118+
},
119+
),
120+
QCMetric(
121+
name="Video 2 num frames",
122+
modality=Modality.BEHAVIOR_VIDEOS,
123+
stage=Stage.RAW,
124+
description="Pass when frame count matches expected",
125+
value=662,
126+
status_history=[s],
127+
tags={
128+
"video": "Video 2",
137129
},
138130
),
139131
]
140132

141133
q = QualityControl(
142134
metrics=metrics,
143-
default_grouping=[["probe", "video"], ["type"]], # in visualizations group probes together and videos together, then group metrics by type
144-
allow_tag_failures=["Video 2"], # allow any metrics with tag video: Video 2 to fail without failing overall QC
135+
# in visualizations split first by modality, then by probe / video tags
136+
default_grouping=[["modality"], ["probe", "video"]],
137+
# allow any metrics with tag video: Video 2 to fail without failing overall QC
138+
allow_tag_failures=["Video 2"],
145139
)
146140

147141
if __name__ == "__main__":

src/aind_data_schema/core/quality_control.py

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from datetime import datetime, timezone
44
from enum import Enum
55
from typing import Any, List, Literal, Optional, Union
6+
import warnings
67

78
from aind_data_schema_models.modalities import Modality
89
from pydantic import Field, SkipValidation, field_validator, model_validator
@@ -56,9 +57,7 @@ class QCMetric(DataModel):
5657
evaluated_assets: Optional[List[str]] = Field(
5758
default=None,
5859
title="List of asset names that this metric depends on",
59-
description=(
60-
"Set to None except when a metric's calculation required data " "coming from a different data asset."
61-
),
60+
description="Set to None except when a metric's calculation required data coming from a different data asset.",
6261
)
6362

6463
@property
@@ -90,17 +89,12 @@ def fix_tag_lists(cls, self):
9089
9190
Remove this function in aind-data-schema v3.X
9291
"""
92+
if "tags" not in self:
93+
return self
9394
tags = self["tags"]
9495
if isinstance(tags, list):
95-
# Convert list of strings to dict with string keys
96-
if len(tags) == 1:
97-
self["tags"] = {
98-
"tag": tags[0],
99-
"name": self["name"],
100-
}
101-
else:
102-
# Unfortunately there is no reasonable way to handle multiple tags, these assets should be re-generated
103-
self["tags"] = {f"tag_{i+1}": tag for i, tag in enumerate(tags)}
96+
warnings.warn("QCMetric 'tags' field is now a dict. Converting from list to dict", DeprecationWarning)
97+
self["tags"] = {f"tag_{i+1}": tag for i, tag in enumerate(tags)}
10498
return self
10599

106100

@@ -156,11 +150,11 @@ def tags(self) -> List[str]:
156150
Returns
157151
-------
158152
List[str]
159-
List of all unique tags across all metrics
153+
List of all unique tag values across all metrics
160154
"""
161155
all_tags = []
162156
for metric in self.metrics:
163-
all_tags.extend(metric.tags)
157+
all_tags.extend(metric.tags.values())
164158
return list(set(all_tags))
165159

166160
@property
@@ -280,15 +274,18 @@ def __add__(self, other: "QualityControl") -> "QualityControl":
280274
allow_tag_failures=combined_allow_tag_failures,
281275
)
282276

283-
@field_validator("default_grouping", mode="before")
277+
@model_validator(mode="before")
284278
def fix_default_grouping_list(cls, value: dict) -> dict:
285279
"""Convert default grouping from list of strings to list of list of strings if necessary
286280
This function is for backwards compatibility with v2.2.X where default_grouping was stored as a list of strings.
287281
Remove this function in aind-data-schema v3.X
288282
"""
289-
if value and len(value) > 0 and isinstance(value[0], str):
283+
if "default_grouping" not in value:
284+
return value
285+
if value["default_grouping"] and isinstance(value["default_grouping"][0], str):
286+
# Add the modality as the top-level grouping
290287
# Convert list of strings to list of list of strings
291-
value = [[tag] for tag in value]
288+
value["default_grouping"] = [[value["modality"]["abbreviation"]]] + ["tag_0"]
292289
return value
293290

294291

@@ -331,7 +328,7 @@ def _get_filtered_statuses(
331328
modality_filter: Optional[List[Modality.ONE_OF]] = None,
332329
stage_filter: Optional[List[Stage]] = None,
333330
tag_filter: Optional[List[str]] = None,
334-
allow_tag_failures: List[str | tuple] = [],
331+
allow_tag_failures: List[str] = [],
335332
):
336333
"""Get the status of metrics filtered by modality, stage, tag, and date."""
337334
filtered_statuses = []
@@ -341,22 +338,16 @@ def _get_filtered_statuses(
341338
continue
342339
if stage_filter and metric.stage not in stage_filter:
343340
continue
344-
if tag_filter and not (metric.tags and any(t in metric.tags for t in tag_filter)):
341+
if tag_filter and not (metric.tags and any(t in metric.tags.values() for t in tag_filter)):
345342
continue
346343

347344
# Get status at the specified date using the helper function
348345
status = _get_status_by_date(metric, date)
349-
# Check if any of our tags are in the allow_tag_failures list
346+
# Check if any of our tag values are in the allow_tag_failures list
350347
if status == Status.FAIL and metric.tags:
351-
for fail2pass_tags in allow_tag_failures:
352-
if isinstance(fail2pass_tags, tuple):
353-
# If it's a tuple, check if all of the tags match
354-
if all(t in metric.tags for t in fail2pass_tags):
355-
status = Status.PASS
356-
break
357-
elif fail2pass_tags in metric.tags:
358-
status = Status.PASS
359-
break
348+
metric_tag_values = set(metric.tags.values())
349+
if any(tag_value in allow_tag_failures for tag_value in metric_tag_values):
350+
status = Status.PASS
360351
filtered_statuses.append(status)
361352

362353
return filtered_statuses

0 commit comments

Comments
 (0)