|
1 | 1 | # -*- coding: utf-8 -*- |
| 2 | +from copy import copy |
| 3 | +from math import sqrt |
| 4 | + |
2 | 5 | import click |
3 | 6 | from msgspec.yaml import decode as yaml_decode |
4 | 7 | from pioreactor import structs |
5 | 8 | from pioreactor import types as pt |
6 | 9 | from pioreactor.calibrations import get_calibration_protocols |
| 10 | +from pioreactor.calibrations.utils import curve_to_callable |
| 11 | +from pioreactor.calibrations.utils import curve_to_functional_form |
7 | 12 | from pioreactor.estimators import ESTIMATOR_PATH |
8 | 13 | from pioreactor.estimators import list_estimator_devices |
9 | 14 | from pioreactor.estimators import list_of_estimators_by_device |
10 | 15 | from pioreactor.estimators import load_active_estimator |
11 | 16 | from pioreactor.estimators import load_estimator |
| 17 | +from pioreactor.utils.akimas import akima_eval |
| 18 | +from pioreactor.utils.akimas import akima_fit |
12 | 19 |
|
13 | 20 |
|
14 | 21 | def green(string: str) -> str: |
@@ -142,3 +149,137 @@ def delete_estimator(device: str, estimator_name: str) -> None: |
142 | 149 | target_file.unlink() |
143 | 150 |
|
144 | 151 | click.echo(f"Deleted estimator '{estimator_name}' of device '{device}'.") |
| 152 | + |
| 153 | + |
| 154 | +def _extract_fusion_by_angle_records( |
| 155 | + estimator: structs.ODFusionEstimator, |
| 156 | +) -> dict[pt.PdAngle, dict[str, list[float]]]: |
| 157 | + recorded_data = estimator.recorded_data |
| 158 | + if isinstance(recorded_data, dict) and "by_angle" in recorded_data: |
| 159 | + by_angle = recorded_data.get("by_angle") |
| 160 | + if isinstance(by_angle, dict): |
| 161 | + return {angle: value for angle, value in by_angle.items() if isinstance(value, dict)} |
| 162 | + |
| 163 | + if isinstance(recorded_data, dict) and "base_recorded_data" in recorded_data: |
| 164 | + base_recorded_data = recorded_data.get("base_recorded_data") |
| 165 | + if isinstance(base_recorded_data, dict) and "by_angle" in base_recorded_data: |
| 166 | + by_angle = base_recorded_data.get("by_angle") |
| 167 | + if isinstance(by_angle, dict): |
| 168 | + return {angle: value for angle, value in by_angle.items() if isinstance(value, dict)} |
| 169 | + |
| 170 | + return {} |
| 171 | + |
| 172 | + |
| 173 | +def _fit_curve_data_from_points( |
| 174 | + *, |
| 175 | + fit: str, |
| 176 | + x: list[float], |
| 177 | + y: list[float], |
| 178 | +) -> structs.AkimaFitData: |
| 179 | + if len(x) < 2 or len(y) < 2: |
| 180 | + raise ValueError("Need at least two points to fit a curve.") |
| 181 | + |
| 182 | + if fit == "akima": |
| 183 | + return akima_fit(x, y) |
| 184 | + raise ValueError(f"Unsupported fit type: {fit}") |
| 185 | + |
| 186 | + |
| 187 | +def _rmse_for_fit(curve_data: structs.CalibrationCurveData, x: list[float], y: list[float]) -> float: |
| 188 | + curve_callable = curve_to_callable(curve_data) |
| 189 | + residuals = [(curve_callable(x_val) - y_val) ** 2 for x_val, y_val in zip(x, y)] |
| 190 | + if not residuals: |
| 191 | + return 0.0 |
| 192 | + return sqrt(sum(residuals) / len(residuals)) |
| 193 | + |
| 194 | + |
| 195 | +@estimators.command(name="analyze") |
| 196 | +@click.option("--device", required=True, help="Which estimator device to analyze.") |
| 197 | +@click.option("--name", "estimator_name", required=True, help="Which estimator name to analyze.") |
| 198 | +@click.option( |
| 199 | + "--fit", |
| 200 | + "fit", |
| 201 | + default="akima", |
| 202 | + type=click.Choice(["poly", "spline", "akima"]), |
| 203 | + show_default=True, |
| 204 | + help="Curve fit type to use when analyzing.", |
| 205 | +) |
| 206 | +def analyze_estimator(device: str, estimator_name: str, fit: str) -> None: |
| 207 | + """ |
| 208 | + Analyze an estimator file from local storage. |
| 209 | + """ |
| 210 | + target_file = ESTIMATOR_PATH / device / f"{estimator_name}.yaml" |
| 211 | + if not target_file.exists(): |
| 212 | + click.echo(f"No such estimator file: {target_file}", err=True) |
| 213 | + raise SystemExit(1) |
| 214 | + |
| 215 | + try: |
| 216 | + estimator = load_estimator(device, estimator_name) |
| 217 | + except Exception as exc: |
| 218 | + click.echo(f"Unable to load estimator: {exc}", err=True) |
| 219 | + raise SystemExit(1) from exc |
| 220 | + |
| 221 | + if not isinstance(estimator, structs.ODFusionEstimator): |
| 222 | + click.echo("Only od_fused estimators are supported for analyze.", err=True) |
| 223 | + raise SystemExit(1) |
| 224 | + |
| 225 | + if fit != "akima": |
| 226 | + click.echo("Only akima fits are supported for od_fused estimators.", err=True) |
| 227 | + raise SystemExit(1) |
| 228 | + |
| 229 | + by_angle = _extract_fusion_by_angle_records(estimator) |
| 230 | + if not by_angle: |
| 231 | + click.echo("No recorded fusion data available to analyze.", err=True) |
| 232 | + raise SystemExit(1) |
| 233 | + |
| 234 | + click.echo(f"Estimator: {estimator.estimator_name}") |
| 235 | + click.echo(f"Device: {device}") |
| 236 | + click.echo(f"Fit: {fit}") |
| 237 | + click.echo("") |
| 238 | + |
| 239 | + new_estimator = copy(estimator) |
| 240 | + mu_splines: dict[pt.PdAngle, structs.AkimaFitData] = {} |
| 241 | + sigma_splines_log: dict[pt.PdAngle, structs.AkimaFitData] = {} |
| 242 | + |
| 243 | + for angle in estimator.angles: |
| 244 | + points = by_angle.get(angle) |
| 245 | + if not points: |
| 246 | + click.echo(f"{angle}°: no recorded data found.", err=True) |
| 247 | + continue |
| 248 | + |
| 249 | + x_vals = points.get("x") |
| 250 | + y_vals = points.get("y") |
| 251 | + if not isinstance(x_vals, list) or not isinstance(y_vals, list): |
| 252 | + click.echo(f"{angle}°: recorded data malformed.", err=True) |
| 253 | + continue |
| 254 | + |
| 255 | + try: |
| 256 | + mu_curve = _fit_curve_data_from_points(fit=fit, x=x_vals, y=y_vals) |
| 257 | + mu_rmse = _rmse_for_fit(mu_curve, x_vals, y_vals) |
| 258 | + except Exception as exc: |
| 259 | + click.echo(f"{angle}°: unable to fit mu curve: {exc}", err=True) |
| 260 | + raise SystemExit(1) from exc |
| 261 | + |
| 262 | + sigma_reference = [akima_eval(estimator.sigma_splines_log[angle], float(x_val)) for x_val in x_vals] |
| 263 | + try: |
| 264 | + sigma_curve = _fit_curve_data_from_points(fit=fit, x=x_vals, y=sigma_reference) |
| 265 | + sigma_rmse = _rmse_for_fit(sigma_curve, x_vals, sigma_reference) |
| 266 | + except Exception as exc: |
| 267 | + click.echo(f"{angle}°: unable to fit sigma curve: {exc}", err=True) |
| 268 | + raise SystemExit(1) from exc |
| 269 | + |
| 270 | + click.echo(f"{angle}° mu: {curve_to_functional_form(mu_curve)}") |
| 271 | + click.echo(f"{angle}° mu rmse: {mu_rmse:0.4f}") |
| 272 | + click.echo(f"{angle}° sigma(log): {curve_to_functional_form(sigma_curve)}") |
| 273 | + click.echo(f"{angle}° sigma(log) rmse: {sigma_rmse:0.4f}") |
| 274 | + click.echo("") |
| 275 | + |
| 276 | + mu_splines[angle] = mu_curve |
| 277 | + sigma_splines_log[angle] = sigma_curve |
| 278 | + |
| 279 | + confirm = click.confirm(green("Save updated estimator fit?"), default=False) |
| 280 | + if not confirm: |
| 281 | + raise SystemExit(0) |
| 282 | + |
| 283 | + new_estimator.mu_splines = mu_splines |
| 284 | + new_estimator.sigma_splines_log = sigma_splines_log |
| 285 | + new_estimator.save_to_disk_for_device(device) |
0 commit comments