diff --git a/canopen/sync.py b/canopen/sync.py index 44ea56c3..c9a1c679 100644 --- a/canopen/sync.py +++ b/canopen/sync.py @@ -15,15 +15,17 @@ class SyncProducer: def __init__(self, network: canopen.network.Network): self.network = network self.period: Optional[float] = None - self._task = None + self._task: Optional[canopen.network.PeriodicMessageTask] = None def transmit(self, count: Optional[int] = None): """Send out a SYNC message once. :param count: Counter to add in message. + :raises ValueError: + If the counter value does not fit in one byte. """ - data = [count] if count is not None else [] + data = bytes([count]) if count is not None else b"" self.network.send_message(self.cob_id, data) def start(self, period: Optional[float] = None): @@ -31,16 +33,24 @@ def start(self, period: Optional[float] = None): :param period: Period of SYNC message in seconds. + :raises RuntimeError: + If a periodic transmission is already started. + :raises ValueError: + If no period is set via argument nor the instance attribute. """ + if self._task is not None: + raise RuntimeError("Periodic SYNC transmission task already running") + if period is not None: self.period = period if not self.period: raise ValueError("A valid transmission period has not been given") - self._task = self.network.send_periodic(self.cob_id, [], self.period) + self._task = self.network.send_periodic(self.cob_id, b"", self.period) def stop(self): """Stop periodic transmission of SYNC message.""" if self._task is not None: self._task.stop() + self._task = None diff --git a/test/test_sync.py b/test/test_sync.py index 93633538..66c4867d 100644 --- a/test/test_sync.py +++ b/test/test_sync.py @@ -74,6 +74,16 @@ def periodicity(): if msg is not None: self.assertIsNone(self.net.bus.recv(TIMEOUT)) + def test_sync_producer_restart(self): + self.sync.start(PERIOD) + self.addCleanup(self.sync.stop) + # Cannot start again while running + with self.assertRaises(RuntimeError): + self.sync.start(PERIOD) + # Can restart after stopping + self.sync.stop() + self.sync.start(PERIOD) + if __name__ == "__main__": unittest.main()