Skip to content

Commit 0a7a591

Browse files
authored
Split connection process to enable earlier zeroconf listening (#1205)
1 parent 569370a commit 0a7a591

File tree

9 files changed

+311
-87
lines changed

9 files changed

+311
-87
lines changed

aioesphomeapi/client.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ async def connect(
206206
login: bool = False,
207207
) -> None:
208208
"""Connect to the device."""
209-
await self.start_connection(on_stop)
209+
await self.start_resolve_host(on_stop)
210+
await self.start_connection()
210211
await self.finish_connection(login)
211212

212213
def _on_stop(
@@ -219,11 +220,11 @@ def _on_stop(
219220
if on_stop:
220221
self._create_background_task(on_stop(expected_disconnect))
221222

222-
async def start_connection(
223+
async def start_resolve_host(
223224
self,
224225
on_stop: Callable[[bool], Coroutine[Any, Any, None]] | None = None,
225226
) -> None:
226-
"""Start connecting to the device."""
227+
"""Start resolving the host."""
227228
if self._connection is not None:
228229
raise APIConnectionError(f"Already connected to {self.log_name}!")
229230
self._connection = APIConnection(
@@ -232,6 +233,12 @@ async def start_connection(
232233
self._debug_enabled,
233234
self.log_name,
234235
)
236+
await self._execute_connection_coro(self._connection.start_resolve_host())
237+
238+
async def start_connection(self) -> None:
239+
"""Start connecting to the device."""
240+
if TYPE_CHECKING:
241+
assert self._connection is not None
235242
await self._execute_connection_coro(self._connection.start_connection())
236243
# If we connected, we should set the log name now
237244
if self._connection.connected_address:

aioesphomeapi/connection.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ cdef object SocketClosedAPIError
5353
cdef object astuple
5454

5555
cdef object CONNECTION_STATE_INITIALIZED
56+
cdef object CONNECTION_STATE_HOST_RESOLVED
5657
cdef object CONNECTION_STATE_SOCKET_OPENED
5758
cdef object CONNECTION_STATE_HANDSHAKE_COMPLETE
5859
cdef object CONNECTION_STATE_CONNECTED
@@ -104,6 +105,7 @@ cdef class APIConnection:
104105
cdef object _pong_timer
105106
cdef float _keep_alive_interval
106107
cdef float _keep_alive_timeout
108+
cdef object _resolve_host_future
107109
cdef object _start_connect_future
108110
cdef object _finish_connect_future
109111
cdef public Exception _fatal_exception
@@ -115,6 +117,7 @@ cdef class APIConnection:
115117
cdef bint _debug_enabled
116118
cdef public str received_name
117119
cdef public str connected_address
120+
cdef list _addrs_info
118121

119122
cpdef void send_message(self, object msg) except *
120123

aioesphomeapi/connection.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,19 @@ class ConnectionParams:
128128
class ConnectionState(enum.Enum):
129129
# The connection is initialized, but connect() wasn't called yet
130130
INITIALIZED = 0
131+
# The host has been resolved, but the socket hasn't been opened yet
132+
HOST_RESOLVED = 1
131133
# The socket has been opened, but the handshake and login haven't been completed
132-
SOCKET_OPENED = 1
134+
SOCKET_OPENED = 2
133135
# The handshake has been completed, messages can be exchanged
134-
HANDSHAKE_COMPLETE = 2
136+
HANDSHAKE_COMPLETE = 3
135137
# The connection has been established, authenticated data can be exchanged
136-
CONNECTED = 3
137-
CLOSED = 4
138+
CONNECTED = 4
139+
CLOSED = 5
138140

139141

140142
CONNECTION_STATE_INITIALIZED = ConnectionState.INITIALIZED
143+
CONNECTION_STATE_HOST_RESOLVED = ConnectionState.HOST_RESOLVED
141144
CONNECTION_STATE_SOCKET_OPENED = ConnectionState.SOCKET_OPENED
142145
CONNECTION_STATE_HANDSHAKE_COMPLETE = ConnectionState.HANDSHAKE_COMPLETE
143146
CONNECTION_STATE_CONNECTED = ConnectionState.CONNECTED
@@ -192,6 +195,7 @@ class APIConnection:
192195
"""
193196

194197
__slots__ = (
198+
"_addrs_info",
195199
"_debug_enabled",
196200
"_expected_disconnect",
197201
"_fatal_exception",
@@ -206,6 +210,7 @@ class APIConnection:
206210
"_ping_timer",
207211
"_pong_timer",
208212
"_read_exception_futures",
213+
"_resolve_host_future",
209214
"_send_pending_ping",
210215
"_socket",
211216
"_start_connect_future",
@@ -247,6 +252,7 @@ def __init__(
247252
self._keep_alive_interval = keepalive
248253
self._keep_alive_timeout = keepalive * KEEP_ALIVE_TIMEOUT_RATIO
249254

255+
self._resolve_host_future: asyncio.Future[None] | None = None
250256
self._start_connect_future: asyncio.Future[None] | None = None
251257
self._finish_connect_future: asyncio.Future[None] | None = None
252258
self._fatal_exception: Exception | None = None
@@ -258,6 +264,7 @@ def __init__(
258264
self._debug_enabled = debug_enabled
259265
self.received_name: str = ""
260266
self.connected_address: str | None = None
267+
self._addrs_info: list[hr.AddrInfo] = []
261268

262269
def set_log_name(self, name: str) -> None:
263270
"""Set the friendly log name for this connection."""
@@ -286,6 +293,7 @@ def _cleanup(self) -> None:
286293
fut.set_exception(new_exc)
287294
self._read_exception_futures.clear()
288295

296+
self._set_resolve_host_future()
289297
self._set_start_connect_future()
290298
self._set_finish_connect_future()
291299

@@ -572,32 +580,61 @@ def _async_pong_not_received(self) -> None:
572580
)
573581
)
574582

575-
async def _do_connect(self) -> None:
576-
"""Do the actual connect process."""
577-
addrs_info = await hr.async_resolve_host(
578-
self._params.addresses,
579-
self._params.port,
580-
self._params.zeroconf_manager,
581-
)
582-
await self._connect_socket_connect(addrs_info)
583+
async def start_resolve_host(self) -> None:
584+
"""Start the host resolution process.
585+
586+
This part of the process resolves the hostnames to IP addresses
587+
and prepares the connection for the next step.
588+
"""
589+
if self.connection_state is not CONNECTION_STATE_INITIALIZED:
590+
raise RuntimeError(
591+
"Connection can only be used once, connection is not in init state"
592+
)
593+
594+
self._resolve_host_future = self._loop.create_future()
595+
try:
596+
async with interrupt(
597+
self._resolve_host_future, ConnectionInterruptedError, None
598+
):
599+
self._addrs_info = await hr.async_resolve_host(
600+
self._params.addresses,
601+
self._params.port,
602+
self._params.zeroconf_manager,
603+
)
604+
except (Exception, CancelledError) as ex:
605+
# If the task was cancelled, we need to clean up the connection
606+
# and raise the CancelledError as APIConnectionError
607+
self._cleanup()
608+
raise self._wrap_fatal_connection_exception("resolving", ex)
609+
finally:
610+
self._set_resolve_host_future()
611+
self._set_connection_state(CONNECTION_STATE_HOST_RESOLVED)
612+
613+
def _set_resolve_host_future(self) -> None:
614+
if (
615+
self._resolve_host_future is not None
616+
and not self._resolve_host_future.done()
617+
):
618+
self._resolve_host_future.set_result(None)
619+
self._resolve_host_future = None
583620

584621
async def start_connection(self) -> None:
585622
"""Start the connection process.
586623
587624
This part of the process establishes the socket connection but
588625
does not initialize the frame helper or send the hello message.
589626
"""
590-
if self.connection_state is not CONNECTION_STATE_INITIALIZED:
627+
if self.connection_state is not CONNECTION_STATE_HOST_RESOLVED:
591628
raise RuntimeError(
592-
"Connection can only be used once, connection is not in init state"
629+
"Connection must be in HOST_RESOLVED state to start connection"
593630
)
594631

595632
self._start_connect_future = self._loop.create_future()
596633
try:
597634
async with interrupt(
598635
self._start_connect_future, ConnectionInterruptedError, None
599636
):
600-
await self._do_connect()
637+
await self._connect_socket_connect(self._addrs_info)
601638
except (Exception, CancelledError) as ex:
602639
# If the task was cancelled, we need to clean up the connection
603640
# and raise the CancelledError as APIConnectionError

aioesphomeapi/reconnect_logic.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,17 @@
3434

3535

3636
class ReconnectLogicState(Enum):
37-
CONNECTING = 0
38-
HANDSHAKING = 1
39-
READY = 2
40-
DISCONNECTED = 3
37+
RESOLVING = 0
38+
CONNECTING = 1
39+
HANDSHAKING = 2
40+
READY = 3
41+
DISCONNECTED = 4
4142

4243

4344
NOT_YET_CONNECTED_STATES = {
4445
ReconnectLogicState.DISCONNECTED,
4546
ReconnectLogicState.CONNECTING,
47+
ReconnectLogicState.RESOLVING,
4648
}
4749

4850

@@ -187,10 +189,21 @@ def _async_log_connection_error(self, err: Exception) -> None:
187189

188190
async def _try_connect(self) -> bool:
189191
"""Try connecting to the API client."""
192+
self._async_set_connection_state_while_locked(ReconnectLogicState.RESOLVING)
193+
start_resolve_time = time.perf_counter()
194+
try:
195+
await self._cli.start_resolve_host(on_stop=self._on_disconnect)
196+
except Exception as err: # pylint: disable=broad-except
197+
await self._handle_connection_failure(err)
198+
return False
190199
self._async_set_connection_state_while_locked(ReconnectLogicState.CONNECTING)
191200
start_connect_time = time.perf_counter()
201+
resolve_time = start_connect_time - start_resolve_time
202+
_LOGGER.info(
203+
"Successfully resolved %s in %0.3fs", self._cli.log_name, resolve_time
204+
)
192205
try:
193-
await self._cli.start_connection(on_stop=self._on_disconnect)
206+
await self._cli.start_connection()
194207
except Exception as err: # pylint: disable=broad-except
195208
await self._handle_connection_failure(err)
196209
return False
@@ -248,7 +261,13 @@ def _call_connect_once(self) -> None:
248261
if self._connect_task and not self._connect_task.done():
249262
if self._connection_state != ReconnectLogicState.CONNECTING:
250263
# Connection state is far enough along that we should
251-
# not restart the connect task
264+
# not restart the connect task.
265+
#
266+
# Zeroconf triggering scenarios:
267+
# - RESOLVING state: Don't cancel, the resolve task will complete immediately
268+
# since it's waiting for the same records zeroconf is delivering
269+
# - CONNECTING state: Cancel and restart to use potentially updated connection info
270+
# - HANDSHAKING state or later: Don't cancel, too far along in the process
252271
_LOGGER.debug(
253272
"%s: Not cancelling existing connect task as its already %s!",
254273
self._cli.log_name,
@@ -300,6 +319,7 @@ async def _connect_once_or_reschedule(self) -> None:
300319
or self._is_stopped
301320
):
302321
return
322+
self._start_zc_listen()
303323
if await self._try_connect():
304324
return
305325
tries = min(self._tries, 10) # prevent OverflowError
@@ -309,9 +329,6 @@ async def _connect_once_or_reschedule(self) -> None:
309329
"Trying to connect to %s in the background", self._cli.log_name
310330
)
311331
_LOGGER.debug("Retrying %s in %.2f seconds", self._cli.log_name, wait_time)
312-
if wait_time:
313-
# If we are waiting, start listening for mDNS records
314-
self._start_zc_listen()
315332
self._schedule_connect(wait_time)
316333

317334
def _remove_stop_task(self, _fut: asyncio.Future[None]) -> None:

tests/common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def async_fire_time_changed(
146146

147147
async def connect(conn: APIConnection, login: bool = True):
148148
"""Wrapper for connection logic to do both parts."""
149+
await conn.start_resolve_host()
149150
await conn.start_connection()
150151
await conn.finish_connection(login=login)
151152

@@ -156,7 +157,8 @@ async def connect_client(
156157
on_stop: Callable[[bool], Awaitable[None]] | None = None,
157158
) -> None:
158159
"""Wrapper for connection logic to do both parts."""
159-
await client.start_connection(on_stop=on_stop)
160+
await client.start_resolve_host(on_stop=on_stop)
161+
await client.start_connection()
160162
await client.finish_connection(login=login)
161163

162164

tests/test_client.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,16 +196,18 @@ async def test_expected_name(auth_client: APIClient) -> None:
196196

197197

198198
async def test_connect_backwards_compat() -> None:
199-
"""Verify connect is a thin wrapper around start_connection and finish_connection."""
199+
"""Verify connect is a thin wrapper around start_resolve_host, start_connection and finish_connection."""
200200

201201
cli = PatchableAPIClient("host", 1234, None)
202202
with (
203+
patch.object(cli, "start_resolve_host") as mock_start_resolve_host,
203204
patch.object(cli, "start_connection") as mock_start_connection,
204205
patch.object(cli, "finish_connection") as mock_finish_connection,
205206
):
206207
await cli.connect()
207208

208-
assert mock_start_connection.mock_calls == [call(None)]
209+
assert mock_start_resolve_host.mock_calls == [call(None)]
210+
assert mock_start_connection.mock_calls == [call()]
209211
assert mock_finish_connection.mock_calls == [call(False)]
210212

211213

@@ -216,6 +218,7 @@ async def test_finish_connection_wraps_exceptions_as_unhandled_api_error(
216218

217219
cli = APIClient("127.0.0.1", 1234, None)
218220
with patch("aioesphomeapi.client.APIConnection", PatchableAPIConnection):
221+
await cli.start_resolve_host()
219222
await cli.start_connection()
220223

221224
with (
@@ -244,6 +247,7 @@ async def _start_connection_with_delay(*args, **kwargs):
244247
"aioesphomeapi.connection.aiohappyeyeballs.start_connection",
245248
_start_connection_with_delay,
246249
):
250+
await cli.start_resolve_host()
247251
start_task = asyncio.create_task(cli.start_connection())
248252
await asyncio.sleep(0)
249253
assert cli._connection is not None
@@ -265,6 +269,7 @@ async def _start_connection_without_delay(*args, **kwargs):
265269
_start_connection_without_delay,
266270
),
267271
):
272+
await cli.start_resolve_host()
268273
await cli.start_connection()
269274
await asyncio.sleep(0)
270275

@@ -303,7 +308,7 @@ async def test_request_while_handshaking() -> None:
303308
async def test_connect_while_already_connected(auth_client: APIClient) -> None:
304309
"""Test connecting while already connected raises."""
305310
with pytest.raises(APIConnectionError):
306-
await auth_client.start_connection()
311+
await auth_client.start_resolve_host()
307312

308313

309314
@pytest.mark.parametrize(
@@ -1066,7 +1071,11 @@ async def test_noise_psk_handles_subclassed_string():
10661071
)
10671072
assert rl._connection_state is ReconnectLogicState.DISCONNECTED
10681073

1069-
with patch.object(cli, "start_connection"), patch.object(cli, "finish_connection"):
1074+
with (
1075+
patch.object(cli, "start_resolve_host"),
1076+
patch.object(cli, "start_connection"),
1077+
patch.object(cli, "finish_connection"),
1078+
):
10701079
await rl.start()
10711080
for _ in range(3):
10721081
await asyncio.sleep(0)

tests/test_connection.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ async def test_connection_cannot_be_reused(
10371037
send_plaintext_connect_response(protocol, False)
10381038
await connect_task
10391039
with pytest.raises(RuntimeError):
1040-
await conn.start_connection()
1040+
await conn.start_resolve_host()
10411041

10421042

10431043
async def test_attempting_to_finish_unstarted_connection(
@@ -1048,6 +1048,17 @@ async def test_attempting_to_finish_unstarted_connection(
10481048
await conn.finish_connection(login=False)
10491049

10501050

1051+
async def test_start_connection_wrong_state(
1052+
conn: APIConnection,
1053+
) -> None:
1054+
"""Test that we raise when trying to start connection in wrong state."""
1055+
with pytest.raises(
1056+
RuntimeError,
1057+
match="Connection must be in HOST_RESOLVED state to start connection",
1058+
):
1059+
await conn.start_connection()
1060+
1061+
10511062
async def test_internal_message_received_immediately_after_connection(
10521063
conn: APIConnection,
10531064
resolve_host: AsyncMock,

0 commit comments

Comments
 (0)