Skip to content

Commit 73dc810

Browse files
RaHehlbdraco
andauthored
Implement reconfiguration flow for UniFi Protect integration (home-assistant#157532)
Co-authored-by: J. Nick Koston <[email protected]>
1 parent f306cde commit 73dc810

File tree

4 files changed

+1362
-100
lines changed

4 files changed

+1362
-100
lines changed

homeassistant/components/unifiprotect/config_flow.py

Lines changed: 214 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
CONF_VERIFY_SSL,
3232
)
3333
from homeassistant.core import HomeAssistant, callback
34+
from homeassistant.helpers import selector
3435
from homeassistant.helpers.aiohttp_client import (
3536
async_create_clientsession,
3637
async_get_clientsession,
@@ -56,15 +57,113 @@
5657
)
5758
from .data import UFPConfigEntry, async_last_update_was_successful
5859
from .discovery import async_start_discovery
59-
from .utils import _async_resolve, _async_short_mac, _async_unifi_mac_from_hass
60+
from .utils import (
61+
_async_resolve,
62+
_async_short_mac,
63+
_async_unifi_mac_from_hass,
64+
async_create_api_client,
65+
)
6066

6167
_LOGGER = logging.getLogger(__name__)
6268

69+
70+
def _filter_empty_credentials(user_input: dict[str, Any]) -> dict[str, Any]:
71+
"""Filter out empty credential fields to preserve existing values."""
72+
return {k: v for k, v in user_input.items() if v not in (None, "")}
73+
74+
75+
def _normalize_port(data: dict[str, Any]) -> dict[str, Any]:
76+
"""Ensure port is stored as int (NumberSelector returns float)."""
77+
return {**data, CONF_PORT: int(data.get(CONF_PORT, DEFAULT_PORT))}
78+
79+
80+
def _build_data_without_credentials(entry_data: Mapping[str, Any]) -> dict[str, Any]:
81+
"""Build form data from existing config entry, excluding sensitive credentials."""
82+
return {
83+
CONF_HOST: entry_data[CONF_HOST],
84+
CONF_PORT: entry_data[CONF_PORT],
85+
CONF_VERIFY_SSL: entry_data[CONF_VERIFY_SSL],
86+
CONF_USERNAME: entry_data[CONF_USERNAME],
87+
}
88+
89+
90+
async def _async_clear_session_if_credentials_changed(
91+
hass: HomeAssistant,
92+
entry: UFPConfigEntry,
93+
new_data: Mapping[str, Any],
94+
) -> None:
95+
"""Clear stored session if credentials have changed to force fresh authentication."""
96+
existing_data = entry.data
97+
if existing_data.get(CONF_USERNAME) != new_data.get(
98+
CONF_USERNAME
99+
) or existing_data.get(CONF_PASSWORD) != new_data.get(CONF_PASSWORD):
100+
_LOGGER.debug("Credentials changed, clearing stored session")
101+
protect = async_create_api_client(hass, entry)
102+
try:
103+
await protect.clear_session()
104+
except Exception as ex: # noqa: BLE001
105+
_LOGGER.debug("Failed to clear session, continuing anyway: %s", ex)
106+
107+
63108
ENTRY_FAILURE_STATES = (
64109
ConfigEntryState.SETUP_ERROR,
65110
ConfigEntryState.SETUP_RETRY,
66111
)
67112

