Skip to content

Commit edc0444

Browse files
committed
Refurbish async guard system
* Thread dependent sentinel guard
1 parent fdb6414 commit edc0444

File tree

7 files changed

+58
-22
lines changed

7 files changed

+58
-22
lines changed

canopen/async_guard.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,24 @@
11
""" Utils for async """
2-
32
import functools
4-
from typing import Optional, Callable
5-
6-
TSentinel = Callable[[], bool]
3+
import threading
74

85
# NOTE: Global, but needed to be able to use ensure_not_async() in
96
# decorator context.
10-
_ASYNC_SENTINEL: Optional[TSentinel] = None
7+
_ASYNC_SENTINELS: dict[int, bool] = {}
118

129

13-
def set_async_sentinel(fn: TSentinel):
10+
def set_async_sentinel(enable: bool):
1411
""" Register a function to validate if async is running """
15-
global _ASYNC_SENTINEL
16-
_ASYNC_SENTINEL = fn
12+
_ASYNC_SENTINELS[threading.get_ident()] = enable
1713

1814

1915
def ensure_not_async(fn):
2016
""" Decorator that will ensure that the function is not called if async
2117
is running.
2218
"""
23-
2419
@functools.wraps(fn)
25-
def async_guard(*args, **kwargs):
26-
global _ASYNC_SENTINEL
27-
if _ASYNC_SENTINEL:
28-
if _ASYNC_SENTINEL():
29-
raise RuntimeError("Calling a blocking function while running async")
20+
def async_guard_wrap(*args, **kwargs):
21+
if _ASYNC_SENTINELS.get(threading.get_ident(), False):
22+
raise RuntimeError(f"Calling a blocking function in async. {fn.__qualname__}() in {fn.__code__.co_filename}:{fn.__code__.co_firstlineno}, while running async")
3023
return fn(*args, **kwargs)
31-
return async_guard
24+
return async_guard_wrap

canopen/emcy.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
import asyncio
23
import logging
34
import struct
@@ -83,7 +84,7 @@ def reset(self):
8384
@ensure_not_async # NOTE: Safeguard for accidental async use
8485
def wait(
8586
self, emcy_code: Optional[int] = None, timeout: float = 10
86-
) -> "EmcyError":
87+
) -> EmcyError:
8788
"""Wait for a new EMCY to arrive.
8889
8990
:param emcy_code: EMCY code to wait for
@@ -111,6 +112,14 @@ def wait(
111112
# This is the one we're interested in
112113
return emcy
113114

115+
def async_wait(
116+
self, emcy_code: Optional[int] = None, timeout: float = 10
117+
) -> EmcyError:
118+
# FIXME: Implement this function
119+
raise NotImplementedError(
120+
"async_wait is not implemented."
121+
)
122+
114123

115124
class EmcyProducer:
116125

canopen/network.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, bus: Optional[can.BusABC] = None, notifier: Optional[can.Noti
6060
# Register this function as the means to check if canopen is run in
6161
# async mode. This enables the @ensure_not_async() decorator to
6262
# work. See async_guard.py
63-
set_async_sentinel(self.is_async)
63+
set_async_sentinel(self.is_async())
6464

6565
if self.is_async():
6666
self.subscribe(self.lss.LSS_RX_COBID, self.lss.aon_message_received)
@@ -142,6 +142,9 @@ def disconnect(self) -> None:
142142
self.bus = None
143143
self.check()
144144

145+
# Remove the async sentinel
146+
set_async_sentinel(False)
147+
145148
def __enter__(self):
146149
return self
147150

canopen/objectdictionary/eds.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from configparser import NoOptionError, NoSectionError, RawConfigParser
55

66
from canopen import objectdictionary
7-
from canopen.async_guard import ensure_not_async
87
from canopen.objectdictionary import ObjectDictionary, datatypes
98
from canopen.sdo import SdoClient
109

@@ -175,7 +174,6 @@ def import_eds(source, node_id):
175174

176175

177176
# FIXME: Make async variant "aimport_from_node"
178-
@ensure_not_async # NOTE: Safeguard for accidental async use
179177
def import_from_node(node_id, network):
180178
""" Download the configuration from the remote node
181179
:param int node_id: Identifier of the node

canopen/pdo/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __getitem__(self, key) -> PdoBase:
5858
def __len__(self):
5959
return len(self.map)
6060

61+
@ensure_not_async # NOTE: Safeguard for accidental async use
6162
def read(self, from_od=False):
6263
"""Read PDO configuration from node using SDO."""
6364
for pdo_map in self.map.values():
@@ -68,6 +69,7 @@ async def aread(self, from_od=False):
6869
for pdo_map in self.map.values():
6970
await pdo_map.aread(from_od=from_od)
7071

72+
@ensure_not_async # NOTE: Safeguard for accidental async use
7173
def save(self):
7274
"""Save PDO configuration to node using SDO."""
7375
for pdo_map in self.map.values():

canopen/sdo/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from canopen import objectdictionary
99
from canopen import variable
1010
from canopen.utils import pretty_index
11+
from canopen.async_guard import ensure_not_async
1112

1213

1314
class CrcXmodem:
@@ -183,12 +184,14 @@ def __init__(self, sdo_node: SdoBase, od: objectdictionary.ODVariable):
183184
def __await__(self):
184185
return self.aget_raw().__await__()
185186

