Skip to content

Commit db01e4c

Browse files
committed
Implement async guarding to prevent accidental blocking IO
1 parent 0a0157d commit db01e4c

File tree

13 files changed

+196
-64
lines changed

13 files changed

+196
-64
lines changed

canopen/async_guard.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
""" Utils for async """
2+
3+
import functools
4+
from typing import Optional, Callable
5+
6+
TSentinel = Callable[[], bool]
7+
8+
# NOTE: Global, but needed to be able to use ensure_not_async() in
9+
# decorator context.
10+
_ASYNC_SENTINEL: Optional[TSentinel] = None
11+
12+
13+
def set_async_sentinel(fn: TSentinel):
14+
""" Register a function to validate if async is running """
15+
global _ASYNC_SENTINEL
16+
_ASYNC_SENTINEL = fn
17+
18+
19+
def ensure_not_async(fn):
20+
""" Decorator that will ensure that the function is not called if async
21+
is running.
22+
"""
23+
24+
@functools.wraps(fn)
25+
def async_guard(*args, **kwargs):
26+
global _ASYNC_SENTINEL
27+
if _ASYNC_SENTINEL:
28+
if _ASYNC_SENTINEL():
29+
raise RuntimeError("Calling a blocking function while running async")
30+
return fn(*args, **kwargs)
31+
return async_guard

canopen/emcy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import time
77
from typing import Callable, List, Optional, TYPE_CHECKING
8+
from .async_guard import ensure_not_async
89

910
if TYPE_CHECKING:
1011
from .network import Network
@@ -26,6 +27,7 @@ def __init__(self):
2627
self.emcy_received = threading.Condition()
2728
self.aemcy_received = asyncio.Condition()
2829

30+
@ensure_not_async # NOTE: Safeguard for accidental async use
2931
def on_emcy(self, can_id, data, timestamp):
3032
# NOTE: Callback. Called from another thread unless async
3133
code, register, data = EMCY_STRUCT.unpack(data)
@@ -76,6 +78,8 @@ def reset(self):
7678
self.log = []
7779
self.active = []
7880

81+
# FIXME: Make async implementation
82+
@ensure_not_async # NOTE: Safeguard for accidental async use
7983
def wait(
8084
self, emcy_code: Optional[int] = None, timeout: float = 10
8185
) -> "EmcyError":

canopen/lss.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import queue
99
except ImportError:
1010
import Queue as queue
11+
from .async_guard import ensure_not_async
1112

1213
if TYPE_CHECKING:
1314
from .network import Network
@@ -248,6 +249,8 @@ def send_identify_non_configured_remote_slave(self):
248249
message[0] = CS_IDENTIFY_NON_CONFIGURED_REMOTE_SLAVE
249250
self.__send_command(message)
250251

252+
# FIXME: Make async implementation
253+
@ensure_not_async # NOTE: Safeguard for accidental async use
251254
def fast_scan(self):
252255
"""This command sends a series of fastscan message
253256
to find unconfigured slave with lowest number of LSS idenities
@@ -279,7 +282,7 @@ def fast_scan(self):
279282
if not self.__send_fast_scan_message(lss_id[lss_sub], lss_bit_check, lss_sub, lss_next):
280283
return False, None
281284

282-
time.sleep(0.01) # NOTE: Blocking
285+
time.sleep(0.01) # NOTE: Blocking call
283286

284287
# Now the next 32 bits will be scanned
285288
lss_sub += 1
@@ -303,6 +306,8 @@ def __send_fast_scan_message(self, id_number, bit_checker, lss_sub, lss_next):
303306

304307
return False
305308

309+
# FIXME: Make async implementation
310+
@ensure_not_async # NOTE: Safeguard for accidental async use
306311
def __send_lss_address(self, req_cs, number):
307312
message = bytearray(8)
308313

@@ -366,6 +371,8 @@ def __send_configure(self, req_cs, value1=0, value2=0):
366371
error_msg = "LSS Error: %d" % error_code
367372
raise LssError(error_msg)
368373

374+
# FIXME: Make async implementation
375+
@ensure_not_async # NOTE: Safeguard for accidental async use
369376
def __send_command(self, message):
370377
"""Send a LSS operation code to the network
371378
@@ -385,7 +392,7 @@ def __send_command(self, message):
385392
response = None
386393
if not self.responses.empty():
387394
logger.info("There were unexpected messages in the queue")
388-
self.responses = queue.Queue() # FIXME: Recreating the queue. Async too?
395+
self.responses = queue.Queue() # FIXME: Recreating the queue
389396