113+
# Selectors for config flow form fields
114+
_TEXT_SELECTOR = selector.TextSelector()
115+
_PASSWORD_SELECTOR = selector.TextSelector(
116+
selector.TextSelectorConfig(type=selector.TextSelectorType.PASSWORD)
117+
)
118+
_PORT_SELECTOR = selector.NumberSelector(
119+
selector.NumberSelectorConfig(
120+
mode=selector.NumberSelectorMode.BOX, min=1, max=65535
121+
)
122+
)
123+
_BOOL_SELECTOR = selector.BooleanSelector()
124+
125+
126+
def _build_schema(
127+
*,
128+
include_host: bool = True,
129+
include_connection: bool = True,
130+
credentials_optional: bool = False,
131+
) -> vol.Schema:
132+
"""Build a config flow schema.
133+
134+
Args:
135+
include_host: Include host field (False when host comes from discovery).
136+
include_connection: Include port/verify_ssl fields.
137+
credentials_optional: Credentials optional (True to keep existing values).
138+
139+
"""
140+
req, opt = vol.Required, vol.Optional
141+
cred_key = opt if credentials_optional else req
142+
143+
schema: dict[vol.Marker, selector.Selector] = {}
144+
if include_host:
145+
schema[req(CONF_HOST)] = _TEXT_SELECTOR
146+
if include_connection:
147+
schema[req(CONF_PORT, default=DEFAULT_PORT)] = _PORT_SELECTOR
148+
schema[req(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL)] = _BOOL_SELECTOR
149+
schema[req(CONF_USERNAME)] = _TEXT_SELECTOR
150+
schema[cred_key(CONF_PASSWORD)] = _PASSWORD_SELECTOR
151+
schema[cred_key(CONF_API_KEY)] = _PASSWORD_SELECTOR
152+
return vol.Schema(schema)
153+
154+
155+
# Schemas for different flow contexts
156+
# User flow: all fields required
157+
CONFIG_SCHEMA = _build_schema()
158+
# Reconfigure flow: keep existing credentials if not provided
159+
RECONFIGURE_SCHEMA = _build_schema(credentials_optional=True)
160+
# Discovery flow: host comes from discovery, user sets port/ssl
161+
DISCOVERY_SCHEMA = _build_schema(include_host=False)
162+
# Reauth flow: only credentials, connection settings preserved
163+
REAUTH_SCHEMA = _build_schema(
164+
include_host=False, include_connection=False, credentials_optional=True
165+
)
166+
68167

