Skip to content

Commit 9c8050d

Browse files
auto select knots using aicc
1 parent e9e6a7f commit 9c8050d

File tree

9 files changed

+109
-21
lines changed

9 files changed

+109
-21
lines changed

core/pioreactor/calibrations/protocols/od_reference_standard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def get_reference_standard_step(
256256
return get_session_step(_REFERENCE_STANDARD_STEPS, session, executor)
257257

258258

259-
def get_valid_od_devices_for_this_unit() -> list[str]:
259+
def get_valid_od_devices_for_this_unit() -> list[pt.ODCalibrationDevices]:
260260

261261
pd_channels = config["od_config.photodiode_channel"]
262262
valid_devices: list[pt.ODCalibrationDevices] = []

core/pioreactor/calibrations/protocols/od_single_vial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def run_od_calibration(target_device: pt.ODCalibrationDevices) -> structs.OD600C
484484

485485

486486
class SingleVialODProtocol(CalibrationProtocol[pt.ODCalibrationDevices]):
487-
target_device = pt.OD_DEVICES
487+
target_device = pt.OD_DEVICES # type: ignore
488488
protocol_name = "single_vial"
489489
description = "Calibrate OD using a single vial"
490490

core/pioreactor/calibrations/protocols/od_standards.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ def _calculate_curve_data(
187187
if len(od600_values) >= 3:
188188
from pioreactor.utils.splines import spline_fit
189189

190-
knots = min(4, len(od600_values))
191-
return "spline", spline_fit(od600_values, voltages, knots=knots, weights=weights)
190+
return "spline", spline_fit(od600_values, voltages, knots="auto", weights=weights)
192191

193192
degree = min(3, max(1, len(od600_values) - 1))
194193
return "poly", utils.calculate_poly_curve_of_best_fit(od600_values, voltages, degree, weights)
@@ -598,7 +597,7 @@ def get_standards_step(
598597
return get_session_step(_OD_STANDARDS_STEPS, session, executor)
599598

600599

601-
def get_valid_od_devices_for_this_unit() -> list[str]:
600+
def get_valid_od_devices_for_this_unit() -> list[pt.ODCalibrationDevices]:
602601
pd_channels = config["od_config.photodiode_channel"]
603602
valid_devices: list[pt.ODCalibrationDevices] = []
604603

core/pioreactor/calibrations/protocols/pump_duration_based.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ def run_pump_calibration(
580580

581581

582582
class DurationBasedPumpProtocol(CalibrationProtocol[pt.PumpCalibrationDevices]):
583-
target_device = pt.PUMP_DEVICES
583+
target_device = cast(list[pt.PumpCalibrationDevices], pt.PUMP_DEVICES)
584584
protocol_name = "duration_based"
585585
title = "Duration-based pump calibration"
586586
description = "Build a duration-to-volume curve for the {device} pump using a simple multi-step flow."

core/pioreactor/calibrations/protocols/stirring_dc_based.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,13 @@ def _build_stirring_calibration_from_measurements(
140140

141141
from pioreactor.utils.splines import spline_fit
142142

143-
knots = min(3, len(dcs))
144143
return SimpleStirringCalibration(
145144
pwm_hz=config.getfloat("stirring.config", "pwm_hz"),
146145
voltage=voltage,
147146
calibration_name=f"stirring-calibration-{current_utc_datetime().strftime('%Y-%m-%d_%H-%M')}",
148147
calibrated_on_pioreactor_unit=unit,
149148
created_at=current_utc_datetime(),
150-
curve_data_=spline_fit(dcs, rpms, knots=knots),
149+
curve_data_=spline_fit(dcs, rpms, knots="auto"),
151150
curve_type="spline",
152151
recorded_data={"x": dcs, "y": rpms},
153152
)

core/pioreactor/calibrations/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
class CalibrationProtocol(Generic[Device]):
1818
protocol_name: ClassVar[ProtocolName]
19-
target_device: ClassVar[str | list[str]]
19+
target_device: ClassVar[str | list[Device]]
2020
title: ClassVar[str] = ""
2121
description: ClassVar[str] = ""
2222
requirements: ClassVar[tuple[str, ...]] = ()

core/pioreactor/utils/polys.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
import numpy as np
77

88

9+
def _to_pyfloat(seq: list[float]) -> list[float]:
10+
# we have trouble serializing numpy floats
11+
return [float(_) for _ in seq]
12+
13+
914
def poly_fit(
1015
x: Sequence[float],
1116
y: Sequence[float],
@@ -33,7 +38,7 @@ def poly_fit(
3338
raise ValueError("weights must be non-negative.")
3439

3540
coefs = np.polyfit(x_values, y_values, deg=degree, w=weight_values)
36-
return coefs.tolist()
41+
return _to_pyfloat(coefs.tolist())
3742

3843

3944
def poly_eval(poly_data: list[float], x: float) -> float:

core/pioreactor/utils/splines.py

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@
77
import numpy as np
88

99

10+
def _to_pyfloat(seq: list[float]) -> list[float]:
11+
# we have trouble serializing numpy floats
12+
return [float(_) for _ in seq]
13+
14+
1015
def spline_fit(
1116
x: Sequence[float],
1217
y: Sequence[float],
13-
knots: int | Sequence[float],
18+
knots: int | Sequence[float] | str | None = "auto",
1419
weights: Sequence[float] | None = None,
1520
) -> list:
1621
"""
@@ -21,7 +26,8 @@ def spline_fit(
2126
x, y
2227
Observations.
2328
knots
24-
Either the number of knots to use (including boundaries) or explicit knot positions.
29+
Either the number of knots to use (including boundaries), explicit knot positions, or "auto".
30+
When "auto" (default), knot count is selected by AICc over a small candidate range.
2531
weights
2632
Optional weights for each observation.
2733
@@ -41,10 +47,6 @@ def spline_fit(
4147
if np.allclose(x_values, x_values[0]):
4248
raise ValueError("x values must not all be the same.")
4349

44-
knot_positions = _normalize_knots(x_values, knots)
45-
if len(knot_positions) < 2:
46-
raise ValueError("At least two knots are required.")
47-
4850
if weights is None:
4951
weight_values = np.ones_like(x_values)
5052
else:
@@ -54,14 +56,19 @@ def spline_fit(
5456
if np.any(weight_values < 0):
5557
raise ValueError("weights must be non-negative.")
5658

57-
design_matrix = _build_spline_design_matrix(knot_positions, x_values)
58-
weighted_design = design_matrix * np.sqrt(weight_values)[:, None]
59-
weighted_y = y_values * np.sqrt(weight_values)
59+
if knots is None or knots == "auto":
60+
knot_positions = _auto_select_knots(x_values, y_values, weight_values)
61+
elif isinstance(knots, str):
62+
raise ValueError('knots must be an int, a sequence of floats, or "auto".')
63+
else:
64+
knot_positions = _normalize_knots(x_values, knots)
65+
if len(knot_positions) < 2:
66+
raise ValueError("At least two knots are required.")
6067

61-
knot_values, *_ = np.linalg.lstsq(weighted_design, weighted_y, rcond=None)
68+
knot_values, _ = _fit_knot_values(knot_positions, x_values, y_values, weight_values)
6269
coefficients = _natural_cubic_spline_coefficients(knot_positions, knot_values)
6370

64-
return [knot_positions.tolist(), [coeff.tolist() for coeff in coefficients]]
71+
return [_to_pyfloat(knot_positions.tolist()), [_to_pyfloat(coeff.tolist()) for coeff in coefficients]]
6572

6673

6774
def spline_eval(spline_data: list, x: float) -> float:
@@ -133,6 +140,61 @@ def _normalize_knots(x_values: np.ndarray, knots: int | Sequence[float]) -> np.n
133140
return knot_positions
134141

135142

143+
def _fit_knot_values(
144+
knot_positions: np.ndarray,
145+
x_values: np.ndarray,
146+
y_values: np.ndarray,
147+
weight_values: np.ndarray,
148+
) -> tuple[np.ndarray, np.ndarray]:
149+
design_matrix = _build_spline_design_matrix(knot_positions, x_values)
150+
sqrt_weights = np.sqrt(weight_values)
151+
weighted_design = design_matrix * sqrt_weights[:, None]
152+
weighted_y = y_values * sqrt_weights
153+
knot_values, *_ = np.linalg.lstsq(weighted_design, weighted_y, rcond=None)
154+
return knot_values, design_matrix
155+
156+
157+
def _aicc_score(weighted_sse: float, n_obs: int, n_params: int) -> float:
158+
if n_obs <= n_params + 1:
159+
return float("inf")
160+
sse = max(weighted_sse, np.finfo(float).tiny)
161+
correction = (2 * n_params * (n_params + 1)) / (n_obs - n_params - 1)
162+
return n_obs * np.log(sse / n_obs) + 2 * n_params + correction
163+
164+
165+
def _auto_select_knots(
166+
x_values: np.ndarray,
167+
y_values: np.ndarray,
168+
weight_values: np.ndarray,
169+
*,
170+
max_knots: int | None = None,
171+
) -> np.ndarray:
172+
n_obs = x_values.size
173+
unique_x = np.unique(x_values).size
174+
if max_knots is None:
175+
max_knots = min(6, n_obs)
176+
max_knots = min(max_knots, unique_x)
177+
max_knots = max(2, max_knots)
178+
179+
best_score = float("inf")
180+
best_knots: np.ndarray | None = None
181+
182+
for count in range(2, max_knots + 1):
183+
knot_positions = _normalize_knots(x_values, count)
184+
knot_values, design_matrix = _fit_knot_values(knot_positions, x_values, y_values, weight_values)
185+
y_pred = design_matrix @ knot_values
186+
residual = y_values - y_pred
187+
weighted_sse = float(np.sum(weight_values * residual**2))
188+
score = _aicc_score(weighted_sse, n_obs, knot_positions.size)
189+
if score < best_score:
190+
best_score = score
191+
best_knots = knot_positions
192+
193+
if best_knots is None:
194+
return _normalize_knots(x_values, 2)
195+
return best_knots
196+
197+
136198
def _build_spline_design_matrix(knots: np.ndarray, x_values: np.ndarray) -> np.ndarray:
137199
n = x_values.size
138200
m = knots.size

core/tests/test_splines.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@ def test_spline_fit_and_eval_linear() -> None:
1515
assert spline_eval(spline_data, -1.0) == pytest.approx(-1.0, rel=1e-6)
1616

1717

18+
def test_spline_fit_auto_selects_knots() -> None:
19+
x = [0.0, 1.0, 2.0, 3.0]
20+
y = [1.0, 3.0, 5.0, 7.0]
21+
spline_data = spline_fit(x, y, knots="auto")
22+
23+
assert spline_eval(spline_data, 2.5) == pytest.approx(6.0, rel=1e-6)
24+
assert spline_data[0][0] == pytest.approx(min(x), rel=1e-6)
25+
assert spline_data[0][-1] == pytest.approx(max(x), rel=1e-6)
26+
27+
1828
def test_spline_fit_explicit_knots_interpolate_at_knots() -> None:
1929
x = [0.0, 1.0, 2.0]
2030
y = [0.0, 1.0, 0.0]
@@ -133,6 +143,19 @@ def test_spline_fit_reduces_residuals_on_noisy_linear_data() -> None:
133143
assert mse < 0.2
134144

135145

146+
def test_spline_fit_auto_selects_linear_curve_for_noisy_linear_data() -> None:
147+
rng = np.random.default_rng(123)
148+
x = np.linspace(0.0, 10.0, 30)
149+
y = 2.5 * x - 1.0 + rng.normal(0.0, 0.1, size=x.size)
150+
spline_data = spline_fit(x.tolist(), y.tolist(), knots="auto")
151+
152+
knots, coefficients = spline_data
153+
assert len(knots) == 2
154+
assert len(coefficients) == 1
155+
assert coefficients[0][2] == pytest.approx(0.0, abs=1e-12)
156+
assert coefficients[0][3] == pytest.approx(0.0, abs=1e-12)
157+
158+
136159
def test_spline_fit_respects_sorted_or_unsorted_input() -> None:
137160
x = [0.0, 1.0, 2.0, 3.0]
138161
y = [1.0, 2.0, 0.0, 3.0]

0 commit comments

Comments
 (0)