Skip to content

Commit e2d9650

Browse files
committed
MockMigrogrid/Resampler: Fail on unhandled exceptions
The tasks created by the MockMigrogrid/Resampler were not handling exceptions, so if anything goes wrong when streaming (or when setting up the streaming), the task will silently die, which means no values are sent to the channels, so tests will probably hang forever waiting for values. We now make sure that tasks are cleaned up from the lists when they are done and also to exit the tests if there was an unhandled exception, so at least the test will fail and show the origin of the raised exception. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 077612b commit e2d9650

File tree

2 files changed

+78
-36
lines changed

2 files changed

+78
-36
lines changed

tests/timeseries/mock_microgrid.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,18 @@ def inverters(comp_type: InverterType) -> list[int]:
155155
self.evc_component_states: dict[int, EVChargerComponentState] = {}
156156
self.evc_cable_states: dict[int, EVChargerCableState] = {}
157157

158-
self._streaming_coros: list[Coroutine[None, None, None]] = []
159-
self._streaming_tasks: list[asyncio.Task[None]] = []
158+
self._streaming_coros: list[tuple[int, Coroutine[None, None, None]]] = []
159+
"""The streaming coroutines for each component.
160+
161+
The tuple stores the component id we are streaming for as the first item and the
162+
coroutine as the second item.
163+
"""
164+
165+
self._streaming_tasks: dict[int, asyncio.Task[None]] = {}
166+
"""The streaming tasks for each component.
167+
168+
The key is the component id we are streaming for in this task.
169+
"""
160170

