Skip to content

Commit 5b20f0f

Browse files
committed
Refactor the code to move the temporary file creation and cleanup into the ExitStack context manager.
1 parent 555a3d1 commit 5b20f0f

File tree

1 file changed

+95
-93
lines changed

1 file changed

+95
-93
lines changed

tests/test_cli.py

Lines changed: 95 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -67,42 +67,42 @@ async def test_download_ble(self):
6767
mock_hub._mpy_abi_version = 6
6868
mock_hub.download = AsyncMock()
6969

70-
# Create a temporary file
71-
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp:
70+
# Set up mocks using ExitStack
71+
with contextlib.ExitStack() as stack:
72+
# Create and manage temporary file
73+
temp = stack.enter_context(
74+
tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False)
75+
)
7276
temp.write("print('test')")
7377
temp_path = temp.name
78+
stack.callback(os.unlink, temp_path)
7479

75-
try:
7680
# Create args
7781
args = argparse.Namespace(
7882
conntype="ble",
7983
file=open(temp_path, "r"),
8084
name="MyHub",
8185
)
8286

83-
# Set up mocks using ExitStack
84-
with contextlib.ExitStack() as stack:
85-
mock_hub_class = stack.enter_context(
86-
patch(
87-
"pybricksdev.connections.pybricks.PybricksHubBLE",
88-
return_value=mock_hub,
89-
)
90-
)
91-
stack.enter_context(
92-
patch("pybricksdev.ble.find_device", return_value="mock_device")
87+
mock_hub_class = stack.enter_context(
88+
patch(
89+
"pybricksdev.connections.pybricks.PybricksHubBLE",
90+
return_value=mock_hub,
9391
)
92+
)
93+
stack.enter_context(
94+
patch("pybricksdev.ble.find_device", return_value="mock_device")
95+
)
9496

95-
# Run the command
96-
download = Download()
97-
await download.run(args)
97+
# Run the command
98+
download = Download()
99+
await download.run(args)
98100

99-
# Verify the hub was created and used correctly
100-
mock_hub_class.assert_called_once_with("mock_device")
101-
mock_hub.connect.assert_called_once()
102-
mock_hub.download.assert_called_once()
103-
mock_hub.disconnect.assert_called_once()
104-
finally:
105-
os.unlink(temp_path)
101+
# Verify the hub was created and used correctly
102+
mock_hub_class.assert_called_once_with("mock_device")
103+
mock_hub.connect.assert_called_once()
104+
mock_hub.download.assert_called_once()
105+
mock_hub.disconnect.assert_called_once()
106106

107107
@pytest.mark.asyncio
108108
async def test_download_usb(self):
@@ -112,40 +112,40 @@ async def test_download_usb(self):
112112
mock_hub._mpy_abi_version = 6
113113
mock_hub.download = AsyncMock()
114114

115-
# Create a temporary file
116-
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp:
115+
# Set up mocks using ExitStack
116+
with contextlib.ExitStack() as stack:
117+
# Create and manage temporary file
118+
temp = stack.enter_context(
119+
tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False)
120+
)
117121
temp.write("print('test')")
118122
temp_path = temp.name
123+
stack.callback(os.unlink, temp_path)
119124

120-
try:
121125
# Create args
122126
args = argparse.Namespace(
123127
conntype="usb",
124128
file=open(temp_path, "r"),
125129
name=None,
126130
)
127131

128-
# Set up mocks using ExitStack
129-
with contextlib.ExitStack() as stack:
130-
mock_hub_class = stack.enter_context(
131-
patch(
132-
"pybricksdev.connections.pybricks.PybricksHubUSB",
133-
return_value=mock_hub,
134-
)
132+
mock_hub_class = stack.enter_context(
133+
patch(
134+
"pybricksdev.connections.pybricks.PybricksHubUSB",
135+
return_value=mock_hub,
135136
)
136-
stack.enter_context(patch("usb.core.find", return_value="mock_device"))
137+
)
138+
stack.enter_context(patch("usb.core.find", return_value="mock_device"))
137139

138-
# Run the command
139-
download = Download()
140-
await download.run(args)
140+
# Run the command
141+
download = Download()
142+
await download.run(args)
141143

142-
# Verify the hub was created and used correctly
143-
mock_hub_class.assert_called_once_with("mock_device")
144-
mock_hub.connect.assert_called_once()
145-
mock_hub.download.assert_called_once()
146-
mock_hub.disconnect.assert_called_once()
147-
finally:
148-
os.unlink(temp_path)
144+
# Verify the hub was created and used correctly
145+
mock_hub_class.assert_called_once_with("mock_device")
146+
mock_hub.connect.assert_called_once()
147+
mock_hub.download.assert_called_once()
148+
mock_hub.disconnect.assert_called_once()
149149

150150
@pytest.mark.asyncio
151151
async def test_download_ssh(self):
@@ -154,52 +154,56 @@ async def test_download_ssh(self):
154154
mock_hub = AsyncMock()
155155
mock_hub.download = AsyncMock()
156156

157-
# Create a temporary file
158-
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp:
157+
# Set up mocks using ExitStack
158+
with contextlib.ExitStack() as stack:
159+
# Create and manage temporary file
160+
temp = stack.enter_context(
161+
tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False)
162+
)
159163
temp.write("print('test')")
160164
temp_path = temp.name
165+
stack.callback(os.unlink, temp_path)
161166

