Skip to content

Commit e301f68

Browse files
fix some race conditions in PWM and caches
1 parent 5df9444 commit e301f68

File tree

9 files changed

+179
-14
lines changed

9 files changed

+179
-14
lines changed

core/pioreactor/actions/led_intensity.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
from contextlib import contextmanager
44
from contextlib import nullcontext
5+
from time import time_ns
56
from typing import Any
67
from typing import Iterator
78

@@ -55,15 +56,19 @@ def change_leds_intensities_temporarily(
5556

5657
@contextmanager
5758
def lock_leds_temporarily(channels: list[LedChannel]) -> Iterator[None]:
59+
lock_id = f"{os.getpid()}:{time_ns()}"
60+
acquired_channels: list[LedChannel] = []
5861
try:
5962
with local_intermittent_storage("led_locks") as cache:
6063
for c in channels:
61-
cache[c] = os.getpid()
64+
if cache.set_if_absent(c, lock_id):
65+
acquired_channels.append(c)
6266
yield
6367
finally:
6468
with local_intermittent_storage("led_locks") as cache:
65-
for c in channels:
66-
cache.pop(c)
69+
for c in acquired_channels:
70+
if cache.get(c) == lock_id:
71+
cache.pop(c)
6772

6873

6974
def is_led_channel_locked(channel: LedChannel) -> bool:

core/pioreactor/actions/pump.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,9 +314,15 @@ def _get_pump_action(pump_device: PumpCalibrationDevices) -> str:
314314
)
315315
return 0.0
316316

317-
with PWMPump(
318-
unit, experiment, pin, calibration=calibration, mqtt_client=mqtt_client, logger=logger
319-
) as pump:
317+
try:
318+
pump_instance = PWMPump(
319+
unit, experiment, pin, calibration=calibration, mqtt_client=mqtt_client, logger=logger
320+
)
321+
except exc.PWMError as e:
322+
logger.error(str(e))
323+
return 0.0
324+
325+
with pump_instance as pump:
320326
sub_duration = 0.5
321327
volume_moved_ml = 0.0
322328

core/pioreactor/utils/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,16 @@ def __setitem__(self, key, value):
380380
def set(self, key, value):
381381
return self.__setitem__(key, value)
382382

383+
def set_if_absent(self, key, value) -> bool:
384+
self.cursor.execute(
385+
f"""
386+
INSERT OR IGNORE INTO {self.table_name} (key, value)
387+
VALUES (?, ?)
388+
""",
389+
(key, value),
390+
)
391+
return self.cursor.rowcount == 1
392+
383393
def get(self, key, default=None):
384394
self.cursor.execute(f"SELECT value FROM {self.table_name} WHERE key = ?", (key,))
385395
result = self.cursor.fetchone()

