Skip to content

Commit 63b21fd

Browse files
jpbedefrenck
authored andcommitted
Ensure response is fully read to prevent premature connection closure in rest command (home-assistant#148532)
1 parent d87379d commit 63b21fd

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

homeassistant/components/rest_command/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ async def async_service_handler(service: ServiceCall) -> ServiceResponse:
178178
)
179179

180180
if not service.return_response:
181+
# always read the response to avoid closing the connection
182+
# before the server has finished sending it, while avoiding excessive memory usage
183+
async for _ in response.content.iter_chunked(1024):
184+
pass
185+
181186
return None
182187

183188
_content = None

tests/components/rest_command/test_init.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ async def test_rest_command_get_response_malformed_json(
326326

327327
aioclient_mock.get(
328328
TEST_URL,
329-
content='{"status": "failure", 42',
329+
content=b'{"status": "failure", 42',
330330
headers={"content-type": "application/json"},
331331
)
332332

@@ -379,3 +379,27 @@ async def test_rest_command_get_response_none(
379379
)
380380

381381
assert not response
382+
383+
384+
async def test_rest_command_response_iter_chunked(
385+
hass: HomeAssistant,
386+
setup_component: ComponentSetup,
387+
aioclient_mock: AiohttpClientMocker,
388+
) -> None:
389+
"""Ensure response is consumed when return_response is False."""
390+
await setup_component()
391+
392+
png = base64.decodebytes(
393+
b"iVBORw0KGgoAAAANSUhEUgAAAAIAAAABCAIAAAB7QOjdAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQ"
394+
b"UAAAAJcEhZcwAAFiUAABYlAUlSJPAAAAAPSURBVBhXY/h/ku////8AECAE1JZPvDAAAAAASUVORK5CYII="
395+
)
396+
aioclient_mock.get(TEST_URL, content=png)
397+
398+
with patch("aiohttp.StreamReader.iter_chunked", autospec=True) as mock_iter_chunked:
399+
response = await hass.services.async_call(DOMAIN, "get_test", {}, blocking=True)
400+
401+
# Ensure the response is not returned
402+
assert response is None
403+
404+
# Verify iter_chunked was called with a chunk size
405+
assert mock_iter_chunked.called

0 commit comments

Comments
 (0)