162-
try:
163167
# Create args
164168
args = argparse.Namespace(
165169
conntype="ssh",
166170
file=open(temp_path, "r"),
167171
name="ev3dev.local",
168172
)
169173

170-
# Set up mocks using ExitStack
171-
with contextlib.ExitStack() as stack:
172-
mock_hub_class = stack.enter_context(
173-
patch(
174-
"pybricksdev.connections.ev3dev.EV3Connection",
175-
return_value=mock_hub,
176-
)
177-
)
178-
stack.enter_context(
179-
patch("socket.gethostbyname", return_value="192.168.1.1")
174+
mock_hub_class = stack.enter_context(
175+
patch(
176+
"pybricksdev.connections.ev3dev.EV3Connection",
177+
return_value=mock_hub,
180178
)
179+
)
180+
stack.enter_context(
181+
patch("socket.gethostbyname", return_value="192.168.1.1")
182+
)
181183

182-
# Run the command
183-
download = Download()
184-
await download.run(args)
184+
# Run the command
185+
download = Download()
186+
await download.run(args)
185187

186-
# Verify the hub was created and used correctly
187-
mock_hub_class.assert_called_once_with("192.168.1.1")
188-
mock_hub.connect.assert_called_once()
189-
mock_hub.download.assert_called_once()
190-
mock_hub.disconnect.assert_called_once()
191-
finally:
192-
os.unlink(temp_path)
188+
# Verify the hub was created and used correctly
189+
mock_hub_class.assert_called_once_with("192.168.1.1")
190+
mock_hub.connect.assert_called_once()
191+
mock_hub.download.assert_called_once()
192+
mock_hub.disconnect.assert_called_once()
193193

194194
@pytest.mark.asyncio
195195
async def test_download_ssh_no_name(self):
196196
"""Test that SSH connection requires a name."""
197-
# Create a temporary file
198-
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp:
197+
# Set up mocks using ExitStack
198+
with contextlib.ExitStack() as stack:
199+
# Create and manage temporary file
200+
temp = stack.enter_context(
201+
tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False)
202+
)
199203
temp.write("print('test')")
200204
temp_path = temp.name
205+
stack.callback(os.unlink, temp_path)
201206

202-
try:
203207
# Create args without name
204208
args = argparse.Namespace(
205209
conntype="ssh",
@@ -211,8 +215,6 @@ async def test_download_ssh_no_name(self):
211215
download = Download()
212216
with pytest.raises(SystemExit, match="1"):
213217
await download.run(args)
214-
finally:
215-
os.unlink(temp_path)
216218

217219
@pytest.mark.asyncio
218220
async def test_download_stdin(self):
@@ -265,37 +267,37 @@ async def test_download_connection_error(self):
265267
mock_hub = AsyncMock()
266268
mock_hub.connect.side_effect = RuntimeError("Connection failed")
267269

268-
# Create a temporary file
269-
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp:
270+
# Set up mocks using ExitStack
271+
with contextlib.ExitStack() as stack:
272+
# Create and manage temporary file
273+
temp = stack.enter_context(
274+
tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False)
275+
)
270276
temp.write("print('test')")
271277
temp_path = temp.name
278+
stack.callback(os.unlink, temp_path)
272279

273-
try:
274280
# Create args
275281
args = argparse.Namespace(
276282
conntype="ble",
277283
file=open(temp_path, "r"),
278284
name="MyHub",
279285
)
280286

281-
# Set up mocks using ExitStack
282-
with contextlib.ExitStack() as stack:
283-
stack.enter_context(
284-
patch(
285-
"pybricksdev.connections.pybricks.PybricksHubBLE",
286-
return_value=mock_hub,
287-
)
288-
)
289-
stack.enter_context(
290-
patch("pybricksdev.ble.find_device", return_value="mock_device")
287+
stack.enter_context(
288+
patch(
289+
"pybricksdev.connections.pybricks.PybricksHubBLE",
290+
return_value=mock_hub,
291291
)
292+
)
293+
stack.enter_context(
294+
patch("pybricksdev.ble.find_device", return_value="mock_device")
295+
)
292296

293-
# Run the command and verify it raises the error
294-
download = Download()
295-
with pytest.raises(RuntimeError, match="Connection failed"):
296-
await download.run(args)
297+
# Run the command and verify it raises the error
298+
download = Download()
299+
with pytest.raises(RuntimeError, match="Connection failed"):
300+
await download.run(args)
297301

298-
# Verify disconnect was not called since connection failed
299-
mock_hub.disconnect.assert_not_called()
300-
finally:
301-
os.unlink(temp_path)
302+
# Verify disconnect was not called since connection failed
303+
mock_hub.disconnect.assert_not_called()

0 commit comments

Comments
 (0)