core/pioreactor/utils/pwm.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,10 @@ def dc(self, dc: pt.FloatBetween0and100) -> None:
123123
if self._started:
124124
try:
125125
lgpio.tx_pwm(self._handle, self.pin, self.frequency, self.dc)
126-
except lgpio.error:
127-
pass
126+
except lgpio.error as e:
127+
raise PWMError(
128+
f"Failed to set software PWM on GPIO-{self.pin} to {dc:.5g}% duty cycle at {self.frequency:.5g} Hz."
129+
) from e
128130
elif dc == 0:
129131
pass
130132
else:
@@ -213,6 +215,7 @@ def __init__(
213215
self.pin: GpioPin = pin
214216
self.hz = hz
215217
self.duty_cycle = 0.0
218+
self._lock_id = f"{getpid()}:{id(self)}"
216219

217220
if self.is_locked():
218221
msg = f"GPIO-{self.pin} is currently locked but a task is overwriting it. Either too many jobs are trying to access this pin, or a job didn't clean up properly. If your confident you can release it, use `pio cache clear pwm_locks {self.pin} --as-int` on the command line for {self.unit}."
@@ -318,11 +321,21 @@ def is_locked(self) -> bool:
318321

319322
def lock(self) -> None:
320323
with local_intermittent_storage("pwm_locks") as pwm_locks:
321-
pwm_locks[self.pin] = getpid()
324+
if pwm_locks.set_if_absent(self.pin, self._lock_id):
325+
return
326+
327+
raise PWMError(
328+
f"GPIO-{self.pin} is currently locked but a task is overwriting it. Either too many jobs are trying to access this pin, or a job didn't clean up properly. If your confident you can release it, use `pio cache clear pwm_locks {self.pin} --as-int` on the command line for {self.unit}."
329+
)
322330

323331
def unlock(self) -> None:
324332
with local_intermittent_storage("pwm_locks") as pwm_locks:
325-
pwm_locks.pop(self.pin)
333+
owner_lock_id = pwm_locks.get(self.pin)
334+
if owner_lock_id is None:
335+
return
336+
337+
if owner_lock_id == self._lock_id:
338+
pwm_locks.pop(self.pin)
326339

327340
@contextmanager
328341
def lock_temporarily(self) -> Iterator[None]:

core/pioreactor/web/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,13 @@ def is_rate_limited(job: str, expire_time_seconds=1.0) -> bool:
156156
Check if the user has made a request within the debounce duration.
157157
"""
158158
with local_intermittent_storage("debounce") as cache:
159-
if cache.get(job) and (time() - cache.get(job)) < expire_time_seconds:
160-
return True
161-
else:
162-
cache.set(job, time())
159+
now = time()
160+
if cache.set_if_absent(job, now):
163161
return False
162+
163+
last_request_time = cache.get(job)
164+
if (last_request_time is not None) and ((now - float(last_request_time)) < expire_time_seconds):
165+
return True
166+
167+
cache.set(job, now)
168+
return False

core/tests/test_led_intensity.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,18 @@ def test_is_led_channel_locked_directly() -> None:
149149
assert not is_led_channel_locked("A")
150150

151151

152+
def test_nested_lock_leds_temporarily_does_not_release_outer_lock() -> None:
153+
assert not is_led_channel_locked("A")
154+
155+
with lock_leds_temporarily(["A"]):
156+
assert is_led_channel_locked("A")
157+
with lock_leds_temporarily(["A"]):
158+
assert is_led_channel_locked("A")
159+
assert is_led_channel_locked("A")
160+
161+
assert not is_led_channel_locked("A")
162+
163+
152164
def test_change_leds_intensities_temporarily_invalid_raises_and_state_unchanged() -> None:
153165
unit = get_unit_name()
154166
exp = "test_change_leds_invalid"

core/tests/test_pwms.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# -*- coding: utf-8 -*-
22
# test_pwms
33
import json
4+
import sys
45
import time
6+
import types
57

68
import pytest
79
from pioreactor import pubsub
10+
from pioreactor.exc import PWMError
811
from pioreactor.utils import local_intermittent_storage
912
from pioreactor.utils.pwm import PWM
13+
from pioreactor.utils.pwm import SoftwarePWMOutputDevice
1014
from pioreactor.whoami import get_unit_name
1115

1216

@@ -98,3 +102,76 @@ def collect(msg) -> None:
98102

99103
assert json.loads(mqtt_items[-1]).get("17", 0.0) == 0.0
100104
assert json.loads(mqtt_items[-1]).get("12", 0.0) == 0.0
105+
106+
107+
def _install_fake_lgpio(monkeypatch: pytest.MonkeyPatch) -> dict[str, float | list[float] | None]:
108+
fake_lgpio = types.ModuleType("lgpio")
109+
110+
class FakeLgpioError(Exception):
111+
pass
112+
113+
state: dict[str, float | list[float] | None] = {"raise_when_dc": None, "tx_calls": []}
114+
115+
def gpiochip_open(_chip: int) -> int:
116+
return 1
117+
118+
def gpio_claim_output(_handle: int, _pin: int) -> None:
119+
return
120+
121+
def tx_pwm(_handle: int, _pin: int, _frequency: float, duty_cycle: float) -> None:
122+
tx_calls = state["tx_calls"]
123+
assert isinstance(tx_calls, list)
124+
tx_calls.append(duty_cycle)
125+
126+
raise_when_dc = state["raise_when_dc"]
127+
if (raise_when_dc is not None) and (duty_cycle == raise_when_dc):
128+
raise FakeLgpioError("simulated lgpio tx_pwm failure")
129+
130+
def gpiochip_close(_handle: int) -> None:
131+
return
132+
133+
fake_lgpio.error = FakeLgpioError
134+
fake_lgpio.gpiochip_open = gpiochip_open
135+
fake_lgpio.gpio_claim_output = gpio_claim_output
136+
fake_lgpio.tx_pwm = tx_pwm
137+
fake_lgpio.gpiochip_close = gpiochip_close
138+
139+
monkeypatch.setitem(sys.modules, "lgpio", fake_lgpio)
140+
return state
141+
142+
143+
def test_software_pwm_dc_errors_raise_pwm_error(monkeypatch: pytest.MonkeyPatch) -> None:
144+
state = _install_fake_lgpio(monkeypatch)
145+
pwm = SoftwarePWMOutputDevice(pin=17, frequency=100)
146+
pwm.start(0)
147+
148+
state["raise_when_dc"] = 25.0
149+
150+
with pytest.raises(PWMError, match="Failed to set software PWM"):
151+
pwm.dc = 25.0
152+
153+
154+
def test_software_pwm_stop_errors_raise_pwm_error(monkeypatch: pytest.MonkeyPatch) -> None:
155+
state = _install_fake_lgpio(monkeypatch)
156+
pwm = SoftwarePWMOutputDevice(pin=17, frequency=100)
157+
pwm.start(20)
158+
159+
state["raise_when_dc"] = 0.0
160+
161+
with pytest.raises(PWMError, match="Failed to set software PWM"):
162+
pwm.off()
163+
164+
165+
def test_lock_is_exclusive_after_creation() -> None:
166+
exp = "test_lock_is_exclusive_after_creation"
167+
unit = get_unit_name()
168+
169+
pwm_first = PWM(12, 10, experiment=exp, unit=unit)
170+
pwm_second = PWM(12, 10, experiment=exp, unit=unit)
171+
172+
pwm_first.lock()
173+
with pytest.raises(PWMError, match="GPIO-12 is currently locked"):
174+
pwm_second.lock()
175+
176+
pwm_first.clean_up()
177+
pwm_second.clean_up()

core/tests/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,18 @@ def test_caches_pop() -> None:
9696
assert cache.pop("C", default=3) == 3
9797

9898

99+
def test_cache_set_if_absent() -> None:
100+
with local_intermittent_storage("test") as cache:
101+
cache.empty()
102+
103+
with local_intermittent_storage("test") as cache:
104+
assert cache.set_if_absent("A", "1")
105+
assert not cache.set_if_absent("A", "2")
106+
107+
with local_intermittent_storage("test") as cache:
108+
assert cache["A"] == "1"
109+
110+
99111
def test_caches_can_have_tuple_or_singleton_keys() -> None:
100112
with local_persistent_storage("test_caches_can_have_tuple_keys") as c:
101113
c[(1, 2)] = 1

core/tests/web/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# -*- coding: utf-8 -*-
22
import pytest
3+
from pioreactor.utils import local_intermittent_storage
4+
from pioreactor.web import utils as web_utils
5+
from pioreactor.web.utils import is_rate_limited
36
from pioreactor.web.utils import is_valid_unix_filename
47
from pioreactor.web.utils import scrub_to_valid
58

@@ -53,3 +56,25 @@ def test_valid_unix_filenames(name) -> None:
5356
)
5457
def test_invalid_unix_filenames(name) -> None:
5558
assert not is_valid_unix_filename(name)
59+
60+
61+
def test_is_rate_limited_blocks_second_request_within_window() -> None:
62+
job_name = "test_rate_limit_second_blocked"
63+
with local_intermittent_storage("debounce") as cache:
64+
cache.pop(job_name)
65+
66+
assert not is_rate_limited(job_name, expire_time_seconds=10.0)
67+
assert is_rate_limited(job_name, expire_time_seconds=10.0)
68+
69+
70+
def test_is_rate_limited_allows_after_expiry(monkeypatch: pytest.MonkeyPatch) -> None:
71+
job_name = "test_rate_limit_allows_after_expiry"
72+
with local_intermittent_storage("debounce") as cache:
73+
cache.pop(job_name)
74+
75+
timeline = iter([1000.0, 1002.0, 1002.2])
76+
monkeypatch.setattr(web_utils, "time", lambda: next(timeline))
77+
78+
assert not is_rate_limited(job_name, expire_time_seconds=1.0)
79+
assert not is_rate_limited(job_name, expire_time_seconds=1.0)
80+
assert is_rate_limited(job_name, expire_time_seconds=1.0)

0 commit comments

Comments
 (0)