390397
self.network.send_message(self.LSS_TX_COBID, message)
391398

@@ -402,6 +409,7 @@ def __send_command(self, message):
402409

403410
return response
404411

412+
@ensure_not_async # NOTE: Safeguard for accidental async use
405413
def on_message_received(self, can_id, data, timestamp):
406414
# NOTE: Callback. Called from another thread
407415
self.responses.put(bytes(data)) # NOTE: Blocking call

canopen/network.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from .lss import LssMaster
3030
from .objectdictionary.eds import import_from_node
3131
from .objectdictionary import ObjectDictionary
32+
from .async_guard import set_async_sentinel
3233

3334
logger = logging.getLogger(__name__)
3435

@@ -132,6 +133,10 @@ def connect(self, *args, **kwargs) -> "Network":
132133
kwargs_notifier["loop"] = kwargs["loop"]
133134
self.loop = kwargs["loop"]
134135
del kwargs["loop"]
136+
# Register this function as the means to check if canopen is run in
137+
# async mode. This enables the @ensure_not_async() decorator to
138+
# work. See async_guard.py
139+
set_async_sentinel(self.is_async)
135140
self.bus = can.Bus(*args, **kwargs)
136141
logger.info("Connected to '%s'", self.bus.channel_info)
137142
self.notifier = can.Notifier(self.bus, self.listeners, 1, **kwargs_notifier)
@@ -359,8 +364,7 @@ def update(self, data: bytes) -> None:
359364
:param data:
360365
New data to transmit
361366
"""
362-
# NOTE: Called from callback, which is another thread on non-async use.
363-
# Make sure this is thread-safe.
367+
# NOTE: Callback. Called from another thread unless async
364368
new_data = bytearray(data)
365369
old_data = self.msg.data
366370
self.msg.data = new_data
@@ -436,6 +440,7 @@ def search(self, limit: int = 127) -> None:
436440
"""Search for nodes by sending SDO requests to all node IDs."""
437441
if self.network is None:
438442
raise RuntimeError("A Network is required to do active scanning")
443+
# SDO upload request, parameter 0x1000:0x00
439444
sdo_req = b"\x40\x00\x10\x00\x00\x00\x00\x00"
440445
for node_id in range(1, limit + 1):
441446
self.network.send_message(0x600 + node_id, sdo_req)

canopen/nmt.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Callable, Optional, TYPE_CHECKING
88

99
from .network import CanError
10+
from .async_guard import ensure_not_async
1011

1112
if TYPE_CHECKING:
1213
from .network import Network
@@ -96,15 +97,14 @@ def state(self) -> str:
9697
- 'RESET'
9798
- 'RESET COMMUNICATION'
9899
"""
99-
logger.warning("Accessing NmtBase.state attribute is deprecated")
100100
if self._state in NMT_STATES:
101101
return NMT_STATES[self._state]
102102
else:
103103
return self._state
104104

105105
@state.setter
106106
def state(self, new_state: str):
107-
logger.warning("Accessing NmtBase.state setter is deprecated")
107+
logger.warning("Accessing NmtBase.state setter is deprecated, use set_state()")
108108
self.set_state(new_state)
109109

110110
def set_state(self, new_state: str):
@@ -129,6 +129,7 @@ def __init__(self, node_id: int):
129129
self.astate_update = asyncio.Condition()
130130
self._callbacks = []
131131

132+
@ensure_not_async # NOTE: Safeguard for accidental async use
132133
def on_heartbeat(self, can_id, data, timestamp):
133134
# NOTE: Callback. Called from another thread unless async
134135
with self.state_update: # NOTE: Blocking call
@@ -177,6 +178,7 @@ def send_command(self, code: int):
177178
"Sending NMT command 0x%X to node %d", code, self.id)
178179
self.network.send_message(0, [code, self.id])
179180

