Skip to content

Commit 2d203d7

Browse files
generalize this
1 parent 80acc20 commit 2d203d7

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

core/pioreactor/estimators/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _estimator_path_for(device: str, name: str) -> Path:
2121
return ESTIMATOR_PATH / device / f"{name}.yaml"
2222

2323

24-
def load_active_estimator(device: Device) -> structs.ODFusionEstimator | None:
24+
def load_active_estimator(device: Device) -> structs.AnyEstimator | None:
2525
with local_persistent_storage("active_estimators") as storage:
2626
active_name = storage.get(device)
2727

@@ -30,15 +30,15 @@ def load_active_estimator(device: Device) -> structs.ODFusionEstimator | None:
3030
return load_estimator(device, active_name)
3131

3232

33-
def load_estimator(device: Device, estimator_name: str) -> structs.ODFusionEstimator:
33+
def load_estimator(device: Device, estimator_name: str) -> structs.AnyEstimator:
3434
target_file = _estimator_path_for(device, estimator_name)
3535
if not target_file.is_file():
3636
raise FileNotFoundError(f"Estimator {estimator_name} was not found in {ESTIMATOR_PATH / device}")
3737
if target_file.stat().st_size == 0:
3838
raise FileNotFoundError(f"Estimator {estimator_name} is empty")
3939

4040
try:
41-
return yaml_decode(target_file.read_bytes(), type=structs.ODFusionEstimator)
41+
return yaml_decode(target_file.read_bytes(), type=structs.subclass_union(structs.EstimatorBase))
4242
except ValidationError as exc:
4343
raise ValidationError(f"Error reading {target_file.stem}: {exc}") from exc
4444

core/pioreactor/web/tasks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,14 +551,15 @@ def calibration_save_calibration(device: str, calibration_payload: dict[str, obj
551551
def estimator_save_estimator(device: str, estimator_payload: dict[str, object]) -> dict[str, str]:
552552
from msgspec.json import decode as json_decode
553553
from msgspec.json import encode as json_encode
554-
from pioreactor.structs import ODFusionEstimator
554+
from pioreactor.structs import EstimatorBase
555+
from pioreactor.structs import subclass_union
555556

556557
logger.debug(
557558
"Starting estimator save: device=%s payload_keys=%s",
558559
device,
559560
sorted(estimator_payload.keys()),
560561
)
561-
estimator = json_decode(json_encode(estimator_payload), type=ODFusionEstimator)
562+
estimator = json_decode(json_encode(estimator_payload), type=subclass_union(EstimatorBase))
562563
path = estimator.save_to_disk_for_device(device)
563564
estimator.set_as_active_calibration_for_device(device)
564565
logger.debug(

0 commit comments

Comments
 (0)