Skip to content

Commit e83b152

Browse files
authored
Add additional Discover flow tests (#217)
1 parent 1da1885 commit e83b152

File tree

1 file changed

+177
-16
lines changed

1 file changed

+177
-16
lines changed

msmart/tests/test_discover.py

Lines changed: 177 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,37 @@
1+
import asyncio
12
import unittest
3+
import unittest.mock as mock
4+
from unittest.mock import patch
25

3-
from msmart.const import DeviceType
6+
from msmart.const import DISCOVERY_MSG, DeviceType
47
from msmart.device import AirConditioner as AC
5-
from msmart.discover import Discover
8+
from msmart.discover import _IPV4_BROADCAST, Discover
9+
10+
_DISCOVER_RESPONSES = [
11+
(("10.100.1.140", 6445), bytes.fromhex("5a5a011178007a8000000000000000000000000060ca0000000e0000000000000000000001000000c08651cb1b88a167bdcf7d37534ef81312d39429bf9b2673f200b635fae369a560fa9655eab8344be22b1e3b024ef5dfd392dc3db64dbffb6a66fb9cd5ec87a78000cd9043833b9f76991e8af29f3496")),
12+
(("10.100.1.239", 6445), bytes.fromhex("837000c8200f00005a5a0111b8007a800000000061433702060817143daa00000086000000000000000001800000000041c7129527bc03ee009284a90c2fbd2f179764ac35b55e7fb0e4ab0de9298fa1a5ca328046c603fb1ab60079d550d03546b605180127fdb5bb33a105f5206b5f008bffba2bae272aa0c96d56b45c4afa33f826a0a4215d1dd87956a267d2dbd34bdfb3e16e33d88768cc4c3d0658937d0bb19369bf0317b24d3a4de9e6a13106f7ceb5acc6651ce53d684a32ce34dc3a4fbe0d4139de99cc88a0285e14657045")),
13+
]
614

715

816
class TestDiscover(unittest.IsolatedAsyncioTestCase):
917
# pylint: disable=protected-access
1018

1119
async def test_discover_v2(self) -> None:
1220
"""Test that we can parse a V2 discovery response."""
13-
DISCOVER_RESPONSE_V2 = bytes.fromhex(
14-
"5a5a011178007a8000000000000000000000000060ca0000000e0000000000000000000001000000c08651cb1b88a167bdcf7d37534ef81312d39429bf9b2673f200b635fae369a560fa9655eab8344be22b1e3b024ef5dfd392dc3db64dbffb6a66fb9cd5ec87a78000cd9043833b9f76991e8af29f3496")
15-
IP_ADDRESS = "10.100.1.140"
21+
HOST, RESPONSE_V2 = _DISCOVER_RESPONSES[0]
1622

1723
# Check version
18-
version = Discover._get_device_version(DISCOVER_RESPONSE_V2)
24+
version = Discover._get_device_version(RESPONSE_V2)
1925
self.assertEqual(version, 2)
2026

2127
# Check info matches
22-
info = await Discover._get_device_info(IP_ADDRESS, version, DISCOVER_RESPONSE_V2)
28+
info = await Discover._get_device_info(HOST[0], version, RESPONSE_V2)
2329
self.assertIsNotNone(info)
2430

25-
# Stop type errors
31+
# Suppress type errors
2632
assert info is not None
2733

28-
self.assertEqual(info["ip"], IP_ADDRESS)
34+
self.assertEqual(info["ip"], HOST[0])
2935
self.assertEqual(info["port"], 6444)
3036

3137
self.assertEqual(info["device_id"], 15393162840672)
@@ -44,22 +50,20 @@ async def test_discover_v2(self) -> None:
4450

4551
async def test_discover_v3(self) -> None:
4652
"""Test that we can parse a V3 discovery response."""
47-
DISCOVER_RESPONSE_V3 = bytes.fromhex(
48-
"837000c8200f00005a5a0111b8007a800000000061433702060817143daa00000086000000000000000001800000000041c7129527bc03ee009284a90c2fbd2f179764ac35b55e7fb0e4ab0de9298fa1a5ca328046c603fb1ab60079d550d03546b605180127fdb5bb33a105f5206b5f008bffba2bae272aa0c96d56b45c4afa33f826a0a4215d1dd87956a267d2dbd34bdfb3e16e33d88768cc4c3d0658937d0bb19369bf0317b24d3a4de9e6a13106f7ceb5acc6651ce53d684a32ce34dc3a4fbe0d4139de99cc88a0285e14657045")
49-
IP_ADDRESS = "10.100.1.239"
53+
HOST, RESPONSE_V3 = _DISCOVER_RESPONSES[1]
5054

5155
# Check version
52-
version = Discover._get_device_version(DISCOVER_RESPONSE_V3)
56+
version = Discover._get_device_version(RESPONSE_V3)
5357
self.assertEqual(version, 3)
5458

5559
# Check info matches
56-
info = await Discover._get_device_info(IP_ADDRESS, version, DISCOVER_RESPONSE_V3)
60+
info = await Discover._get_device_info(HOST[0], version, RESPONSE_V3)
5761
self.assertIsNotNone(info)
5862

59-
# Stop type errors
63+
# Suppress type errors
6064
assert info is not None
6165

62-
self.assertEqual(info["ip"], IP_ADDRESS)
66+
self.assertEqual(info["ip"], HOST[0])
6367
self.assertEqual(info["port"], 6444)
6468

6569
self.assertEqual(info["device_id"], 147334558165565)
@@ -77,5 +81,162 @@ async def test_discover_v3(self) -> None:
7781
self.assertIsNotNone(device)
7882

7983

84+
class TestDiscoverProtocol(unittest.IsolatedAsyncioTestCase):
85+
# pylint: disable=protected-access
86+
87+
async def _discover(self, *args, method=Discover.discover, **kwargs):
88+
"""Run the msmart-ng discover flow with necessary mocking."""
89+
90+
# Mock the underlying transport
91+
mock_transport = mock.MagicMock()
92+
protocol = None
93+
94+
# Define the side effect method for our mock create_datagram_endpoint which creates the real protocol
95+
def mock_create_datagram_endpoint_side_effect(protocol_factory, *args, **kwargs):
96+
nonlocal protocol, mock_transport
97+
98+
# Build the protocol from the factory
99+
protocol = protocol_factory()
100+
101+
# "Make" a connection
102+
protocol.connection_made(mock_transport)
103+
104+
return (mock_transport, protocol)
105+
106+
# Patch the create_datagram_endpoint to use our side effect method
107+
with patch('asyncio.BaseEventLoop.create_datagram_endpoint', side_effect=mock_create_datagram_endpoint_side_effect) as mock_create_datagram_endpoint:
108+
# Create a task to run discover concurrently
109+
task = asyncio.create_task(method(*args, **kwargs))
110+
111+
# Sleep a little to let the discover task start
112+
await asyncio.sleep(0.1)
113+
114+
# Assert the mocked create_datagram_endpoint was called
115+
mock_create_datagram_endpoint.assert_called_once()
116+
117+
# Suppress type errors
118+
assert protocol is not None
119+
120+
# Assert protocol and transport are assigned
121+
self.assertIsNotNone(protocol)
122+
self.assertEqual(protocol._transport, mock_transport)
123+
124+
return mock_transport, protocol, task
125+
126+
async def test_discover_broadcast(self) -> None:
127+
"""Test that Discover.discover sends broadcast packets."""
128+
# Start discovery
129+
mock_transport, protocol, discover_task = await self._discover(method=Discover.discover, discovery_packets=1, timeout=1)
130+
131+
# Wait for discovery to finish
132+
devices = await discover_task
133+
134+
# Assert that we tried to send discovery broadcasts
135+
mock_transport.sendto.assert_has_calls([
136+
mock.call(DISCOVERY_MSG, (_IPV4_BROADCAST, 6445)),
137+
mock.call(DISCOVERY_MSG, (_IPV4_BROADCAST, 20086))
138+
])
139+
140+
# Check that transport is closed
141+
mock_transport.close.assert_called_once()
142+
143+
# Assert no devices discovered
144+
self.assertEqual(devices, [])
145+
146+
async def test_discover_single(self) -> None:
147+
"""Test that Discover.discover_single sends packets to a particular host."""
148+
TARGET_HOST = "1.1.1.1"
149+
150+
# Start discovery
151+
mock_transport, protocol, discover_task = await self._discover(TARGET_HOST, method=Discover.discover_single, discovery_packets=1, timeout=1)
152+
153+
# Wait for discovery to finish
154+
device = await discover_task
155+
156+
# Assert that we tried to send discovery broadcasts
157+
mock_transport.sendto.assert_has_calls([
158+
mock.call(DISCOVERY_MSG, (TARGET_HOST, 6445)),
159+
mock.call(DISCOVERY_MSG, (TARGET_HOST, 20086))
160+
])
161+
162+
# Check that transport is closed
163+
mock_transport.close.assert_called_once()
164+
165+
# Assert no devices discovered
166+
self.assertEqual(device, None)
167+
168+
async def test_discover_devices(self):
169+
"""Test that discover processes device responses and returns a list of devices."""
170+
# Start discovery
171+
mock_transport, protocol, discover_task = await self._discover(
172+
discovery_packets=1,
173+
timeout=1,
174+
auto_connect=False # Disable auto connect for this test
175+
)
176+
177+
# Suppress type errors
178+
assert protocol is not None
179+
180+
# Mock responses from devices
181+
for host, response in _DISCOVER_RESPONSES:
182+
protocol.datagram_received(response, host)
183+
184+
# Wait for discovery to complete
185+
devices = await discover_task
186+
187+
# Check that transport is closed
188+
mock_transport.close.assert_called_once()
189+
190+
# Assert expected devices were found
191+
self.assertIsNotNone(devices)
192+
self.assertEqual(len(devices), len(_DISCOVER_RESPONSES))
193+
194+
self.assertIsInstance(devices[0], AC)
195+
self.assertIsInstance(devices[1], AC)
196+
197+
async def test_discover_device_with_connect(self):
198+
"""Test that discover attempts to automatically connect to discovered device."""
199+
# Start discovery
200+
mock_transport, protocol, discover_task = await self._discover(
201+
discovery_packets=1,
202+
timeout=1,
203+
auto_connect=True # Enable auto connect for this test
204+
)
205+
206+
# Suppress type errors
207+
assert protocol is not None
208+
209+
# Mock responses from a device
210+
host, response = _DISCOVER_RESPONSES[0]
211+
protocol.datagram_received(response, host)
212+
213+
# Define the side effect method for our mock connect to force the device online
214+
def mock_connect_side_effect(dev):
215+
# Force device online and supported
216+
dev._online = True
217+
dev._supported = True
218+
return True
219+
220+
# Patch the Discover.connect method to fake a device connection
221+
with patch("msmart.discover.Discover.connect", side_effect=mock_connect_side_effect) as mock_connect:
222+
# Wait for discovery to complete
223+
devices = await discover_task
224+
225+
# Assert expected device was found
226+
self.assertIsNotNone(devices)
227+
self.assertEqual(len(devices), 1)
228+
229+
# Assert connection attempt was made
230+
device = devices[0]
231+
mock_connect.assert_called_once_with(device)
232+
233+
# Check that transport is closed
234+
mock_transport.close.assert_called_once()
235+
236+
# Assert device connected and is supported
237+
self.assertTrue(device.online)
238+
self.assertTrue(device.supported)
239+
240+
80241
if __name__ == "__main__":
81242
unittest.main()

0 commit comments

Comments
 (0)