Skip to content

Commit d7733c1

Browse files
first spin at new ekf algorithm to handle dilutions
1 parent e18e741 commit d7733c1

File tree

11 files changed

+239
-199
lines changed

11 files changed

+239
-199
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
### Upcoming
2+
3+
- adding dirs to exported data zips
4+
15
### 25.5.1
26

37
#### Enhancements

config.dev.ini

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,9 @@ ws_protocol=ws
149149
use_tls=0
150150

151151
[growth_rate_kalman]
152-
acc_std=0.0008
153152
obs_std=1.5
154-
od_std=0.005
155-
rate_std=0.1
153+
od_std=0.0025
154+
rate_std=0.25
156155

157156

158157
[dosing_automation.config]
@@ -165,7 +164,6 @@ max_subdose=1.0
165164
[growth_rate_calculating.config]
166165
# these next two parameters control the length and magnitude
167166
# of the variance shift that our Kalman filter performs after a dosing event
168-
ekf_variance_shift_post_dosing_minutes=0.40
169167
ekf_variance_shift_post_dosing_factor=2500
170168
ekf_outlier_std_threshold=3.0
171169
samples_for_od_statistics=35

pioreactor/actions/leader/export_experiment_data.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,10 @@ def export_experiment_data(
261261
if count == 0:
262262
logger.warning(f"No data present in {dataset_name}. Check database?")
263263

264+
zf.mkdir(dataset_name)
264265
for filename in filenames:
265-
path_to_file = Path(Path(output).parent / filename)
266-
zf.write(path_to_file, arcname=filename)
266+
path_to_file = Path(output, filename)
267+
zf.write(path_to_file, arcname=f"{dataset_name}/{filename}")
267268
Path(path_to_file).unlink()
268269

269270
logger.info("Finished export.")

pioreactor/background_jobs/growth_rate_calculating.py

Lines changed: 25 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,14 @@ def __init__(
9898
self.source_obs_from_mqtt = source_obs_from_mqtt
9999
self.ignore_cache = ignore_cache
100100
self.time_of_previous_observation: datetime | None = None
101-
self.expected_dt = 1 / (60 * 60 * config.getfloat("od_reading.config", "samples_per_second"))
101+
self.expected_dt = 1 / (
102+
60 * 60 * config.getfloat("od_reading.config", "samples_per_second")
103+
) # in hours
104+
105+
# ekf parameters for when a dosing event occurs
106+
self._obs_since_last_dose: int | None = None
107+
self._obs_required_to_reset: int | None = None
108+
self._recent_dilution = False
102109

103110
def on_ready(self) -> None:
104111
# this is here since the below is long running, and if kept in the init(), there is a large window where
@@ -133,7 +140,6 @@ def on_ready(self) -> None:
133140
self.logger.debug(f"od_normalization_mean={self.od_normalization_factors}")
134141
self.logger.debug(f"od_normalization_variance={self.od_variances}")
135142
self.ekf = self.initialize_extended_kalman_filter(
136-
acc_std=config.getfloat("growth_rate_kalman", "acc_std"),
137143
od_std=config.getfloat("growth_rate_kalman", "od_std"),
138144
rate_std=config.getfloat("growth_rate_kalman", "rate_std"),
139145
obs_std=config.getfloat("growth_rate_kalman", "obs_std"),
@@ -143,32 +149,29 @@ def on_ready(self) -> None:
143149
self.start_passive_listeners()
144150

145151
def initialize_extended_kalman_filter(
146-
self, acc_std: float, od_std: float, rate_std: float, obs_std: float
152+
self, od_std: float, rate_std: float, obs_std: float
147153
) -> CultureGrowthEKF:
148154
import numpy as np
149155

150156
initial_state = np.array(
151157
[
152158
self.initial_nOD,
153159
self.initial_growth_rate,
154-
self.initial_acc,
155160
]
156161
)
157162
self.logger.debug(f"Initial state: {repr(initial_state)}")
158163

159164
initial_covariance = 1e-4 * np.eye(
160-
3
165+
2
161166
) # empirically selected - TODO: this should probably scale with `expected_dt`
162167
self.logger.debug(f"Initial covariance matrix:\n{repr(initial_covariance)}")
163168

164-
acc_process_variance = (acc_std * self.expected_dt) ** 2
165169
od_process_variance = (od_std * self.expected_dt) ** 2
166170
rate_process_variance = (rate_std * self.expected_dt) ** 2
167171

168-
process_noise_covariance = np.zeros((3, 3))
172+
process_noise_covariance = np.zeros((2, 2))
169173
process_noise_covariance[0, 0] = od_process_variance
170174
process_noise_covariance[1, 1] = rate_process_variance
171-
process_noise_covariance[2, 2] = acc_process_variance
172175
self.logger.debug(f"Process noise covariance matrix:\n{repr(process_noise_covariance)}")
173176

174177
observation_noise_covariance = self.create_obs_noise_covariance(obs_std)
@@ -371,21 +374,6 @@ def get_od_variances_from_cache(self) -> dict[pt.PdChannel, float]:
371374

372375
return variances
373376

374-
def update_ekf_variance_after_event(self, minutes: float, factor: float) -> None:
375-
if whoami.is_testing_env():
376-
# TODO: replace with jobmanager
377-
msg = subscribe( # needs to be pubsub.subscribe (ie not sub_client.subscribe) since this is called in a callback
378-
f"pioreactor/{self.unit}/{self.experiment}/od_reading/interval",
379-
timeout=1.0,
380-
)
381-
if msg:
382-
interval = float(msg.payload)
383-
else:
384-
interval = 5
385-
self.ekf.scale_OD_variance_for_next_n_seconds(factor, minutes * (12 * interval))
386-
else:
387-
self.ekf.scale_OD_variance_for_next_n_seconds(factor, minutes * 60)
388-
389377
def scale_raw_observations(self, observations: dict[pt.PdChannel, float]) -> dict[pt.PdChannel, float]:
390378
def _scale_and_shift(obs, shift, scale) -> float:
391379
return (obs - shift) / (scale - shift)
@@ -474,9 +462,19 @@ def _update_state_from_observation(
474462

475463
self.time_of_previous_observation = timestamp
476464

477-
updated_state_, covariance_ = self.ekf.update(list(scaled_observations.values()), dt)
465+
updated_state_, covariance_ = self.ekf.update(
466+
list(scaled_observations.values()), dt, self._recent_dilution
467+
)
478468
latest_od_filtered, latest_growth_rate = float(updated_state_[0]), float(updated_state_[1])
479469

470+
if self._obs_since_last_dose is not None and self._obs_required_to_reset is not None:
471+
self._obs_since_last_dose += 1
472+
473+
if self._obs_since_last_dose >= self._obs_required_to_reset:
474+
self._obs_since_last_dose = None
475+
self._obs_required_to_reset = None
476+
self._recent_dilution = False
477+
480478
growth_rate = structs.GrowthRate(
481479
growth_rate=latest_growth_rate,
482480
timestamp=timestamp,
@@ -499,21 +497,9 @@ def respond_to_dosing_event_from_mqtt(self, message: pt.MQTTMessage) -> None:
499497
return self.respond_to_dosing_event(dosing_event)
500498

501499
def respond_to_dosing_event(self, dosing_event: structs.DosingEvent) -> None:
502-
# here we can add custom logic to handle dosing events.
503-
# an improvement to this: the variance factor is proportional to the amount exchanged.
504-
if dosing_event.event != "remove_waste":
505-
self.update_ekf_variance_after_event(
506-
minutes=config.getfloat(
507-
"growth_rate_calculating.config",
508-
"ekf_variance_shift_post_dosing_minutes",
509-
fallback=0.40,
510-
),
511-
factor=config.getfloat(
512-
"growth_rate_calculating.config",
513-
"ekf_variance_shift_post_dosing_factor",
514-
fallback=2500,
515-
),
516-
)
500+
self._obs_since_last_dose = 0
501+
self._obs_required_to_reset = 1
502+
self._recent_dilution = True
517503

518504
def start_passive_listeners(self) -> None:
519505
# process incoming data

pioreactor/background_jobs/od_reading.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -926,13 +926,13 @@ def _determine_best_ir_led_intensity(
926926

927927
_, REF_on_signal = on_reading.popitem()
928928

929-
ir_intensity_argmax_REF_can_be = initial_ir_intensity / REF_on_signal.reading * 0.240
929+
ir_intensity_argmax_REF_can_be = initial_ir_intensity / REF_on_signal.reading * 0.250
930930

931931
ir_intensity_argmax_ANGLE_can_be = (
932932
initial_ir_intensity / culture_on_signal.reading * 3.0
933933
) / 50 # divide by N since the culture is unlikely to Nx.
934934

935-
ir_intensity_max = 80.0
935+
ir_intensity_max = 85.0
936936

937937
return round(
938938
max(

pioreactor/pubsub.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ def conform_and_validate_api_endpoint(endpoint: str) -> str:
359359

360360

361361
def create_webserver_path(address: str, endpoint: str) -> str:
362+
# pioreactor cluster specific (note the use of protocol and ports from our config!)
362363
# Most commonly, address can be an mdns name (test.local), or an IP address.
363364
port = config.getint("ui", "port", fallback=80)
364365
proto = config.get("ui", "proto", fallback="http")
@@ -367,6 +368,7 @@ def create_webserver_path(address: str, endpoint: str) -> str:
367368

368369

369370
def get_from(address: str, endpoint: str, **kwargs) -> mureq.Response:
371+
# pioreactor cluster specific
370372
return mureq.get(create_webserver_path(address, endpoint), **kwargs)
371373

372374

@@ -377,6 +379,7 @@ def get_from_leader(endpoint: str, **kwargs) -> mureq.Response:
377379
def put_into(
378380
address: str, endpoint: str, body: bytes | None = None, json: dict | Struct | None = None, **kwargs
379381
) -> mureq.Response:
382+
# pioreactor cluster specific
380383
return mureq.put(create_webserver_path(address, endpoint), body=body, json=json, **kwargs)
381384

382385

@@ -389,6 +392,7 @@ def put_into_leader(
389392
def patch_into(
390393
address: str, endpoint: str, body: bytes | None = None, json: dict | Struct | None = None, **kwargs
391394
) -> mureq.Response:
395+
# pioreactor cluster specific
392396
return mureq.patch(create_webserver_path(address, endpoint), body=body, json=json, **kwargs)
393397

394398

@@ -401,6 +405,7 @@ def patch_into_leader(
401405
def post_into(
402406
address: str, endpoint: str, body: bytes | None = None, json: dict | Struct | None = None, **kwargs
403407
) -> mureq.Response:
408+
# pioreactor cluster specific
404409
return mureq.post(create_webserver_path(address, endpoint), body=body, json=json, **kwargs)
405410

406411

@@ -411,6 +416,7 @@ def post_into_leader(
411416

412417

413418
def delete_from(address: str, endpoint: str, **kwargs) -> mureq.Response:
419+
# pioreactor cluster specific
414420
return mureq.delete(create_webserver_path(address, endpoint), **kwargs)
415421

416422

pioreactor/tests/conftest.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111

1212
from pioreactor.mureq import Response
1313
from pioreactor.pubsub import publish
14-
from pioreactor.structs import ODReadings
15-
from pioreactor.structs import RawODReading
16-
from pioreactor.utils.timing import to_datetime
1714

1815

1916
@pytest.fixture(autouse=True)
@@ -123,30 +120,3 @@ def mock_request(method, url, **kwargs):
123120
# Patch the mureq.request method
124121
with patch("pioreactor.mureq.request", side_effect=mock_request):
125122
yield bucket
126-
127-
128-
class StreamODReadingsFromExport:
129-
def __init__(self, filename: str, skip_first_n_rows=0):
130-
self.filename = filename
131-
self.skip_first_n_rows = skip_first_n_rows
132-
133-
def __enter__(self, *args, **kwargs):
134-
import csv
135-
136-
self.file_instance = open(self.filename, "r")
137-
self.csv_reader = csv.DictReader(self.file_instance, quoting=csv.QUOTE_MINIMAL)
138-
return self
139-
140-
def __exit__(self, *args, **kwargs):
141-
self.file_instance.close()
142-
143-
def __iter__(self):
144-
for i, line in enumerate(self.csv_reader):
145-
if i <= self.skip_first_n_rows:
146-
continue
147-
dt = to_datetime(line["timestamp"])
148-
od = RawODReading(
149-
angle=line["angle"], channel=line["channel"], timestamp=dt, od=float(line["od_reading"])
150-
)
151-
ods = ODReadings(timestamp=dt, ods={"2": od})
152-
yield ods

pioreactor/tests/test_growth_rate_calculating.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import numpy as np
88
from msgspec.json import encode
9-
from numpy.testing import assert_array_equal
109

1110
from pioreactor import structs
1211
from pioreactor.background_jobs.growth_rate_calculating import GrowthRateCalculator
@@ -16,7 +15,7 @@
1615
from pioreactor.config import temporary_config_changes
1716
from pioreactor.pubsub import collect_all_logs_of_level
1817
from pioreactor.pubsub import publish
19-
from pioreactor.tests.conftest import StreamODReadingsFromExport
18+
from pioreactor.tests.utils import StreamODReadingsFromExport
2019
from pioreactor.utils import local_persistent_storage
2120
from pioreactor.utils.timing import current_utc_timestamp
2221
from pioreactor.utils.timing import default_datetime_for_pioreactor
@@ -357,27 +356,11 @@ def test_shock_from_dosing_works(self) -> None:
357356
)
358357
pause()
359358

360-
previous_covariance_matrix = calc.ekf.covariance_.copy()
361-
362-
# trigger dosing events, which change the "regime"
363-
publish(
364-
f"pioreactor/{unit}/{experiment}/dosing_events",
365-
encode(
366-
structs.DosingEvent(
367-
volume_change=1.0,
368-
event="add_media",
369-
source_of_event="algo",
370-
timestamp=to_datetime("2010-01-01T12:00:48.000Z"),
371-
)
372-
),
373-
)
374-
pause()
375-
376359
publish(
377360
f"pioreactor/{unit}/{experiment}/od_reading/ods",
378361
create_encoded_od_raw_batched(
379362
["1"],
380-
[0.49],
363+
[0.52],
381364
["90"],
382365
timestamp="2010-01-01T12:00:50.000Z",
383366
),
@@ -387,15 +370,13 @@ def test_shock_from_dosing_works(self) -> None:
387370
f"pioreactor/{unit}/{experiment}/od_reading/ods",
388371
create_encoded_od_raw_batched(
389372
["1"],
390-
[0.48],
373+
[0.52],
391374
["90"],
392375
timestamp="2010-01-01T12:00:55.000Z",
393376
),
394377
)
395378
pause()
396379

397-
assert not np.array_equal(previous_covariance_matrix, calc.ekf.covariance_)
398-
399380
publish(
400381
f"pioreactor/{unit}/{experiment}/dosing_events",
401382
encode(
@@ -408,6 +389,8 @@ def test_shock_from_dosing_works(self) -> None:
408389
),
409390
)
410391
pause()
392+
assert calc._recent_dilution
393+
411394
publish(
412395
f"pioreactor/{unit}/{experiment}/od_reading/ods",
413396
create_encoded_od_raw_batched(
@@ -418,20 +401,7 @@ def test_shock_from_dosing_works(self) -> None:
418401
),
419402
)
420403
pause()
421-
422-
time.sleep(8)
423-
assert calc.ekf._currently_scaling_covariance
424-
assert not np.array_equal(previous_covariance_matrix, calc.ekf.covariance_)
425-
426-
time.sleep(10)
427-
pause()
428-
429-
# should revert back
430-
while calc.ekf._currently_scaling_covariance:
431-
pass
432-
433-
assert_array_equal(calc.ekf.covariance_, previous_covariance_matrix)
434-
calc.clean_up()
404+
assert not calc._recent_dilution
435405

436406
def test_end_to_end(self) -> None:
437407
with temporary_config_changes(

0 commit comments

Comments
 (0)