Skip to content

Commit e213691

Browse files
committed
Use contextlib.ExitStack to make the nested context managers cleaner.
1 parent e8a4d80 commit e213691

File tree

1 file changed

+101
-74
lines changed

1 file changed

+101
-74
lines changed

tests/test_cli.py

Lines changed: 101 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Tests for the pybricksdev CLI commands."""
22

33
import argparse
4+
import contextlib
45
import io
56
import os
67
import tempfile
7-
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
8+
from unittest.mock import AsyncMock, mock_open, patch
89

910
import pytest
1011

@@ -79,20 +80,27 @@ async def test_download_ble(self):
7980
name="MyHub",
8081
)
8182

82-
# Mock the hub creation
83-
with patch(
84-
"pybricksdev.connections.pybricks.PybricksHubBLE", return_value=mock_hub
85-
) as mock_hub_class:
86-
with patch("pybricksdev.ble.find_device", return_value="mock_device"):
87-
# Run the command
88-
download = Download()
89-
await download.run(args)
83+
# Set up mocks using ExitStack
84+
with contextlib.ExitStack() as stack:
85+
mock_hub_class = stack.enter_context(
86+
patch(
87+
"pybricksdev.connections.pybricks.PybricksHubBLE",
88+
return_value=mock_hub,
89+
)
90+
)
91+
stack.enter_context(
92+
patch("pybricksdev.ble.find_device", return_value="mock_device")
93+
)
94+
95+
# Run the command
96+
download = Download()
97+
await download.run(args)
9098

91-
# Verify the hub was created and used correctly
92-
mock_hub_class.assert_called_once_with("mock_device")
93-
mock_hub.connect.assert_called_once()
94-
mock_hub.download.assert_called_once()
95-
mock_hub.disconnect.assert_called_once()
99+
# Verify the hub was created and used correctly
100+
mock_hub_class.assert_called_once_with("mock_device")
101+
mock_hub.connect.assert_called_once()
102+
mock_hub.download.assert_called_once()
103+
mock_hub.disconnect.assert_called_once()
96104
finally:
97105
os.unlink(temp_path)
98106

@@ -117,20 +125,25 @@ async def test_download_usb(self):
117125
name=None,
118126
)
119127

120-
# Mock the hub creation
121-
with patch(
122-
"pybricksdev.connections.pybricks.PybricksHubUSB", return_value=mock_hub
123-
) as mock_hub_class:
124-
with patch("usb.core.find", return_value="mock_device"):
125-
# Run the command
126-
download = Download()
127-
await download.run(args)
128+
# Set up mocks using ExitStack
129+
with contextlib.ExitStack() as stack:
130+
mock_hub_class = stack.enter_context(
131+
patch(
132+
"pybricksdev.connections.pybricks.PybricksHubUSB",
133+
return_value=mock_hub,
134+
)
135+
)
136+
stack.enter_context(patch("usb.core.find", return_value="mock_device"))
137+
138+
# Run the command
139+
download = Download()
140+
await download.run(args)
128141

129-
# Verify the hub was created and used correctly
130-
mock_hub_class.assert_called_once_with("mock_device")
131-
mock_hub.connect.assert_called_once()
132-
mock_hub.download.assert_called_once()
133-
mock_hub.disconnect.assert_called_once()
142+
# Verify the hub was created and used correctly
143+
mock_hub_class.assert_called_once_with("mock_device")
144+
mock_hub.connect.assert_called_once()
145+
mock_hub.download.assert_called_once()
146+
mock_hub.disconnect.assert_called_once()
134147
finally:
135148
os.unlink(temp_path)
136149

@@ -154,20 +167,27 @@ async def test_download_ssh(self):
154167
name="ev3dev.local",
155168
)
156169

157-
# Mock the hub creation
158-
with patch(
159-
"pybricksdev.connections.ev3dev.EV3Connection", return_value=mock_hub
160-
) as mock_hub_class:
161-
with patch("socket.gethostbyname", return_value="192.168.1.1"):
162-
# Run the command
163-
download = Download()
164-
await download.run(args)
170+
# Set up mocks using ExitStack
171+
with contextlib.ExitStack() as stack:
172+
mock_hub_class = stack.enter_context(
173+
patch(
174+
"pybricksdev.connections.ev3dev.EV3Connection",
175+
return_value=mock_hub,
176+
)
177+
)
178+
stack.enter_context(
179+
patch("socket.gethostbyname", return_value="192.168.1.1")
180+
)
181+
182+
# Run the command
183+
download = Download()
184+
await download.run(args)
165185

166-
# Verify the hub was created and used correctly
167-
mock_hub_class.assert_called_once_with("192.168.1.1")
168-
mock_hub.connect.assert_called_once()
169-
mock_hub.download.assert_called_once()
170-
mock_hub.disconnect.assert_called_once()
186+
# Verify the hub was created and used correctly
187+
mock_hub_class.assert_called_once_with("192.168.1.1")
188+
mock_hub.connect.assert_called_once()
189+
mock_hub.download.assert_called_once()
190+
mock_hub.disconnect.assert_called_once()
171191
finally:
172192
os.unlink(temp_path)
173193

@@ -214,22 +234,29 @@ async def test_download_stdin(self):
214234
name="MyHub",
215235
)
216236

217-
# Mock the hub creation and file handling
218-
with patch(
219-
"pybricksdev.connections.pybricks.PybricksHubBLE", return_value=mock_hub
220-
) as mock_hub_class:
221-
with patch("pybricksdev.ble.find_device", return_value="mock_device"):
222-
with patch("tempfile.NamedTemporaryFile") as mock_temp:
223-
mock_temp.return_value.__enter__.return_value.name = "/tmp/test.py"
224-
# Run the command
225-
download = Download()
226-
await download.run(args)
237+
# Set up mocks using ExitStack
238+
with contextlib.ExitStack() as stack:
239+
mock_hub_class = stack.enter_context(
240+
patch(
241+
"pybricksdev.connections.pybricks.PybricksHubBLE",
242+
return_value=mock_hub,
243+
)
244+
)
245+
stack.enter_context(
246+
patch("pybricksdev.ble.find_device", return_value="mock_device")
247+
)
248+
mock_temp = stack.enter_context(patch("tempfile.NamedTemporaryFile"))
249+
mock_temp.return_value.__enter__.return_value.name = "/tmp/test.py"
250+
251+
# Run the command
252+
download = Download()
253+
await download.run(args)
227254

228-
# Verify the hub was created and used correctly
229-
mock_hub_class.assert_called_once_with("mock_device")
230-
mock_hub.connect.assert_called_once()
231-
mock_hub.download.assert_called_once()
232-
mock_hub.disconnect.assert_called_once()
255+
# Verify the hub was created and used correctly
256+
mock_hub_class.assert_called_once_with("mock_device")
257+
mock_hub.connect.assert_called_once()
258+
mock_hub.download.assert_called_once()
259+
mock_hub.disconnect.assert_called_once()
233260

234261
@pytest.mark.asyncio
235262
async def test_download_connection_error(self):
@@ -251,24 +278,24 @@ async def test_download_connection_error(self):
251278
name="MyHub",
252279
)
253280

254-
# Mock the compilation
255-
mock_mpy = MagicMock()
256-
mock_mpy.__bytes__ = lambda self: b"compiled code"
257-
258-
# Mock the hub creation
259-
with patch(
260-
"pybricksdev.connections.pybricks.PybricksHubBLE", return_value=mock_hub
261-
):
262-
with patch("pybricksdev.ble.find_device", return_value="mock_device"):
263-
with patch(
264-
"pybricksdev.compile.compile_multi_file", return_value=mock_mpy
265-
):
266-
# Run the command and verify it raises the error
267-
download = Download()
268-
with pytest.raises(RuntimeError, match="Connection failed"):
269-
await download.run(args)
270-
271-
# Verify disconnect was not called since connection failed
272-
mock_hub.disconnect.assert_not_called()
281+
# Set up mocks using ExitStack
282+
with contextlib.ExitStack() as stack:
283+
stack.enter_context(
284+
patch(
285+
"pybricksdev.connections.pybricks.PybricksHubBLE",
286+
return_value=mock_hub,
287+
)
288+
)
289+
stack.enter_context(
290+
patch("pybricksdev.ble.find_device", return_value="mock_device")
291+
)
292+
293+
# Run the command and verify it raises the error
294+
download = Download()
295+
with pytest.raises(RuntimeError, match="Connection failed"):
296+
await download.run(args)
297+
298+
# Verify disconnect was not called since connection failed
299+
mock_hub.disconnect.assert_not_called()
273300
finally:
274301
os.unlink(temp_path)

0 commit comments

Comments
 (0)