187+
@ensure_not_async # NOTE: Safeguard for accidental async use
186188
def get_data(self) -> bytes:
187189
return self.sdo_node.upload(self.od.index, self.od.subindex)
188190

189191
async def aget_data(self) -> bytes:
190192
return await self.sdo_node.aupload(self.od.index, self.od.subindex)
191193

194+
@ensure_not_async # NOTE: Safeguard for accidental async use
192195
def set_data(self, data: bytes):
193196
force_segment = self.od.data_type == objectdictionary.DOMAIN
194197
self.sdo_node.download(self.od.index, self.od.subindex, data, force_segment)
@@ -205,6 +208,7 @@ def writable(self) -> bool:
205208
def readable(self) -> bool:
206209
return self.od.readable
207210

211+
@ensure_not_async # NOTE: Safeguard for accidental async use
208212
def open(self, mode="rb", encoding="ascii", buffering=1024, size=None,
209213
block_transfer=False, request_crc_support=True):
210214
"""Open the data stream as a file like object.

canopen/sdo/client.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from canopen.sdo.constants import *
1313
from canopen.sdo.exceptions import *
1414
from canopen.utils import pretty_index
15+
from canopen.async_guard import ensure_not_async
1516

1617

1718
logger = logging.getLogger(__name__)
@@ -49,9 +50,11 @@ def __init__(self, rx_cobid, tx_cobid, od):
4950
def on_response(self, can_id, data, timestamp):
5051
self.responses.put_nowait(bytes(data))
5152

53+
@ensure_not_async # NOTE: Safeguard for accidental async use
5254
def send_request(self, request):
5355
retries_left = self.MAX_RETRIES
5456
if self.PAUSE_BEFORE_SEND:
57+
# NOTE: Blocking
5558
time.sleep(self.PAUSE_BEFORE_SEND)
5659
while True:
5760
try:
@@ -63,6 +66,7 @@ def send_request(self, request):
6366
raise
6467
logger.info(str(e))
6568
if self.RETRY_DELAY:
69+
# NOTE: Blocking
6670
time.sleep(self.RETRY_DELAY)
6771
else:
6872
break
@@ -108,6 +112,7 @@ def abort(self, abort_code=0x08000000):
108112
self.send_request(request)
109113
logger.error("Transfer aborted by client with code 0x%08X", abort_code)
110114

115+
@ensure_not_async # NOTE: Safeguard for accidental async use
111116
def upload(self, index: int, subindex: int) -> bytes:
112117
"""May be called to make a read operation without an Object Dictionary.
113118
@@ -126,7 +131,9 @@ def upload(self, index: int, subindex: int) -> bytes:
126131
with self.open(index, subindex, buffering=0) as fp:
127132
response_size = fp.size
128133
data = fp.read()
134+
return self.truncate_data(index, subindex, data, response_size)
129135

136+
def truncate_data(self, index: int, subindex: int, data: bytes, size: int) -> bytes:
130137
# If size is available through variable in OD, then use the smaller of the two sizes.
131138
# Some devices send U32/I32 even if variable is smaller in OD
132139
var = self.od.get_variable(index, subindex)
@@ -137,7 +144,7 @@ def upload(self, index: int, subindex: int) -> bytes:
137144
if var.data_type not in objectdictionary.DATA_TYPES:
138145
# Get the size in bytes for this variable
139146
var_size = len(var) // 8
140-
if response_size is None or var_size < response_size:
147+
if size is None or var_size < size:
141148
# Truncate the data to specified size
142149
data = data[0:var_size]
143150
return data
@@ -152,8 +159,16 @@ async def aupload(self, index: int, subindex: int) -> bytes:
152159
# upload -> open -> ReadableStream -> request_reponse -> send_request -> network.send_message
153160
# recv -> on_reponse -> queue.put
154161
# request_reponse -> read_response -> queue.get
155-
return await asyncio.to_thread(self.upload, index, subindex)
162+
def _upload():
163+
with self._open(index, subindex, buffering=0) as fp:
164+
response_size = fp.size
165+
data = fp.read()
166+
return data, response_size
156167

168+
data, response_size = await asyncio.to_thread(_upload)
169+
return self.truncate_data(index, subindex, data, response_size)
170+
171+
@ensure_not_async # NOTE: Safeguard for accidental async use
157172
def download(
158173
self,
159174
index: int,
@@ -193,10 +208,22 @@ async def adownload(
193208
"""
194209
async with self.lock: # Ensure only one active SDO request per channel
195210
# Deferring to thread because there are sleeps in the call chain
196-
return await asyncio.to_thread(self.download, index, subindex, data, force_segment)
197211

212+
def _download():
213+
with self._open(index, subindex, "wb", buffering=7, size=len(data),
214+
force_segment=force_segment) as fp:
215+
fp.write(data)
216+
217+
return await asyncio.to_thread(_download)
218+
219+
@ensure_not_async # NOTE: Safeguard for accidental async use
198220
def open(self, index, subindex=0, mode="rb", encoding="ascii",
199221
buffering=1024, size=None, block_transfer=False, force_segment=False, request_crc_support=True):
222+
return self._open(index, subindex, mode, encoding, buffering,
223+
size, block_transfer, force_segment, request_crc_support)
224+
225+
def _open(self, index, subindex=0, mode="rb", encoding="ascii",
226+
buffering=1024, size=None, block_transfer=False, force_segment=False, request_crc_support=True):
200227
"""Open the data stream as a file like object.
201228
202229
:param int index:

0 commit comments

Comments
 (0)