|
41 | 41 | from threading import Thread |
42 | 42 | from time import sleep |
43 | 43 | from typing import Generator |
| 44 | +from typing import Iterator |
44 | 45 |
|
45 | 46 | import click |
46 | 47 | from msgspec.json import decode as loads |
|
55 | 56 | from pioreactor.config import config |
56 | 57 | from pioreactor.utils import local_persistent_storage |
57 | 58 | from pioreactor.utils.streaming import DosingObservationSource |
| 59 | +from pioreactor.utils.streaming import IteratorBackedStream |
58 | 60 | from pioreactor.utils.streaming import merge_historical_streams |
59 | 61 | from pioreactor.utils.streaming import merge_live_streams |
60 | 62 | from pioreactor.utils.streaming import MqttDosingSource |
@@ -108,13 +110,13 @@ def __init__( |
108 | 110 | self._recent_dilution = False |
109 | 111 |
|
110 | 112 | def initialize_extended_kalman_filter( |
111 | | - self, od_std: float, rate_std: float, obs_std: float, od_stream: ODObservationSource |
| 113 | + self, od_std: float, rate_std: float, obs_std: float, od_iter: Iterator[structs.ODReadings] |
112 | 114 | ) -> CultureGrowthEKF: |
113 | 115 | import numpy as np |
114 | 116 |
|
115 | 117 | self.logger.debug(f"{od_std=}, {rate_std=}, {obs_std=}") |
116 | 118 |
|
117 | | - initial_nOD, initial_growth_rate = self.get_initial_values(od_stream) |
| 119 | + initial_nOD, initial_growth_rate = self.get_initial_values(od_iter) |
118 | 120 |
|
119 | 121 | initial_state = np.array([initial_nOD, initial_growth_rate]) |
120 | 122 | self.logger.debug(f"Initial state: {repr(initial_state)}") |
@@ -240,13 +242,13 @@ def _compute_and_cache_od_statistics( |
240 | 242 |
|
241 | 243 | return means, variances |
242 | 244 |
|
def get_initial_values(self, od_iter: "Iterator[structs.ODReadings]") -> tuple[float, float]:
    """
    Seed values for the Kalman filter, as (initial_nOD, initial_growth_rate).

    When ``ignore_cache`` is set, the growth rate starts at 0.0 and the initial
    normalized OD is computed from the next reading of *od_iter*; otherwise both
    values come from the persistent caches (the OD lookup falls back to the
    iterator when nothing is cached).
    """
    if not self.ignore_cache:
        growth_rate = self.get_growth_rate_from_cache()
        nOD = self.get_filtered_od_from_cache_or_iterator(od_iter)
        return (nOD, growth_rate)
    # fresh start: no cached growth rate, OD taken straight from the stream
    return (self.get_filtered_od_from_iterator(od_iter), 0.0)
251 | 253 |
|
252 | 254 | def get_precomputed_values( |
@@ -304,16 +306,16 @@ def get_growth_rate_from_cache(self) -> float: |
304 | 306 | with local_persistent_storage("growth_rate") as cache: |
305 | 307 | return cache.get(self.cache_key, 0.0) |
306 | 308 |
|
def get_filtered_od_from_cache_or_iterator(self, od_iter: "Iterator[structs.ODReadings]") -> float:
    """
    Initial filtered OD: prefer the persisted value, else compute it from *od_iter*.

    NOTE(review): the truthiness test below treats a cached 0.0 the same as a
    missing cache entry and recomputes from the stream — presumably intentional,
    since a zero filtered OD is degenerate; confirm.
    """
    with local_persistent_storage("od_filtered") as cache:
        cached = cache.get(self.cache_key)
        return cached if cached else self.get_filtered_od_from_iterator(od_iter)
314 | 316 |
|
def get_filtered_od_from_iterator(self, od_iter: "Iterator[structs.ODReadings]") -> float:
    """Average the scaled OD over all photodiode channels of the next reading in *od_iter*."""
    raw_reading = next(od_iter)
    scaled = self.scale_raw_observations(raw_reading)
    # same values as iterating scaled[channel] over scaled.keys(), just via .values()
    return mean(scaled.values())
318 | 320 |
|
319 | 321 | def get_od_normalization_from_cache(self) -> dict[pt.PdChannel, float]: |
@@ -454,18 +456,20 @@ def process_until_disconnected_or_exhausted( |
454 | 456 | self.logger.debug(f"od_blank={dict(self.od_blank)}") |
455 | 457 |
|
456 | 458 | # create kalman filter |
| 459 | + od_iter = iter(od_stream) |
457 | 460 | self.ekf = self.initialize_extended_kalman_filter( |
458 | 461 | od_std=config.getfloat("growth_rate_kalman", "od_std"), |
459 | 462 | rate_std=config.getfloat("growth_rate_kalman", "rate_std"), |
460 | 463 | obs_std=config.getfloat("growth_rate_kalman", "obs_std"), |
461 | | - od_stream=od_stream, |
| 464 | + od_iter=od_iter, |
462 | 465 | ) |
463 | 466 |
|
464 | 467 | # how should we merge streams? |
465 | 468 | if od_stream.is_live and dosing_stream.is_live: |
466 | | - merged_streams = merge_live_streams(od_stream, dosing_stream, stop_event=self.stopping_event) |
| 469 | + od_iter_stream = IteratorBackedStream(od_iter, od_stream.set_stop_event) |
| 470 | + merged_streams = merge_live_streams(od_iter_stream, dosing_stream, stop_event=self.stopping_event) |
467 | 471 | elif not od_stream.is_live and not dosing_stream.is_live: |
468 | | - merged_streams = merge_historical_streams(od_stream, dosing_stream, key=lambda t: t.timestamp) |
| 472 | + merged_streams = merge_historical_streams(od_iter, dosing_stream, key=lambda t: t.timestamp) |
469 | 473 | else: |
470 | 474 | raise ValueError("Both streams must be live or both must be historical.") |
471 | 475 |
|
|
0 commit comments