diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 7d5d02ee..84f73a64 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -625,6 +625,14 @@ async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many if self._debug_enabled: _LOGGER.debug("%s: Using connection version %s", address, request_type) + def on_bluetooth_connection_state_with_notify_cleanup( + connected: bool, mtu: int, error: int + ) -> None: + """Wrap connection state callback to clean up notify callbacks on disconnect.""" + if not connected: + self.bluetooth_gatt_stop_notify_for_address(address) + on_bluetooth_connection_state(connected, mtu, error) + unsub = self._get_connection().send_message_callback_response( BluetoothDeviceRequest( address=address, @@ -636,7 +644,7 @@ async def bluetooth_device_connect( # pylint: disable=too-many-locals, too-many on_bluetooth_device_connection_response, connect_future, address, - on_bluetooth_connection_state, + on_bluetooth_connection_state_with_notify_cleanup, ), (BluetoothDeviceConnectionResponse,), ) @@ -987,17 +995,44 @@ async def bluetooth_gatt_start_notify( remove_callback() raise + key = (address, handle) + self._notify_callbacks[key] = remove_callback + async def stop_notify() -> None: - if self._connection is None: - return + self.bluetooth_gatt_stop_notify(address, handle) + + def wrapped_remove_callback() -> None: + self._notify_callbacks.pop(key, None) + remove_callback() + + return stop_notify, wrapped_remove_callback + def bluetooth_gatt_stop_notify(self, address: int, handle: int) -> None: + """Stop a notify session for a GATT characteristic. + + This is a synchronous method that can be safely called from + exception handlers without awaiting. + """ + key = (address, handle) + if remove_callback := self._notify_callbacks.pop(key, None): remove_callback() + if self._connection is not None: self._connection.send_message( BluetoothGATTNotifyRequest(address=address, handle=handle, enable=False) ) - return stop_notify, remove_callback + def bluetooth_gatt_stop_notify_for_address(self, address: int) -> None: + """Stop all notify sessions for a Bluetooth device. + + This is a synchronous method that removes all notify callbacks + for a given address. It does not send disable messages since + this is typically called when the device has disconnected. + """ + keys_to_remove = [key for key in self._notify_callbacks if key[0] == address] + for key in keys_to_remove: + if remove_callback := self._notify_callbacks.pop(key, None): + remove_callback() def subscribe_home_assistant_states( self, diff --git a/aioesphomeapi/client_base.pxd b/aioesphomeapi/client_base.pxd index da1028d6..f73566d9 100644 --- a/aioesphomeapi/client_base.pxd +++ b/aioesphomeapi/client_base.pxd @@ -31,6 +31,7 @@ cdef class APIClientBase: cdef public APIConnection _connection cdef public bint _debug_enabled cdef public object _loop + cdef public dict _notify_callbacks cdef public ConnectionParams _params cdef public str cached_name cdef public str log_name diff --git a/aioesphomeapi/client_base.py b/aioesphomeapi/client_base.py index 4e3186dc..905bff3b 100644 --- a/aioesphomeapi/client_base.py +++ b/aioesphomeapi/client_base.py @@ -234,6 +234,7 @@ class APIClientBase: "_connection", "_debug_enabled", "_loop", + "_notify_callbacks", "_params", "cached_name", "log_name", @@ -299,6 +300,7 @@ def __init__( self._cached_device_info: DeviceInfo | None = None self.cached_name: str | None = None self._background_tasks: set[asyncio.Task[Any]] = set() + self._notify_callbacks: dict[tuple[int, int], Callable[[], None]] = {} self._loop = asyncio.get_running_loop() self._call_id_counter = itertools.count(1) self._set_log_name() diff --git a/tests/test_client.py b/tests/test_client.py index 345f29ff..d9077930 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2131,6 +2131,227 @@ def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: assert connection.is_connected +async def test_bluetooth_gatt_stop_notify( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test bluetooth_gatt_stop_notify stops notify and removes callback.""" + client, connection, _transport, protocol = api_client + notifies = [] + + handlers_before = len(list(itertools.chain(*connection._message_handlers.values()))) + + def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: + notifies.append((handle, data)) + + notify_task = asyncio.create_task( + client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify) + ) + await asyncio.sleep(0) + notify_response: message.Message = BluetoothGATTNotifyResponse( + address=1234, handle=1 + ) + mock_data_received(protocol, generate_plaintext_packet(notify_response)) + + await notify_task + + # Verify the callback is registered + assert (1234, 1) in client._notify_callbacks + + # Stop notify using the sync method + client.bluetooth_gatt_stop_notify(1234, 1) + + # Verify callback is removed + assert (1234, 1) not in client._notify_callbacks + + # Verify handlers are cleaned up + assert ( + len(list(itertools.chain(*connection._message_handlers.values()))) + == handlers_before + ) + + +async def test_bluetooth_gatt_start_notify_abort_callback_cleans_up( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test that the abort callback (second return value) cleans up _notify_callbacks.""" + client, connection, _transport, protocol = api_client + + handlers_before = len(list(itertools.chain(*connection._message_handlers.values()))) + + def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: + pass + + notify_task = asyncio.create_task( + client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify) + ) + await asyncio.sleep(0) + mock_data_received( + protocol, + generate_plaintext_packet(BluetoothGATTNotifyResponse(address=1234, handle=1)), + ) + + _cancel_cb, abort_cb = await notify_task + + # Verify the callback is registered + assert (1234, 1) in client._notify_callbacks + + # Call abort callback directly (simulates connection lost scenario) + abort_cb() + + # Verify _notify_callbacks is cleaned up + assert (1234, 1) not in client._notify_callbacks + + # Verify handlers are cleaned up + assert ( + len(list(itertools.chain(*connection._message_handlers.values()))) + == handlers_before + ) + + +async def test_bluetooth_gatt_stop_notify_for_address( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test bluetooth_gatt_stop_notify_for_address stops all notifies for an address.""" + client, connection, _transport, protocol = api_client + + handlers_before = len(list(itertools.chain(*connection._message_handlers.values()))) + + def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: + pass + + # Start multiple notifies for the same address + notify_task1 = asyncio.create_task( + client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify) + ) + await asyncio.sleep(0) + mock_data_received( + protocol, + generate_plaintext_packet(BluetoothGATTNotifyResponse(address=1234, handle=1)), + ) + await notify_task1 + + notify_task2 = asyncio.create_task( + client.bluetooth_gatt_start_notify(1234, 2, on_bluetooth_gatt_notify) + ) + await asyncio.sleep(0) + mock_data_received( + protocol, + generate_plaintext_packet(BluetoothGATTNotifyResponse(address=1234, handle=2)), + ) + await notify_task2 + + # Also start a notify for a different address + notify_task3 = asyncio.create_task( + client.bluetooth_gatt_start_notify(5678, 1, on_bluetooth_gatt_notify) + ) + await asyncio.sleep(0) + mock_data_received( + protocol, + generate_plaintext_packet(BluetoothGATTNotifyResponse(address=5678, handle=1)), + ) + await notify_task3 + + # Verify all callbacks are registered + assert (1234, 1) in client._notify_callbacks + assert (1234, 2) in client._notify_callbacks + assert (5678, 1) in client._notify_callbacks + + # Stop all notifies for address 1234 + client.bluetooth_gatt_stop_notify_for_address(1234) + + # Verify callbacks for 1234 are removed but 5678 remains + assert (1234, 1) not in client._notify_callbacks + assert (1234, 2) not in client._notify_callbacks + assert (5678, 1) in client._notify_callbacks + + # Clean up + client.bluetooth_gatt_stop_notify(5678, 1) + + # Verify all handlers are cleaned up + assert ( + len(list(itertools.chain(*connection._message_handlers.values()))) + == handlers_before + ) + + +async def test_bluetooth_device_connect_cleans_up_notify_on_disconnect( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], +) -> None: + """Test that notify callbacks are cleaned up when device disconnects.""" + client, _connection, _transport, protocol = api_client + + def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: + pass + + # Start a notify + notify_task = asyncio.create_task( + client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify) + ) + await asyncio.sleep(0) + mock_data_received( + protocol, + generate_plaintext_packet(BluetoothGATTNotifyResponse(address=1234, handle=1)), + ) + await notify_task + + # Verify the callback is registered + assert (1234, 1) in client._notify_callbacks + + # Simulate device connecting (sets up the disconnect handler) + connection_states = [] + + def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: + connection_states.append((connected, mtu, error)) + + connect_task = asyncio.create_task( + client.bluetooth_device_connect( + 1234, + on_bluetooth_connection_state, + feature_flags=BluetoothProxyFeature.REMOTE_CACHING, + address_type=0, + ) + ) + await asyncio.sleep(0) + + # Send connection response + mock_data_received( + protocol, + generate_plaintext_packet( + BluetoothDeviceConnectionResponse(address=1234, connected=True, mtu=500) + ), + ) + cancel = await connect_task + + # Verify connected state was received + assert connection_states == [(True, 500, 0)] + + # Now simulate disconnect + mock_data_received( + protocol, + generate_plaintext_packet( + BluetoothDeviceConnectionResponse(address=1234, connected=False, error=0) + ), + ) + await asyncio.sleep(0) + + # Verify disconnect state was received + assert connection_states == [(True, 500, 0), (False, 0, 0)] + + # Verify notify callback was cleaned up + assert (1234, 1) not in client._notify_callbacks + + # Clean up + cancel() + + async def test_subscribe_bluetooth_le_advertisements( api_client: tuple[ APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper