diff --git a/LICENSE.txt b/LICENSE.txt index bfce6f06..db5094ed 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,7 @@ MIT License Copyright (c) 2016 Christian Sandberg +Copyright (c) 2025 Svein Seldal Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.rst b/README.rst index 0014cd67..947f3bb4 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,5 @@ -CANopen for Python -================== +CANopen for Python, asyncio port +================================ A Python implementation of the CANopen_ standard. The aim of the project is to support the most common parts of the CiA 301 @@ -8,6 +8,84 @@ automation tasks rather than a standard compliant master implementation. The library supports Python 3.8 or newer. +This library is the asyncio port of CANopen. See below for code example. + + +Asyncio port +------------ + +The objective of the library is to provide a canopen implementation in +either async or non-async environment, with suitable API for both. + +To minimize the impact of the async changes, this port is designed to use the +existing synchronous backend of the library. This means that the library +uses :code:`asyncio.to_thread()` for many asynchronous operations. + +This port remains compatible with using it in a regular non-asyncio +environment. This is selected with the `loop` parameter in the +:code:`Network` constructor. If you pass a valid asyncio event loop, the +library will run in async mode. If you pass `loop=None`, it will run in +regular blocking mode. It cannot be used in both modes at the same time. + + +Difference between async and non-async version +---------------------------------------------- + +This port have some differences with the upstream non-async version of canopen. + +* Minimum python version is 3.9, while the upstream version supports 3.8. + +* The :code:`Network` accepts additional parameters than upstream. It accepts + :code:`loop` which selects the mode of operation. If :code:`None` it will + run in blocking mode, otherwise it will run in async mode. It supports + providing a custom CAN :code:`notifier` if the CAN bus will be shared by + multiple protocols. + +* The :code:`Network` class can be (and should be) used in an async context + manager. This will ensure the network will be automatically disconnected when + exiting the context. See the example below. + +* Most async functions follow an "a" prefix naming scheme. + E.g. the async variant for :code:`SdoClient.download()` is available + as :code:`SdoClient.adownload()`. + +* Variables in the regular canopen library uses properties for getting and + setting. This is replaced with awaitable methods in the async version. + + var = sdo['Variable'].raw # synchronous + sdo['Variable'].raw = 12 # synchronous + + var = await sdo['Variable'].get_raw() # async + await sdo['Variable'].set_raw(12) # async + +* Installed :code:`ensure_not_async()` sentinel guard in functions which + prevents calling blocking functions in async context. It will raise the + exception :code:`RuntimeError` "Calling a blocking function" when this + happen. If this is encountered, it is likely that the code is not using the + async variants of the library. + +* The mechanism for CAN bus callbacks have been changed. Callbacks might be + async, which means they cannot be called immediately. This affects how + error handling is done in the library. + +* The callbacks to the message handlers have been changed to be handled by + :code:`Network.dispatch_callbacks()`. They are no longer called with any + locks held, as this would not work with async. This affects: + * :code:`PdoMaps.on_message` + * :code:`EmcyConsumer.on_emcy` + * :code:`NtmMaster.on_heartbaet` + +* SDO block upload and download is not yet supported in async mode. + +* :code:`ODVariable.__len__()` returns 64 bits instead of 8 bits to support + truncated 24-bits integers, see #436 + +* :code:`BaseNode402` does not work with async + +* :code:`LssMaster` does not work with async, except :code:`LssMaster.fast_scan()` + +* :code:`Bits` is not working in async + Features -------- @@ -156,6 +234,70 @@ The :code:`n` is the PDO index (normally 1 to 4). The second form of access is f network.disconnect() +Asyncio +------- + +This is the same example as above, but using asyncio + +.. code-block:: python + + import asyncio + import canopen + import can + + async def my_node(network, nodeid, od): + + # Create the node object and load the OD + node = network.add_node(nodeid, od) + + # Read the PDOs from the remote + await node.tpdo.aread() + await node.rpdo.aread() + + # Set the module state + node.nmt.set_state('OPERATIONAL') + + # Set motor speed via SDO + await node.sdo['MotorSpeed'].aset_raw(2) + + while True: + + # Wait for TPDO 1 + t = await node.tpdo[1].await_for_reception(1) + if not t: + continue + + # Get the TPDO 1 value + rpm = node.tpdo[1]['MotorSpeed Actual'].get_raw() + print(f'SPEED on motor {nodeid}:', rpm) + + # Sleep a little + await asyncio.sleep(0.2) + + # Send RPDO 1 with some data + node.rpdo[1]['Some variable'].set_phys(42) + node.rpdo[1].transmit() + + async def main(): + + # Connect to the CAN bus + # Arguments are passed to python-can's can.Bus() constructor + # (see https://python-can.readthedocs.io/en/latest/bus.html). + # Note the loop parameter to enable asyncio operation + loop = asyncio.get_running_loop() + async with canopen.Network(loop=loop).connect( + interface='pcan', bitrate=1000000) as network: + + # Create two independent tasks for two nodes 51 and 52 which will run concurrently + task1 = asyncio.create_task(my_node(network, 51, '/path/to/object_dictionary.eds')) + task2 = asyncio.create_task(my_node(network, 52, '/path/to/object_dictionary.eds')) + + # Wait for both to complete (which will never happen) + await asyncio.gather((task1, task2)) + + asyncio.run(main()) + + Debugging --------- diff --git a/canopen/async_guard.py b/canopen/async_guard.py new file mode 100644 index 00000000..405c4dbf --- /dev/null +++ b/canopen/async_guard.py @@ -0,0 +1,43 @@ +""" Utils for async """ +import functools +import logging +import threading +import traceback + +# NOTE: Global, but needed to be able to use ensure_not_async() in +# decorator context. +_ASYNC_SENTINELS: dict[int, bool] = {} + +logger = logging.getLogger(__name__) + + +def set_async_sentinel(enable: bool): + """ Register a function to validate if async is running """ + _ASYNC_SENTINELS[threading.get_ident()] = enable + + +def ensure_not_async(fn): + """ Decorator that will ensure that the function is not called if async + is running. + """ + @functools.wraps(fn) + def async_guard_wrap(*args, **kwargs): + if _ASYNC_SENTINELS.get(threading.get_ident(), False): + st = "".join(traceback.format_stack()) + logger.debug("Traceback:\n%s", st.rstrip()) + raise RuntimeError(f"Calling a blocking function, {fn.__qualname__}() in {fn.__code__.co_filename}:{fn.__code__.co_firstlineno}, while running async") + return fn(*args, **kwargs) + return async_guard_wrap + + +class AllowBlocking: + """ Context manager to pause async guard """ + def __init__(self): + self._enabled = _ASYNC_SENTINELS.get(threading.get_ident(), False) + + def __enter__(self): + set_async_sentinel(False) + return self + + def __exit__(self, exc_type, exc_value, traceback): + set_async_sentinel(self._enabled) diff --git a/canopen/emcy.py b/canopen/emcy.py index ec2c489f..f4a555fe 100644 --- a/canopen/emcy.py +++ b/canopen/emcy.py @@ -1,9 +1,12 @@ +from __future__ import annotations +import asyncio import logging import struct import threading import time from typing import Callable, List, Optional +from canopen.async_guard import ensure_not_async import canopen.network @@ -22,11 +25,15 @@ def __init__(self): self.active: List["EmcyError"] = [] self.callbacks = [] self.emcy_received = threading.Condition() + self.network: canopen.network.Network = canopen.network._UNINITIALIZED_NETWORK + # @callback # NOTE: called from another thread + @ensure_not_async # NOTE: Safeguard for accidental async use def on_emcy(self, can_id, data, timestamp): code, register, data = EMCY_STRUCT.unpack(data) entry = EmcyError(code, register, data, timestamp) + # NOTE: Blocking lock with self.emcy_received: if code & 0xFF00 == 0: # Error reset @@ -36,8 +43,8 @@ def on_emcy(self, can_id, data, timestamp): self.log.append(entry) self.emcy_received.notify_all() - for callback in self.callbacks: - callback(entry) + # Call all registered callbacks + self.network.dispatch_callbacks(self.callbacks, entry) def add_callback(self, callback: Callable[["EmcyError"], None]): """Get notified on EMCY messages from this node. @@ -53,6 +60,7 @@ def reset(self): self.log = [] self.active = [] + @ensure_not_async # NOTE: Safeguard for accidental async use def wait( self, emcy_code: Optional[int] = None, timeout: float = 10 ) -> "EmcyError": @@ -65,8 +73,10 @@ def wait( """ end_time = time.time() + timeout while True: + # NOTE: Blocking lock with self.emcy_received: prev_log_size = len(self.log) + # NOTE: Blocking call self.emcy_received.wait(timeout) if len(self.log) == prev_log_size: # Resumed due to timeout @@ -81,6 +91,18 @@ def wait( # This is the one we're interested in return emcy + async def async_wait( + self, emcy_code: Optional[int] = None, timeout: float = 10 + ) -> EmcyError: + """Wait for a new EMCY to arrive. + + :param emcy_code: EMCY code to wait for + :param timeout: Max time in seconds to wait + + :return: The EMCY exception object or None if timeout + """ + return await asyncio.to_thread(self.wait, emcy_code, timeout) + class EmcyProducer: diff --git a/canopen/lss.py b/canopen/lss.py index 7c0b92a6..38e0b61a 100644 --- a/canopen/lss.py +++ b/canopen/lss.py @@ -1,8 +1,10 @@ +import asyncio import logging import queue import struct import time +from canopen.async_guard import ensure_not_async import canopen.network @@ -87,6 +89,8 @@ def __init__(self) -> None: self._data = None self.responses = queue.Queue() + # FIXME: Async implementation of the public methods in this class + def send_switch_state_global(self, mode): """switch mode to CONFIGURATION_STATE or WAITING_STATE in the all slaves on CAN bus. @@ -241,6 +245,7 @@ def send_identify_non_configured_remote_slave(self): message[0] = CS_IDENTIFY_NON_CONFIGURED_REMOTE_SLAVE self.__send_command(message) + @ensure_not_async # NOTE: Safeguard for accidental async use def fast_scan(self): """This command sends a series of fastscan message to find unconfigured slave with lowest number of LSS idenities @@ -257,6 +262,7 @@ def fast_scan(self): lss_next = 0 if self.__send_fast_scan_message(lss_id[0], lss_bit_check, lss_sub, lss_next): + # NOTE: Blocking call time.sleep(0.01) while lss_sub < 4: lss_bit_check = 32 @@ -266,12 +272,14 @@ def fast_scan(self): if not self.__send_fast_scan_message(lss_id[lss_sub], lss_bit_check, lss_sub, lss_next): lss_id[lss_sub] |= 1< None: @@ -108,7 +120,16 @@ def connect(self, *args, **kwargs) -> Network: if self.bus is None: self.bus = can.Bus(*args, **kwargs) logger.info("Connected to '%s'", self.bus.channel_info) - self.notifier = can.Notifier(self.bus, self.listeners, self.NOTIFIER_CYCLE) + if self.notifier is None: + # Do not start a can notifier with the async loop. It changes the + # behavior of the notifier callbacks. Instead of running the + # callbacks from a separate thread, it runs the callbacks in the + # same thread as the event loop where blocking calls are not allowed. + # This library needs to support both async and sync, so we need to + # use the notifier in a separate thread. + self.notifier = can.Notifier(self.bus, [], self.NOTIFIER_CYCLE) + for listener in self.listeners: + self.notifier.add_listener(listener) return self def disconnect(self) -> None: @@ -126,12 +147,25 @@ def disconnect(self) -> None: self.bus = None self.check() + # Remove the async sentinel + set_async_sentinel(False) + def __enter__(self): return self def __exit__(self, type, value, traceback): self.disconnect() + async def __aenter__(self): + # FIXME: When TaskGroup are available, we should use them to manage the + # tasks. The user must use the `async with` statement with the Network + # to ensure its created. + return self + + async def __aexit__(self, type, value, traceback): + self.disconnect() + + @ensure_not_async # NOTE: Safeguard for accidental async use def add_node( self, node: Union[int, RemoteNode, LocalNode], @@ -161,6 +195,20 @@ def add_node( self[node.id] = node return node + async def aadd_node( + self, + node: Union[int, RemoteNode, LocalNode], + object_dictionary: Union[str, ObjectDictionary, None] = None, + upload_eds: bool = False, + ) -> RemoteNode: + """Add a remote node to the network, async variant. + + See add_node() for description + """ + # NOTE: The async variant exists because import_from_node might block + return await asyncio.to_thread(self.add_node, node, + object_dictionary, upload_eds) + def create_node( self, node: int, @@ -206,6 +254,8 @@ def send_message(self, can_id: int, data: bytes, remote: bool = False) -> None: arbitration_id=can_id, data=data, is_remote_frame=remote) + # NOTE: Blocking lock. This is probably ok for async, because async + # only use one thread. with self.send_lock: self.bus.send(msg) self.check() @@ -229,6 +279,7 @@ def send_periodic( """ return PeriodicMessageTask(can_id, data, period, self.bus, remote) + # @callback # NOTE: called from another thread def notify(self, can_id: int, data: bytearray, timestamp: float) -> None: """Feed incoming message to this library. @@ -243,11 +294,44 @@ def notify(self, can_id: int, data: bytearray, timestamp: float) -> None: Timestamp of the message, preferably as a Unix timestamp """ if can_id in self.subscribers: - callbacks = self.subscribers[can_id] - for callback in callbacks: - callback(can_id, data, timestamp) + self.dispatch_callbacks(self.subscribers[can_id], can_id, data, timestamp) self.scanner.on_message_received(can_id) + def on_error(self, exc: BaseException) -> None: + """This method is called to handle any exception in the callbacks.""" + + # Exceptions in any callbaks should not affect CAN processing + logger.exception("Exception in callback: %s", exc_info=exc) + + def dispatch_callbacks(self, callbacks: List[Callback], *args) -> None: + """Dispatch a list of callbacks with the given arguments. + + :param callbacks: + List of callbacks to call + :param args: + Arguments to pass to the callbacks + """ + def task_done(task: asyncio.Task) -> None: + """Callback to be called when a task is done.""" + self._tasks.discard(task) + + # FIXME: This section should probably be migrated to a TaskGroup. + # However, this is not available yet in Python 3.8 - 3.10. + try: + if (exc := task.exception()) is not None: + self.on_error(exc) + except (asyncio.CancelledError, asyncio.InvalidStateError) as exc: + # Handle cancelled tasks and unfinished tasks gracefully + self.on_error(exc) + + # Run the callbacks + for callback in callbacks: + result = callback(*args) + if result is not None and asyncio.iscoroutine(result): + task = asyncio.create_task(result) + self._tasks.add(task) + task.add_done_callback(task_done) + def check(self) -> None: """Check that no fatal error has occurred in the receiving thread. @@ -260,6 +344,10 @@ def check(self) -> None: logger.error("An error has caused receiving of messages to stop") raise exc + def is_async(self) -> bool: + """Check if canopen has been connected with async""" + return self.loop is not None + def __getitem__(self, node_id: int) -> Union[RemoteNode, LocalNode]: return self.nodes[node_id] @@ -335,6 +423,7 @@ def stop(self): """Stop transmission""" self._task.stop() + # @callback # NOTE: Indirectly called from another thread via other callbacks def update(self, data: bytes) -> None: """Update data of message @@ -362,6 +451,7 @@ class MessageListener(Listener): def __init__(self, network: Network): self.network = network + # @callback # NOTE: called from another thread def on_message_received(self, msg): if msg.is_error_frame or msg.is_remote_frame: return @@ -370,7 +460,7 @@ def on_message_received(self, msg): self.network.notify(msg.arbitration_id, msg.data, msg.timestamp) except Exception as e: # Exceptions in any callbaks should not affect CAN processing - logger.error(str(e)) + self.network.on_error(e) def stop(self) -> None: """Override abstract base method to release any resources.""" @@ -398,6 +488,7 @@ def __init__(self, network: Optional[Network] = None): #: A :class:`list` of nodes discovered self.nodes: List[int] = [] + # @callback # NOTE: called from another thread def on_message_received(self, can_id: int): service = can_id & 0x780 node_id = can_id & 0x7F diff --git a/canopen/nmt.py b/canopen/nmt.py index c13d0779..ed4ec02d 100644 --- a/canopen/nmt.py +++ b/canopen/nmt.py @@ -1,9 +1,11 @@ +import asyncio import logging import struct import threading import time from typing import Callable, Dict, Final, List, Optional, TYPE_CHECKING +from canopen.async_guard import ensure_not_async import canopen.network if TYPE_CHECKING: @@ -54,6 +56,7 @@ def __init__(self, node_id: int): self.network: canopen.network.Network = canopen.network._UNINITIALIZED_NETWORK self._state = 0 + # @callback # NOTE: called from another thread def on_command(self, can_id, data, timestamp): cmd, node_id = struct.unpack_from("BB", data) if node_id in (self.id, 0): @@ -63,6 +66,7 @@ def on_command(self, can_id, data, timestamp): if new_state != self._state: logger.info("New NMT state %s, old state %s", NMT_STATES[new_state], NMT_STATES[self._state]) + # FIXME: Is this thread-safe? self._state = new_state def send_command(self, code: int): @@ -119,15 +123,17 @@ def __init__(self, node_id: int): self.state_update = threading.Condition() self._callbacks: List[Callable[[int], None]] = [] + # @callback # NOTE: called from another thread + @ensure_not_async # NOTE: Safeguard for accidental async use def on_heartbeat(self, can_id, data, timestamp): + new_state, = struct.unpack_from("B", data) + # Mask out toggle bit + new_state &= 0x7F + logger.debug("Received heartbeat can-id %d, state is %d", can_id, new_state) + + # NOTE: Blocking lock with self.state_update: self.timestamp = timestamp - new_state, = struct.unpack_from("B", data) - # Mask out toggle bit - new_state &= 0x7F - logger.debug("Received heartbeat can-id %d, state is %d", can_id, new_state) - for callback in self._callbacks: - callback(new_state) if new_state == 0: # Boot-up, will go to PRE-OPERATIONAL automatically self._state = 127 @@ -136,6 +142,9 @@ def on_heartbeat(self, can_id, data, timestamp): self._state_received = new_state self.state_update.notify_all() + # Call all registered callbacks + self.network.dispatch_callbacks(self._callbacks, new_state) + def send_command(self, code: int): """Send an NMT command code to the node. @@ -147,28 +156,42 @@ def send_command(self, code: int): "Sending NMT command 0x%X to node %d", code, self.id) self.network.send_message(0, [code, self.id]) + @ensure_not_async # NOTE: Safeguard for accidental async use def wait_for_heartbeat(self, timeout: float = 10): """Wait until a heartbeat message is received.""" + # NOTE: Blocking lock with self.state_update: self._state_received = None + # NOTE: Blocking call self.state_update.wait(timeout) if self._state_received is None: raise NmtError("No boot-up or heartbeat received") return self.state + async def await_for_heartbeat(self, timeout: float = 10): + """Wait until a heartbeat message is received.""" + return await asyncio.to_thread(self.wait_for_heartbeat, timeout) + + @ensure_not_async # NOTE: Safeguard for accidental async use def wait_for_bootup(self, timeout: float = 10) -> None: """Wait until a boot-up message is received.""" end_time = time.time() + timeout while True: now = time.time() + # NOTE: Blocking lock with self.state_update: self._state_received = None + # NOTE: Blocking call self.state_update.wait(end_time - now + 0.1) if now > end_time: raise NmtError("Timeout waiting for boot-up message") if self._state_received == 0: break + async def await_for_bootup(self, timeout: float = 10) -> None: + """Wait until a boot-up message is received.""" + return await asyncio.to_thread(self.wait_for_bootup, timeout) + def add_heartbeat_callback(self, callback: Callable[[int], None]): """Add function to be called on heartbeat reception. @@ -208,6 +231,7 @@ def __init__(self, node_id: int, local_node): self._heartbeat_time_ms = 0 self._local_node = local_node + # @callback # NOTE: called from another thread def on_command(self, can_id, data, timestamp): super(NmtSlave, self).on_command(can_id, data, timestamp) self.update_heartbeat() @@ -228,7 +252,12 @@ def send_command(self, code: int) -> None: # The heartbeat service should start on the transition # between INITIALIZING and PRE-OPERATIONAL state if old_state == 0 and self._state == 127: - heartbeat_time_ms = self._local_node.sdo[0x1017].raw + # FIXME: Document why this was fixed + if self._heartbeat_time_ms == 0: + # NOTE: Blocking - protected in SdoClient + heartbeat_time_ms = self._local_node.sdo[0x1017].raw + else: + heartbeat_time_ms = self._heartbeat_time_ms self.start_heartbeat(heartbeat_time_ms) else: self.update_heartbeat() @@ -263,8 +292,10 @@ def stop_heartbeat(self): self._send_task.stop() self._send_task = None + # @callback # NOTE: Indirectly called from another thread via on_command def update_heartbeat(self): if self._send_task is not None: + # FIXME: Make this thread-safe self._send_task.update([self._state]) diff --git a/canopen/node/remote.py b/canopen/node/remote.py index 371f784c..3b6ef3e8 100644 --- a/canopen/node/remote.py +++ b/canopen/node/remote.py @@ -59,6 +59,7 @@ def associate_network(self, network: canopen.network.Network): self.tpdo.network = network self.rpdo.network = network self.nmt.network = network + self.emcy.network = network for sdo in self.sdo_channels: network.subscribe(sdo.tx_cobid, sdo.on_response) network.subscribe(0x700 + self.id, self.nmt.on_heartbeat) @@ -79,6 +80,7 @@ def remove_network(self) -> None: self.tpdo.network = canopen.network._UNINITIALIZED_NETWORK self.rpdo.network = canopen.network._UNINITIALIZED_NETWORK self.nmt.network = canopen.network._UNINITIALIZED_NETWORK + self.emcy.network = canopen.network._UNINITIALIZED_NETWORK def add_sdo(self, rx_cobid, tx_cobid): """Add an additional SDO channel. @@ -132,8 +134,10 @@ def __load_configuration_helper(self, index, subindex, name, value): if subindex is not None: logger.info('SDO [0x%04X][0x%02X]: %s: %#06x', index, subindex, name, value) + # NOTE: Blocking - protected in SdoClient self.sdo[index][subindex].raw = value else: + # NOTE: Blocking - protected in SdoClient self.sdo[index].raw = value logger.info('SDO [0x%04X]: %s: %#06x', index, name, value) diff --git a/canopen/objectdictionary/__init__.py b/canopen/objectdictionary/__init__.py index f394da23..d8f5e6cc 100644 --- a/canopen/objectdictionary/__init__.py +++ b/canopen/objectdictionary/__init__.py @@ -395,7 +395,8 @@ def __len__(self) -> int: if self.data_type in self.STRUCT_TYPES: return self.STRUCT_TYPES[self.data_type].size * 8 else: - return 8 + # FIXME: Temporary fix for trucated 24-bit integers, see #436 + return 64 @property def writable(self) -> bool: diff --git a/canopen/objectdictionary/eds.py b/canopen/objectdictionary/eds.py index 986d2a37..8f350400 100644 --- a/canopen/objectdictionary/eds.py +++ b/canopen/objectdictionary/eds.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import copy import logging import re @@ -7,6 +8,7 @@ from typing import TYPE_CHECKING from canopen import objectdictionary +from canopen.async_guard import ensure_not_async from canopen.objectdictionary import ObjectDictionary, datatypes from canopen.sdo import SdoClient @@ -33,6 +35,7 @@ def import_eds(source, node_id): else: fp = open(source) opened_here = True + # NOTE: Blocking call if fp is a file eds.read_file(fp) finally: # Only close object if opened in this fn @@ -179,6 +182,7 @@ def import_eds(source, node_id): return od +@ensure_not_async # NOTE: Safeguard for accidental async use def import_from_node(node_id: int, network: canopen.network.Network): """ Download the configuration from the remote node :param int node_id: Identifier of the node @@ -191,6 +195,7 @@ def import_from_node(node_id: int, network: canopen.network.Network): network.subscribe(0x580 + node_id, sdo_client.on_response) # Create file like object for Store EDS variable try: + # NOTE: This results in a blocking call with sdo_client.open(0x1021, 0, "rt") as eds_fp: od = import_eds(eds_fp, node_id) except Exception as e: @@ -202,6 +207,14 @@ def import_from_node(node_id: int, network: canopen.network.Network): return od +async def aimport_from_node(node_id: int, network: canopen.network.Network): + """ Download the configuration from the remote node + :param int node_id: Identifier of the node + :param network: network object + """ + return await asyncio.to_thread(import_from_node, node_id, network) + + def _calc_bit_length(data_type): if data_type == datatypes.INTEGER8: return 8 diff --git a/canopen/pdo/base.py b/canopen/pdo/base.py index 0ba65199..52e734ba 100644 --- a/canopen/pdo/base.py +++ b/canopen/pdo/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import binascii import logging import math @@ -10,6 +11,7 @@ import canopen.network from canopen import objectdictionary from canopen import variable +from canopen.async_guard import ensure_not_async from canopen.sdo import SdoAbortedError if TYPE_CHECKING: @@ -56,16 +58,28 @@ def __getitem__(self, key): def __len__(self): return len(self.map) + @ensure_not_async # NOTE: Safeguard for accidental async use def read(self, from_od=False): """Read PDO configuration from node using SDO.""" for pdo_map in self.map.values(): pdo_map.read(from_od=from_od) + async def aread(self, from_od=False): + """Read PDO configuration from node using SDO, async variant.""" + for pdo_map in self.map.values(): + await pdo_map.aread(from_od=from_od) + + @ensure_not_async # NOTE: Safeguard for accidental async use def save(self): """Save PDO configuration to node using SDO.""" for pdo_map in self.map.values(): pdo_map.save() + async def asave(self): + """Save PDO configuration to node using SDO, async variant.""" + for pdo_map in self.map.values(): + await pdo_map.asave() + def subscribe(self): """Register the node's PDOs for reception on the network. @@ -309,9 +323,12 @@ def is_periodic(self) -> bool: # Unknown transmission type, assume non-periodic return False + # @callback # NOTE: called from another thread + @ensure_not_async # NOTE: Safeguard for accidental async use def on_message(self, can_id, data, timestamp): is_transmitting = self._task is not None if can_id == self.cob_id and not is_transmitting: + # NOTE: Blocking lock with self.receive_condition: self.is_received = True self.data = data @@ -319,8 +336,9 @@ def on_message(self, can_id, data, timestamp): self.period = timestamp - self.timestamp self.timestamp = timestamp self.receive_condition.notify_all() - for callback in self.callbacks: - callback(self) + + # Call all registered callbacks + self.pdo_node.network.dispatch_callbacks(self.callbacks, self) def add_callback(self, callback: Callable[[PdoMap], None]) -> None: """Add a callback which will be called on receive. @@ -331,58 +349,48 @@ def add_callback(self, callback: Callable[[PdoMap], None]) -> None: """ self.callbacks.append(callback) - def read(self, from_od=False) -> None: - """Read PDO configuration for this map. - - :param from_od: - Read using SDO if False, read from object dictionary if True. - When reading from object dictionary, if DCF populated a value, the - DCF value will be used, otherwise the EDS default will be used instead. - """ + def read_generator(self): + """Generator to run through steps for reading the PDO configuration + for this map. - def _raw_from(param): - if from_od: - if param.od.value is not None: - return param.od.value - else: - return param.od.default - return param.raw + This function does not do any io. This must be done by the caller. - cob_id = _raw_from(self.com_record[1]) + """ + cob_id = yield self.com_record[1] self.cob_id = cob_id & 0x1FFFFFFF logger.info("COB-ID is 0x%X", self.cob_id) self.enabled = cob_id & PDO_NOT_VALID == 0 logger.info("PDO is %s", "enabled" if self.enabled else "disabled") self.rtr_allowed = cob_id & RTR_NOT_ALLOWED == 0 logger.info("RTR is %s", "allowed" if self.rtr_allowed else "not allowed") - self.trans_type = _raw_from(self.com_record[2]) + self.trans_type = yield self.com_record[2] logger.info("Transmission type is %d", self.trans_type) if self.trans_type >= 254: try: - self.inhibit_time = _raw_from(self.com_record[3]) + self.inhibit_time = yield self.com_record[3] except (KeyError, SdoAbortedError) as e: logger.info("Could not read inhibit time (%s)", e) else: logger.info("Inhibit time is set to %d ms", self.inhibit_time) try: - self.event_timer = _raw_from(self.com_record[5]) + self.event_timer = yield self.com_record[5] except (KeyError, SdoAbortedError) as e: logger.info("Could not read event timer (%s)", e) else: logger.info("Event timer is set to %d ms", self.event_timer) try: - self.sync_start_value = _raw_from(self.com_record[6]) + self.sync_start_value = yield self.com_record[6] except (KeyError, SdoAbortedError) as e: logger.info("Could not read SYNC start value (%s)", e) else: logger.info("SYNC start value is set to %d ms", self.sync_start_value) self.clear() - nof_entries = _raw_from(self.map_array[0]) + nof_entries = yield self.map_array[0] for subindex in range(1, nof_entries + 1): - value = _raw_from(self.map_array[subindex]) + value = yield self.map_array[subindex] index = value >> 16 subindex = (value >> 8) & 0xFF # Ignore the highest bit, it is never valid for <= 64 PDO length @@ -397,66 +405,147 @@ def _raw_from(param): self.subscribe() - def save(self) -> None: - """Save PDO configuration for this map using SDO.""" + @ensure_not_async # NOTE: Safeguard for accidental async use + def read(self, from_od=False) -> None: + """Read PDO configuration for this map. + + :param from_od: + Read using SDO if False, read from object dictionary if True. + When reading from object dictionary, if DCF populated a value, the + DCF value will be used, otherwise the EDS default will be used instead. + """ + gen = self.read_generator() + param = next(gen) + while param: + if from_od: + # Use value from OD + if param.od.value is not None: + value = param.od.value + else: + value = param.od.default + else: + # Get value from SDO + # NOTE: Blocking - protected in SdoClient + value = param.raw + try: + # Deliver value into read_generator and wait for next object + param = gen.send(value) + except StopIteration: + break + + async def aread(self, from_od=False) -> None: + """Read PDO configuration for this map. Async variant. + + :param from_od: + Read using SDO if False, read from object dictionary if True. + When reading from object dictionary, if DCF populated a value, the + DCF value will be used, otherwise the EDS default will be used instead. + """ + gen = self.read_generator() + param = next(gen) + while param: + if from_od: + # Use value from OD + if param.od.value is not None: + value = param.od.value + else: + value = param.od.default + else: + # Get value from SDO + value = await param.aget_raw() + try: + param = gen.send(value) + except StopIteration: + break + + def save_generator(self): + """Generator to run through steps for saving the PDO configuration + using SDO. + + This function does not do any io. This must be done by the caller. + + """ if self.cob_id is None: logger.info("Skip saving %s: COB-ID was never set", self.com_record.od.name) return logger.info("Setting COB-ID 0x%X and temporarily disabling PDO", self.cob_id) - self.com_record[1].raw = self.cob_id | PDO_NOT_VALID | (RTR_NOT_ALLOWED if not self.rtr_allowed else 0x0) + yield self.com_record[1], self.cob_id | PDO_NOT_VALID | (RTR_NOT_ALLOWED if not self.rtr_allowed else 0x0) if self.trans_type is not None: logger.info("Setting transmission type to %d", self.trans_type) - self.com_record[2].raw = self.trans_type + yield self.com_record[2], self.trans_type if self.inhibit_time is not None: logger.info("Setting inhibit time to %d us", (self.inhibit_time * 100)) - self.com_record[3].raw = self.inhibit_time + yield self.com_record[3], self.inhibit_time if self.event_timer is not None: logger.info("Setting event timer to %d ms", self.event_timer) - self.com_record[5].raw = self.event_timer + yield self.com_record[5], self.event_timer if self.sync_start_value is not None: logger.info("Setting SYNC start value to %d", self.sync_start_value) - self.com_record[6].raw = self.sync_start_value + yield self.com_record[6], self.sync_start_value - try: - self.map_array[0].raw = 0 - except SdoAbortedError: - # WORKAROUND for broken implementations: If the array has a - # fixed number of entries (count not writable), generate dummy - # mappings for an invalid object 0x0000:00 to overwrite any - # excess entries with all-zeros. - self._fill_map(self.map_array[0].raw) - subindex = 1 - for var in self.map: - logger.info("Writing %s (0x%04X:%02X, %d bits) to PDO map", - var.name, var.index, var.subindex, var.length) - if getattr(self.pdo_node.node, "curtis_hack", False): - # Curtis HACK: mixed up field order - self.map_array[subindex].raw = (var.index | - var.subindex << 16 | - var.length << 24) - else: - self.map_array[subindex].raw = (var.index << 16 | - var.subindex << 8 | - var.length) - subindex += 1 - try: - self.map_array[0].raw = len(self.map) - except SdoAbortedError as e: - # WORKAROUND for broken implementations: If the array - # number-of-entries parameter is not writable, we have already - # generated the required number of mappings above. - if e.code != 0x06010002: - # Abort codes other than "Attempt to write a read-only - # object" should still be reported. - raise - self._update_data_size() + if self.map is not None: + try: + yield self.map_array[0], 0 + except SdoAbortedError: + # WORKAROUND for broken implementations: If the array has a + # fixed number of entries (count not writable), generate dummy + # mappings for an invalid object 0x0000:00 to overwrite any + # excess entries with all-zeros. + # + # The '@@fill_map' yield will run + # self._fill_map(self.map_array[0].raw()) + yield self.map_array[0], '@@fill_map' + subindex = 1 + for var in self.map: + logger.info("Writing %s (0x%04X:%02X, %d bits) to PDO map", + var.name, var.index, var.subindex, var.length) + if hasattr(self.pdo_node.node, "curtis_hack", False): + # Curtis HACK: mixed up field order + yield self.map_array[subindex], (var.index | + var.subindex << 16 | + var.length << 24) + else: + yield self.map_array[subindex], (var.index << 16 | + var.subindex << 8 | + var.length) + subindex += 1 + try: + yield self.map_array[0], len(self.map) + except SdoAbortedError as e: + # WORKAROUND for broken implementations: If the array + # number-of-entries parameter is not writable, we have already + # generated the required number of mappings above. + if e.code != 0x06010002: + # Abort codes other than "Attempt to write a read-only + # object" should still be reported. + raise + self._update_data_size() if self.enabled: cob_id = self.cob_id | (RTR_NOT_ALLOWED if not self.rtr_allowed else 0x0) logger.info("Setting COB-ID 0x%X and re-enabling PDO", cob_id) - self.com_record[1].raw = cob_id + yield self.com_record[1], cob_id self.subscribe() + @ensure_not_async # NOTE: Safeguard for accidental async use + def save(self) -> None: + """Read PDO configuration for this map using SDO.""" + for sdo, value in self.save_generator(): + if value == '@@fillmap': + # NOTE: Blocking - protected in SdoClient + self._fill_map(sdo.raw) + else: + # NOTE: Blocking call + sdo.raw = value + + async def asave(self) -> None: + """Read PDO configuration for this map using SDO, async variant.""" + for sdo, value in self.save_generator(): + if value == '@@fillmap': + self._fill_map(await sdo.aget_raw()) + else: + await sdo.aset_raw(value) + def subscribe(self) -> None: """Register the PDO for reception on the network. @@ -555,17 +644,28 @@ def remote_request(self) -> None: if self.enabled and self.rtr_allowed: self.pdo_node.network.send_message(self.cob_id, bytes(), remote=True) + @ensure_not_async # NOTE: Safeguard for accidental async use def wait_for_reception(self, timeout: float = 10) -> float: """Wait for the next transmit PDO. :param float timeout: Max time to wait in seconds. :return: Timestamp of message received or None if timeout. """ + # NOTE: Blocking lock with self.receive_condition: self.is_received = False + # NOTE: Blocking call self.receive_condition.wait(timeout) return self.timestamp if self.is_received else None + async def await_for_reception(self, timeout: float = 10) -> float: + """Wait for the next transmit PDO. + + :param float timeout: Max time to wait in seconds. + :return: Timestamp of message received or None if timeout. + """ + return await asyncio.to_thread(self.wait_for_reception, timeout) + class PdoVariable(variable.Variable): """One object dictionary variable mapped to a PDO.""" @@ -605,6 +705,11 @@ def get_data(self) -> bytes: return data + async def aget_data(self) -> bytes: + # Since get_data() is not making any IO, it can be called + # directly with no special async variant + return self.get_data() + def set_data(self, data: bytes): """Set for the given variable the PDO data. @@ -638,6 +743,11 @@ def set_data(self, data: bytes): self.pdo_parent.update() + async def aset_data(self, data: bytes): + # Since get_data() is not making any IO, it can be called + # directly with no special async variant + return self.set_data(data) + # For compatibility Variable = PdoVariable diff --git a/canopen/profiles/p402.py b/canopen/profiles/p402.py index 88d86aad..4b6ee8a9 100644 --- a/canopen/profiles/p402.py +++ b/canopen/profiles/p402.py @@ -3,10 +3,19 @@ import time from typing import Dict +from canopen.async_guard import ensure_not_async from canopen.node import RemoteNode from canopen.pdo import PdoMap from canopen.sdo import SdoCommunicationError +""" +NOTE: Async compatibility +This file is not async compatible, as it contains numerous setters and getters +in many of its function. The BaseNode402 class should probably be refactored +and ported to a design which is async compatible. For now, "ensure_not_async" +guard is installed in its init function to warn the user not to use it. +""" + logger = logging.getLogger(__name__) @@ -212,6 +221,12 @@ class BaseNode402(RemoteNode): TIMEOUT_CHECK_TPDO = 0.2 # seconds TIMEOUT_HOMING_DEFAULT = 30 # seconds + # FIXME: Add async implementation of this class + + # NOTE: This safeguard is placed to prevent accidental async use of this + # class, as it is not async compatible. + + @ensure_not_async # NOTE: Safeguard for accidental async use def __init__(self, node_id, object_dictionary): super(BaseNode402, self).__init__(node_id, object_dictionary) self.tpdo_values = {} # { index: value from last received TPDO } @@ -250,6 +265,7 @@ def setup_pdos(self, upload=True): def _init_tpdo_values(self): for tpdo in self.tpdo.values(): if tpdo.enabled: + # NOTE: Adding blocking callback tpdo.add_callback(self.on_TPDOs_update_callback) for obj in tpdo: logger.debug('Configured TPDO: 0x%04X', obj.index) @@ -289,35 +305,47 @@ def _check_op_mode_configured(self): "Operation Mode Display not configured in node %s's PDOs. Using SDOs can cause slow performance.", self.id) + # NOTE: Blocking def reset_from_fault(self): """Reset node from fault and set it to Operation Enable state.""" + # NOTE: Blocking getter on errors if self.state == 'FAULT': # Resets the Fault Reset bit (rising edge 0 -> 1) + # NOTE: Blocking setter self.controlword = State402.CW_DISABLE_VOLTAGE # FIXME! The rising edge happens with the transitions toward OPERATION # ENABLED below, but until then the loop will always reach the timeout! timeout = time.monotonic() + self.TIMEOUT_RESET_FAULT + # NOTE: Blocking on errors while self.is_faulted(): if time.monotonic() > timeout: break + # NOTE: Blocking self.check_statusword() + # NOTE: Blocking setter self.state = 'OPERATION ENABLED' + # NOTE: Blocking on errors def is_faulted(self): bitmask, bits = State402.SW_MASK['FAULT'] + # NOTE: Blocking getter on errors return self.statusword & bitmask == bits + # NOTE Blocking def _homing_status(self): """Interpret the current Statusword bits as homing state string.""" # Wait to make sure a TPDO was received + # NOTE: Blocking self.check_statusword() status = None for key, value in Homing.STATES.items(): bitmask, bits = value + # NOTE: Blocking getter on errors if self.statusword & bitmask == bits: status = key return status + # NOTE: Blocking def is_homed(self, restore_op_mode=False): """Switch to homing mode and determine its status. @@ -325,15 +353,20 @@ def is_homed(self, restore_op_mode=False): :return: If the status indicates successful homing. :rtype: bool """ + # NOTE: Blocking getter previous_op_mode = self.op_mode if previous_op_mode != 'HOMING': logger.info('Switch to HOMING from %s', previous_op_mode) + # NOTE: Blocking setter self.op_mode = 'HOMING' # blocks until confirmed + # NOTE: Blocking homingstatus = self._homing_status() if restore_op_mode: + # NOTE: Blocking setter self.op_mode = previous_op_mode return homingstatus in ('TARGET REACHED', 'ATTAINED') + # NOTE: Blocking def homing(self, timeout=None, restore_op_mode=False): """Execute the configured Homing method on the node. @@ -346,17 +379,23 @@ def homing(self, timeout=None, restore_op_mode=False): if timeout is None: timeout = self.TIMEOUT_HOMING_DEFAULT if restore_op_mode: + # NOTE: Blocking getter previous_op_mode = self.op_mode + # NOTE: Blocking setter self.op_mode = 'HOMING' # The homing process will initialize at operation enabled + # NOTE: Blocking setter self.state = 'OPERATION ENABLED' homingstatus = 'UNKNOWN' + # NOTE: Blocking setter self.controlword = State402.CW_OPERATION_ENABLED | Homing.CW_START # does not block # Wait for one extra cycle, to make sure the controlword was received + # NOTE: Blocking self.check_statusword() t = time.monotonic() + timeout try: while homingstatus not in ('TARGET REACHED', 'ATTAINED'): + # NOTE: Blocking homingstatus = self._homing_status() if homingstatus in ('INTERRUPTED', 'ERROR VELOCITY IS NOT ZERO', 'ERROR VELOCITY IS ZERO'): @@ -369,9 +408,11 @@ def homing(self, timeout=None, restore_op_mode=False): logger.info(str(e)) finally: if restore_op_mode: + # NOTE: Blocking setter self.op_mode = previous_op_mode return False + # NOTE: Blocking getter @property def op_mode(self): """The node's Operation Mode stored in the object 0x6061. @@ -398,15 +439,18 @@ def op_mode(self): try: pdo = self.tpdo_pointers[0x6061].pdo_parent if pdo.is_periodic: + # NOTE: Call to blocking method timestamp = pdo.wait_for_reception(timeout=self.TIMEOUT_CHECK_TPDO) if timestamp is None: raise RuntimeError(f"Timeout getting node {self.id}'s mode of operation.") code = self.tpdo_values[0x6061] except KeyError: logger.warning('The object 0x6061 is not a configured TPDO, fallback to SDO') + # NOTE: Blocking - protected in SdoClient code = self.sdo[0x6061].raw return OperationMode.CODE2NAME[code] + # NOTE: Blocking setter @op_mode.setter def op_mode(self, mode): try: @@ -415,13 +459,16 @@ def op_mode(self, mode): f'Operation mode {mode} not suppported on node {self.id}.') # Update operation mode in RPDO if possible, fall back to SDO if 0x6060 in self.rpdo_pointers: + # NOTE: Blocking - protected in SdoClient self.rpdo_pointers[0x6060].raw = OperationMode.NAME2CODE[mode] pdo = self.rpdo_pointers[0x6060].pdo_parent if not pdo.is_periodic: pdo.transmit() else: + # NOTE: Blocking - protected in SdoClient self.sdo[0x6060].raw = OperationMode.NAME2CODE[mode] timeout = time.monotonic() + self.TIMEOUT_SWITCH_OP_MODE + # NOTE: Blocking getter while self.op_mode != mode: if time.monotonic() > timeout: raise RuntimeError( @@ -432,12 +479,15 @@ def op_mode(self, mode): except (RuntimeError, ValueError) as e: logger.warning(str(e)) + # NOTE: Blocking def _clear_target_values(self): # [target velocity, target position, target torque] for target_index in [0x60FF, 0x607A, 0x6071]: if target_index in self.sdo.keys(): + # NOTE: Blocking - protected in SdoClient self.sdo[target_index].raw = 0 + # NOTE: Blocking def is_op_mode_supported(self, mode): """Check if the operation mode is supported by the node. @@ -450,20 +500,26 @@ def is_op_mode_supported(self, mode): """ if not hasattr(self, '_op_mode_support'): # Cache value only on first lookup, this object should never change. + # NOTE: Blocking - protected in SdoClient self._op_mode_support = self.sdo[0x6502].raw logger.info('Caching node %s supported operation modes 0x%04X', self.id, self._op_mode_support) bits = OperationMode.SUPPORTED[mode] return self._op_mode_support & bits == bits + # NOTE: Blocking def on_TPDOs_update_callback(self, mapobject: PdoMap): """Cache updated values from a TPDO received from this node. :param mapobject: The received PDO message. """ + # NOTE: Callback. Called from another thread unless async for obj in mapobject: + # FIXME: Is this thread-safe? + # NOTE: Blocking - protected in SdoClient self.tpdo_values[obj.index] = obj.raw + # NOTE: Blocking getter on errors @property def statusword(self): """Return the last read value of the Statusword (0x6041) from the device. @@ -475,8 +531,10 @@ def statusword(self): return self.tpdo_values[0x6041] except KeyError: logger.warning('The object 0x6041 is not a configured TPDO, fallback to SDO') + # NOTE: Blocking - protected in SdoClient return self.sdo[0x6041].raw + # NOTE: Blocking, conditional def check_statusword(self, timeout=None): """Report an up-to-date reading of the Statusword (0x6041) from the device. @@ -496,7 +554,9 @@ def check_statusword(self, timeout=None): if timestamp is None: raise RuntimeError('Timeout waiting for updated statusword') else: + # NOTE: Blocking - protected in SdoClient return self.sdo[0x6041].raw + # NOTE: Blocking getter on errors return self.statusword @property @@ -508,16 +568,20 @@ def controlword(self): """ raise RuntimeError('The Controlword is write-only.') + # NOTE: Blocking setter @controlword.setter def controlword(self, value): if 0x6040 in self.rpdo_pointers: + # NOTE: Blocking - protected in SdoClient self.rpdo_pointers[0x6040].raw = value pdo = self.rpdo_pointers[0x6040].pdo_parent if not pdo.is_periodic: pdo.transmit() else: + # NOTE: Blocking - protected in SdoClient self.sdo[0x6040].raw = value + # NOTE: Blocking getter on errors @property def state(self): """Manipulate current state of the DS402 State Machine on the node. @@ -541,42 +605,55 @@ def state(self): """ for state, mask_val_pair in State402.SW_MASK.items(): bitmask, bits = mask_val_pair + # NOTE: Blocking getter on errors if self.statusword & bitmask == bits: return state return 'UNKNOWN' + # NOTE: Blocking setter @state.setter def state(self, target_state): timeout = time.monotonic() + self.TIMEOUT_SWITCH_STATE_FINAL + # NOTE: Blocking getter on errors while self.state != target_state: + # NOTE: Blocking next_state = self._next_state(target_state) + # NOTE: Blocking if self._change_state(next_state): continue if time.monotonic() > timeout: raise RuntimeError('Timeout when trying to change state') + # NOTE: Blocking self.check_statusword() + # NOTE: Blocking def _next_state(self, target_state): if target_state in ('NOT READY TO SWITCH ON', 'FAULT REACTION ACTIVE', 'FAULT'): raise ValueError( f'Target state {target_state} cannot be entered programmatically') + # NOTE: Blocking getter on errors from_state = self.state if (from_state, target_state) in State402.TRANSITIONTABLE: return target_state else: return State402.next_state_indirect(from_state) + # NOTE: Blocking def _change_state(self, target_state): try: + # NOTE: Blocking setter, getter on errors self.controlword = State402.TRANSITIONTABLE[(self.state, target_state)] except KeyError: + # NOTE: Blocking getter on errors raise ValueError( f'Illegal state transition from {self.state} to {target_state}') timeout = time.monotonic() + self.TIMEOUT_SWITCH_STATE_SINGLE + # NOTE: Blocking getter on errors while self.state != target_state: if time.monotonic() > timeout: return False + # NOTE: Blocking self.check_statusword() return True diff --git a/canopen/sdo/base.py b/canopen/sdo/base.py index ddc75ed9..4996c649 100644 --- a/canopen/sdo/base.py +++ b/canopen/sdo/base.py @@ -7,6 +7,7 @@ import canopen.network from canopen import objectdictionary from canopen import variable +from canopen.async_guard import ensure_not_async from canopen.utils import pretty_index @@ -83,6 +84,9 @@ def get_variable( def upload(self, index: int, subindex: int) -> bytes: raise NotImplementedError() + async def aupload(self, index: int, subindex: int) -> bytes: + raise NotImplementedError() + def download( self, index: int, @@ -92,6 +96,15 @@ def download( ) -> None: raise NotImplementedError() + async def adownload( + self, + index: int, + subindex: int, + data: bytes, + force_segment: bool = False, + ) -> None: + raise NotImplementedError() + class SdoRecord(Mapping): @@ -109,10 +122,20 @@ def __iter__(self) -> Iterator[int]: # Skip the "highest subindex" entry, which is not part of the data return filter(None, iter(self.od)) + async def aiter(self): + for i in iter(self.od): + yield i + + def __aiter__(self): + return self.aiter() + def __len__(self) -> int: # Skip the "highest subindex" entry, which is not part of the data return len(self.od) - int(0 in self.od) + async def alen(self) -> int: + return len(self.od) + def __contains__(self, subindex: Union[int, str]) -> bool: return subindex in self.od @@ -133,9 +156,20 @@ def __iter__(self) -> Iterator[int]: # Skip the "highest subindex" entry, which is not part of the data return iter(range(1, len(self) + 1)) + async def aiter(self): + for i in range(1, await self.alen() + 1): + yield i + + def __aiter__(self): + return self.aiter() + def __len__(self) -> int: + # NOTE: Blocking - protected in SdoClient return self[0].raw + async def alen(self) -> int: + return await self[0].aget_raw() # type: ignore[return-value] + def __contains__(self, subindex: int) -> bool: return 0 <= subindex <= len(self) @@ -147,13 +181,25 @@ def __init__(self, sdo_node: SdoBase, od: objectdictionary.ODVariable): self.sdo_node = sdo_node variable.Variable.__init__(self, od) + def __await__(self): + return self.aget_raw().__await__() + + @ensure_not_async # NOTE: Safeguard for accidental async use def get_data(self) -> bytes: return self.sdo_node.upload(self.od.index, self.od.subindex) + async def aget_data(self) -> bytes: + return await self.sdo_node.aupload(self.od.index, self.od.subindex) + + @ensure_not_async # NOTE: Safeguard for accidental async use def set_data(self, data: bytes): force_segment = self.od.data_type == objectdictionary.DOMAIN self.sdo_node.download(self.od.index, self.od.subindex, data, force_segment) + async def aset_data(self, data: bytes): + force_segment = self.od.data_type == objectdictionary.DOMAIN + await self.sdo_node.adownload(self.od.index, self.od.subindex, data, force_segment) + @property def writable(self) -> bool: return self.od.writable @@ -162,6 +208,7 @@ def writable(self) -> bool: def readable(self) -> bool: return self.od.readable + @ensure_not_async # NOTE: Safeguard for accidental async use def open(self, mode="rb", encoding="ascii", buffering=1024, size=None, block_transfer=False, request_crc_support=True): """Open the data stream as a file like object. @@ -196,6 +243,13 @@ def open(self, mode="rb", encoding="ascii", buffering=1024, size=None, return self.sdo_node.open(self.od.index, self.od.subindex, mode, encoding, buffering, size, block_transfer, request_crc_support=request_crc_support) + async def aopen(self, mode="rb", encoding="ascii", buffering=1024, size=None, + block_transfer=False, request_crc_support=True): + """Open the data stream as a file like object. See open()""" + return await self.sdo_node.aopen(self.od.index, self.od.subindex, mode, + encoding, buffering, size, block_transfer, + request_crc_support=request_crc_support) + # For compatibility Record = SdoRecord diff --git a/canopen/sdo/client.py b/canopen/sdo/client.py index 76fa6fbc..490cba05 100644 --- a/canopen/sdo/client.py +++ b/canopen/sdo/client.py @@ -1,3 +1,4 @@ +import asyncio import io import logging import queue @@ -7,6 +8,7 @@ from can import CanError from canopen import objectdictionary +from canopen.async_guard import ensure_not_async from canopen.sdo.base import SdoBase from canopen.sdo.constants import * from canopen.sdo.exceptions import * @@ -42,13 +44,17 @@ def __init__(self, rx_cobid, tx_cobid, od): """ SdoBase.__init__(self, rx_cobid, tx_cobid, od) self.responses = queue.Queue() + self.lock = asyncio.Lock() # For ensuring only one pending SDO request in async + # @callback # NOTE: called from another thread def on_response(self, can_id, data, timestamp): self.responses.put(bytes(data)) + @ensure_not_async # NOTE: Safeguard for accidental async use def send_request(self, request): retries_left = self.MAX_RETRIES if self.PAUSE_BEFORE_SEND: + # NOTE: Blocking time.sleep(self.PAUSE_BEFORE_SEND) while True: try: @@ -60,12 +66,14 @@ def send_request(self, request): raise logger.info(str(e)) if self.RETRY_DELAY: + # NOTE: Blocking time.sleep(self.RETRY_DELAY) else: break def read_response(self): try: + # NOTE: Blocking call response = self.responses.get( block=True, timeout=self.RESPONSE_TIMEOUT) except queue.Empty: @@ -79,7 +87,8 @@ def read_response(self): def request_response(self, sdo_request): retries_left = self.MAX_RETRIES if not self.responses.empty(): - # logger.warning("There were unexpected messages in the queue") + # FIXME: Recreating the queue + logger.warning("There were unexpected messages in the queue") self.responses = queue.Queue() while True: self.send_request(sdo_request) @@ -102,6 +111,11 @@ def abort(self, abort_code=0x08000000): self.send_request(request) logger.error("Transfer aborted by client with code 0x%08X", abort_code) + async def aabort(self, abort_code=0x08000000): + """Abort current transfer. Async version.""" + return await asyncio.to_thread(self.abort, abort_code) + + @ensure_not_async # NOTE: Safeguard for accidental async use def upload(self, index: int, subindex: int) -> bytes: """May be called to make a read operation without an Object Dictionary. @@ -120,7 +134,9 @@ def upload(self, index: int, subindex: int) -> bytes: with self.open(index, subindex, buffering=0) as fp: response_size = fp.size data = fp.read() + return self.truncate_data(index, subindex, data, response_size) + def truncate_data(self, index: int, subindex: int, data: bytes, size: int) -> bytes: # If size is available through variable in OD, then use the smaller of the two sizes. # Some devices send U32/I32 even if variable is smaller in OD var = self.od.get_variable(index, subindex) @@ -129,11 +145,31 @@ def upload(self, index: int, subindex: int) -> bytes: if var.fixed_size: # Get the size in bytes for this variable var_size = len(var) // 8 - if response_size is None or var_size < response_size: + if size is None or var_size < size: # Truncate the data to specified size data = data[0:var_size] return data + async def aupload(self, index: int, subindex: int) -> bytes: + """May be called to make a read operation without an Object Dictionary. + Async version. + """ + async with self.lock: # Ensure only one active SDO request per channel + # Deferring to thread because there are sleeps and queue waits in the call chain + # The call stack is typically: + # upload -> open -> ReadableStream -> request_reponse -> send_request -> network.send_message + # recv -> on_reponse -> queue.put + # request_reponse -> read_response -> queue.get + def _upload(): + with self.open(index, subindex, buffering=0) as fp: + response_size = fp.size + data = fp.read() + return data, response_size + + data, response_size = await asyncio.to_thread(_upload) + return self.truncate_data(index, subindex, data, response_size) + + @ensure_not_async # NOTE: Safeguard for accidental async use def download( self, index: int, @@ -161,6 +197,27 @@ def download( force_segment=force_segment) as fp: fp.write(data) + async def adownload( + self, + index: int, + subindex: int, + data: bytes, + force_segment: bool = False, + ) -> None: + """May be called to make a write operation without an Object Dictionary. + Async version. + """ + async with self.lock: # Ensure only one active SDO request per channel + # Deferring to thread because there are sleeps in the call chain + + def _download(): + with self.open(index, subindex, "wb", buffering=7, size=len(data), + force_segment=force_segment) as fp: + fp.write(data) + + return await asyncio.to_thread(_download) + + @ensure_not_async # NOTE: Safeguard for accidental async use def open(self, index, subindex=0, mode="rb", encoding="ascii", buffering=1024, size=None, block_transfer=False, force_segment=False, request_crc_support=True): """Open the data stream as a file like object. diff --git a/canopen/sdo/server.py b/canopen/sdo/server.py index 8d693633..0f6eec2c 100644 --- a/canopen/sdo/server.py +++ b/canopen/sdo/server.py @@ -1,5 +1,6 @@ import logging +from canopen.async_guard import ensure_not_async from canopen.sdo.base import SdoBase from canopen.sdo.constants import * from canopen.sdo.exceptions import * @@ -28,7 +29,9 @@ def __init__(self, rx_cobid, tx_cobid, node): self._subindex = None self.last_received_error = 0x00000000 + # @callback # NOTE: called from another thread def on_request(self, can_id, data, timestamp): + # FIXME: There is a lot of calls here, this must be checked for thread safe command, = struct.unpack_from("B", data, 0) ccs = command & 0xE0 @@ -190,6 +193,7 @@ def abort(self, abort_code=0x08000000): self.send_response(data) # logger.error("Transfer aborted with code 0x%08X", abort_code) + @ensure_not_async # NOTE: Safeguard for accidental async use def upload(self, index: int, subindex: int) -> bytes: """May be called to make a read operation without an Object Dictionary. @@ -205,6 +209,22 @@ def upload(self, index: int, subindex: int) -> bytes: """ return self._node.get_data(index, subindex) + async def aupload(self, index: int, subindex: int) -> bytes: + """May be called to make a read operation without an Object Dictionary. + + :param index: + Index of object to read. + :param subindex: + Sub-index of object to read. + + :return: A data object. + + :raises canopen.SdoAbortedError: + When node responds with an error. + """ + return self._node.get_data(index, subindex) + + @ensure_not_async # NOTE: Safeguard for accidental async use def download( self, index: int, @@ -225,3 +245,24 @@ def download( When node responds with an error. """ return self._node.set_data(index, subindex, data) + + async def adownload( + self, + index: int, + subindex: int, + data: bytes, + force_segment: bool = False, + ): + """May be called to make a write operation without an Object Dictionary. + + :param index: + Index of object to write. + :param subindex: + Sub-index of object to write. + :param data: + Data to be written. + + :raises canopen.SdoAbortedError: + When node responds with an error. + """ + return self._node.set_data(index, subindex, data) diff --git a/canopen/variable.py b/canopen/variable.py index d2538c3f..ff2c47ce 100644 --- a/canopen/variable.py +++ b/canopen/variable.py @@ -3,6 +3,7 @@ from typing import Union from canopen import objectdictionary +from canopen.async_guard import ensure_not_async from canopen.utils import pretty_index @@ -33,9 +34,15 @@ def __repr__(self) -> str: def get_data(self) -> bytes: raise NotImplementedError("Variable is not readable") + async def aget_data(self) -> bytes: + raise NotImplementedError("Variable is not readable") + def set_data(self, data: bytes): raise NotImplementedError("Variable is not writable") + async def aset_data(self, data: bytes): + raise NotImplementedError("Variable is not writable") + @property def data(self) -> bytes: """Byte representation of the object as :class:`bytes`.""" @@ -75,7 +82,14 @@ def raw(self) -> Union[int, bool, float, str, bytes]: Data types that this library does not handle yet must be read and written as :class:`bytes`. """ - value = self.od.decode_raw(self.data) + return self._get_raw(self.get_data()) + + async def aget_raw(self) -> Union[int, bool, float, str, bytes]: + """Raw representation of the object, async variant""" + return self._get_raw(await self.aget_data()) + + def _get_raw(self, data: bytes) -> Union[int, bool, float, str, bytes]: + value = self.od.decode_raw(data) text = f"Value of {self.name!r} ({pretty_index(self.index, self.subindex)}) is {value!r}" if value in self.od.value_descriptions: text += f" ({self.od.value_descriptions[value]})" @@ -84,10 +98,17 @@ def raw(self) -> Union[int, bool, float, str, bytes]: @raw.setter def raw(self, value: Union[int, bool, float, str, bytes]): + self.set_data(self._set_raw(value)) + + async def aset_raw(self, value: Union[int, bool, float, str, bytes]): + """Set the raw value of the object, async variant""" + await self.aset_data(self._set_raw(value)) + + def _set_raw(self, value: Union[int, bool, float, str, bytes]): logger.debug("Writing %r (0x%04X:%02X) = %r", self.name, self.index, self.subindex, value) - self.data = self.od.encode_raw(value) + return self.od.encode_raw(value) @property def phys(self) -> Union[int, bool, float, str, bytes]: @@ -97,7 +118,14 @@ def phys(self) -> Union[int, bool, float, str, bytes]: either a :class:`float` or an :class:`int`. Non integers will be passed as is. """ - value = self.od.decode_phys(self.raw) + return self._get_phys(self.raw) + + async def aget_phys(self) -> Union[int, bool, float, str, bytes]: + """Physical value scaled with some factor (defaults to 1), async variant.""" + return self._get_phys(await self.aget_raw()) + + def _get_phys(self, raw: Union[int, bool, float, str, bytes]): + value = self.od.decode_phys(raw) if self.od.unit: logger.debug("Physical value is %s %s", value, self.od.unit) return value @@ -106,6 +134,10 @@ def phys(self) -> Union[int, bool, float, str, bytes]: def phys(self, value: Union[int, bool, float, str, bytes]): self.raw = self.od.encode_phys(value) + async def aset_phys(self, value: Union[int, bool, float, str, bytes]): + """Set physical value scaled with some factor (defaults to 1). Async variant""" + await self.aset_raw(self.od.encode_phys(value)) + @property def desc(self) -> str: """Converts to and from a description of the value as a string.""" @@ -113,10 +145,20 @@ def desc(self) -> str: logger.debug("Description is '%s'", value) return value + async def aget_desc(self) -> str: + """Converts to and from a description of the value as a string, async variant.""" + value = self.od.decode_desc(await self.aget_raw()) + logger.debug("Description is '%s'", value) + return value + @desc.setter def desc(self, desc: str): self.raw = self.od.encode_desc(desc) + async def aset_desc(self, desc: str): + """Set variable description, async variant.""" + await self.aset_raw(self.od.encode_desc(desc)) + @property def bits(self) -> "Bits": """Access bits using integers, slices, or bit descriptions.""" @@ -143,6 +185,16 @@ def read(self, fmt: str = "raw") -> Union[int, bool, float, str, bytes]: elif fmt == "desc": return self.desc + async def aread(self, fmt: str = "raw") -> Union[int, bool, float, str, bytes]: + """Alternative way of reading using a function instead of attributes. Async variant.""" + if fmt == "raw": + return await self.aget_raw() + elif fmt == "phys": + return await self.aget_phys() + elif fmt == "desc": + return await self.aget_desc() + raise ValueError(f"Unknown format '{fmt}'") + def write( self, value: Union[int, bool, float, str, bytes], fmt: str = "raw" ) -> None: @@ -163,11 +215,24 @@ def write( elif fmt == "desc": self.desc = value + async def awrite( + self, value: Union[int, bool, float, str, bytes], fmt: str = "raw" + ) -> None: + """Alternative way of writing using a function instead of attributes. Async variant""" + if fmt == "raw": + await self.aset_raw(value) + elif fmt == "phys": + await self.aset_phys(value) + elif fmt == "desc": + await self.aset_desc(value) # type: ignore[arg-type] + class Bits(Mapping): + @ensure_not_async # NOTE: Safeguard for accidental async use def __init__(self, variable: Variable): self.variable = variable + # FIXME: This is not compatible with async self.read() @staticmethod @@ -199,3 +264,9 @@ def read(self): def write(self): self.variable.raw = self.raw + + async def aread(self): + self.raw = await self.variable.aget_raw() + + async def awrite(self): + await self.variable.aset_raw(self.raw) diff --git a/examples/canopen_async.py b/examples/canopen_async.py new file mode 100644 index 00000000..371c5047 --- /dev/null +++ b/examples/canopen_async.py @@ -0,0 +1,69 @@ +import asyncio +import logging +import canopen + +# Set logging output +logging.basicConfig(level=logging.INFO) +log = logging.getLogger(__name__) + + +async def do_loop(network: canopen.Network, nodeid): + + # Create the node object and load the OD + node: canopen.RemoteNode = await network.aadd_node(nodeid, 'eds/e35.eds') + + # Get the PDOs from the remote + await node.tpdo.aread(from_od=False) + await node.rpdo.aread(from_od=False) + + # Set the remote state + node.nmt.state = 'OPERATIONAL' + + # Set SDO + await node.sdo['something'].aset_raw(2) + + i = 0 + while True: + i += 1 + + # Wait for PDO + t = await node.tpdo[1].await_for_reception(1) + if not t: + continue + + # Get TPDO value + # PDO values are accessed non-synchronously using attributes + state = node.tpdo[1]['state'].raw + + # If state send RPDO to remote + if state == 5: + + await asyncio.sleep(0.2) + + # Set RPDO and transmit + node.rpdo[1]['count'].phys = i + node.rpdo[1].transmit() + + +async def amain(): + + # Create the canopen network and connect it to the CAN bus + loop = asyncio.get_running_loop() + async with canopen.Network(loop=loop).connect( + interface='virtual', bitrate=1000000, recieve_own_messages=True + ) as network: + + # Start two instances and run them concurrently + # NOTE: It is better to use asyncio.TaskGroup to manage tasks, but this + # is not available before Python 3.11. + await asyncio.gather( + asyncio.create_task(do_loop(network, 20)), + asyncio.create_task(do_loop(network, 21)), + ) + + +def main(): + asyncio.run(amain()) + +if __name__ == '__main__': + main() diff --git a/pyproject.toml b/pyproject.toml index e9f3b871..cf433b61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,10 +8,11 @@ authors = [ {name = "Christian Sandberg", email = "christiansandberg@me.com"}, {name = "André Colomb", email = "src@andre.colomb.de"}, {name = "André Filipe Silva", email = "afsilva.work@gmail.com"}, + {name = "Svein Seldal", email = "sveinse@seldal.com"}, ] description = "CANopen stack implementation" readme = "README.rst" -requires-python = ">=3.8" +requires-python = ">=3.9" license = {file = "LICENSE.txt"} classifiers = [ "Development Status :: 5 - Production/Stable", @@ -50,9 +51,17 @@ filterwarnings = [ ] [tool.mypy] -python_version = "3.8" +python_version = "3.9" exclude = [ "^examples*", "^test*", "^setup.py*", ] + +[tool.coverage.run] +branch = true + +[tool.coverage.report] +exclude_also = [ + 'if TYPE_CHECKING:', +] diff --git a/test/test_eds.py b/test/test_eds.py index 68f5ad3c..3c2218e9 100644 --- a/test/test_eds.py +++ b/test/test_eds.py @@ -213,8 +213,6 @@ def test_reading_factor(self): self.assertEqual(var2.factor, 1) self.assertEqual(var2.unit, '') - - def test_comments(self): self.assertEqual(self.od.comments, """ @@ -296,7 +294,6 @@ def test_export_eds_to_stdout(self): buf.name = "mock.eds" self.verify_od(buf, "eds") - def verify_od(self, source, doctype): exported_od = canopen.import_od(source) diff --git a/test/test_emcy.py b/test/test_emcy.py index d883e9c8..966c58c5 100644 --- a/test/test_emcy.py +++ b/test/test_emcy.py @@ -1,20 +1,36 @@ import logging import threading import unittest +import asyncio from contextlib import contextmanager import can import canopen -from canopen.emcy import EmcyError +from canopen.emcy import EmcyError, EmcyConsumer TIMEOUT = 0.1 -class TestEmcy(unittest.TestCase): +class TestEmcy(unittest.IsolatedAsyncioTestCase): + + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + def setUp(self): - self.emcy = canopen.emcy.EmcyConsumer() + loop = None + if self.use_async: + loop = asyncio.get_event_loop() + self.loop = loop + + self.net = canopen.Network(loop=loop) + self.net.connect(interface="virtual") + self.emcy = EmcyConsumer() + self.emcy.network = self.net + + def tearDown(self): + self.net.disconnect() def check_error(self, err, code, reg, data, ts): self.assertIsInstance(err, EmcyError) @@ -24,7 +40,16 @@ def check_error(self, err, code, reg, data, ts): self.assertEqual(err.data, data) self.assertAlmostEqual(err.timestamp, ts) - def test_emcy_consumer_on_emcy(self): + async def dispatch_emcy(self, can_id, data, ts): + # Dispatch an EMCY datagram. + if self.use_async: + await asyncio.to_thread( + self.emcy.on_emcy, can_id, data, ts + ) + else: + self.emcy.on_emcy(can_id, data, ts) + + async def test_emcy_consumer_on_emcy(self): # Make sure multiple callbacks receive the same information. acc1 = [] acc2 = [] @@ -32,7 +57,7 @@ def test_emcy_consumer_on_emcy(self): self.emcy.add_callback(lambda err: acc2.append(err)) # Dispatch an EMCY datagram. - self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + await self.dispatch_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) self.assertEqual(len(self.emcy.log), 1) self.assertEqual(len(self.emcy.active), 1) @@ -46,7 +71,7 @@ def test_emcy_consumer_on_emcy(self): ) # Dispatch a new EMCY datagram. - self.emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) + await self.dispatch_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) self.assertEqual(len(self.emcy.log), 2) self.assertEqual(len(self.emcy.active), 2) @@ -59,13 +84,13 @@ def test_emcy_consumer_on_emcy(self): ) # Dispatch an EMCY reset. - self.emcy.on_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 2000) + await self.dispatch_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 2000) self.assertEqual(len(self.emcy.log), 3) self.assertEqual(len(self.emcy.active), 0) - def test_emcy_consumer_reset(self): - self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) - self.emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) + async def test_emcy_consumer_reset(self): + await self.dispatch_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + await self.dispatch_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) self.assertEqual(len(self.emcy.log), 2) self.assertEqual(len(self.emcy.active), 2) @@ -73,7 +98,10 @@ def test_emcy_consumer_reset(self): self.assertEqual(len(self.emcy.log), 0) self.assertEqual(len(self.emcy.active), 0) - def test_emcy_consumer_wait(self): + async def test_emcy_consumer_wait(self): + if self.use_async: + self.skipTest("Not implemented for async") + PAUSE = TIMEOUT / 2 def push_err(): @@ -95,7 +123,10 @@ def timer(func): t.join(TIMEOUT) # Check unfiltered wait, on timeout. - self.assertIsNone(self.emcy.wait(timeout=TIMEOUT)) + if self.use_async: + self.assertIsNone(await self.emcy.async_wait(timeout=TIMEOUT)) + else: + self.assertIsNone(self.emcy.wait(timeout=TIMEOUT)) # Check unfiltered wait, on success. with timer(push_err) as t: @@ -124,6 +155,18 @@ def push_reset(): self.assertIsNone(self.emcy.wait(0x9000, TIMEOUT)) +class TestEmcySync(TestEmcy): + """ Run the tests in non-asynchronous mode. """ + __test__ = True + use_async = False + + +class TestEmcyAsync(TestEmcy): + """ Run the tests in asynchronous mode. """ + __test__ = True + use_async = True + + class TestEmcyError(unittest.TestCase): def test_emcy_error(self): error = EmcyError(0x2001, 0x02, b'\x00\x01\x02\x03\x04', 1000) @@ -181,11 +224,19 @@ def check(code, expected): check(0xffff, "Device Specific") -class TestEmcyProducer(unittest.TestCase): +class TestEmcyProducer(unittest.IsolatedAsyncioTestCase): + + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + def setUp(self): - self.txbus = can.Bus(interface="virtual") - self.rxbus = can.Bus(interface="virtual") - self.net = canopen.Network(self.txbus) + loop = None + if self.use_async: + loop = asyncio.get_event_loop() + + self.txbus = can.Bus(interface="virtual", loop=loop) + self.rxbus = can.Bus(interface="virtual", loop=loop) + self.net = canopen.Network(self.txbus, loop=loop) self.net.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 self.net.connect() self.emcy = canopen.emcy.EmcyProducer(0x80 + 1) @@ -202,7 +253,7 @@ def check_response(self, expected): actual = msg.data self.assertEqual(actual, expected) - def test_emcy_producer_send(self): + async def test_emcy_producer_send(self): def check(*args, res): self.emcy.send(*args) self.check_response(res) @@ -211,7 +262,7 @@ def check(*args, res): check(0x2001, 0x2, res=b'\x01\x20\x02\x00\x00\x00\x00\x00') check(0x2001, 0x2, b'\x2a', res=b'\x01\x20\x02\x2a\x00\x00\x00\x00') - def test_emcy_producer_reset(self): + async def test_emcy_producer_reset(self): def check(*args, res): self.emcy.reset(*args) self.check_response(res) @@ -221,5 +272,17 @@ def check(*args, res): check(3, b"\xaa\xbb", res=b'\x00\x00\x03\xaa\xbb\x00\x00\x00') +class TestEmcyProducerSync(TestEmcyProducer): + """ Run the tests in non-asynchronous mode. """ + __test__ = True + use_async = False + + +class TestEmcyProducerAsync(TestEmcyProducer): + """ Run the tests in asynchronous mode. """ + __test__ = True + use_async = True + + if __name__ == "__main__": unittest.main() diff --git a/test/test_local.py b/test/test_local.py index 31404bd6..14770133 100644 --- a/test/test_local.py +++ b/test/test_local.py @@ -1,43 +1,56 @@ import time import unittest +import asyncio import canopen +from canopen.async_guard import AllowBlocking from .util import SAMPLE_EDS -class TestSDO(unittest.TestCase): +class TestSDO(unittest.IsolatedAsyncioTestCase): """ Test SDO client and server against each other. """ - @classmethod - def setUpClass(cls): - cls.network1 = canopen.Network() - cls.network1.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 - cls.network1.connect("test", interface="virtual") - cls.remote_node = cls.network1.add_node(2, SAMPLE_EDS) - - cls.network2 = canopen.Network() - cls.network2.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 - cls.network2.connect("test", interface="virtual") - cls.local_node = cls.network2.create_node(2, SAMPLE_EDS) - - cls.remote_node2 = cls.network1.add_node(3, SAMPLE_EDS) - - cls.local_node2 = cls.network2.create_node(3, SAMPLE_EDS) - - @classmethod - def tearDownClass(cls): - cls.network1.disconnect() - cls.network2.disconnect() - - def test_expedited_upload(self): - self.local_node.sdo[0x1400][1].raw = 0x99 - vendor_id = self.remote_node.sdo[0x1400][1].raw + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + + def setUp(self): + loop = None + if self.use_async: + loop = asyncio.get_event_loop() + + self.network1 = canopen.Network(loop=loop) + self.network1.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.network1.connect("test", interface="virtual") + with AllowBlocking(): + self.remote_node = self.network1.add_node(2, SAMPLE_EDS) + + self.network2 = canopen.Network(loop=loop) + self.network2.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.network2.connect("test", interface="virtual") + self.local_node = self.network2.create_node(2, SAMPLE_EDS) + with AllowBlocking(): + self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS) + self.local_node2 = self.network2.create_node(3, SAMPLE_EDS) + + def tearDown(self): + self.network1.disconnect() + self.network2.disconnect() + + async def test_expedited_upload(self): + if self.use_async: + await self.local_node.sdo[0x1400][1].aset_raw(0x99) + vendor_id = await self.remote_node.sdo[0x1400][1].aget_raw() + else: + self.local_node.sdo[0x1400][1].raw = 0x99 + vendor_id = self.remote_node.sdo[0x1400][1].raw self.assertEqual(vendor_id, 0x99) - def test_block_upload_switch_to_expedite_upload(self): + async def test_block_upload_switch_to_expedite_upload(self): + if self.use_async: + self.skipTest("Block upload not supported in async mode") with self.assertRaises(canopen.SdoCommunicationError) as context: with self.remote_node.sdo[0x1008].open('r', block_transfer=True) as fp: pass @@ -45,7 +58,9 @@ def test_block_upload_switch_to_expedite_upload(self): # from block upload to expedite upload self.assertEqual("Unexpected response 0x41", str(context.exception)) - def test_block_download_not_supported(self): + async def test_block_download_not_supported(self): + if self.use_async: + self.skipTest("Block download not supported in async mode") data = b"TEST DEVICE" with self.assertRaises(canopen.SdoAbortedError) as context: with self.remote_node.sdo[0x1008].open('wb', @@ -54,104 +69,174 @@ def test_block_download_not_supported(self): pass self.assertEqual(context.exception.code, 0x05040001) - def test_expedited_upload_default_value_visible_string(self): - device_name = self.remote_node.sdo["Manufacturer device name"].raw + async def test_expedited_upload_default_value_visible_string(self): + if self.use_async: + device_name = await self.remote_node.sdo["Manufacturer device name"].aget_raw() + else: + device_name = self.remote_node.sdo["Manufacturer device name"].raw self.assertEqual(device_name, "TEST DEVICE") - def test_expedited_upload_default_value_real(self): - sampling_rate = self.remote_node.sdo["Sensor Sampling Rate (Hz)"].raw + async def test_expedited_upload_default_value_real(self): + if self.use_async: + sampling_rate = await self.remote_node.sdo["Sensor Sampling Rate (Hz)"].aget_raw() + else: + sampling_rate = self.remote_node.sdo["Sensor Sampling Rate (Hz)"].raw self.assertAlmostEqual(sampling_rate, 5.2, places=2) - def test_upload_zero_length(self): - self.local_node.sdo["Manufacturer device name"].raw = b"" - with self.assertRaises(canopen.SdoAbortedError) as error: - self.remote_node.sdo["Manufacturer device name"].data + async def test_upload_zero_length(self): + if self.use_async: + await self.local_node.sdo["Manufacturer device name"].aset_raw(b"") + with self.assertRaises(canopen.SdoAbortedError) as error: + await self.remote_node.sdo["Manufacturer device name"].aget_data() + else: + self.local_node.sdo["Manufacturer device name"].raw = b"" + with self.assertRaises(canopen.SdoAbortedError) as error: + self.remote_node.sdo["Manufacturer device name"].data # Should be No data available self.assertEqual(error.exception.code, 0x0800_0024) - def test_segmented_upload(self): - self.local_node.sdo["Manufacturer device name"].raw = "Some cool device" - device_name = self.remote_node.sdo["Manufacturer device name"].data + async def test_segmented_upload(self): + if self.use_async: + await self.local_node.sdo["Manufacturer device name"].aset_raw("Some cool device") + device_name = await self.remote_node.sdo["Manufacturer device name"].aget_data() + else: + self.local_node.sdo["Manufacturer device name"].raw = "Some cool device" + device_name = self.remote_node.sdo["Manufacturer device name"].data self.assertEqual(device_name, b"Some cool device") - def test_expedited_download(self): - self.remote_node.sdo[0x2004].raw = 0xfeff - value = self.local_node.sdo[0x2004].raw + async def test_expedited_download(self): + if self.use_async: + await self.remote_node.sdo[0x2004].aset_raw(0xfeff) + value = await self.local_node.sdo[0x2004].aget_raw() + else: + self.remote_node.sdo[0x2004].raw = 0xfeff + value = self.local_node.sdo[0x2004].raw self.assertEqual(value, 0xfeff) - def test_expedited_download_wrong_datatype(self): + async def test_expedited_download_wrong_datatype(self): # Try to write 32 bit in integer16 type - with self.assertRaises(canopen.SdoAbortedError) as error: - self.remote_node.sdo.download(0x2001, 0x0, bytes([10, 10, 10, 10])) + if self.use_async: + with self.assertRaises(canopen.SdoAbortedError) as error: + await self.remote_node.sdo.adownload(0x2001, 0x0, bytes([10, 10, 10, 10])) + else: + with self.assertRaises(canopen.SdoAbortedError) as error: + self.remote_node.sdo.download(0x2001, 0x0, bytes([10, 10, 10, 10])) self.assertEqual(error.exception.code, 0x06070010) # Try to write normal 16 bit word, should be ok - self.remote_node.sdo.download(0x2001, 0x0, bytes([10, 10])) - value = self.remote_node.sdo.upload(0x2001, 0x0) + if self.use_async: + await self.remote_node.sdo.adownload(0x2001, 0x0, bytes([10, 10])) + value = await self.remote_node.sdo.aupload(0x2001, 0x0) + else: + self.remote_node.sdo.download(0x2001, 0x0, bytes([10, 10])) + value = self.remote_node.sdo.upload(0x2001, 0x0) self.assertEqual(value, bytes([10, 10])) - def test_segmented_download(self): - self.remote_node.sdo[0x2000].raw = "Another cool device" - value = self.local_node.sdo[0x2000].data + async def test_segmented_download(self): + if self.use_async: + await self.remote_node.sdo[0x2000].aset_raw("Another cool device") + value = await self.local_node.sdo[0x2000].aget_data() + else: + self.remote_node.sdo[0x2000].raw = "Another cool device" + value = self.local_node.sdo[0x2000].data self.assertEqual(value, b"Another cool device") - def test_slave_send_heartbeat(self): + async def test_slave_send_heartbeat(self): # Setting the heartbeat time should trigger heartbeating # to start - self.remote_node.sdo["Producer heartbeat time"].raw = 100 - state = self.remote_node.nmt.wait_for_heartbeat() + if self.use_async: + await self.remote_node.sdo["Producer heartbeat time"].aset_raw(100) + state = await self.remote_node.nmt.await_for_heartbeat() + else: + self.remote_node.sdo["Producer heartbeat time"].raw = 100 + state = self.remote_node.nmt.wait_for_heartbeat() self.local_node.nmt.stop_heartbeat() # The NMT master will change the state INITIALISING (0) # to PRE-OPERATIONAL (127) self.assertEqual(state, 'PRE-OPERATIONAL') - def test_nmt_state_initializing_to_preoper(self): + async def test_nmt_state_initializing_to_preoper(self): # Initialize the heartbeat timer - self.local_node.sdo["Producer heartbeat time"].raw = 100 + if self.use_async: + await self.local_node.sdo["Producer heartbeat time"].aset_raw(100) + else: + self.local_node.sdo["Producer heartbeat time"].raw = 100 self.local_node.nmt.stop_heartbeat() # This transition shall start the heartbeating self.local_node.nmt.state = 'INITIALISING' self.local_node.nmt.state = 'PRE-OPERATIONAL' - state = self.remote_node.nmt.wait_for_heartbeat() + if self.use_async: + state = await self.remote_node.nmt.await_for_heartbeat() + else: + state = self.remote_node.nmt.wait_for_heartbeat() self.local_node.nmt.stop_heartbeat() self.assertEqual(state, 'PRE-OPERATIONAL') - def test_receive_abort_request(self): - self.remote_node.sdo.abort(0x05040003) + async def test_receive_abort_request(self): + if self.use_async: + await self.remote_node.sdo.aabort(0x05040003) + else: + self.remote_node.sdo.abort(0x05040003) # Line below is just so that we are sure the client have received the abort # before we do the check - time.sleep(0.1) + if self.use_async: + await asyncio.sleep(0.1) + else: + time.sleep(0.1) self.assertEqual(self.local_node.sdo.last_received_error, 0x05040003) - def test_start_remote_node(self): + async def test_start_remote_node(self): self.remote_node.nmt.state = 'OPERATIONAL' # Line below is just so that we are sure the client have received the command # before we do the check - time.sleep(0.1) + if self.use_async: + await asyncio.sleep(0.1) + else: + time.sleep(0.1) slave_state = self.local_node.nmt.state self.assertEqual(slave_state, 'OPERATIONAL') - def test_two_nodes_on_the_bus(self): - self.local_node.sdo["Manufacturer device name"].raw = "Some cool device" - device_name = self.remote_node.sdo["Manufacturer device name"].data + async def test_two_nodes_on_the_bus(self): + if self.use_async: + await self.local_node.sdo["Manufacturer device name"].aset_raw("Some cool device") + device_name = await self.remote_node.sdo["Manufacturer device name"].aget_data() + else: + self.local_node.sdo["Manufacturer device name"].raw = "Some cool device" + device_name = self.remote_node.sdo["Manufacturer device name"].data self.assertEqual(device_name, b"Some cool device") - self.local_node2.sdo["Manufacturer device name"].raw = "Some cool device2" - device_name = self.remote_node2.sdo["Manufacturer device name"].data + if self.use_async: + await self.local_node2.sdo["Manufacturer device name"].aset_raw("Some cool device2") + device_name = await self.remote_node2.sdo["Manufacturer device name"].aget_data() + else: + self.local_node2.sdo["Manufacturer device name"].raw = "Some cool device2" + device_name = self.remote_node2.sdo["Manufacturer device name"].data self.assertEqual(device_name, b"Some cool device2") - def test_abort(self): - with self.assertRaises(canopen.SdoAbortedError) as cm: - _ = self.remote_node.sdo.upload(0x1234, 0) + async def test_abort(self): + if self.use_async: + with self.assertRaises(canopen.SdoAbortedError) as cm: + _ = await self.remote_node.sdo.aupload(0x1234, 0) + else: + with self.assertRaises(canopen.SdoAbortedError) as cm: + _ = self.remote_node.sdo.upload(0x1234, 0) # Should be Object does not exist self.assertEqual(cm.exception.code, 0x06020000) - with self.assertRaises(canopen.SdoAbortedError) as cm: - _ = self.remote_node.sdo.upload(0x1018, 100) + if self.use_async: + with self.assertRaises(canopen.SdoAbortedError) as cm: + _ = await self.remote_node.sdo.aupload(0x1018, 100) + else: + with self.assertRaises(canopen.SdoAbortedError) as cm: + _ = self.remote_node.sdo.upload(0x1018, 100) # Should be Subindex does not exist self.assertEqual(cm.exception.code, 0x06090011) - with self.assertRaises(canopen.SdoAbortedError) as cm: - _ = self.remote_node.sdo[0x1001].data + if self.use_async: + with self.assertRaises(canopen.SdoAbortedError) as cm: + _ = await self.remote_node.sdo[0x1001].aget_data() + else: + with self.assertRaises(canopen.SdoAbortedError) as cm: + _ = self.remote_node.sdo[0x1001].data # Should be Resource not available self.assertEqual(cm.exception.code, 0x060A0023) @@ -163,54 +248,98 @@ def _some_read_callback(self, **kwargs): def _some_write_callback(self, **kwargs): self._kwargs = kwargs - def test_callbacks(self): + async def test_callbacks(self): self.local_node.add_read_callback(self._some_read_callback) self.local_node.add_write_callback(self._some_write_callback) - data = self.remote_node.sdo.upload(0x1003, 5) + if self.use_async: + data = await self.remote_node.sdo.aupload(0x1003, 5) + else: + data = self.remote_node.sdo.upload(0x1003, 5) self.assertEqual(data, b"\x01\x02\x00\x00") self.assertEqual(self._kwargs["index"], 0x1003) self.assertEqual(self._kwargs["subindex"], 5) - self.remote_node.sdo.download(0x1017, 0, b"\x03\x04") + if self.use_async: + await self.remote_node.sdo.adownload(0x1017, 0, b"\x03\x04") + else: + self.remote_node.sdo.download(0x1017, 0, b"\x03\x04") self.assertEqual(self._kwargs["index"], 0x1017) self.assertEqual(self._kwargs["subindex"], 0) self.assertEqual(self._kwargs["data"], b"\x03\x04") -class TestPDO(unittest.TestCase): +class TestSDOSync(TestSDO): + """ Run the test in non-async mode. """ + __test__ = True + use_async = False + + +class TestSDOAsync(TestSDO): + """ Run the test in async mode. """ + __test__ = True + use_async = True + + +class TestPDO(unittest.IsolatedAsyncioTestCase): """ Test PDO slave. """ - @classmethod - def setUpClass(cls): - cls.network1 = canopen.Network() - cls.network1.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 - cls.network1.connect("test", interface="virtual") - cls.remote_node = cls.network1.add_node(2, SAMPLE_EDS) + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + + def setUp(self): + loop = None + if self.use_async: + loop = asyncio.get_event_loop() - cls.network2 = canopen.Network() - cls.network2.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 - cls.network2.connect("test", interface="virtual") - cls.local_node = cls.network2.create_node(2, SAMPLE_EDS) + self.network1 = canopen.Network(loop=loop) + self.network1.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.network1.connect("test", interface="virtual") + with AllowBlocking(): + self.remote_node = self.network1.add_node(2, SAMPLE_EDS) - @classmethod - def tearDownClass(cls): - cls.network1.disconnect() - cls.network2.disconnect() + self.network2 = canopen.Network(loop=loop) + self.network2.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.network2.connect("test", interface="virtual") + self.local_node = self.network2.create_node(2, SAMPLE_EDS) - def test_read(self): + def tearDown(self): + self.network1.disconnect() + self.network2.disconnect() + + async def test_read(self): # TODO: Do some more checks here. Currently it only tests that they # can be called without raising an error. - self.remote_node.pdo.read() - self.local_node.pdo.read() - - def test_save(self): + if self.use_async: + await self.remote_node.pdo.aread() + await self.local_node.pdo.aread() + else: + self.remote_node.pdo.read() + self.local_node.pdo.read() + + async def test_save(self): # TODO: Do some more checks here. Currently it only tests that they # can be called without raising an error. - self.remote_node.pdo.save() - self.local_node.pdo.save() + if self.use_async: + await self.remote_node.pdo.asave() + await self.local_node.pdo.asave() + else: + self.remote_node.pdo.save() + self.local_node.pdo.save() + + +class TestPDOSync(TestPDO): + """ Run the test in non-async mode. """ + __test__ = True + use_async = False + + +class TestPDOAsync(TestPDO): + """ Run the test in async mode. """ + __test__ = True + use_async = True if __name__ == "__main__": diff --git a/test/test_network.py b/test/test_network.py index cd65ea71..409d795d 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -1,6 +1,7 @@ import logging import time import unittest +import asyncio import can @@ -9,22 +10,39 @@ from .util import SAMPLE_EDS -class TestNetwork(unittest.TestCase): +class TestNetwork(unittest.IsolatedAsyncioTestCase): + + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool def setUp(self): - self.network = canopen.Network() + self.loop = None + if self.use_async: + self.loop = asyncio.get_event_loop() + + self.network = canopen.Network(loop=self.loop) self.network.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 - def test_network_add_node(self): + def tearDown(self): + if self.network.bus is not None: + self.network.disconnect() + + async def test_network_add_node(self): # Add using str. with self.assertLogs(): - node = self.network.add_node(2, SAMPLE_EDS) + if self.use_async: + node = await self.network.aadd_node(2, SAMPLE_EDS) + else: + node = self.network.add_node(2, SAMPLE_EDS) self.assertEqual(self.network[2], node) self.assertEqual(node.id, 2) self.assertIsInstance(node, canopen.RemoteNode) # Add using OD. - node = self.network.add_node(3, self.network[2].object_dictionary) + if self.use_async: + node = await self.network.aadd_node(3, self.network[2].object_dictionary) + else: + node = self.network.add_node(3, self.network[2].object_dictionary) self.assertEqual(self.network[3], node) self.assertEqual(node.id, 3) self.assertIsInstance(node, canopen.RemoteNode) @@ -32,7 +50,10 @@ def test_network_add_node(self): # Add using RemoteNode. with self.assertLogs(): node = canopen.RemoteNode(4, SAMPLE_EDS) - self.network.add_node(node) + if self.use_async: + await self.network.aadd_node(node) + else: + self.network.add_node(node) self.assertEqual(self.network[4], node) self.assertEqual(node.id, 4) self.assertIsInstance(node, canopen.RemoteNode) @@ -40,7 +61,10 @@ def test_network_add_node(self): # Add using LocalNode. with self.assertLogs(): node = canopen.LocalNode(5, SAMPLE_EDS) - self.network.add_node(node) + if self.use_async: + await self.network.aadd_node(node) + else: + self.network.add_node(node) self.assertEqual(self.network[5], node) self.assertEqual(node.id, 5) self.assertIsInstance(node, canopen.LocalNode) @@ -48,12 +72,15 @@ def test_network_add_node(self): # Verify that we've got the correct number of nodes. self.assertEqual(len(self.network), 4) - def test_network_add_node_upload_eds(self): + async def test_network_add_node_upload_eds(self): # Will err because we're not connected to a real network. with self.assertLogs(level=logging.ERROR): - self.network.add_node(2, SAMPLE_EDS, upload_eds=True) + if self.use_async: + await self.network.aadd_node(2, SAMPLE_EDS, upload_eds=True) + else: + self.network.add_node(2, SAMPLE_EDS, upload_eds=True) - def test_network_create_node(self): + async def test_network_create_node(self): with self.assertLogs(): self.network.create_node(2, SAMPLE_EDS) self.network.create_node(3, SAMPLE_EDS) @@ -63,7 +90,7 @@ def test_network_create_node(self): self.assertIsInstance(self.network[3], canopen.LocalNode) self.assertIsInstance(self.network[4], canopen.RemoteNode) - def test_network_check(self): + async def test_network_check(self): self.network.connect(interface="virtual") def cleanup(): @@ -86,18 +113,29 @@ class Custom(Exception): with self.assertLogs(level=logging.ERROR): self.network.disconnect() - def test_network_notify(self): + async def test_network_notify(self): with self.assertLogs(): - self.network.add_node(2, SAMPLE_EDS) + if self.use_async: + await self.network.aadd_node(2, SAMPLE_EDS) + else: + self.network.add_node(2, SAMPLE_EDS) node = self.network[2] - self.network.notify(0x82, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1473418396.0) + async def notify(*args): + """Simulate a notification from the network.""" + if self.use_async: + # If we're using async, we must run the notify in a thread + # to avoid getting blocking call errors. + await asyncio.to_thread(self.network.notify, *args) + else: + self.network.notify(*args) + await notify(0x82, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1473418396.0) self.assertEqual(len(node.emcy.active), 1) - self.network.notify(0x702, b'\x05', 1473418396.0) + await notify(0x702, b'\x05', 1473418396.0) self.assertEqual(node.nmt.state, 'OPERATIONAL') self.assertListEqual(self.network.scanner.nodes, [2]) - def test_network_send_message(self): - bus = can.interface.Bus(interface="virtual") + async def test_network_send_message(self): + bus = can.interface.Bus(interface="virtual", loop=self.loop) self.addCleanup(bus.shutdown) self.network.connect(interface="virtual") @@ -118,7 +156,7 @@ def test_network_send_message(self): self.assertEqual(msg.arbitration_id, 0x12345) self.assertTrue(msg.is_extended_id) - def test_network_subscribe_unsubscribe(self): + async def test_network_subscribe_unsubscribe(self): N_HOOKS = 3 accumulators = [] * N_HOOKS @@ -148,7 +186,7 @@ def hook(*args, i=i): # Verify that no new data was added to the accumulator. self.assertEqual(accumulators[0], [(0, bytes([1, 2, 3]), 1000)]) - def test_network_subscribe_multiple(self): + async def test_network_subscribe_multiple(self): N_HOOKS = 3 self.network.connect(interface="virtual", receive_own_messages=True) self.addCleanup(self.network.disconnect) @@ -201,16 +239,20 @@ def hook(*args, i=i): self.assertEqual(accumulators[1], BATCH1) self.assertEqual(accumulators[2], BATCH1 + [BATCH2] + [BATCH3]) - def test_network_context_manager(self): + async def test_network_context_manager(self): with self.network.connect(interface="virtual"): pass with self.assertRaisesRegex(RuntimeError, "Not connected"): self.network.send_message(0, []) - def test_network_item_access(self): + async def test_network_item_access(self): with self.assertLogs(): - self.network.add_node(2, SAMPLE_EDS) - self.network.add_node(3, SAMPLE_EDS) + if self.use_async: + await self.network.aadd_node(2, SAMPLE_EDS) + await self.network.aadd_node(3, SAMPLE_EDS) + else: + self.network.add_node(2, SAMPLE_EDS) + self.network.add_node(3, SAMPLE_EDS) self.assertEqual([2, 3], [node for node in self.network]) # Check __delitem__. @@ -229,7 +271,7 @@ def test_network_item_access(self): self.assertNotEqual(self.network[3], old) self.assertEqual([3], [node for node in self.network]) - def test_network_send_periodic(self): + async def test_network_send_periodic(self): DATA1 = bytes([1, 2, 3]) DATA2 = bytes([4, 5, 6]) COB_ID = 0x123 @@ -238,7 +280,7 @@ def test_network_send_periodic(self): self.network.connect(interface="virtual") self.addCleanup(self.network.disconnect) - bus = can.Bus(interface="virtual") + bus = can.Bus(interface="virtual", loop=self.loop) self.addCleanup(bus.shutdown) acc = [] @@ -285,14 +327,90 @@ def wait_for_periodicity(): if msg is not None: self.assertIsNone(bus.recv(PERIOD)) + def test_dispatch_callbacks_sync(self): + + result1 = 0 + result2 = 0 + + def callback1(arg): + nonlocal result1 + result1 = arg + 1 + + def callback2(arg): + nonlocal result2 + result2 = arg * 2 + + # Check that the synchronous callbacks are called correctly + self.network.dispatch_callbacks([callback1, callback2], 5) + self.assertEqual([result1, result2], [6, 10]) + + async def async_callback(arg): + return arg + 1 + + # This is a workaround to create an async callback which we have the + # ability to clean up after the test. Logicallt its the same as calling + # async_callback directly. + coro = None + def _create_async_callback(arg): + nonlocal coro + coro = async_callback(arg) + return coro + + # Check that it's not possible to call async callbacks in a non-async context + with self.assertRaises(RuntimeError): + self.network.dispatch_callbacks([_create_async_callback], 5) + + # Cleanup + if coro is not None: + coro.close() # Close the coroutine to prevent warnings. + + async def test_dispatch_callbacks_async(self): + + result1 = 0 + result2 = 0 -class TestScanner(unittest.TestCase): + event = asyncio.Event() + + def callback(arg): + nonlocal result1 + result1 = arg + 1 + + async def async_callback(arg): + nonlocal result2 + result2 = arg * 2 + event.set() # Notify the test that the async callback is done + + # Check that both callbacks are called correctly in an async context + self.network.dispatch_callbacks([callback, async_callback], 5) + await event.wait() + self.assertEqual([result1, result2], [6, 10]) + + +class TestNetworkSync(TestNetwork): + """ Run tests in a synchronous context. """ + __test__ = True + use_async = False + + +class TestNetworkAsync(TestNetwork): + """ Run tests in an asynchronous context. """ + __test__ = True + use_async = True + + +class TestScanner(unittest.IsolatedAsyncioTestCase): TIMEOUT = 0.1 + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + def setUp(self): + self.loop = None + if self.use_async: + self.loop = asyncio.get_event_loop() self.scanner = canopen.network.NodeScanner() - def test_scanner_on_message_received(self): + async def test_scanner_on_message_received(self): # Emergency frames should be recognized. self.scanner.on_message_received(0x081) # Heartbeats should be recognized. @@ -312,23 +430,23 @@ def test_scanner_on_message_received(self): self.scanner.on_message_received(0x50e) self.assertListEqual(self.scanner.nodes, [1, 3, 5, 7, 9, 11, 13]) - def test_scanner_reset(self): + async def test_scanner_reset(self): self.scanner.nodes = [1, 2, 3] # Mock scan. self.scanner.reset() self.assertListEqual(self.scanner.nodes, []) - def test_scanner_search_no_network(self): + async def test_scanner_search_no_network(self): with self.assertRaisesRegex(RuntimeError, "No actual Network object was assigned"): self.scanner.search() - def test_scanner_search(self): - rxbus = can.Bus(interface="virtual") + async def test_scanner_search(self): + rxbus = can.Bus(interface="virtual", loop=self.loop) self.addCleanup(rxbus.shutdown) - txbus = can.Bus(interface="virtual") + txbus = can.Bus(interface="virtual", loop=self.loop) self.addCleanup(txbus.shutdown) - net = canopen.Network(txbus) + net = canopen.Network(txbus, loop=self.loop) net.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 net.connect() self.addCleanup(net.disconnect) @@ -346,9 +464,9 @@ def test_scanner_search(self): # Check that no spurious packets were sent. self.assertIsNone(rxbus.recv(self.TIMEOUT)) - def test_scanner_search_limit(self): - bus = can.Bus(interface="virtual", receive_own_messages=True) - net = canopen.Network(bus) + async def test_scanner_search_limit(self): + bus = can.Bus(interface="virtual", receive_own_messages=True, loop=self.loop) + net = canopen.Network(bus, loop=self.loop) net.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 net.connect() self.addCleanup(net.disconnect) @@ -363,5 +481,17 @@ def test_scanner_search_limit(self): self.assertIsNone(bus.recv(self.TIMEOUT)) +class TestScannerSync(TestScanner): + """ Run the tests in a synchronous context. """ + __test__ = True + use_async = False + + +class TestScannerAsync(TestScanner): + """ Run the tests in an asynchronous context. """ + __test__ = True + use_async = True + + if __name__ == "__main__": unittest.main() diff --git a/test/test_nmt.py b/test/test_nmt.py index 636126dc..5fc0caee 100644 --- a/test/test_nmt.py +++ b/test/test_nmt.py @@ -1,16 +1,19 @@ import threading import time import unittest +import asyncio import can import canopen +from canopen.async_guard import AllowBlocking from canopen.nmt import COMMAND_TO_STATE, NMT_COMMANDS, NMT_STATES, NmtError from .util import SAMPLE_EDS class TestNmtBase(unittest.TestCase): + def setUp(self): node_id = 2 self.node_id = node_id @@ -42,19 +45,27 @@ def test_state_set_invalid(self): self.nmt.state = "INVALID" -class TestNmtMaster(unittest.TestCase): +class TestNmtMaster(unittest.IsolatedAsyncioTestCase): NODE_ID = 2 PERIOD = 0.01 TIMEOUT = PERIOD * 10 + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + def setUp(self): - net = canopen.Network() + loop = None + if self.use_async: + loop = asyncio.get_event_loop() + + net = canopen.Network(loop=loop) net.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 net.connect(interface="virtual") with self.assertLogs(): - node = net.add_node(self.NODE_ID, SAMPLE_EDS) + with AllowBlocking(): + node = net.add_node(self.NODE_ID, SAMPLE_EDS) - self.bus = can.Bus(interface="virtual") + self.bus = can.Bus(interface="virtual", loop=loop) self.net = net self.node = node @@ -67,47 +78,65 @@ def dispatch_heartbeat(self, code): hb = can.Message(arbitration_id=cob_id, data=[code]) self.bus.send(hb) - def test_nmt_master_no_heartbeat(self): + async def test_nmt_master_no_heartbeat(self): with self.assertRaisesRegex(NmtError, "heartbeat"): - self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + if self.use_async: + await self.node.nmt.await_for_heartbeat(self.TIMEOUT) + else: + self.node.nmt.wait_for_heartbeat(self.TIMEOUT) with self.assertRaisesRegex(NmtError, "boot-up"): - self.node.nmt.wait_for_bootup(self.TIMEOUT) + if self.use_async: + await self.node.nmt.await_for_bootup(self.TIMEOUT) + else: + self.node.nmt.wait_for_bootup(self.TIMEOUT) - def test_nmt_master_on_heartbeat(self): + async def test_nmt_master_on_heartbeat(self): # Skip the special INITIALISING case. for code in [st for st in NMT_STATES if st != 0]: with self.subTest(code=code): t = threading.Timer(0.01, self.dispatch_heartbeat, args=(code,)) t.start() self.addCleanup(t.join) - actual = self.node.nmt.wait_for_heartbeat(0.1) + if self.use_async: + actual = await self.node.nmt.await_for_heartbeat(0.1) + else: + actual = self.node.nmt.wait_for_heartbeat(0.1) expected = NMT_STATES[code] self.assertEqual(actual, expected) - def test_nmt_master_wait_for_bootup(self): + async def test_nmt_master_wait_for_bootup(self): t = threading.Timer(0.01, self.dispatch_heartbeat, args=(0x00,)) t.start() self.addCleanup(t.join) - self.node.nmt.wait_for_bootup(self.TIMEOUT) + if self.use_async: + await self.node.nmt.await_for_bootup(self.TIMEOUT) + else: + self.node.nmt.wait_for_bootup(self.TIMEOUT) self.assertEqual(self.node.nmt.state, "PRE-OPERATIONAL") - def test_nmt_master_on_heartbeat_initialising(self): + async def test_nmt_master_on_heartbeat_initialising(self): t = threading.Timer(0.01, self.dispatch_heartbeat, args=(0x00,)) t.start() self.addCleanup(t.join) - state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + if self.use_async: + state = await self.node.nmt.await_for_heartbeat(self.TIMEOUT) + else: + state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT) self.assertEqual(state, "PRE-OPERATIONAL") - def test_nmt_master_on_heartbeat_unknown_state(self): + async def test_nmt_master_on_heartbeat_unknown_state(self): t = threading.Timer(0.01, self.dispatch_heartbeat, args=(0xcb,)) t.start() self.addCleanup(t.join) - state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + if self.use_async: + state = await self.node.nmt.await_for_heartbeat(self.TIMEOUT) + else: + state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT) # Expect the high bit to be masked out, and a formatted string to # be returned. self.assertEqual(state, "UNKNOWN STATE '75'") - def test_nmt_master_add_heartbeat_callback(self): + async def test_nmt_master_add_heartbeat_callback(self): event = threading.Event() state = None def hook(st): @@ -117,10 +146,13 @@ def hook(st): self.node.nmt.add_heartbeat_callback(hook) self.dispatch_heartbeat(0x7f) - self.assertTrue(event.wait(self.TIMEOUT)) + if self.use_async: + await asyncio.to_thread(event.wait, self.TIMEOUT) + else: + self.assertTrue(event.wait(self.TIMEOUT)) self.assertEqual(state, 127) - def test_nmt_master_node_guarding(self): + async def test_nmt_master_node_guarding(self): self.node.nmt.start_node_guarding(self.PERIOD) msg = self.bus.recv(self.TIMEOUT) self.assertIsNotNone(msg) @@ -135,64 +167,111 @@ def test_nmt_master_node_guarding(self): self.assertIsNone(self.bus.recv(self.TIMEOUT)) -class TestNmtSlave(unittest.TestCase): +class TestNmtMasterSync(TestNmtMaster): + """ Run tests in non-asynchronous mode. """ + __test__ = True + use_async = False + + +class TestNmtMasterAsync(TestNmtMaster): + """ Run tests in asynchronous mode. """ + __test__ = True + use_async = True + + +class TestNmtSlave(unittest.IsolatedAsyncioTestCase): + + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + def setUp(self): - self.network1 = canopen.Network() + loop = None + if self.use_async: + loop = asyncio.get_event_loop() + + self.network1 = canopen.Network(loop=loop) self.network1.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 self.network1.connect("test", interface="virtual") with self.assertLogs(): - self.remote_node = self.network1.add_node(2, SAMPLE_EDS) + with AllowBlocking(): + self.remote_node = self.network1.add_node(2, SAMPLE_EDS) - self.network2 = canopen.Network() + self.network2 = canopen.Network(loop=loop) self.network2.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 self.network2.connect("test", interface="virtual") with self.assertLogs(): self.local_node = self.network2.create_node(2, SAMPLE_EDS) - self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS) + with AllowBlocking(): + self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS) self.local_node2 = self.network2.create_node(3, SAMPLE_EDS) def tearDown(self): self.network1.disconnect() self.network2.disconnect() - def test_start_two_remote_nodes(self): + async def test_start_two_remote_nodes(self): self.remote_node.nmt.state = "OPERATIONAL" # Line below is just so that we are sure the client have received the command # before we do the check - time.sleep(0.1) + if self.use_async: + await asyncio.sleep(0.1) + else: + time.sleep(0.1) slave_state = self.local_node.nmt.state self.assertEqual(slave_state, "OPERATIONAL") self.remote_node2.nmt.state = "OPERATIONAL" # Line below is just so that we are sure the client have received the command # before we do the check - time.sleep(0.1) + if self.use_async: + await asyncio.sleep(0.1) + else: + time.sleep(0.1) slave_state = self.local_node2.nmt.state self.assertEqual(slave_state, "OPERATIONAL") - def test_stop_two_remote_nodes_using_broadcast(self): + async def test_stop_two_remote_nodes_using_broadcast(self): # This is a NMT broadcast "Stop remote node" # ie. set the node in STOPPED state self.network1.send_message(0, [2, 0]) # Line below is just so that we are sure the slaves have received the command # before we do the check - time.sleep(0.1) + if self.use_async: + await asyncio.sleep(0.1) + else: + time.sleep(0.1) slave_state = self.local_node.nmt.state self.assertEqual(slave_state, "STOPPED") slave_state = self.local_node2.nmt.state self.assertEqual(slave_state, "STOPPED") - def test_heartbeat(self): + async def test_heartbeat(self): self.assertEqual(self.remote_node.nmt.state, "INITIALISING") self.assertEqual(self.local_node.nmt.state, "INITIALISING") self.local_node.nmt.state = "OPERATIONAL" - self.local_node.sdo[0x1017].raw = 100 - time.sleep(0.2) + if self.use_async: + await self.local_node.sdo[0x1017].aset_raw(100) + await asyncio.sleep(0.2) + else: + self.local_node.sdo[0x1017].raw = 100 + time.sleep(0.2) self.assertEqual(self.remote_node.nmt.state, "OPERATIONAL") self.local_node.nmt.stop_heartbeat() +class TestNmtSlaveSync(TestNmtSlave): + """ Run tests in non-asynchronous mode. """ + __test__ = True + use_async = False + + +class TestNmtSlaveAsync(TestNmtSlave): + """ Run tests in asynchronous mode. """ + __test__ = True + use_async = True + + if __name__ == "__main__": unittest.main() diff --git a/test/test_node.py b/test/test_node.py index 43373a2a..96a279c7 100644 --- a/test/test_node.py +++ b/test/test_node.py @@ -1,4 +1,5 @@ import unittest +import asyncio import canopen @@ -8,21 +9,26 @@ def count_subscribers(network: canopen.Network) -> int: return sum(len(n) for n in network.subscribers.values()) -class TestLocalNode(unittest.TestCase): +class TestLocalNode(unittest.IsolatedAsyncioTestCase): - @classmethod - def setUpClass(cls): - cls.network = canopen.Network() - cls.network.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 - cls.network.connect(interface="virtual") + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool - cls.node = canopen.LocalNode(2, canopen.objectdictionary.ObjectDictionary()) + def setUp(self): + loop = None + if self.use_async: + loop = asyncio.get_event_loop() - @classmethod - def tearDownClass(cls): - cls.network.disconnect() + self.network = canopen.Network(loop=loop) + self.network.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.network.connect(interface="virtual") - def test_associate_network(self): + self.node = canopen.LocalNode(2, canopen.objectdictionary.ObjectDictionary()) + + def tearDown(self): + self.network.disconnect() + + async def test_associate_network(self): # Need to store the number of subscribers before associating because the # network implementation automatically adds subscribers to the list n_subscribers = count_subscribers(self.network) @@ -57,21 +63,38 @@ def test_associate_network(self): self.node.remove_network() -class TestRemoteNode(unittest.TestCase): +class TestLocalNodeSync(TestLocalNode): + """ Run the tests in non-asynchronous mode. """ + __test__ = True + use_async = False + + +class TestLocalNodeAsync(TestLocalNode): + """ Run the tests in asynchronous mode. """ + __test__ = True + use_async = True + + +class TestRemoteNode(unittest.IsolatedAsyncioTestCase): + + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool - @classmethod - def setUpClass(cls): - cls.network = canopen.Network() - cls.network.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 - cls.network.connect(interface="virtual") + def setUp(self): + loop = None + if self.use_async: + loop = asyncio.get_event_loop() - cls.node = canopen.RemoteNode(2, canopen.objectdictionary.ObjectDictionary()) + self.network = canopen.Network(loop=loop) + self.network.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.network.connect(interface="virtual") - @classmethod - def tearDownClass(cls): - cls.network.disconnect() + self.node = canopen.RemoteNode(2, canopen.objectdictionary.ObjectDictionary()) - def test_associate_network(self): + def tearDown(self): + self.network.disconnect() + + async def test_associate_network(self): # Need to store the number of subscribers before associating because the # network implementation automatically adds subscribers to the list n_subscribers = count_subscribers(self.network) @@ -83,6 +106,7 @@ def test_associate_network(self): self.assertIs(self.node.tpdo.network, self.network) self.assertIs(self.node.rpdo.network, self.network) self.assertIs(self.node.nmt.network, self.network) + self.assertIs(self.node.emcy.network, self.network) # Test that its not possible to associate the network multiple times with self.assertRaises(RuntimeError) as cm: @@ -98,7 +122,20 @@ def test_associate_network(self): self.assertIs(self.node.tpdo.network, uninitalized) self.assertIs(self.node.rpdo.network, uninitalized) self.assertIs(self.node.nmt.network, uninitalized) + self.assertIs(self.node.emcy.network, uninitalized) self.assertEqual(count_subscribers(self.network), n_subscribers) # Test that its possible to deassociate the network multiple times self.node.remove_network() + + +class TestRemoteNodeSync(TestRemoteNode): + """ Run the tests in non-asynchronous mode. """ + __test__ = True + use_async = False + + +class TestRemoteNodeAsync(TestRemoteNode): + """ Run the tests in asynchronous mode. """ + __test__ = True + use_async = True diff --git a/test/test_pdo.py b/test/test_pdo.py index 1badc89d..f66672cd 100644 --- a/test/test_pdo.py +++ b/test/test_pdo.py @@ -5,7 +5,11 @@ from .util import SAMPLE_EDS, tmp_file -class TestPDO(unittest.TestCase): +class TestPDO(unittest.IsolatedAsyncioTestCase): + + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + def setUp(self): node = canopen.Node(1, SAMPLE_EDS) pdo = node.pdo.tx[1] @@ -16,59 +20,115 @@ def setUp(self): pdo.add_variable('BOOLEAN value', length=1) # 0x2005 pdo.add_variable('BOOLEAN value 2', length=1) # 0x2006 - # Write some values - pdo['INTEGER16 value'].raw = -3 - pdo['UNSIGNED8 value'].raw = 0xf - pdo['INTEGER8 value'].raw = -2 - pdo['INTEGER32 value'].raw = 0x01020304 - pdo['BOOLEAN value'].raw = False - pdo['BOOLEAN value 2'].raw = True - self.pdo = pdo self.node = node - def test_pdo_map_bit_mapping(self): - self.assertEqual(self.pdo.data, b'\xfd\xff\xef\x04\x03\x02\x01\x02') + async def set_values(self): + """Initialize the PDO with some valuues. + + Do this in a separate method in order to be abel to use the + async and sync versions of the tests. + """ + node = self.node + pdo = node.pdo.tx[1] + if self.use_async: + # Write some values (different from the synchronous values) + await pdo['INTEGER16 value'].aset_raw(12) + await pdo['UNSIGNED8 value'].aset_raw(0xe) + await pdo['INTEGER8 value'].aset_raw(-4) + await pdo['INTEGER32 value'].aset_raw(0x56789abc) + await pdo['BOOLEAN value'].aset_raw(True) + await pdo['BOOLEAN value 2'].aset_raw(False) + else: + # Write some values + pdo['INTEGER16 value'].raw = -3 + pdo['UNSIGNED8 value'].raw = 0xf + pdo['INTEGER8 value'].raw = -2 + pdo['INTEGER32 value'].raw = 0x01020304 + pdo['BOOLEAN value'].raw = False + pdo['BOOLEAN value 2'].raw = True - def test_pdo_map_getitem(self): + async def test_pdo_map_bit_mapping(self): + await self.set_values() + if self.use_async: + self.assertEqual(self.pdo.data, b'\x0c\x00\xce\xbc\x9a\x78\x56\x01') + else: + self.assertEqual(self.pdo.data, b'\xfd\xff\xef\x04\x03\x02\x01\x02') + + async def test_pdo_map_getitem(self): + await self.set_values() pdo = self.pdo - self.assertEqual(pdo['INTEGER16 value'].raw, -3) - self.assertEqual(pdo['UNSIGNED8 value'].raw, 0xf) - self.assertEqual(pdo['INTEGER8 value'].raw, -2) - self.assertEqual(pdo['INTEGER32 value'].raw, 0x01020304) - self.assertEqual(pdo['BOOLEAN value'].raw, False) - self.assertEqual(pdo['BOOLEAN value 2'].raw, True) - - def test_pdo_getitem(self): + if self.use_async: + self.assertEqual(await pdo['INTEGER16 value'].aget_raw(), 12) + self.assertEqual(await pdo['UNSIGNED8 value'].aget_raw(), 0xe) + self.assertEqual(await pdo['INTEGER8 value'].aget_raw(), -4) + self.assertEqual(await pdo['INTEGER32 value'].aget_raw(), 0x56789abc) + self.assertEqual(await pdo['BOOLEAN value'].aget_raw(), True) + self.assertEqual(await pdo['BOOLEAN value 2'].aget_raw(), False) + else: + self.assertEqual(pdo['INTEGER16 value'].raw, -3) + self.assertEqual(pdo['UNSIGNED8 value'].raw, 0xf) + self.assertEqual(pdo['INTEGER8 value'].raw, -2) + self.assertEqual(pdo['INTEGER32 value'].raw, 0x01020304) + self.assertEqual(pdo['BOOLEAN value'].raw, False) + self.assertEqual(pdo['BOOLEAN value 2'].raw, True) + + async def test_pdo_getitem(self): + await self.set_values() node = self.node - self.assertEqual(node.tpdo[1]['INTEGER16 value'].raw, -3) - self.assertEqual(node.tpdo[1]['UNSIGNED8 value'].raw, 0xf) - self.assertEqual(node.tpdo[1]['INTEGER8 value'].raw, -2) - self.assertEqual(node.tpdo[1]['INTEGER32 value'].raw, 0x01020304) - self.assertEqual(node.tpdo['INTEGER32 value'].raw, 0x01020304) - self.assertEqual(node.tpdo[1]['BOOLEAN value'].raw, False) - self.assertEqual(node.tpdo[1]['BOOLEAN value 2'].raw, True) - - # Test different types of access - self.assertEqual(node.pdo[0x1600]['INTEGER16 value'].raw, -3) - self.assertEqual(node.pdo['INTEGER16 value'].raw, -3) - self.assertEqual(node.pdo.tx[1]['INTEGER16 value'].raw, -3) - self.assertEqual(node.pdo[0x2001].raw, -3) - self.assertEqual(node.tpdo[0x2001].raw, -3) - self.assertEqual(node.pdo[0x2002].raw, 0xf) - self.assertEqual(node.pdo['0x2002'].raw, 0xf) - self.assertEqual(node.tpdo[0x2002].raw, 0xf) - self.assertEqual(node.pdo[0x1600][0x2002].raw, 0xf) - - def test_pdo_save(self): - self.node.tpdo.save() - self.node.rpdo.save() - - def test_pdo_export(self): + if self.use_async: + self.assertEqual(await node.tpdo[1]['INTEGER16 value'].aget_raw(), 12) + self.assertEqual(await node.tpdo[1]['UNSIGNED8 value'].aget_raw(), 0xe) + self.assertEqual(await node.tpdo[1]['INTEGER8 value'].aget_raw(), -4) + self.assertEqual(await node.tpdo[1]['INTEGER32 value'].aget_raw(), 0x56789abc) + self.assertEqual(await node.tpdo['INTEGER32 value'].aget_raw(), 0x56789abc) + self.assertEqual(await node.tpdo[1]['BOOLEAN value'].aget_raw(), True) + self.assertEqual(await node.tpdo[1]['BOOLEAN value 2'].aget_raw(), False) + + # Test different types of access + self.assertEqual(await node.pdo[0x1600]['INTEGER16 value'].aget_raw(), 12) + self.assertEqual(await node.pdo['INTEGER16 value'].aget_raw(), 12) + self.assertEqual(await node.pdo.tx[1]['INTEGER16 value'].aget_raw(), 12) + self.assertEqual(await node.pdo[0x2001].aget_raw(), 12) + self.assertEqual(await node.tpdo[0x2001].aget_raw(), 12) + self.assertEqual(await node.pdo[0x2002].aget_raw(), 0xe) + self.assertEqual(await node.pdo['0x2002'].aget_raw(), 0xe) + self.assertEqual(await node.tpdo[0x2002].aget_raw(), 0xe) + self.assertEqual(await node.pdo[0x1600][0x2002].aget_raw(), 0xe) + else: + self.assertEqual(node.tpdo[1]['INTEGER16 value'].raw, -3) + self.assertEqual(node.tpdo[1]['UNSIGNED8 value'].raw, 0xf) + self.assertEqual(node.tpdo[1]['INTEGER8 value'].raw, -2) + self.assertEqual(node.tpdo[1]['INTEGER32 value'].raw, 0x01020304) + self.assertEqual(node.tpdo['INTEGER32 value'].raw, 0x01020304) + self.assertEqual(node.tpdo[1]['BOOLEAN value'].raw, False) + self.assertEqual(node.tpdo[1]['BOOLEAN value 2'].raw, True) + + # Test different types of access + self.assertEqual(node.pdo[0x1600]['INTEGER16 value'].raw, -3) + self.assertEqual(node.pdo['INTEGER16 value'].raw, -3) + self.assertEqual(node.pdo.tx[1]['INTEGER16 value'].raw, -3) + self.assertEqual(node.pdo[0x2001].raw, -3) + self.assertEqual(node.tpdo[0x2001].raw, -3) + self.assertEqual(node.pdo[0x2002].raw, 0xf) + self.assertEqual(node.pdo['0x2002'].raw, 0xf) + self.assertEqual(node.tpdo[0x2002].raw, 0xf) + self.assertEqual(node.pdo[0x1600][0x2002].raw, 0xf) + + async def test_pdo_save(self): + await self.set_values() + if self.use_async: + await self.node.tpdo.asave() + await self.node.rpdo.asave() + else: + self.node.tpdo.save() + self.node.rpdo.save() + + async def test_pdo_export(self): try: import canmatrix except ImportError: - raise unittest.SkipTest("The PDO export API requires canmatrix") + self.skipTest("The PDO export API requires canmatrix") for pdo in "tpdo", "rpdo": with tmp_file(suffix=".csv") as tmp: @@ -81,5 +141,17 @@ def test_pdo_export(self): self.assertIn("Frame Name", header) +class TestPDOSync(TestPDO): + """ Test the functions in synchronous mode. """ + __test__ = True + use_async = False + + +class TestPDOAsync(TestPDO): + """ Test the functions in asynchronous mode. """ + __test__ = True + use_async = True + + if __name__ == "__main__": unittest.main() diff --git a/test/test_sdo.py b/test/test_sdo.py index 78012a30..3eeb9950 100644 --- a/test/test_sdo.py +++ b/test/test_sdo.py @@ -1,6 +1,8 @@ import unittest +import asyncio import canopen +from canopen.async_guard import AllowBlocking import canopen.objectdictionary.datatypes as dt from canopen.objectdictionary import ODVariable @@ -11,17 +13,20 @@ RX = 2 -class TestSDOVariables(unittest.TestCase): +class TestSDOVariables(unittest.IsolatedAsyncioTestCase): """Some basic assumptions on the behavior of SDO variable objects. Mostly what is stated in the API docs. """ + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + def setUp(self): node = canopen.LocalNode(1, SAMPLE_EDS) self.sdo_node = node.sdo - def test_record_iter_length(self): + async def test_record_iter_length(self): """Assume the "highest subindex supported" entry is not counted. Sub-objects without an OD entry should be skipped as well. @@ -31,7 +36,7 @@ def test_record_iter_length(self): self.assertEqual(len(record), 3) self.assertEqual(subs, 3) - def test_array_iter_length(self): + async def test_array_iter_length(self): """Assume the "highest subindex supported" entry is not counted.""" array = self.sdo_node[0x1003] subs = sum(1 for _ in iter(array)) @@ -42,19 +47,38 @@ def test_array_iter_length(self): subs = sum(1 for _ in iter(array)) self.assertEqual(subs, 8) - def test_array_members_dynamic(self): + async def test_array_members_dynamic(self): """Check if sub-objects missing from OD entry are generated dynamically.""" array = self.sdo_node[0x1003] - for var in array.values(): - self.assertIsInstance(var, canopen.sdo.SdoVariable) + if self.use_async: + async for i in array: + self.assertIsInstance(array[i], canopen.sdo.SdoVariable) + else: + for var in array.values(): + self.assertIsInstance(var, canopen.sdo.SdoVariable) + + +class TestSDOVariablesSync(TestSDOVariables): + """ Run tests in non-asynchronous mode. """ + __test__ = True + use_async = False + +class TestSDOVariablesAsync(TestSDOVariables): + """ Run tests in asynchronous mode. """ + __test__ = True + use_async = True -class TestSDO(unittest.TestCase): + +class TestSDO(unittest.IsolatedAsyncioTestCase): """ Test SDO traffic by example. Most are taken from http://www.canopensolutions.com/english/about_canopen/device_configuration_canopen.shtml """ + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + def _send_message(self, can_id, data, remote=False): """Will be used instead of the usual Network.send_message method. @@ -71,21 +95,30 @@ def _send_message(self, can_id, data, remote=False): self.message_sent = True def setUp(self): - network = canopen.Network() + loop = None + if self.use_async: + loop = asyncio.get_event_loop() + + network = canopen.Network(loop=loop) network.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 network.send_message = self._send_message - node = network.add_node(2, SAMPLE_EDS) + with AllowBlocking(): + node = network.add_node(2, SAMPLE_EDS) node.sdo.RESPONSE_TIMEOUT = 0.01 self.network = network - self.message_sent = False + def tearDown(self): + self.network.disconnect() - def test_expedited_upload(self): + async def test_expedited_upload(self): self.data = [ (TX, b'\x40\x18\x10\x01\x00\x00\x00\x00'), (RX, b'\x43\x18\x10\x01\x04\x00\x00\x00') ] - vendor_id = self.network[2].sdo[0x1018][1].raw + if self.use_async: + vendor_id = await self.network[2].sdo[0x1018][1].aget_raw() + else: + vendor_id = self.network[2].sdo[0x1018][1].raw self.assertEqual(vendor_id, 4) # UNSIGNED8 without padded data part (see issue #5) @@ -93,29 +126,38 @@ def test_expedited_upload(self): (TX, b'\x40\x00\x14\x02\x00\x00\x00\x00'), (RX, b'\x4f\x00\x14\x02\xfe') ] - trans_type = self.network[2].sdo[0x1400]['Transmission type RPDO 1'].raw + if self.use_async: + trans_type = await self.network[2].sdo[0x1400]['Transmission type RPDO 1'].aget_raw() + else: + trans_type = self.network[2].sdo[0x1400]['Transmission type RPDO 1'].raw self.assertEqual(trans_type, 254) self.assertTrue(self.message_sent) - def test_size_not_specified(self): + async def test_size_not_specified(self): self.data = [ (TX, b'\x40\x00\x14\x02\x00\x00\x00\x00'), (RX, b'\x42\x00\x14\x02\xfe\x00\x00\x00') ] # Make sure the size of the data is 1 byte - data = self.network[2].sdo.upload(0x1400, 2) + if self.use_async: + data = await self.network[2].sdo.aupload(0x1400, 2) + else: + data = self.network[2].sdo.upload(0x1400, 2) self.assertEqual(data, b'\xfe') self.assertTrue(self.message_sent) - def test_expedited_download(self): + async def test_expedited_download(self): self.data = [ (TX, b'\x2b\x17\x10\x00\xa0\x0f\x00\x00'), (RX, b'\x60\x17\x10\x00\x00\x00\x00\x00') ] - self.network[2].sdo[0x1017].raw = 4000 + if self.use_async: + await self.network[2].sdo[0x1017].aset_raw(4000) + else: + self.network[2].sdo[0x1017].raw = 4000 self.assertTrue(self.message_sent) - def test_segmented_upload(self): + async def test_segmented_upload(self): self.data = [ (TX, b'\x40\x08\x10\x00\x00\x00\x00\x00'), (RX, b'\x41\x08\x10\x00\x1A\x00\x00\x00'), @@ -128,10 +170,13 @@ def test_segmented_upload(self): (TX, b'\x70\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x15\x69\x6E\x73\x20\x21\x00\x00') ] - device_name = self.network[2].sdo[0x1008].raw + if self.use_async: + device_name = await self.network[2].sdo[0x1008].aget_raw() + else: + device_name = self.network[2].sdo[0x1008].raw self.assertEqual(device_name, "Tiny Node - Mega Domains !") - def test_segmented_download(self): + async def test_segmented_download(self): self.data = [ (TX, b'\x21\x00\x20\x00\x0d\x00\x00\x00'), (RX, b'\x60\x00\x20\x00\x00\x00\x00\x00'), @@ -140,9 +185,12 @@ def test_segmented_download(self): (TX, b'\x13\x73\x74\x72\x69\x6e\x67\x00'), (RX, b'\x30\x00\x20\x00\x00\x00\x00\x00') ] - self.network[2].sdo['Writable string'].raw = 'A long string' + if self.use_async: + await self.network[2].sdo['Writable string'].aset_raw('A long string') + else: + self.network[2].sdo['Writable string'].raw = 'A long string' - def test_block_download(self): + async def test_block_download(self): self.data = [ (TX, b'\xc6\x00\x20\x00\x1e\x00\x00\x00'), (RX, b'\xa4\x00\x20\x00\x7f\x00\x00\x00'), @@ -156,21 +204,27 @@ def test_block_download(self): (RX, b'\xa1\x00\x00\x00\x00\x00\x00\x00') ] data = b'A really really long string...' - with self.network[2].sdo['Writable string'].open( - 'wb', size=len(data), block_transfer=True) as fp: - fp.write(data) - - def test_segmented_download_zero_length(self): + if self.use_async: + self.skipTest("Async SDO block download not implemented yet") + else: + with self.network[2].sdo['Writable string'].open( + 'wb', size=len(data), block_transfer=True) as fp: + fp.write(data) + + async def test_segmented_download_zero_length(self): self.data = [ (TX, b'\x21\x00\x20\x00\x00\x00\x00\x00'), (RX, b'\x60\x00\x20\x00\x00\x00\x00\x00'), (TX, b'\x0F\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x20\x00\x00\x00\x00\x00\x00\x00'), ] - self.network[2].sdo[0x2000].raw = "" + if self.use_async: + await self.network[2].sdo[0x2000].aset_raw("") + else: + self.network[2].sdo[0x2000].raw = "" self.assertTrue(self.message_sent) - def test_block_upload(self): + async def test_block_upload(self): self.data = [ (TX, b'\xa4\x08\x10\x00\x7f\x00\x00\x00'), (RX, b'\xc6\x08\x10\x00\x1a\x00\x00\x00'), @@ -183,11 +237,14 @@ def test_block_upload(self): (RX, b'\xc9\x40\xe1\x00\x00\x00\x00\x00'), (TX, b'\xa1\x00\x00\x00\x00\x00\x00\x00') ] - with self.network[2].sdo[0x1008].open('r', block_transfer=True) as fp: - data = fp.read() + if self.use_async: + self.skipTest("Async SDO block upload not implemented yet") + else: + with self.network[2].sdo[0x1008].open('r', block_transfer=True) as fp: + data = fp.read() self.assertEqual(data, 'Tiny Node - Mega Domains !') - def test_sdo_block_upload_retransmit(self): + async def test_sdo_block_upload_retransmit(self): """Trigger a retransmit by only validating a block partially.""" self.data = [ (TX, b'\xa4\x08\x10\x00\x7f\x00\x00\x00'), @@ -488,11 +545,14 @@ def test_sdo_block_upload_retransmit(self): (RX, b'\xc9\x3b\x49\x00\x00\x00\x00\x00'), (TX, b'\xa1\x00\x00\x00\x00\x00\x00\x00'), # --> Transfer ends without issues ] - with self.network[2].sdo[0x1008].open('r', block_transfer=True) as fp: - data = fp.read() + if self.use_async: + self.skipTest("Async SDO block upload not implemented yet") + else: + with self.network[2].sdo[0x1008].open('r', block_transfer=True) as fp: + data = fp.read() self.assertEqual(data, 39 * 'the crazy fox jumps over the lazy dog\n') - def test_writable_file(self): + async def test_writable_file(self): self.data = [ (TX, b'\x20\x00\x20\x00\x00\x00\x00\x00'), (RX, b'\x60\x00\x20\x00\x00\x00\x00\x00'), @@ -503,31 +563,65 @@ def test_writable_file(self): (TX, b'\x0f\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x20\x00\x20\x00\x00\x00\x00\x00') ] - with self.network[2].sdo['Writable string'].open('wb') as fp: - fp.write(b'1234') - fp.write(b'56789') - self.assertTrue(fp.closed) - # Write on closed file - with self.assertRaises(ValueError): - fp.write(b'123') - - def test_abort(self): + if self.use_async: + self.skipTest("Async SDO writable file not implemented yet") + else: + with self.network[2].sdo['Writable string'].open('wb') as fp: + fp.write(b'1234') + fp.write(b'56789') + self.assertTrue(fp.closed) + # Write on closed file + with self.assertRaises(ValueError): + fp.write(b'123') + + async def test_abort(self): self.data = [ (TX, b'\x40\x18\x10\x01\x00\x00\x00\x00'), (RX, b'\x80\x18\x10\x01\x11\x00\x09\x06') ] - with self.assertRaises(canopen.SdoAbortedError) as cm: - _ = self.network[2].sdo[0x1018][1].raw + if self.use_async: + with self.assertRaises(canopen.SdoAbortedError) as cm: + _ = await self.network[2].sdo[0x1018][1].aget_raw() + else: + with self.assertRaises(canopen.SdoAbortedError) as cm: + _ = self.network[2].sdo[0x1018][1].raw self.assertEqual(cm.exception.code, 0x06090011) - def test_add_sdo_channel(self): + async def test_add_sdo_channel(self): client = self.network[2].add_sdo(0x123456, 0x234567) self.assertIn(client, self.network[2].sdo_channels) + async def test_async_protection(self): + self.data = [ + (TX, b'\x40\x18\x10\x01\x00\x00\x00\x00'), + (RX, b'\x43\x18\x10\x01\x04\x00\x00\x00') + ] + if self.use_async: + # Test that regular commands are not allowed in async mode + with self.assertRaises(RuntimeError): + _ = self.network[2].sdo[0x1018][1].raw + else: + self.skipTest("No async protection test needed in sync mode") + + +class TestSDOSync(TestSDO): + """ Run tests in synchronous mode. """ + __test__ = True + use_async = False + + +class TestSDOAsync(TestSDO): + """ Run tests in asynchronous mode. """ + __test__ = True + use_async = True -class TestSDOClientDatatypes(unittest.TestCase): + +class TestSDOClientDatatypes(unittest.IsolatedAsyncioTestCase): """Test the SDO client uploads with the different data types in CANopen.""" + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + def _send_message(self, can_id, data, remote=False): """Will be used instead of the usual Network.send_message method. @@ -542,85 +636,117 @@ def _send_message(self, can_id, data, remote=False): self.network.notify(0x582, self.data.pop(0)[1], 0.0) def setUp(self): - network = canopen.Network() + loop = None + if self.use_async: + loop = asyncio.get_event_loop() + + network = canopen.Network(loop=loop) network.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 network.send_message = self._send_message - node = network.add_node(2, DATATYPES_EDS) + with AllowBlocking(): + node = network.add_node(2, DATATYPES_EDS) node.sdo.RESPONSE_TIMEOUT = 0.01 self.node = node self.network = network - def test_boolean(self): + def tearDown(self): + self.network.disconnect() + + async def test_boolean(self): self.data = [ (TX, b'\x40\x01\x20\x00\x00\x00\x00\x00'), (RX, b'\x4f\x01\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2000 + dt.BOOLEAN, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.BOOLEAN, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.BOOLEAN, 0) self.assertEqual(data, b'\xfe') - def test_unsigned8(self): + async def test_unsigned8(self): self.data = [ (TX, b'\x40\x05\x20\x00\x00\x00\x00\x00'), (RX, b'\x4f\x05\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED8, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.UNSIGNED8, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED8, 0) self.assertEqual(data, b'\xfe') - def test_unsigned16(self): + async def test_unsigned16(self): self.data = [ (TX, b'\x40\x06\x20\x00\x00\x00\x00\x00'), (RX, b'\x4b\x06\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED16, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.UNSIGNED16, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED16, 0) self.assertEqual(data, b'\xfe\xfd') - def test_unsigned24(self): + async def test_unsigned24(self): self.data = [ (TX, b'\x40\x16\x20\x00\x00\x00\x00\x00'), (RX, b'\x47\x16\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED24, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.UNSIGNED24, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED24, 0) self.assertEqual(data, b'\xfe\xfd\xfc') - def test_unsigned32(self): + async def test_unsigned32(self): self.data = [ (TX, b'\x40\x07\x20\x00\x00\x00\x00\x00'), (RX, b'\x43\x07\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED32, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.UNSIGNED32, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED32, 0) self.assertEqual(data, b'\xfe\xfd\xfc\xfb') - def test_unsigned40(self): + async def test_unsigned40(self): self.data = [ (TX, b'\x40\x18\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x18\x20\x00\xfe\xfd\xfc\xfb'), (TX, b'\x60\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x05\xb2\x01\x20\x02\x91\x12\x03'), ] - data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED40, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.UNSIGNED40, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED40, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91') - def test_unsigned48(self): + async def test_unsigned48(self): self.data = [ (TX, b'\x40\x19\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x19\x20\x00\xfe\xfd\xfc\xfb'), (TX, b'\x60\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x03\xb2\x01\x20\x02\x91\x12\x03'), ] - data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED48, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.UNSIGNED48, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED48, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91\x12') - def test_unsigned56(self): + async def test_unsigned56(self): self.data = [ (TX, b'\x40\x1a\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x1a\x20\x00\xfe\xfd\xfc\xfb'), (TX, b'\x60\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x01\xb2\x01\x20\x02\x91\x12\x03'), ] - data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED56, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.UNSIGNED56, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED56, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91\x12\x03') - def test_unsigned64(self): + async def test_unsigned64(self): self.data = [ (TX, b'\x40\x1b\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x1b\x20\x00\xfe\xfd\xfc\xfb'), @@ -629,72 +755,96 @@ def test_unsigned64(self): (TX, b'\x70\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x1d\x19\x21\x70\xfe\xfd\xfc\xfb'), ] - data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED64, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.UNSIGNED64, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.UNSIGNED64, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91\x12\x03\x19') - def test_integer8(self): + async def test_integer8(self): self.data = [ (TX, b'\x40\x02\x20\x00\x00\x00\x00\x00'), (RX, b'\x4f\x02\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2000 + dt.INTEGER8, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.INTEGER8, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.INTEGER8, 0) self.assertEqual(data, b'\xfe') - def test_integer16(self): + async def test_integer16(self): self.data = [ (TX, b'\x40\x03\x20\x00\x00\x00\x00\x00'), (RX, b'\x4b\x03\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2000 + dt.INTEGER16, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.INTEGER16, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.INTEGER16, 0) self.assertEqual(data, b'\xfe\xfd') - def test_integer24(self): + async def test_integer24(self): self.data = [ (TX, b'\x40\x10\x20\x00\x00\x00\x00\x00'), (RX, b'\x47\x10\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2000 + dt.INTEGER24, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.INTEGER24, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.INTEGER24, 0) self.assertEqual(data, b'\xfe\xfd\xfc') - def test_integer32(self): + async def test_integer32(self): self.data = [ (TX, b'\x40\x04\x20\x00\x00\x00\x00\x00'), (RX, b'\x43\x04\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2000 + dt.INTEGER32, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.INTEGER32, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.INTEGER32, 0) self.assertEqual(data, b'\xfe\xfd\xfc\xfb') - def test_integer40(self): + async def test_integer40(self): self.data = [ (TX, b'\x40\x12\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x12\x20\x00\xfe\xfd\xfc\xfb'), (TX, b'\x60\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x05\xb2\x01\x20\x02\x91\x12\x03'), ] - data = self.network[2].sdo.upload(0x2000 + dt.INTEGER40, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.INTEGER40, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.INTEGER40, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91') - def test_integer48(self): + async def test_integer48(self): self.data = [ (TX, b'\x40\x13\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x13\x20\x00\xfe\xfd\xfc\xfb'), (TX, b'\x60\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x03\xb2\x01\x20\x02\x91\x12\x03'), ] - data = self.network[2].sdo.upload(0x2000 + dt.INTEGER48, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.INTEGER48, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.INTEGER48, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91\x12') - def test_integer56(self): + async def test_integer56(self): self.data = [ (TX, b'\x40\x14\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x14\x20\x00\xfe\xfd\xfc\xfb'), (TX, b'\x60\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x01\xb2\x01\x20\x02\x91\x12\x03'), ] - data = self.network[2].sdo.upload(0x2000 + dt.INTEGER56, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.INTEGER56, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.INTEGER56, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91\x12\x03') - def test_integer64(self): + async def test_integer64(self): self.data = [ (TX, b'\x40\x15\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x15\x20\x00\xfe\xfd\xfc\xfb'), @@ -703,18 +853,24 @@ def test_integer64(self): (TX, b'\x70\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x1d\x19\x21\x70\xfe\xfd\xfc\xfb'), ] - data = self.network[2].sdo.upload(0x2000 + dt.INTEGER64, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.INTEGER64, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.INTEGER64, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91\x12\x03\x19') - def test_real32(self): + async def test_real32(self): self.data = [ (TX, b'\x40\x08\x20\x00\x00\x00\x00\x00'), (RX, b'\x43\x08\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2000 + dt.REAL32, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.REAL32, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.REAL32, 0) self.assertEqual(data, b'\xfe\xfd\xfc\xfb') - def test_real64(self): + async def test_real64(self): self.data = [ (TX, b'\x40\x11\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x11\x20\x00\xfe\xfd\xfc\xfb'), @@ -723,10 +879,13 @@ def test_real64(self): (TX, b'\x70\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x1d\x19\x21\x70\xfe\xfd\xfc\xfb'), ] - data = self.network[2].sdo.upload(0x2000 + dt.REAL64, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.REAL64, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.REAL64, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91\x12\x03\x19') - def test_visible_string(self): + async def test_visible_string(self): self.data = [ (TX, b'\x40\x09\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x09\x20\x00\x1A\x00\x00\x00'), @@ -739,10 +898,13 @@ def test_visible_string(self): (TX, b'\x70\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x15\x69\x6E\x73\x20\x21\x00\x00') ] - data = self.network[2].sdo.upload(0x2000 + dt.VISIBLE_STRING, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.VISIBLE_STRING, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.VISIBLE_STRING, 0) self.assertEqual(data, b'Tiny Node - Mega Domains !') - def test_unicode_string(self): + async def test_unicode_string(self): self.data = [ (TX, b'\x40\x0b\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x0b\x20\x00\x1A\x00\x00\x00'), @@ -755,10 +917,13 @@ def test_unicode_string(self): (TX, b'\x70\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x15\x69\x6E\x73\x20\x21\x00\x00') ] - data = self.network[2].sdo.upload(0x2000 + dt.UNICODE_STRING, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.UNICODE_STRING, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.UNICODE_STRING, 0) self.assertEqual(data, b'Tiny Node - Mega Domains !') - def test_octet_string(self): + async def test_octet_string(self): self.data = [ (TX, b'\x40\x0a\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x0a\x20\x00\x1A\x00\x00\x00'), @@ -771,10 +936,13 @@ def test_octet_string(self): (TX, b'\x70\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x15\x69\x6E\x73\x20\x21\x00\x00') ] - data = self.network[2].sdo.upload(0x2000 + dt.OCTET_STRING, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.OCTET_STRING, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.OCTET_STRING, 0) self.assertEqual(data, b'Tiny Node - Mega Domains !') - def test_domain(self): + async def test_domain(self): self.data = [ (TX, b'\x40\x0f\x20\x00\x00\x00\x00\x00'), (RX, b'\x41\x0f\x20\x00\x1A\x00\x00\x00'), @@ -787,19 +955,25 @@ def test_domain(self): (TX, b'\x70\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x15\x69\x6E\x73\x20\x21\x00\x00') ] - data = self.network[2].sdo.upload(0x2000 + dt.DOMAIN, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2000 + dt.DOMAIN, 0) + else: + data = self.network[2].sdo.upload(0x2000 + dt.DOMAIN, 0) self.assertEqual(data, b'Tiny Node - Mega Domains !') - def test_unknown_od_32(self): + async def test_unknown_od_32(self): """Test an unknown OD entry of 32 bits (4 bytes).""" self.data = [ (TX, b'\x40\xFF\x20\x00\x00\x00\x00\x00'), (RX, b'\x43\xFF\x20\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x20FF, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x20FF, 0) + else: + data = self.network[2].sdo.upload(0x20FF, 0) self.assertEqual(data, b'\xfe\xfd\xfc\xfb') - def test_unknown_od_112(self): + async def test_unknown_od_112(self): """Test an unknown OD entry of 112 bits (14 bytes).""" self.data = [ (TX, b'\x40\xFF\x20\x00\x00\x00\x00\x00'), @@ -809,10 +983,13 @@ def test_unknown_od_112(self): (TX, b'\x70\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x11\x19\x21\x70\xfe\xfd\xfc\xfb'), ] - data = self.network[2].sdo.upload(0x20FF, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x20FF, 0) + else: + data = self.network[2].sdo.upload(0x20FF, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91\x12\x03\x19\x21\x70\xfe\xfd\xfc\xfb') - def test_unknown_datatype32(self): + async def test_unknown_datatype32(self): """Test an unknown datatype, but known OD, of 32 bits (4 bytes).""" # Add fake entry 0x2100 to OD, using fake datatype 0xFF if 0x2100 not in self.node.object_dictionary: @@ -823,10 +1000,13 @@ def test_unknown_datatype32(self): (TX, b'\x40\x00\x21\x00\x00\x00\x00\x00'), (RX, b'\x43\x00\x21\x00\xfe\xfd\xfc\xfb') ] - data = self.network[2].sdo.upload(0x2100, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2100, 0) + else: + data = self.network[2].sdo.upload(0x2100, 0) self.assertEqual(data, b'\xfe\xfd\xfc\xfb') - def test_unknown_datatype112(self): + async def test_unknown_datatype112(self): """Test an unknown datatype, but known OD, of 112 bits (14 bytes).""" # Add fake entry 0x2100 to OD, using fake datatype 0xFF if 0x2100 not in self.node.object_dictionary: @@ -841,8 +1021,24 @@ def test_unknown_datatype112(self): (TX, b'\x70\x00\x00\x00\x00\x00\x00\x00'), (RX, b'\x11\x19\x21\x70\xfe\xfd\xfc\xfb'), ] - data = self.network[2].sdo.upload(0x2100, 0) + if self.use_async: + data = await self.network[2].sdo.aupload(0x2100, 0) + else: + data = self.network[2].sdo.upload(0x2100, 0) self.assertEqual(data, b'\xb2\x01\x20\x02\x91\x12\x03\x19\x21\x70\xfe\xfd\xfc\xfb') + +class TestSDOClientDatatypesSync(TestSDOClientDatatypes): + """ Run tests in synchronous mode. """ + __test__ = True + use_async = False + + +class TestSDOClientDatatypesAsync(TestSDOClientDatatypes): + """ Run tests in asynchronous mode. """ + __test__ = True + use_async = True + + if __name__ == "__main__": unittest.main() diff --git a/test/test_sync.py b/test/test_sync.py index 93633538..8f7a76b2 100644 --- a/test/test_sync.py +++ b/test/test_sync.py @@ -1,5 +1,6 @@ import threading import unittest +import asyncio import can @@ -10,26 +11,34 @@ TIMEOUT = PERIOD * 10 -class TestSync(unittest.TestCase): +class TestSync(unittest.IsolatedAsyncioTestCase): + + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + def setUp(self): - self.net = canopen.Network() + loop = None + if self.use_async: + loop = asyncio.get_event_loop() + + self.net = canopen.Network(loop=loop) self.net.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 self.net.connect(interface="virtual") self.sync = canopen.sync.SyncProducer(self.net) - self.rxbus = can.Bus(interface="virtual") + self.rxbus = can.Bus(interface="virtual", loop=loop) def tearDown(self): self.net.disconnect() self.rxbus.shutdown() - def test_sync_producer_transmit(self): + async def test_sync_producer_transmit(self): self.sync.transmit() msg = self.rxbus.recv(TIMEOUT) self.assertIsNotNone(msg) self.assertEqual(msg.arbitration_id, 0x80) self.assertEqual(msg.dlc, 0) - def test_sync_producer_transmit_count(self): + async def test_sync_producer_transmit_count(self): self.sync.transmit(2) msg = self.rxbus.recv(TIMEOUT) self.assertIsNotNone(msg) @@ -37,11 +46,11 @@ def test_sync_producer_transmit_count(self): self.assertEqual(msg.dlc, 1) self.assertEqual(msg.data, b"\x02") - def test_sync_producer_start_invalid_period(self): + async def test_sync_producer_start_invalid_period(self): with self.assertRaises(ValueError): self.sync.start(0) - def test_sync_producer_start(self): + async def test_sync_producer_start(self): self.sync.start(PERIOD) self.addCleanup(self.sync.stop) @@ -75,5 +84,17 @@ def periodicity(): self.assertIsNone(self.net.bus.recv(TIMEOUT)) +class TestSyncSync(TestSync): + """ Test the functions in synchronous mode. """ + __test__ = True + use_async = False + + +class TestSyncAsync(TestSync): + """ Test the functions in asynchronous mode. """ + __test__ = True + use_async = True + + if __name__ == "__main__": unittest.main() diff --git a/test/test_time.py b/test/test_time.py index fa45a444..477e2efa 100644 --- a/test/test_time.py +++ b/test/test_time.py @@ -1,3 +1,4 @@ +import asyncio import struct import time import unittest @@ -8,17 +9,26 @@ import canopen.timestamp -class TestTime(unittest.TestCase): +class TestTime(unittest.IsolatedAsyncioTestCase): - def test_epoch(self): + __test__ = False # This is a base class, tests should not be run directly. + use_async: bool + + def setUp(self): + self.loop = None + if self.use_async: + self.loop = asyncio.get_event_loop() + + async def test_epoch(self): """Verify that the epoch matches the standard definition.""" epoch = datetime.strptime( "1984-01-01 00:00:00 +0000", "%Y-%m-%d %H:%M:%S %z" ).timestamp() self.assertEqual(int(epoch), canopen.timestamp.OFFSET) - def test_time_producer(self): - network = canopen.Network() + async def test_time_producer(self): + network = canopen.Network(loop=self.loop) + self.addCleanup(network.disconnect) network.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 network.connect(interface="virtual", receive_own_messages=True) producer = canopen.timestamp.TimeProducer(network) @@ -42,7 +52,17 @@ def test_time_producer(self): self.assertEqual(days, int(current_from_epoch) // 86400) self.assertEqual(ms, int(current_from_epoch % 86400 * 1000)) - network.disconnect() + +class TestTimeSync(TestTime): + """ Test time functions in synchronous mode. """ + __test__ = True + use_async = False + + +class TestTimeAsync(TestTime): + """ Test time functions in asynchronous mode. """ + __test__ = True + use_async = True if __name__ == "__main__": diff --git a/test/test_utils.py b/test/test_utils.py index a17cce92..18c60d2f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -3,7 +3,7 @@ from canopen.utils import pretty_index -class TestUtils(unittest.TestCase): +class TestUtils(unittest.IsolatedAsyncioTestCase): def test_pretty_index(self): self.assertEqual(pretty_index(0x12ab), "0x12AB")