Skip to content

Commit 7e9d863

Browse files
authored
feat: add optional calibration fit to Calibration (#1494)
* feat: working on CalibrationFit * chore: lint * chore: docstring int he wrong place
1 parent 76da9d2 commit 7e9d863

File tree

2 files changed

+117
-8
lines changed

2 files changed

+117
-8
lines changed

src/aind_data_schema/components/measurements.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,52 @@
11
"""Calibration data models"""
22

3+
from enum import Enum
34
from typing import List, Literal, Optional
45

5-
from aind_data_schema_models.units import UNITS, PowerUnit, TimeUnit, VolumeUnit
6+
from aind_data_schema_models.units import UNITS, PowerUnit, TimeUnit, VolumeUnit, VoltageUnit
7+
from pydantic import model_validator
68

7-
from aind_data_schema.base import AwareDatetimeWithDefault, Discriminated, Field
9+
from aind_data_schema.base import AwareDatetimeWithDefault, DataModel, Discriminated, Field, GenericModel
810
from aind_data_schema.components.configs import DeviceConfig
911
from aind_data_schema.components.reagent import Reagent
1012

1113

14+
class FitType(Enum):
15+
"""Type of fit for calibration data"""
16+
17+
LINEAR_INTERPOLATION = "linear_interpolation"
18+
LINEAR = "linear"
19+
OTHER = "other"
20+
21+
22+
class CalibrationFit(DataModel):
23+
"""Fit equation for calibration data"""
24+
25+
fit_type: FitType = Field(
26+
...,
27+
title="Fit type",
28+
)
29+
fit_parameters: Optional[GenericModel] = Field(
30+
default=None,
31+
title="Fit parameters",
32+
description="Parameters of the fit equation, e.g. slope and intercept for linear fit",
33+
)
34+
35+
@model_validator(mode="before")
36+
def validate_fit_type(cls, values):
37+
"""Ensure that parameters are provided for linear and other fits"""
38+
fit_type = values.get("fit_type")
39+
fit_parameters = values.get("fit_parameters")
40+
41+
if fit_type in {FitType.LINEAR, FitType.OTHER} and not fit_parameters:
42+
raise ValueError(f"Fit parameters must be provided for {fit_type.value} fit type")
43+
44+
if fit_type == FitType.LINEAR_INTERPOLATION and fit_parameters is not None:
45+
raise ValueError("Fit parameters should not be provided for linear interpolation fit type")
46+
47+
return values
48+
49+
1250
class Calibration(DeviceConfig):
1351
"""Generic calibration class"""
1452

@@ -25,17 +63,21 @@ class Calibration(DeviceConfig):
2563
..., description="Calibration output (provide the average if repeated)", title="Outputs"
2664
)
2765
output_unit: UNITS = Field(..., title="Output unit")
66+
fit: Optional[CalibrationFit] = Field(
67+
default=None,
68+
title="Fit",
69+
description="Fit equation for the calibration data used during data acquisition",
70+
)
2871
notes: Optional[str] = Field(
2972
default=None,
3073
title="Notes",
31-
description="Fit equation, etc",
3274
)
3375

3476

3577
class VolumeCalibration(Calibration):
36-
"""Calibration of a liquid delivery device"""
78+
"""Calibration of a liquid delivery device based on solenoid/valve opening times"""
3779

