Skip to content

Commit abd0ee7

Browse files
Fix progress step recursion (home-assistant#153906)
1 parent 9e3eb20 commit abd0ee7

File tree

2 files changed

+459
-36
lines changed

2 files changed

+459
-36
lines changed

homeassistant/data_entry_flow.py

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -645,12 +645,24 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
645645
__progress_task: asyncio.Task[Any] | None = None
646646
__no_progress_task_reported = False
647647
deprecated_show_progress = False
648-
_progress_step_data: ProgressStepData[_FlowResultT] = {
649-
"tasks": {},
650-
"abort_reason": "",
651-
"abort_description_placeholders": MappingProxyType({}),
652-
"next_step_result": None,
653-
}
648+
__progress_step_data: ProgressStepData[_FlowResultT] | None = None
649+
650+
@property
651+
def _progress_step_data(self) -> ProgressStepData[_FlowResultT]:
652+
"""Return progress step data.
653+
654+
A property is used instead of a simple attribute as derived classes
655+
do not call super().__init__.
656+
The property makes sure that the dict is initialized if needed.
657+
"""
658+
if not self.__progress_step_data:
659+
self.__progress_step_data = {
660+
"tasks": {},
661+
"abort_reason": "",
662+
"abort_description_placeholders": MappingProxyType({}),
663+
"next_step_result": None,
664+
}
665+
return self.__progress_step_data
654666

655667
@property
656668
def source(self) -> str | None:
@@ -777,9 +789,10 @@ async def async_step__progress_step_abort(
777789
self, user_input: dict[str, Any] | None = None
778790
) -> _FlowResultT:
779791
"""Abort the flow."""
792+
progress_step_data = self._progress_step_data
780793
return self.async_abort(
781-
reason=self._progress_step_data["abort_reason"],
782-
description_placeholders=self._progress_step_data[
794+
reason=progress_step_data["abort_reason"],
795+
description_placeholders=progress_step_data[
783796
"abort_description_placeholders"
784797
],
785798
)
@@ -795,14 +808,15 @@ async def async_step__progress_step_progress_done(
795808
without using async_show_progress_done.
796809
If no next step is set, abort the flow.
797810
"""
798-
if self._progress_step_data["next_step_result"] is None:
811+
progress_step_data = self._progress_step_data
812+
if (next_step_result := progress_step_data["next_step_result"]) is None:
799813
return self.async_abort(
800-
reason=self._progress_step_data["abort_reason"],
801-
description_placeholders=self._progress_step_data[
814+
reason=progress_step_data["abort_reason"],
815+
description_placeholders=progress_step_data[
802816
"abort_description_placeholders"
803817
],
804818
)
805-
return self._progress_step_data["next_step_result"]
819+
return next_step_result
806820

807821
@callback
808822
def async_external_step(
@@ -1021,48 +1035,55 @@ async def wrapper(
10211035
self: FlowHandler[Any, ResultT], *args: P.args, **kwargs: P.kwargs
10221036
) -> ResultT:
10231037
step_id = func.__name__.replace("async_step_", "")
1024-
1038+
progress_step_data = self._progress_step_data
10251039
# Check if we have a progress task running
1026-
progress_task = self._progress_step_data["tasks"].get(step_id)
1040+
progress_task = progress_step_data["tasks"].get(step_id)
10271041

10281042
if progress_task is None:
10291043
# First call - create and start the progress task
10301044
progress_task = self.hass.async_create_task(
10311045
func(self, *args, **kwargs), # type: ignore[arg-type]
10321046
f"Progress step {step_id}",
10331047
)
1034-
self._progress_step_data["tasks"][step_id] = progress_task
1035-
1036-
if not progress_task.done():
1037-
# Handle description placeholders
1038-
placeholders = None
1039-
if description_placeholders is not None:
1040-
if callable(description_placeholders):
1041-
placeholders = description_placeholders(self)
1042-
else:
1043-
placeholders = description_placeholders
1044-
1045-
return self.async_show_progress(
1046-
step_id=step_id,
1047-
progress_action=step_id,
1048-
progress_task=progress_task,
1049-
description_placeholders=placeholders,
1050-
)
1048+
progress_step_data["tasks"][step_id] = progress_task
1049+
1050+
if not progress_task.done():
1051+
# Handle description placeholders
1052+
placeholders = None
1053+
if description_placeholders is not None:
1054+
if callable(description_placeholders):
1055+
placeholders = description_placeholders(self)
1056+
else:
1057+
placeholders = description_placeholders
1058+
1059+
return self.async_show_progress(
1060+
step_id=step_id,
1061+
progress_action=step_id,
1062+
progress_task=progress_task,
1063+
description_placeholders=placeholders,
1064+
)
10511065

10521066
# Task is done or this is a subsequent call
10531067
try:
1054-
self._progress_step_data["next_step_result"] = await progress_task
1068+
progress_task_result = await progress_task
10551069
except AbortFlow as err:
1056-
self._progress_step_data["abort_reason"] = err.reason
1057-
self._progress_step_data["abort_description_placeholders"] = (
1070+
progress_step_data["abort_reason"] = err.reason
1071+
progress_step_data["abort_description_placeholders"] = (
10581072
err.description_placeholders or {}
10591073
)
10601074
return self.async_show_progress_done(
10611075
next_step_id="_progress_step_abort"
10621076
)
10631077
finally:
10641078
# Clean up task reference
1065-
self._progress_step_data["tasks"].pop(step_id, None)
1079+
progress_step_data["tasks"].pop(step_id, None)
1080+
1081+
# If the result type is FlowResultType.SHOW_PROGRESS_DONE
1082+
# an earlier show progress step has already been run and stored its result.
1083+
# In this case we should not overwrite the result,
1084+
# but just use the stored one.
1085+
if progress_task_result["type"] != FlowResultType.SHOW_PROGRESS_DONE:
1086+
progress_step_data["next_step_result"] = progress_task_result
10661087

10671088
return self.async_show_progress_done(
10681089
next_step_id="_progress_step_progress_done"

0 commit comments

Comments
 (0)