Skip to content

Commit cd73824

Browse files
authored
Ensure response is fully read to prevent premature connection closure in rest command (home-assistant#148532)
1 parent 32121a0 commit cd73824

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

homeassistant/components/rest_command/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ async def async_service_handler(service: ServiceCall) -> ServiceResponse:
178178
)
179179

180180
if not service.return_response:
181+
# always read the response to avoid closing the connection
182+
# before the server has finished sending it, while avoiding excessive memory usage
183+
async for _ in response.content.iter_chunked(1024):
184+
pass
185+
181186
return None
182187

183188
_content = None

tests/components/rest_command/test_init.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ async def test_rest_command_get_response_malformed_json(
328328

329329
aioclient_mock.get(
330330
TEST_URL,
331-
content='{"status": "failure", 42',
331+
content=b'{"status": "failure", 42',
332332
headers={"content-type": "application/json"},
333333
)
334334

@@ -381,3 +381,27 @@ async def test_rest_command_get_response_none(
381381
)
382382

383383
assert not response
384+
385+
386+
async def test_rest_command_response_iter_chunked(
387+
hass: HomeAssistant,
388+
setup_component: ComponentSetup,
389+
aioclient_mock: AiohttpClientMocker,
390+
) -> None:
391+
"""Ensure response is consumed when return_response is False."""
392+
await setup_component()
393+
394+
png = base64.decodebytes(
395+
b"iVBORw0KGgoAAAANSUhEUgAAAAIAAAABCAIAAAB7QOjdAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQ"
396+
b"UAAAAJcEhZcwAAFiUAABYlAUlSJPAAAAAPSURBVBhXY/h/ku////8AECAE1JZPvDAAAAAASUVORK5CYII="
397+
)
398+
aioclient_mock.get(TEST_URL, content=png)
399+
400+
with patch("aiohttp.StreamReader.iter_chunked", autospec=True) as mock_iter_chunked:
401+
response = await hass.services.async_call(DOMAIN, "get_test", {}, blocking=True)
402+
403+
# Ensure the response is not returned
404+
assert response is None
405+
406+
# Verify iter_chunked was called with a chunk size
407+
assert mock_iter_chunked.called

0 commit comments

Comments
 (0)