From afa668f2480f034376c6899fba696844206c2a5b Mon Sep 17 00:00:00 2001 From: yuvraj Kolkar Date: Tue, 22 Jul 2025 02:08:47 +0530 Subject: [PATCH] Add test_emcy.py for EMCY module coverage --- test/test_emcy.py | 309 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 301 insertions(+), 8 deletions(-) diff --git a/test/test_emcy.py b/test/test_emcy.py index d883e9c8..f14b1f98 100644 --- a/test/test_emcy.py +++ b/test/test_emcy.py @@ -2,6 +2,7 @@ import threading import unittest from contextlib import contextmanager +from unittest.mock import Mock, patch import can @@ -25,13 +26,11 @@ def check_error(self, err, code, reg, data, ts): self.assertAlmostEqual(err.timestamp, ts) def test_emcy_consumer_on_emcy(self): - # Make sure multiple callbacks receive the same information. acc1 = [] acc2 = [] self.emcy.add_callback(lambda err: acc1.append(err)) self.emcy.add_callback(lambda err: acc2.append(err)) - # Dispatch an EMCY datagram. self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) self.assertEqual(len(self.emcy.log), 1) @@ -45,7 +44,6 @@ def test_emcy_consumer_on_emcy(self): data=bytes([0, 1, 2, 3, 4]), ts=1000, ) - # Dispatch a new EMCY datagram. self.emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) self.assertEqual(len(self.emcy.log), 2) self.assertEqual(len(self.emcy.active), 2) @@ -58,7 +56,6 @@ def test_emcy_consumer_on_emcy(self): data=bytes([4, 3, 2, 1, 0]), ts=2000, ) - # Dispatch an EMCY reset. self.emcy.on_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 2000) self.assertEqual(len(self.emcy.log), 3) self.assertEqual(len(self.emcy.active), 0) @@ -94,24 +91,20 @@ def timer(func): finally: t.join(TIMEOUT) - # Check unfiltered wait, on timeout. self.assertIsNone(self.emcy.wait(timeout=TIMEOUT)) - # Check unfiltered wait, on success. with timer(push_err) as t: with self.assertLogs(level=logging.INFO): t.start() err = self.emcy.wait(timeout=TIMEOUT) check_err(err) - # Check filtered wait, on success. with timer(push_err) as t: with self.assertLogs(level=logging.INFO): t.start() err = self.emcy.wait(0x2001, TIMEOUT) check_err(err) - # Check filtered wait, on timeout. with timer(push_err) as t: t.start() self.assertIsNone(self.emcy.wait(0x9000, TIMEOUT)) @@ -123,6 +116,110 @@ def push_reset(): t.start() self.assertIsNone(self.emcy.wait(0x9000, TIMEOUT)) + def test_emcy_consumer_initialization(self): + """Test EmcyConsumer initialization state.""" + consumer = canopen.emcy.EmcyConsumer() + self.assertEqual(consumer.log, []) + self.assertEqual(consumer.active, []) + self.assertEqual(consumer.callbacks, []) + self.assertIsInstance(consumer.emcy_received, threading.Condition) + + def test_emcy_consumer_multiple_callbacks(self): + """Test adding multiple callbacks and their execution order.""" + call_order = [] + + def callback1(err): + call_order.append('callback1') + + def callback2(err): + call_order.append('callback2') + + def callback3(err): + call_order.append('callback3') + + self.emcy.add_callback(callback1) + self.emcy.add_callback(callback2) + self.emcy.add_callback(callback3) + + self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + + self.assertEqual(call_order, ['callback1', 'callback2', 'callback3']) + self.assertEqual(len(self.emcy.callbacks), 3) + + def test_emcy_consumer_callback_exception_handling(self): + """Test that callback exceptions don't break other callbacks or the system.""" + successful_callbacks = [] + + def failing_callback(err): + raise ValueError("Test exception in callback") + + def successful_callback1(err): + successful_callbacks.append('success1') + + def successful_callback2(err): + successful_callbacks.append('success2') + + self.emcy.add_callback(successful_callback1) + self.emcy.add_callback(failing_callback) + self.emcy.add_callback(successful_callback2) + + with self.assertRaises(ValueError): + self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + + def test_emcy_consumer_error_reset_variants(self): + """Test different error reset code patterns.""" + self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + self.emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) + self.assertEqual(len(self.emcy.active), 2) + + self.emcy.on_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 3000) + self.assertEqual(len(self.emcy.active), 0) + + self.emcy.on_emcy(0x81, b'\x01\x30\x02\x00\x01\x02\x03\x04', 4000) + self.assertEqual(len(self.emcy.active), 1) + + self.emcy.on_emcy(0x81, b'\x99\x00\x01\x00\x00\x00\x00\x00', 5000) + self.assertEqual(len(self.emcy.active), 0) + + def test_emcy_consumer_wait_timeout_edge_cases(self): + """Test wait method with various timeout scenarios.""" + result = self.emcy.wait(timeout=0) + self.assertIsNone(result) + + result = self.emcy.wait(timeout=0.001) + self.assertIsNone(result) + + def test_emcy_consumer_wait_concurrent_errors(self): + """Test wait method when multiple errors arrive concurrently.""" + def push_multiple_errors(): + self.emcy.on_emcy(0x81, b'\x01\x20\x01\x01\x02\x03\x04\x05', 100) + self.emcy.on_emcy(0x81, b'\x02\x20\x01\x01\x02\x03\x04\x05', 101) + self.emcy.on_emcy(0x81, b'\x03\x20\x01\x01\x02\x03\x04\x05', 102) + + t = threading.Timer(TIMEOUT / 2, push_multiple_errors) + with self.assertLogs(level=logging.INFO): + t.start() + err = self.emcy.wait(0x2003, timeout=TIMEOUT) + t.join(TIMEOUT) + + self.assertIsNotNone(err) + self.assertEqual(err.code, 0x2003) + + def test_emcy_consumer_wait_time_expiry_during_execution(self): + """Test wait method when time expires while processing.""" + def push_err_with_delay(): + import time + time.sleep(TIMEOUT * 1.5) + self.emcy.on_emcy(0x81, b'\x01\x20\x01\x01\x02\x03\x04\x05', 100) + + t = threading.Timer(TIMEOUT / 4, push_err_with_delay) + t.start() + + result = self.emcy.wait(timeout=TIMEOUT) + t.join(TIMEOUT * 2) + + self.assertIsNone(result) + class TestEmcyError(unittest.TestCase): def test_emcy_error(self): @@ -180,6 +277,75 @@ def check(code, expected): check(0xff00, "Device Specific") check(0xffff, "Device Specific") + def test_emcy_error_initialization_types(self): + """Test EmcyError initialization with various data types.""" + error = EmcyError(0x1000, 0, b'', 123.456) + self.assertEqual(error.code, 0x1000) + self.assertEqual(error.register, 0) + self.assertEqual(error.data, b'') + self.assertEqual(error.timestamp, 123.456) + + error = EmcyError(0xFFFF, 0xFF, b'\xFF' * 5, float('inf')) + self.assertEqual(error.code, 0xFFFF) + self.assertEqual(error.register, 0xFF) + self.assertEqual(error.data, b'\xFF' * 5) + self.assertEqual(error.timestamp, float('inf')) + + def test_emcy_error_str_edge_cases(self): + """Test string representation with edge cases.""" + error = EmcyError(0x0000, 0, b'', 1000) + self.assertEqual(str(error), "Code 0x0000, Error Reset / No Error") + + error = EmcyError(0x0001, 0, b'', 1000) + self.assertEqual(str(error), "Code 0x0001, Error Reset / No Error") + + error = EmcyError(0x0100, 0, b'', 1000) + self.assertEqual(str(error), "Code 0x0100") + + error = EmcyError(0xFFFF, 0, b'', 1000) + self.assertEqual(str(error), "Code 0xFFFF, Device Specific") + + def test_emcy_error_get_desc_boundary_conditions(self): + """Test get_desc method with boundary conditions.""" + def check(code, expected): + err = EmcyError(code, 1, b'', 1000) + actual = err.get_desc() + self.assertEqual(actual, expected) + + check(0x0000, "Error Reset / No Error") + check(0x00FF, "Error Reset / No Error") + check(0x0100, "") + + check(0x0FFF, "") + check(0x1000, "Generic Error") + check(0x10FF, "Generic Error") + check(0x1100, "") + + check(0x1FFF, "") + check(0x2000, "Current") + check(0x2FFF, "Current") + check(0x3000, "Voltage") + + check(0x4FFF, "Temperature") + check(0x5000, "Device Hardware") + check(0x50FF, "Device Hardware") + check(0x5100, "") + + def test_emcy_error_inheritance(self): + """Test that EmcyError properly inherits from Exception.""" + error = EmcyError(0x1000, 0, b'', 1000) + + self.assertIsInstance(error, Exception) + + with self.assertRaises(EmcyError): + raise error + + try: + raise error + except Exception as e: + self.assertIsInstance(e, EmcyError) + self.assertEqual(e.code, 0x1000) + class TestEmcyProducer(unittest.TestCase): def setUp(self): @@ -220,6 +386,133 @@ def check(*args, res): check(3, res=b'\x00\x00\x03\x00\x00\x00\x00\x00') check(3, b"\xaa\xbb", res=b'\x00\x00\x03\xaa\xbb\x00\x00\x00') + def test_emcy_producer_initialization(self): + """Test EmcyProducer initialization.""" + producer = canopen.emcy.EmcyProducer(0x123) + self.assertEqual(producer.cob_id, 0x123) + network = producer.network + self.assertIsNotNone(network) + + def test_emcy_producer_send_edge_cases(self): + """Test EmcyProducer send method with edge cases.""" + def check(*args, res): + self.emcy.send(*args) + self.check_response(res) + + check(0xFFFF, 0xFF, b'\xFF\xFF\xFF\xFF\xFF', + res=b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF') + + check(0x0000, 0x00, b'', + res=b'\x00\x00\x00\x00\x00\x00\x00\x00') + + check(0x1234, 0x56, b'\xAB\xCD', + res=b'\x34\x12\x56\xAB\xCD\x00\x00\x00') + + check(0x1234, 0x56, b'\xAB\xCD\xEF\x12\x34', + res=b'\x34\x12\x56\xAB\xCD\xEF\x12\x34') + + def test_emcy_producer_reset_edge_cases(self): + """Test EmcyProducer reset method with edge cases.""" + def check(*args, res): + self.emcy.reset(*args) + self.check_response(res) + + check(0xFF, res=b'\x00\x00\xFF\x00\x00\x00\x00\x00') + + check(0xFF, b'\xFF\xFF\xFF\xFF\xFF', + res=b'\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF') + + check(0x12, b'\xAB\xCD', + res=b'\x00\x00\x12\xAB\xCD\x00\x00\x00') + + def test_emcy_producer_network_assignment(self): + """Test EmcyProducer network assignment and usage.""" + producer = canopen.emcy.EmcyProducer(0x100) + initial_network = producer.network + + producer.network = self.net + self.assertEqual(producer.network, self.net) + + producer.send(0x1000) + msg = self.rxbus.recv(TIMEOUT) + self.assertIsNotNone(msg) + self.assertEqual(msg.arbitration_id, 0x100) + + def test_emcy_producer_struct_packing(self): + """Test that the EMCY_STRUCT packing works correctly.""" + from canopen.emcy import EMCY_STRUCT + + packed = EMCY_STRUCT.pack(0x1234, 0x56, b'\xAB\xCD\xEF\x12\x34') + expected = b'\x34\x12\x56\xAB\xCD\xEF\x12\x34' + self.assertEqual(packed, expected) + + code, register, data = EMCY_STRUCT.unpack(expected) + self.assertEqual(code, 0x1234) + self.assertEqual(register, 0x56) + self.assertEqual(data, b'\xAB\xCD\xEF\x12\x34') + + packed = EMCY_STRUCT.pack(0x1234, 0x56, b'\xAB') + expected = b'\x34\x12\x56\xAB\x00\x00\x00\x00' + self.assertEqual(packed, expected) + + +class TestEmcyIntegration(unittest.TestCase): + """Integration tests for EMCY producer and consumer.""" + + def setUp(self): + self.txbus = can.Bus(interface="virtual") + self.rxbus = can.Bus(interface="virtual") + self.net = canopen.Network(self.txbus) + self.net.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.net.connect() + + self.producer = canopen.emcy.EmcyProducer(0x081) + self.producer.network = self.net + + self.consumer = canopen.emcy.EmcyConsumer() + + def tearDown(self): + self.net.disconnect() + self.txbus.shutdown() + self.rxbus.shutdown() + + def test_producer_consumer_integration(self): + """Test that producer and consumer work together.""" + received_errors = [] + self.consumer.add_callback(lambda err: received_errors.append(err)) + + self.producer.send(0x2001, 0x02, b'\x01\x02\x03\x04\x05') + + msg = self.rxbus.recv(TIMEOUT) + self.assertIsNotNone(msg) + + self.consumer.on_emcy(msg.arbitration_id, msg.data, msg.timestamp) + + self.assertEqual(len(received_errors), 1) + self.assertEqual(len(self.consumer.log), 1) + self.assertEqual(len(self.consumer.active), 1) + + error = received_errors[0] + self.assertEqual(error.code, 0x2001) + self.assertEqual(error.register, 0x02) + self.assertEqual(error.data, b'\x01\x02\x03\x04\x05') + + def test_producer_reset_consumer_integration(self): + """Test producer reset clears consumer active errors.""" + self.consumer.on_emcy(0x081, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + self.assertEqual(len(self.consumer.active), 1) + + self.producer.reset() + + msg = self.rxbus.recv(TIMEOUT) + self.assertIsNotNone(msg) + + self.consumer.on_emcy(msg.arbitration_id, msg.data, msg.timestamp) + + self.assertEqual(len(self.consumer.active), 0) + self.assertEqual(len(self.consumer.log), 2) + if __name__ == "__main__": unittest.main() + \ No newline at end of file