Skip to content

Commit 5bb52cd

Browse files
wip
1 parent cf468fd commit 5bb52cd

File tree

6 files changed

+58
-17
lines changed

6 files changed

+58
-17
lines changed

core/pioreactor/background_jobs/growth_rate_calculating.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from threading import Thread
4242
from time import sleep
4343
from typing import Generator
44+
from typing import Iterator
4445

4546
import click
4647
from msgspec.json import decode as loads
@@ -55,6 +56,7 @@
5556
from pioreactor.config import config
5657
from pioreactor.utils import local_persistent_storage
5758
from pioreactor.utils.streaming import DosingObservationSource
59+
from pioreactor.utils.streaming import IteratorBackedStream
5860
from pioreactor.utils.streaming import merge_historical_streams
5961
from pioreactor.utils.streaming import merge_live_streams
6062
from pioreactor.utils.streaming import MqttDosingSource
@@ -108,13 +110,13 @@ def __init__(
108110
self._recent_dilution = False
109111

110112
def initialize_extended_kalman_filter(
111-
self, od_std: float, rate_std: float, obs_std: float, od_stream: ODObservationSource
113+
self, od_std: float, rate_std: float, obs_std: float, od_iter: Iterator[structs.ODReadings]
112114
) -> CultureGrowthEKF:
113115
import numpy as np
114116

115117
self.logger.debug(f"{od_std=}, {rate_std=}, {obs_std=}")
116118

117-
initial_nOD, initial_growth_rate = self.get_initial_values(od_stream)
119+
initial_nOD, initial_growth_rate = self.get_initial_values(od_iter)
118120

119121
initial_state = np.array([initial_nOD, initial_growth_rate])
120122
self.logger.debug(f"Initial state: {repr(initial_state)}")
@@ -240,13 +242,13 @@ def _compute_and_cache_od_statistics(
240242

241243
return means, variances
242244

243-
def get_initial_values(self, od_stream: ODObservationSource) -> tuple[float, float]:
245+
def get_initial_values(self, od_iter: Iterator[structs.ODReadings]) -> tuple[float, float]:
244246
if self.ignore_cache:
245247
initial_growth_rate = 0.0
246-
initial_nOD = self.get_filtered_od_from_stream(od_stream)
248+
initial_nOD = self.get_filtered_od_from_iterator(od_iter)
247249
else:
248250
initial_growth_rate = self.get_growth_rate_from_cache()
249-
initial_nOD = self.get_filtered_od_from_cache_or_stream(od_stream)
251+
initial_nOD = self.get_filtered_od_from_cache_or_iterator(od_iter)
250252
return (initial_nOD, initial_growth_rate)
251253

252254
def get_precomputed_values(
@@ -304,16 +306,16 @@ def get_growth_rate_from_cache(self) -> float:
304306
with local_persistent_storage("growth_rate") as cache:
305307
return cache.get(self.cache_key, 0.0)
306308

307-
def get_filtered_od_from_cache_or_stream(self, od_stream: ODObservationSource) -> float:
309+
def get_filtered_od_from_cache_or_iterator(self, od_iter: Iterator[structs.ODReadings]) -> float:
308310
with local_persistent_storage("od_filtered") as cache:
309311
value = cache.get(self.cache_key)
310312
if value:
311313
return value
312314
else:
313-
return self.get_filtered_od_from_stream(od_stream)
315+
return self.get_filtered_od_from_iterator(od_iter)
314316

315-
def get_filtered_od_from_stream(self, od_stream: ODObservationSource) -> float:
316-
scaled_od_readings = self.scale_raw_observations(next(iter(od_stream)))
317+
def get_filtered_od_from_iterator(self, od_iter: Iterator[structs.ODReadings]) -> float:
318+
scaled_od_readings = self.scale_raw_observations(next(od_iter))
317319
return mean(scaled_od_readings[channel] for channel in scaled_od_readings.keys())
318320

319321
def get_od_normalization_from_cache(self) -> dict[pt.PdChannel, float]:
@@ -454,18 +456,20 @@ def process_until_disconnected_or_exhausted(
454456
self.logger.debug(f"od_blank={dict(self.od_blank)}")
455457

456458
# create kalman filter
459+
od_iter = iter(od_stream)
457460
self.ekf = self.initialize_extended_kalman_filter(
458461
od_std=config.getfloat("growth_rate_kalman", "od_std"),
459462
rate_std=config.getfloat("growth_rate_kalman", "rate_std"),
460463
obs_std=config.getfloat("growth_rate_kalman", "obs_std"),
461-
od_stream=od_stream,
464+
od_iter=od_iter,
462465
)
463466

464467
# how should we merge streams?
465468
if od_stream.is_live and dosing_stream.is_live:
466-
merged_streams = merge_live_streams(od_stream, dosing_stream, stop_event=self.stopping_event)
469+
od_iter_stream = IteratorBackedStream(od_iter, od_stream.set_stop_event)
470+
merged_streams = merge_live_streams(od_iter_stream, dosing_stream, stop_event=self.stopping_event)
467471
elif not od_stream.is_live and not dosing_stream.is_live:
468-
merged_streams = merge_historical_streams(od_stream, dosing_stream, key=lambda t: t.timestamp)
472+
merged_streams = merge_historical_streams(od_iter, dosing_stream, key=lambda t: t.timestamp)
469473
else:
470474
raise ValueError("Both streams must be live or both must be historical.")
471475

core/pioreactor/background_jobs/od_reading.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -980,12 +980,11 @@ def __call__(self, raw_od_readings: structs.ODReadings) -> structs.ODFused | Non
980980
od_fused_value = compute_fused_od(self.estimator, fused_inputs)
981981
if self._should_warn_about_bounds(od_fused_value):
982982
self.logger.warning(
983-
"Fused OD estimate hit estimator bounds: estimator=%s min_logc=%s max_logc=%s od_fused=%s ref_normalization=%s",
983+
"Fused OD estimate hit estimator bounds: estimator=%s min_logc=%s max_logc=%s od_fused=%s",
984984
self.estimator.estimator_name,
985985
self.estimator.min_logc,
986986
self.estimator.max_logc,
987987
od_fused_value,
988-
config.get("od_reading.config", "ref_normalization", fallback="classic"),
989988
)
990989
except Exception as e:
991990
self.logger.debug(f"Failed to compute fused OD: {e}", exc_info=True)

core/pioreactor/calibrations/protocols/stirring_dc_based.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ def start_dc_based_session(
168168
min_dc: float | None = None,
169169
max_dc: float | None = None,
170170
) -> CalibrationSession:
171+
if any(is_pio_job_running(["stirring"])):
172+
raise ValueError("Stirring must be off before starting.")
171173
session_id = str(uuid.uuid4())
172174
now = utc_iso_timestamp()
173175
return CalibrationSession(

core/pioreactor/pubsub.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,12 @@ def loop_stop(self) -> MQTTErrorCode:
4747
return MQTTErrorCode.MQTT_ERR_INVAL
4848
self._thread_terminate = True
4949
# Wake the network loop (select) so it can observe _thread_terminate promptly.
50-
# Paho uses a sockpair to interrupt the loop; writing a byte is enough and avoids closing fds.
51-
self._reset_sockets(sockpair_only=True)
50+
# Avoid closing sockpair fds while the loop thread might be blocked on recv.
51+
if self._sockpairW is not None:
52+
try:
53+
self._sockpairW.send(b"x")
54+
except OSError:
55+
pass
5256

5357
if threading.current_thread() != thread:
5458
thread.join()

core/pioreactor/utils/streaming.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,20 @@ def __iter__(self):
256256
S = ODObservationSource | DosingObservationSource
257257

258258

259+
class IteratorBackedStream(ODObservationSource):
260+
is_live = True
261+
262+
def __init__(self, iterator: Iterator[ODReadings], stop_event_setter: Callable[[Event], None]) -> None:
263+
self._iterator = iterator
264+
self._stop_event_setter = stop_event_setter
265+
266+
def set_stop_event(self, ev: Event) -> None:
267+
self._stop_event_setter(ev)
268+
269+
def __iter__(self) -> Iterator[ODReadings]:
270+
return self._iterator
271+
272+
259273
def merge_live_streams(
260274
*iterables: S,
261275
stop_event: Event | None = None,

core/tests/test_growth_rate_calculating.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,25 @@ def test_restart(self) -> None:
241241
timestamp="2010-01-01T12:00:35.000Z",
242242
),
243243
)
244-
assert wait_for(lambda: float(calc1.processor.ekf.state_[-1]) != 0, timeout=3.0)
244+
publish(
245+
f"pioreactor/{unit}/{experiment}/od_reading/ods",
246+
create_encoded_od_raw_batched(
247+
["1", "2"],
248+
[1.155, 0.935],
249+
["90", "135"],
250+
timestamp="2010-01-01T12:00:35.000Z",
251+
),
252+
)
253+
publish(
254+
f"pioreactor/{unit}/{experiment}/od_reading/ods",
255+
create_encoded_od_raw_batched(
256+
["1", "2"],
257+
[1.156, 0.936],
258+
["90", "135"],
259+
timestamp="2010-01-01T12:00:35.000Z",
260+
),
261+
)
262+
assert wait_for(lambda: float(calc1.processor.ekf.state_[-1]) != 0, timeout=10.0)
245263

246264
with GrowthRateCalculator(unit=unit, experiment=experiment) as calc2:
247265
od_stream, dosing_stream = create_od_stream_from_mqtt(

0 commit comments

Comments
 (0)