38-
input: List[float] = Field(..., title="Input times", description="Length of time solenoid is open")
80+
input: List[float] = Field(..., title="Input times", description="Length of time solenoid/valve is open")
3981
input_unit: TimeUnit = Field(..., title="Input unit")
4082
repeats: Optional[int] = Field(
4183
default=None,
@@ -51,10 +93,10 @@ class VolumeCalibration(Calibration):
5193

5294

5395
class PowerCalibration(Calibration):
54-
"""Calibration of a laser device"""
96+
"""Calibration of a device that outputs power based on input strength"""
5597

56-
input: List[float] = Field(..., title="Input", description="Power or percentage input strength")
57-
input_unit: PowerUnit = Field(..., title="Input unit")
98+
input: List[float] = Field(..., title="Input", description="Power, voltage, or percentage input strength")
99+
input_unit: PowerUnit | VoltageUnit = Field(..., title="Input unit")
58100
output: List[float] = Field(..., title="Output", description="Power output (provide the average if repeated)")
59101
output_unit: PowerUnit = Field(..., title="Output unit")
60102

tests/test_measurements.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Tests for CalibrationFit from measurements module"""
2+
3+
import unittest
4+
from pydantic import ValidationError
5+
from aind_data_schema.components.measurements import CalibrationFit, FitType
6+
7+
8+
class TestCalibrationFit(unittest.TestCase):
9+
"""Tests for CalibrationFit class"""
10+
11+
def test_linear_interpolation_without_parameters(self):
12+
"""Test that linear interpolation fit type works without parameters"""
13+
fit = CalibrationFit(fit_type=FitType.LINEAR_INTERPOLATION)
14+
self.assertEqual(fit.fit_type, FitType.LINEAR_INTERPOLATION.value)
15+
self.assertIsNone(fit.fit_parameters)
16+
17+
def test_linear_interpolation_with_parameters_raises_error(self):
18+
"""Test that linear interpolation fit type raises error with parameters"""
19+
with self.assertRaises(ValidationError) as context:
20+
CalibrationFit(fit_type=FitType.LINEAR_INTERPOLATION, fit_parameters={"slope": 1.0, "intercept": 0.0})
21+
self.assertIn("Fit parameters should not be provided for linear interpolation fit type", str(context.exception))
22+
23+
def test_linear_fit_with_parameters(self):
24+
"""Test that linear fit type works with parameters"""
25+
parameters = {"slope": 1.5, "intercept": 2.0}
26+
fit = CalibrationFit(fit_type=FitType.LINEAR, fit_parameters=parameters)
27+
self.assertEqual(fit.fit_type, FitType.LINEAR.value)
28+
# Compare the fit_parameters as a dict using model_dump()
29+
self.assertIsNotNone(fit.fit_parameters)
30+
self.assertEqual(fit.fit_parameters.model_dump(), parameters)
31+
32+
def test_linear_fit_without_parameters_raises_error(self):
33+
"""Test that linear fit type raises error without parameters"""
34+
with self.assertRaises(ValidationError) as context:
35+
CalibrationFit(fit_type=FitType.LINEAR)
36+
self.assertIn("Fit parameters must be provided for linear fit type", str(context.exception))
37+
38+
def test_other_fit_with_parameters(self):
39+
"""Test that other fit type works with parameters"""
40+
parameters = {"a": 1.0, "b": 2.0, "c": 3.0}
41+
fit = CalibrationFit(fit_type=FitType.OTHER, fit_parameters=parameters)
42+
self.assertEqual(fit.fit_type, FitType.OTHER.value)
43+
# Compare the fit_parameters as a dict using model_dump()
44+
self.assertIsNotNone(fit.fit_parameters)
45+
self.assertEqual(fit.fit_parameters.model_dump(), parameters)
46+
47+
def test_other_fit_without_parameters_raises_error(self):
48+
"""Test that other fit type raises error without parameters"""
49+
with self.assertRaises(ValidationError) as context:
50+
CalibrationFit(fit_type=FitType.OTHER)
51+
self.assertIn("Fit parameters must be provided for other fit type", str(context.exception))
52+
53+
def test_linear_fit_with_none_parameters_raises_error(self):
54+
"""Test that linear fit type raises error with None parameters"""
55+
with self.assertRaises(ValidationError) as context:
56+
CalibrationFit(fit_type=FitType.LINEAR, fit_parameters=None)
57+
self.assertIn("Fit parameters must be provided for linear fit type", str(context.exception))
58+
59+
def test_other_fit_with_none_parameters_raises_error(self):
60+
"""Test that other fit type raises error with None parameters"""
61+
with self.assertRaises(ValidationError) as context:
62+
CalibrationFit(fit_type=FitType.OTHER, fit_parameters=None)
63+
self.assertIn("Fit parameters must be provided for other fit type", str(context.exception))
64+
65+
66+
if __name__ == "__main__":
67+
unittest.main()

0 commit comments

Comments
 (0)