Skip to content

Commit 1c0dd02

Browse files
authored
Abort USB discovery flows on device unplug (home-assistant#156303)
1 parent c414938 commit 1c0dd02

File tree

2 files changed

+204
-189
lines changed

2 files changed

+204
-189
lines changed

homeassistant/components/usb/__init__.py

Lines changed: 26 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import asyncio
66
from collections.abc import Callable, Coroutine, Sequence
7-
import dataclasses
87
from datetime import datetime, timedelta
98
from functools import partial
109
import logging
@@ -45,7 +44,7 @@
4544
usb_device_from_path, # noqa: F401
4645
usb_device_from_port, # noqa: F401
4746
usb_device_matches_matcher,
48-
usb_service_info_from_device, # noqa: F401
47+
usb_service_info_from_device,
4948
usb_unique_id_from_service_info, # noqa: F401
5049
)
5150

@@ -59,7 +58,6 @@
5958

6059
__all__ = [
6160
"USBCallbackMatcher",
62-
"async_is_plugged_in",
6361
"async_register_port_event_callback",
6462
"async_register_scan_request_callback",
6563
]
@@ -101,51 +99,6 @@ def async_register_port_event_callback(
10199
return discovery.async_register_port_event_callback(callback)
102100

103101

104-
@hass_callback
105-
def async_is_plugged_in(hass: HomeAssistant, matcher: USBCallbackMatcher) -> bool:
106-
"""Return True is a USB device is present."""
107-
108-
vid = matcher.get("vid", "")
109-
pid = matcher.get("pid", "")
110-
serial_number = matcher.get("serial_number", "")
111-
manufacturer = matcher.get("manufacturer", "")
112-
description = matcher.get("description", "")
113-
114-
if (
115-
vid != vid.upper()
116-
or pid != pid.upper()
117-
or serial_number != serial_number.lower()
118-
or manufacturer != manufacturer.lower()
119-
or description != description.lower()
120-
):
121-
raise ValueError(
122-
f"vid and pid must be uppercase, the rest lowercase in matcher {matcher!r}"
123-
)
124-
125-
usb_discovery: USBDiscovery = hass.data[DOMAIN]
126-
return any(
127-
usb_device_matches_matcher(
128-
USBDevice(
129-
device=device,
130-
vid=vid,
131-
pid=pid,
132-
serial_number=serial_number,
133-
manufacturer=manufacturer,
134-
description=description,
135-
),
136-
matcher,
137-
)
138-
for (
139-
device,
140-
vid,
141-
pid,
142-
serial_number,
143-
manufacturer,
144-
description,
145-
) in usb_discovery.seen
146-
)
147-
148-
149102
@hass_callback
150103
def async_get_usb_matchers_for_device(
151104
hass: HomeAssistant, device: USBDevice
@@ -244,7 +197,6 @@ def __init__(
244197
"""Init USB Discovery."""
245198
self.hass = hass
246199
self.usb = usb
247-
self.seen: set[tuple[str, ...]] = set()
248200
self.observer_active = False
249201
self._request_debouncer: Debouncer[Coroutine[Any, Any, None]] | None = None
250202
self._add_remove_debouncer: Debouncer[Coroutine[Any, Any, None]] | None = None
@@ -393,37 +345,40 @@ def async_get_usb_matchers_for_device(self, device: USBDevice) -> list[USBMatche
393345
async def _async_process_discovered_usb_device(self, device: USBDevice) -> None:
394346
"""Process a USB discovery."""
395347
_LOGGER.debug("Discovered USB Device: %s", device)
396-
device_tuple = dataclasses.astuple(device)
397-
if device_tuple in self.seen:
398-
return
399-
self.seen.add(device_tuple)
400-
401348
matched = self.async_get_usb_matchers_for_device(device)
402349
if not matched:
403350
return
404351

405-
service_info: _UsbServiceInfo | None = None
352+
service_info = usb_service_info_from_device(device)
406353

407354
for matcher in matched:
408-
if service_info is None:
409-
service_info = _UsbServiceInfo(
410-
device=await self.hass.async_add_executor_job(
411-
get_serial_by_id, device.device
412-
),
413-
vid=device.vid,
414-
pid=device.pid,
415-
serial_number=device.serial_number,
416-
manufacturer=device.manufacturer,
417-
description=device.description,
418-
)
419-
420355
discovery_flow.async_create_flow(
421356
self.hass,
422357
matcher["domain"],
423358
{"source": config_entries.SOURCE_USB},
424359
service_info,
425360
)
426361

362+
async def _async_process_removed_usb_device(self, device: USBDevice) -> None:
363+
"""Process a USB removal."""
364+
_LOGGER.debug("Removed USB Device: %s", device)
365+
matched = self.async_get_usb_matchers_for_device(device)
366+
if not matched:
367+
return
368+
369+
service_info = usb_service_info_from_device(device)
370+
371+
for matcher in matched:
372+
for flow in self.hass.config_entries.flow.async_progress_by_init_data_type(
373+
_UsbServiceInfo,
374+
lambda flow_service_info: flow_service_info == service_info,
375+
):
376+
if matcher["domain"] != flow["handler"]:
377+
continue
378+
379+
_LOGGER.debug("Aborting existing flow %s", flow["flow_id"])
380+
self.hass.config_entries.flow.async_abort(flow["flow_id"])
381+
427382
async def _async_process_ports(self, usb_devices: Sequence[USBDevice]) -> None:
428383
"""Process each discovered port."""
429384
_LOGGER.debug("USB devices: %r", usb_devices)
@@ -464,7 +419,10 @@ async def _async_process_ports(self, usb_devices: Sequence[USBDevice]) -> None:
464419
except Exception:
465420
_LOGGER.exception("Error in USB port event callback")
466421

467-
for usb_device in filtered_usb_devices:
422+
for usb_device in removed_devices:
423+
await self._async_process_removed_usb_device(usb_device)
424+
425+
for usb_device in added_devices:
468426
await self._async_process_discovered_usb_device(usb_device)
469427

470428
@hass_callback

0 commit comments

Comments
 (0)