diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 20a703af..636857ce 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -113,6 +113,7 @@ cdef class APIConnection: cdef bint _send_pending_ping cdef public bint is_connected cdef bint _handshake_complete + cdef bint _initial_time_sent cdef bint _debug_enabled cdef public str received_name cdef public str connected_address diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 4d34c047..ea3b87e4 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -152,6 +152,13 @@ def _make_hello_request(client_info: str) -> HelloRequest: ) +def _make_time_response() -> GetTimeResponse: + """Create a GetTimeResponse.""" + resp = GetTimeResponse() + resp.epoch_seconds = int(time.time()) + return resp + + _cached_make_hello_request = lru_cache(maxsize=16)(_make_hello_request) make_hello_request = _cached_make_hello_request @@ -200,6 +207,7 @@ class APIConnection: "_finish_connect_future", "_frame_helper", "_handshake_complete", + "_initial_time_sent", "_keep_alive_interval", "_keep_alive_timeout", "_log_errors", @@ -261,6 +269,7 @@ def __init__( self._loop = asyncio.get_running_loop() self.is_connected = False self._handshake_complete = False + self._initial_time_sent = False self._debug_enabled = debug_enabled self.received_name: str = "" self.connected_address: str | None = None @@ -476,6 +485,13 @@ async def _connect_hello_login(self, login: bool) -> None: # the device has a password but we don't expect it msg_types.append(ConnectResponse) + # Send a GetTimeResponse proactively to reduce latency during reconnect. + # This avoids an additional round-trip for the GetTimeRequest. + # If the device doesn't have Home Assistant time enabled, it will + # simply ignore this response, but since it's included in the same + # packet, it's nearly free to send and reduces pressure during reconnect. + messages.append(_make_time_response()) + self._initial_time_sent = True responses = await self.send_messages_await_response_complex( tuple(messages), None, @@ -1061,9 +1077,11 @@ def _handle_get_time_request_internal( # pylint: disable=unused-argument self, _msg: GetTimeRequest ) -> None: """Handle a GetTimeRequest.""" - resp = GetTimeResponse() - resp.epoch_seconds = int(time.time()) - self.send_messages((resp,)) + if self._initial_time_sent: + # Ignore the first time request since we already sent it proactively + self._initial_time_sent = False + return + self.send_messages((_make_time_response(),)) async def disconnect(self) -> None: """Disconnect from the API.""" diff --git a/tests/common.py b/tests/common.py index fb66bb0c..fc10bde6 100644 --- a/tests/common.py +++ b/tests/common.py @@ -23,6 +23,7 @@ from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper from aioesphomeapi.api_pb2 import ( ConnectResponse, + GetTimeRequest, HelloResponse, PingRequest, PingResponse, @@ -192,6 +193,11 @@ def send_ping_request(protocol: APIPlaintextFrameHelper) -> None: protocol.data_received(generate_plaintext_packet(ping_request)) +def send_time_request(protocol: APIPlaintextFrameHelper) -> None: + time_request: message.Message = GetTimeRequest() + protocol.data_received(generate_plaintext_packet(time_request)) + + def get_mock_protocol(conn: APIConnection): protocol = APIPlaintextFrameHelper( connection=conn, diff --git a/tests/test_connection.py b/tests/test_connection.py index 84557ec5..4ef4d89d 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -52,6 +52,7 @@ send_ping_response, send_plaintext_connect_response, send_plaintext_hello, + send_time_request, utcnow, ) @@ -1291,3 +1292,66 @@ async def test_report_fatal_error_with_log_errors_false( # Verify the error is still stored internally assert conn._fatal_exception is regular_error + + +async def test_time_request_response( + plaintext_connect_task_with_login: tuple[ + APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task + ], +) -> None: + """Test that GetTimeResponse is sent proactively and first request is ignored.""" + conn, transport, protocol, connect_task = plaintext_connect_task_with_login + + # Verify that GetTimeResponse is sent proactively during initial handshake + # This happens before we even receive HelloResponse/ConnectResponse + initial_calls = transport.writelines.call_args_list + # Find the initial handshake packet that should contain Hello, Connect, and GetTimeResponse + handshake_found = False + for call_args in initial_calls: + full_data = b"".join(call_args[0][0]) + # Check if this packet contains GetTimeResponse (message type 0x25) + if b"\x25" in full_data: + handshake_found = True + break + assert handshake_found, "GetTimeResponse was not sent proactively during handshake" + + send_plaintext_hello(protocol) + send_plaintext_connect_response(protocol, False) + + await connect_task + assert conn.is_connected + + # Reset transport mock to check what gets sent after connection + transport.reset_mock() + + # Send first GetTimeRequest - this should be ignored since we sent time proactively + send_time_request(protocol) + await asyncio.sleep(0) + + # Verify no response was sent for the first request + assert transport.writelines.call_count == 0 + + # Send second GetTimeRequest - this should be answered + send_time_request(protocol) + await asyncio.sleep(0) + + # Verify GetTimeResponse was sent + assert transport.writelines.call_count == 1 + # GetTimeResponse message type is 37 (0x25) + # writelines is called with a list of bytes, check that we have the right message type + call_args = transport.writelines.call_args_list[0][0][0] + # Join all the bytes together to check + full_data = b"".join(call_args) + # Message type 37 is 0x25 + assert b"\x25" in full_data + + # Send third GetTimeRequest - this should also be answered + transport.reset_mock() + send_time_request(protocol) + await asyncio.sleep(0) + + # Verify another GetTimeResponse was sent + assert transport.writelines.call_count == 1 + call_args = transport.writelines.call_args_list[0][0][0] + full_data = b"".join(call_args) + assert b"\x25" in full_data