181+
@ensure_not_async # NOTE: Safeguard for accidental async use
180182
def wait_for_heartbeat(self, timeout: float = 10):
181183
"""Wait until a heartbeat message is received."""
182184
with self.state_update: # NOTE: Blocking call
@@ -186,6 +188,17 @@ def wait_for_heartbeat(self, timeout: float = 10):
186188
raise NmtError("No boot-up or heartbeat received")
187189
return self.state
188190

191+
async def await_for_heartbeat(self, timeout: float = 10):
192+
"""Wait until a heartbeat message is received."""
193+
async with self.astate_update:
194+
self._state_received = None
195+
try:
196+
await asyncio.wait_for(self.astate_update.wait(), timeout=timeout)
197+
except asyncio.TimeoutError:
198+
raise NmtError("No boot-up or heartbeat received")
199+
return self.state
200+
201+
@ensure_not_async # NOTE: Safeguard for accidental async use
189202
def wait_for_bootup(self, timeout: float = 10) -> None:
190203
"""Wait until a boot-up message is received."""
191204
end_time = time.time() + timeout
@@ -199,6 +212,20 @@ def wait_for_bootup(self, timeout: float = 10) -> None:
199212
if self._state_received == 0:
200213
break
201214

215+
async def await_for_bootup(self, timeout: float = 10) -> None:
216+
"""Wait until a boot-up message is received."""
217+
async def wait_for_bootup():
218+
while True:
219+
async with self.astate_update:
220+
self._state_received = None
221+
await self.astate_update.wait()
222+
if self._state_received == 0:
223+
return
224+
try:
225+
await asyncio.wait_for(wait_for_bootup(), timeout=timeout)
226+
except asyncio.TimeoutError:
227+
raise NmtError("Timeout waiting for boot-up message")
228+
202229
def add_hearbeat_callback(self, callback: Callable[[int], None]):
203230
"""Add function to be called on heartbeat reception.
204231
@@ -255,7 +282,7 @@ def send_command(self, code: int) -> None:
255282
# The heartbeat service should start on the transition
256283
# between INITIALIZING and PRE-OPERATIONAL state
257284
if old_state == 0 and self._state == 127:
258-
heartbeat_time_ms = self._local_node.sdo[0x1017].get_raw()
285+
heartbeat_time_ms = self._local_node.sdo[0x1017].get_raw() # FIXME: Blocking?
259286
self.start_heartbeat(heartbeat_time_ms)
260287
else:
261288
self.update_heartbeat()

canopen/node/remote.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ def __load_configuration_helper(self, index, subindex, name, value):
151151
subindex=subindex,
152152
name=name,
153153
value=value)))
154-
self.sdo[index][subindex].set_raw(value)
154+
self.sdo[index][subindex].set_raw(value) # FIXME: Blocking?
155155
else:
156-
self.sdo[index].set_raw(value)
156+
self.sdo[index].set_raw(value) # FIXME: Blocking?
157157
logger.info(str('SDO [{index:#06x}]: {name}: {value:#06x}'.format(
158158
index=index,
159159
name=name,

canopen/objectdictionary/eds.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ConfigParser import RawConfigParser, NoOptionError, NoSectionError
99
from canopen import objectdictionary
1010
from canopen.sdo import SdoClient
11+
from canopen.async_guard import ensure_not_async
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -167,12 +168,13 @@ def import_eds(source, node_id):
167168
return od
168169

169170

171+
# FIXME: Make async variant
172+
@ensure_not_async # NOTE: Safeguard for accidental async use
170173
def import_from_node(node_id, network):
171174
""" Download the configuration from the remote node
172175
:param int node_id: Identifier of the node
173176
:param network: network object
174177
"""
175-
# FIXME: Implement async variant
176178
# Create temporary SDO client
177179
sdo_client = SdoClient(0x600 + node_id, 0x580 + node_id, objectdictionary.ObjectDictionary())
178180
sdo_client.network = network

canopen/pdo/base.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ..sdo import SdoAbortedError
1717
from .. import objectdictionary
1818
from .. import variable
19+
from ..async_guard import ensure_not_async
1920

2021
PDO_NOT_VALID = 1 << 31
2122
RTR_NOT_ALLOWED = 1 << 30
@@ -205,7 +206,7 @@ def __init__(self, pdo_node: PdoBase, com_record, map_array):
205206
#: Set explicitly or using the :meth:`start()` method.
206207
self.period: Optional[float] = None
207208
self.callbacks = []
208-
self.receive_condition = threading.Condition() # FIXME Async
209+
self.receive_condition = threading.Condition()
209210
self.areceive_condition = asyncio.Condition()
210211
self.is_received: bool = False
211212
self._task = None
@@ -307,11 +308,12 @@ def is_periodic(self) -> bool:
307308
# Unknown transmission type, assume non-periodic
308309
return False
309310

311+
@ensure_not_async # NOTE: Safeguard for accidental async use
310312
def on_message(self, can_id, data, timestamp):
311313
# NOTE: Callback. Called from another thread unless async
312314
is_transmitting = self._task is not None
313315
if can_id == self.cob_id and not is_transmitting:
314-
with self.receive_condition: # FIXME: Blocking
316+
with self.receive_condition: # NOTE: Blocking call
315317
self.is_received = True
316318
self.data = data
317319
if self.timestamp is not None:
@@ -394,6 +396,7 @@ def read_generator(self):
394396

395397
self.subscribe()
396398

399+
@ensure_not_async # NOTE: Safeguard for accidental async use
397400
def read(self, from_od=False) -> None:
398401
"""Read PDO configuration for this map using SDO or from OD."""
399402
gen = self.read_generator()
@@ -404,7 +407,7 @@ def read(self, from_od=False) -> None:
404407
value = var.od.default
405408
else:
406409
# Get value from SDO
407-
value = var.get_raw()
410+
value = var.get_raw() # FIXME: Blocking?
408411
try:
409412
# Deliver value into read_generator and wait for next object
410413
var = gen.send(value)
@@ -454,8 +457,7 @@ def save_generator(self):
454457
# mappings for an invalid object 0x0000:00 to overwrite any
455458
# excess entries with all-zeros.
456459

457-
# FIXME: This is a blocking call which might be called from async
458-
self._fill_map(self.map_array[0].get_raw())
460+
self._fill_map(self.map_array[0].get_raw()) # FIXME: Blocking?
459461
subindex = 1
460462
for var in self.map:
461463
logger.info("Writing %s (0x%X:%d, %d bits) to PDO map",
@@ -485,10 +487,11 @@ def save_generator(self):
485487
yield self.com_record[1], self.cob_id | (RTR_NOT_ALLOWED if not self.rtr_allowed else 0x0)
486488
self.subscribe()
487489

490+
@ensure_not_async # NOTE: Safeguard for accidental async use
488491
def save(self) -> None:
489492
"""Read PDO configuration for this map using SDO."""
490493
for sdo, value in self.save_generator():
491-
sdo.set_raw(value)
494+
sdo.set_raw(value) # FIXME: Blocking?
492495

493496
async def asave(self) -> None:
494497
"""Read PDO configuration for this map using SDO, async variant."""
@@ -596,15 +599,16 @@ def remote_request(self) -> None:
596599
if self.enabled and self.rtr_allowed:
597600
self.pdo_node.network.send_message(self.cob_id, None, remote=True)
598601

602+
@ensure_not_async # NOTE: Safeguard for accidental async use
599603
def wait_for_reception(self, timeout: float = 10) -> float:
600604
"""Wait for the next transmit PDO.
601605
602606
:param float timeout: Max time to wait in seconds.
603607
:return: Timestamp of message received or None if timeout.
604608
"""
605-
with self.receive_condition: # FIXME: Blocking
609+
with self.receive_condition: # NOTE: Blocking call
606610
self.is_received = False
607-
self.receive_condition.wait(timeout) # FIXME: Blocking
611+
self.receive_condition.wait(timeout) # NOTE: Blocking call
608612
return self.timestamp if self.is_received else None
609613

610614
async def await_for_reception(self, timeout: float = 10) -> float:

0 commit comments

Comments
 (0)