161171
if grid_meter:
162172
self._connect_to = self._grid_meter_id
@@ -218,9 +228,16 @@ def start_mock_client(
218228
A MockMicrogridClient instance.
219229
"""
220230
self.init_mock_client(initialize_cb)
221-
self._streaming_tasks = [
222-
asyncio.create_task(coro) for coro in self._streaming_coros
223-
]
231+
232+
def _done_callback(task: asyncio.Task[None]) -> None:
233+
if exc := task.exception():
234+
raise SystemExit(f"Streaming task {task.get_name()!r} failed: {exc}")
235+
236+
for component_id, coro in self._streaming_coros:
237+
task = asyncio.create_task(coro, name=f"component-id:{component_id}")
238+
self._streaming_tasks[component_id] = task
239+
task.add_done_callback(_done_callback)
240+
224241
return self.mock_client
225242

226243
async def _comp_data_send_task(
@@ -243,14 +260,17 @@ def _start_meter_streaming(self, meter_id: int) -> None:
243260
if not self._api_client_streaming:
244261
return
245262
self._streaming_coros.append(
246-
self._comp_data_send_task(
263+
(
247264
meter_id,
248-
lambda value, ts: MeterDataWrapper(
249-
component_id=meter_id,
250-
timestamp=ts,
251-
active_power=value,
252-
current_per_phase=(value + 100.0, value + 101.0, value + 102.0),
253-
voltage_per_phase=(value + 200.0, value + 199.8, value + 200.2),
265+
self._comp_data_send_task(
266+
meter_id,
267+
lambda value, ts: MeterDataWrapper(
268+
component_id=meter_id,
269+
timestamp=ts,
270+
active_power=value,
271+
current_per_phase=(value + 100.0, value + 101.0, value + 102.0),
272+
voltage_per_phase=(value + 200.0, value + 199.8, value + 200.2),
273+
),
254274
),
255275
)
256276
)
@@ -259,10 +279,13 @@ def _start_battery_streaming(self, bat_id: int) -> None:
259279
if not self._api_client_streaming:
260280
return
261281
self._streaming_coros.append(
262-
self._comp_data_send_task(
282+
(
263283
bat_id,
264-
lambda value, ts: BatteryDataWrapper(
265-
component_id=bat_id, timestamp=ts, soc=value
284+
self._comp_data_send_task(
285+
bat_id,
286+
lambda value, ts: BatteryDataWrapper(
287+
component_id=bat_id, timestamp=ts, soc=value
288+
),
266289
),
267290
)
268291
)
@@ -271,10 +294,13 @@ def _start_inverter_streaming(self, inv_id: int) -> None:
271294
if not self._api_client_streaming:
272295
return
273296
self._streaming_coros.append(
274-
self._comp_data_send_task(
297+
(
275298
inv_id,
276-
lambda value, ts: InverterDataWrapper(
277-
component_id=inv_id, timestamp=ts, active_power=value
299+
self._comp_data_send_task(
300+
inv_id,
301+
lambda value, ts: InverterDataWrapper(
302+
component_id=inv_id, timestamp=ts, active_power=value
303+
),
278304
),
279305
)
280306
)
@@ -283,17 +309,20 @@ def _start_ev_charger_streaming(self, evc_id: int) -> None:
283309
if not self._api_client_streaming:
284310
return
285311
self._streaming_coros.append(
286-
self._comp_data_send_task(
312+
(
287313
evc_id,
288-
lambda value, ts: EvChargerDataWrapper(
289-
component_id=evc_id,
290-
timestamp=ts,
291-
active_power=value,
292-
current_per_phase=(value + 10.0, value + 11.0, value + 12.0),
293-
component_state=self.evc_component_states[evc_id],
294-
cable_state=self.evc_cable_states[evc_id],
314+
self._comp_data_send_task(
315+
evc_id,
316+
lambda value, ts: EvChargerDataWrapper(
317+
component_id=evc_id,
318+
timestamp=ts,
319+
active_power=value,
320+
current_per_phase=(value + 10.0, value + 11.0, value + 12.0),
321+
component_state=self.evc_component_states[evc_id],
322+
cable_state=self.evc_cable_states[evc_id],
323+
),
295324
),
296-
),
325+
)
297326
)
298327

299328
def add_consumer_meters(self, count: int = 1) -> None:
@@ -563,10 +592,10 @@ async def cleanup(self) -> None:
563592

564593
await self.mock_resampler._stop()
565594

566-
for coro in self._streaming_coros:
595+
for _, coro in self._streaming_coros:
567596
coro.close()
568597

569-
for task in self._streaming_tasks:
598+
for task in self._streaming_tasks.values():
570599
await cancel_and_await(task)
571600
microgrid.connection_manager._CONNECTION_MANAGER = None
572601
# pylint: enable=protected-access

tests/timeseries/mock_resampler.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,13 @@ def voltage_senders(ids: list[int]) -> list[list[Sender[Sample[Quantity]]]]:
179179
)
180180

181181
self._forward_tasks: dict[str, asyncio.Task[None]] = {}
182-
self._request_handler_task = asyncio.create_task(
183-
self._handle_resampling_requests()
184-
)
182+
task = asyncio.create_task(self._handle_resampling_requests())
183+
task.add_done_callback(self._handle_task_done)
184+
self._request_handler_task = task
185+
186+
def _handle_task_done(self, task: asyncio.Task[None]) -> None:
187+
if exc := task.exception():
188+
raise SystemExit(f"Task {task.get_name()!r} failed: {exc}") from exc
185189

186190
async def _stop(self) -> None:
187191
tasks_to_stop = [
@@ -202,20 +206,29 @@ async def _channel_forward_messages(
202206

203207
async def _handle_resampling_requests(self) -> None:
204208
async for request in self._resampler_request_channel.new_receiver():
205-
if request.get_channel_name() in self._forward_tasks:
209+
name = request.get_channel_name()
210+
if name in self._forward_tasks:
206211
continue
207212
basic_recv_name = f"{request.component_id}:{request.metric_id}"
208213
recv = self._basic_receivers[basic_recv_name].pop()
209214
assert recv is not None
210-
self._forward_tasks[request.get_channel_name()] = asyncio.create_task(
215+
task = asyncio.create_task(
211216
self._channel_forward_messages(
212217
recv,
213218
self._channel_registry.get_or_create(
214-
Sample[Quantity], request.get_channel_name()
219+
Sample[Quantity], name
215220
).new_sender(),
216-
)
221+
),
222+
name=name,
217223
)
218224

225+
def _done_callback(task: asyncio.Task[None]) -> None:
226+
del self._forward_tasks[task.get_name()]
227+
self._handle_task_done(task)
228+
229+
task.add_done_callback(_done_callback)
230+
self._forward_tasks[name] = task
231+
219232
def make_sample(self, value: float | None) -> Sample[Quantity]:
220233
"""Create a sample with the given value."""
221234
return Sample(

0 commit comments

Comments
 (0)