69168
async def async_local_user_documentation_url(hass: HomeAssistant) -> str:
70169
"""Get the documentation url for creating a local user."""
@@ -178,19 +277,40 @@ async def async_step_discovery_confirm(
178277
"""Confirm discovery."""
179278
errors: dict[str, str] = {}
180279
discovery_info = self._discovered_device
280+
281+
form_data = {
282+
CONF_HOST: discovery_info["direct_connect_domain"]
283+
or discovery_info["source_ip"],
284+
CONF_PORT: DEFAULT_PORT,
285+
CONF_VERIFY_SSL: bool(discovery_info["direct_connect_domain"]),
286+
CONF_USERNAME: "",
287+
CONF_PASSWORD: "",
288+
}
289+
181290
if user_input is not None:
182-
user_input[CONF_PORT] = DEFAULT_PORT
291+
# Merge user input with discovery info
292+
merged_input = {**form_data, **user_input}
183293
nvr_data = None
184294
if discovery_info["direct_connect_domain"]:
185-
user_input[CONF_HOST] = discovery_info["direct_connect_domain"]
186-
user_input[CONF_VERIFY_SSL] = True
187-
nvr_data, errors = await self._async_get_nvr_data(user_input)
295+
merged_input[CONF_HOST] = discovery_info["direct_connect_domain"]
296+
merged_input[CONF_VERIFY_SSL] = True
297+
nvr_data, errors = await self._async_get_nvr_data(merged_input)
188298
if not nvr_data or errors:
189-
user_input[CONF_HOST] = discovery_info["source_ip"]
190-
user_input[CONF_VERIFY_SSL] = False
191-
nvr_data, errors = await self._async_get_nvr_data(user_input)
299+
merged_input[CONF_HOST] = discovery_info["source_ip"]
300+
merged_input[CONF_VERIFY_SSL] = False
301+
nvr_data, errors = await self._async_get_nvr_data(merged_input)
192302
if nvr_data and not errors:
193-
return self._async_create_entry(nvr_data.display_name, user_input)
303+
return self._async_create_entry(nvr_data.display_name, merged_input)
304+
# Preserve user input for form re-display, but keep discovery info
305+
form_data = {
306+
CONF_HOST: merged_input[CONF_HOST],
307+
CONF_PORT: merged_input[CONF_PORT],
308+
CONF_VERIFY_SSL: merged_input[CONF_VERIFY_SSL],
309+
CONF_USERNAME: user_input.get(CONF_USERNAME, ""),
310+
CONF_PASSWORD: user_input.get(CONF_PASSWORD, ""),
311+
}
312+
if CONF_API_KEY in user_input:
313+
form_data[CONF_API_KEY] = user_input[CONF_API_KEY]
194314

195315
placeholders = {
196316
"name": discovery_info["hostname"]
@@ -199,7 +319,6 @@ async def async_step_discovery_confirm(
199319
"ip_address": discovery_info["source_ip"],
200320
}
201321
self.context["title_placeholders"] = placeholders
202-
user_input = user_input or {}
203322
return self.async_show_form(
204323
step_id="discovery_confirm",
205324
description_placeholders={
@@ -208,14 +327,8 @@ async def async_step_discovery_confirm(
208327
self.hass
209328
),
210329
},
211-
data_schema=vol.Schema(
212-
{
213-
vol.Required(
214-
CONF_USERNAME, default=user_input.get(CONF_USERNAME)
215-
): str,
216-
vol.Required(CONF_PASSWORD): str,
217-
vol.Required(CONF_API_KEY): str,
218-
}
330+
data_schema=self.add_suggested_values_to_schema(
331+
DISCOVERY_SCHEMA, form_data
219332
),
220333
errors=errors,
221334
)
@@ -232,7 +345,7 @@ def async_get_options_flow(
232345
def _async_create_entry(self, title: str, data: dict[str, Any]) -> ConfigFlowResult:
233346
return self.async_create_entry(
234347
title=title,
235-
data={**data, CONF_ID: title},
348+
data={**_normalize_port(data), CONF_ID: title},
236349
options={
237350
CONF_DISABLE_RTSP: False,
238351
CONF_ALL_UPDATES: False,
@@ -251,7 +364,7 @@ async def _async_get_nvr_data(
251364
public_api_session = async_get_clientsession(self.hass)
252365

253366
host = user_input[CONF_HOST]
254-
port = user_input.get(CONF_PORT, DEFAULT_PORT)
367+
port = int(user_input.get(CONF_PORT, DEFAULT_PORT))
255368
verify_ssl = user_input.get(CONF_VERIFY_SSL, DEFAULT_VERIFY_SSL)
256369

257370
protect = ProtectApiClient(
@@ -261,7 +374,7 @@ async def _async_get_nvr_data(
261374
port=port,
262375
username=user_input[CONF_USERNAME],
263376
password=user_input[CONF_PASSWORD],
264-
api_key=user_input[CONF_API_KEY],
377+
api_key=user_input.get(CONF_API_KEY, ""),
265378
verify_ssl=verify_ssl,
266379
cache_dir=Path(self.hass.config.path(STORAGE_DIR, "unifiprotect")),
267380
config_dir=Path(self.hass.config.path(STORAGE_DIR, "unifiprotect")),
@@ -290,14 +403,17 @@ async def _async_get_nvr_data(
290403
auth_user = bootstrap.users.get(bootstrap.auth_user_id)
291404
if auth_user and auth_user.cloud_account:
292405
errors["base"] = "cloud_user"
293-
try:
294-
await protect.get_meta_info()
295-
except NotAuthorized as ex:
296-
_LOGGER.debug(ex)
297-
errors[CONF_API_KEY] = "invalid_auth"
298-
except ClientError as ex:
299-
_LOGGER.error(ex)
300-
errors["base"] = "cannot_connect"
406+
407+
# Only validate API key if bootstrap succeeded
408+
if nvr_data and not errors:
409+
try:
410+
await protect.get_meta_info()
411+
except NotAuthorized as ex:
412+
_LOGGER.debug(ex)
413+
errors[CONF_API_KEY] = "invalid_auth"
414+
except ClientError as ex:
415+
_LOGGER.error(ex)
416+
errors["base"] = "cannot_connect"
301417

302418
return nvr_data, errors
303419

@@ -313,16 +429,27 @@ async def async_step_reauth_confirm(
313429
"""Confirm reauth."""
314430
errors: dict[str, str] = {}
315431

316-
# prepopulate fields
317432
reauth_entry = self._get_reauth_entry()
318-
form_data = {**reauth_entry.data}
433+
form_data = _build_data_without_credentials(reauth_entry.data)
434+
319435
if user_input is not None:
320-
form_data.update(user_input)
436+
# Merge with existing config - empty credentials keep existing values
437+
merged_input = {
438+
**reauth_entry.data,
439+
**_filter_empty_credentials(user_input),
440+
}
441+
442+
# Clear stored session if credentials changed to force fresh authentication
443+
await _async_clear_session_if_credentials_changed(
444+
self.hass, reauth_entry, merged_input
445+
)
321446

322447
# validate login data
323-
_, errors = await self._async_get_nvr_data(form_data)
448+
_, errors = await self._async_get_nvr_data(merged_input)
324449
if not errors:
325-
return self.async_update_reload_and_abort(reauth_entry, data=form_data)
450+
return self.async_update_reload_and_abort(
451+
reauth_entry, data=_normalize_port(merged_input)
452+
)
326453

327454
self.context["title_placeholders"] = {
328455
"name": reauth_entry.title,
@@ -335,14 +462,58 @@ async def async_step_reauth_confirm(
335462
self.hass
336463
),
337464
},
338-
data_schema=vol.Schema(
339-
{
340-
vol.Required(
341-
CONF_USERNAME, default=form_data.get(CONF_USERNAME)
342-
): str,
343-
vol.Required(CONF_PASSWORD): str,
344-
vol.Required(CONF_API_KEY): str,
345-
}
465+
data_schema=self.add_suggested_values_to_schema(REAUTH_SCHEMA, form_data),
466+
errors=errors,
467+
)
468+
469+
async def async_step_reconfigure(
470+
self, user_input: dict[str, Any] | None = None
471+
) -> ConfigFlowResult:
472+
"""Handle reconfiguration of the integration."""
473+
errors: dict[str, str] = {}
474+
475+
reconfigure_entry = self._get_reconfigure_entry()
476+
form_data = _build_data_without_credentials(reconfigure_entry.data)
477+
478+
if user_input is not None:
479+
# Merge with existing config - empty credentials keep existing values
480+
merged_input = {
481+
**reconfigure_entry.data,
482+
**_filter_empty_credentials(user_input),
483+
}
484+
485+
# Clear stored session if credentials changed to force fresh authentication
486+
await _async_clear_session_if_credentials_changed(
487+
self.hass, reconfigure_entry, merged_input
488+
)
489+
490+
# validate login data
491+
nvr_data, errors = await self._async_get_nvr_data(merged_input)
492+
if nvr_data and not errors:
493+
new_unique_id = _async_unifi_mac_from_hass(nvr_data.mac)
494+
_LOGGER.debug(
495+
"Reconfigure: Current unique_id=%s, NVR MAC=%s, formatted=%s",
496+
reconfigure_entry.unique_id,
497+
nvr_data.mac,
498+
new_unique_id,
499+
)
500+
await self.async_set_unique_id(new_unique_id)
501+
self._abort_if_unique_id_mismatch(reason="wrong_nvr")
502+
503+
return self.async_update_reload_and_abort(
504+
reconfigure_entry,
505+
data=_normalize_port(merged_input),
506+
)
507+
508+
return self.async_show_form(
509+
step_id="reconfigure",
510+
description_placeholders={
511+
"local_user_documentation_url": await async_local_user_documentation_url(
512+
self.hass
513+
),
514+
},
515+
data_schema=self.add_suggested_values_to_schema(
516+
RECONFIGURE_SCHEMA, form_data
346517
),
347518
errors=errors,
348519
)
@@ -362,31 +533,14 @@ async def async_step_user(
362533

363534
return self._async_create_entry(nvr_data.display_name, user_input)
364535

365-
user_input = user_input or {}
366536
return self.async_show_form(
367537
step_id="user",
368538
description_placeholders={
369539
"local_user_documentation_url": await async_local_user_documentation_url(
370540
self.hass
371541
)
372542
},
373-
data_schema=vol.Schema(
374-
{
375-
vol.Required(CONF_HOST, default=user_input.get(CONF_HOST)): str,
376-
vol.Required(
377-
CONF_PORT, default=user_input.get(CONF_PORT, DEFAULT_PORT)
378-
): int,
379-
vol.Required(
380-
CONF_VERIFY_SSL,
381-
default=user_input.get(CONF_VERIFY_SSL, DEFAULT_VERIFY_SSL),
382-
): bool,
383-
vol.Required(
384-
CONF_USERNAME, default=user_input.get(CONF_USERNAME)
385-
): str,
386-
vol.Required(CONF_PASSWORD): str,
387-
vol.Required(CONF_API_KEY): str,
388-
}
389-
),
543+
data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input),
390544
errors=errors,
391545
)
392546

0 commit comments

Comments
 (0)