|
5 | 5 | import abc |
6 | 6 | import asyncio |
7 | 7 | from collections import defaultdict |
8 | | -from collections.abc import Callable, Container, Hashable, Iterable, Mapping |
| 8 | +from collections.abc import Callable, Container, Coroutine, Hashable, Iterable, Mapping |
9 | 9 | from contextlib import suppress |
10 | 10 | import copy |
11 | 11 | from dataclasses import dataclass |
12 | 12 | from enum import StrEnum |
| 13 | +import functools |
13 | 14 | import logging |
14 | 15 | from types import MappingProxyType |
15 | | -from typing import Any, Generic, Required, TypedDict, TypeVar, cast |
| 16 | +from typing import Any, Concatenate, Generic, Required, TypedDict, TypeVar, cast |
16 | 17 |
|
17 | 18 | import voluptuous as vol |
18 | 19 |
|
@@ -150,6 +151,15 @@ class FlowResult(TypedDict, Generic[_FlowContextT, _HandlerT], total=False): |
150 | 151 | url: str |
151 | 152 |
|
152 | 153 |
|
| 154 | +class ProgressStepData[_FlowResultT](TypedDict): |
| 155 | + """Typed data for progress step tracking.""" |
| 156 | + |
| 157 | + tasks: dict[str, asyncio.Task[Any]] |
| 158 | + abort_reason: str |
| 159 | + abort_description_placeholders: Mapping[str, str] |
| 160 | + next_step_result: _FlowResultT | None |
| 161 | + |
| 162 | + |
153 | 163 | def _map_error_to_schema_errors( |
154 | 164 | schema_errors: dict[str, Any], |
155 | 165 | error: vol.Invalid, |
@@ -639,6 +649,12 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]): |
639 | 649 | __progress_task: asyncio.Task[Any] | None = None |
640 | 650 | __no_progress_task_reported = False |
641 | 651 | deprecated_show_progress = False |
| 652 | + _progress_step_data: ProgressStepData[_FlowResultT] = { |
| 653 | + "tasks": {}, |
| 654 | + "abort_reason": "", |
| 655 | + "abort_description_placeholders": MappingProxyType({}), |
| 656 | + "next_step_result": None, |
| 657 | + } |
642 | 658 |
|
643 | 659 | @property |
644 | 660 | def source(self) -> str | None: |
@@ -761,6 +777,37 @@ def async_abort( |
761 | 777 | description_placeholders=description_placeholders, |
762 | 778 | ) |
763 | 779 |
|
| 780 | + async def async_step__progress_step_abort( |
| 781 | + self, user_input: dict[str, Any] | None = None |
| 782 | + ) -> _FlowResultT: |
| 783 | + """Abort the flow.""" |
| 784 | + return self.async_abort( |
| 785 | + reason=self._progress_step_data["abort_reason"], |
| 786 | + description_placeholders=self._progress_step_data[ |
| 787 | + "abort_description_placeholders" |
| 788 | + ], |
| 789 | + ) |
| 790 | + |
| 791 | + async def async_step__progress_step_progress_done( |
| 792 | + self, user_input: dict[str, Any] | None = None |
| 793 | + ) -> _FlowResultT: |
| 794 | + """Progress done. Return the next step. |
| 795 | +
|
| 796 | + Used by the progress_step decorator |
| 797 | + to allow decorated step methods |
| 798 | + to call the next step method, to change step, |
| 799 | + without using async_show_progress_done. |
| 800 | + If no next step is set, abort the flow. |
| 801 | + """ |
| 802 | + if self._progress_step_data["next_step_result"] is None: |
| 803 | + return self.async_abort( |
| 804 | + reason=self._progress_step_data["abort_reason"], |
| 805 | + description_placeholders=self._progress_step_data[ |
| 806 | + "abort_description_placeholders" |
| 807 | + ], |
| 808 | + ) |
| 809 | + return self._progress_step_data["next_step_result"] |
| 810 | + |
764 | 811 | @callback |
765 | 812 | def async_external_step( |
766 | 813 | self, |
@@ -930,3 +977,90 @@ def __init__( |
930 | 977 | def __call__(self, value: Any) -> Any: |
931 | 978 | """Validate input.""" |
932 | 979 | return self.schema(value) |
| 980 | + |
| 981 | + |
| 982 | +type _FuncType[_T: FlowHandler[Any, Any, Any], _R: FlowResult[Any, Any], **_P] = ( |
| 983 | + Callable[Concatenate[_T, _P], Coroutine[Any, Any, _R]] |
| 984 | +) |
| 985 | + |
| 986 | + |
| 987 | +def progress_step[ |
| 988 | + HandlerT: FlowHandler[Any, Any, Any], |
| 989 | + ResultT: FlowResult[Any, Any], |
| 990 | + **P, |
| 991 | +]( |
| 992 | + description_placeholders: ( |
| 993 | + dict[str, str] | Callable[[Any], dict[str, str]] | None |
| 994 | + ) = None, |
| 995 | +) -> Callable[[_FuncType[HandlerT, ResultT, P]], _FuncType[HandlerT, ResultT, P]]: |
| 996 | + """Decorator to create a progress step from an async function. |
| 997 | +
|
| 998 | + The decorated method should be a step method |
| 999 | + which needs to show progress. |
| 1000 | + The method should accept dict[str, Any] as user_input |
| 1001 | + and should return a FlowResult or raise AbortFlow. |
| 1002 | + The method can call self.async_update_progress(progress) |
| 1003 | + to update progress. |
| 1004 | +
|
| 1005 | + Args: |
| 1006 | + description_placeholders: Static dict or callable that returns dict for progress UI placeholders. |
| 1007 | + """ |
| 1008 | + |
| 1009 | + def decorator( |
| 1010 | + func: _FuncType[HandlerT, ResultT, P], |
| 1011 | + ) -> _FuncType[HandlerT, ResultT, P]: |
| 1012 | + @functools.wraps(func) |
| 1013 | + async def wrapper( |
| 1014 | + self: FlowHandler[Any, ResultT], *args: P.args, **kwargs: P.kwargs |
| 1015 | + ) -> ResultT: |
| 1016 | + step_id = func.__name__.replace("async_step_", "") |
| 1017 | + |
| 1018 | + # Check if we have a progress task running |
| 1019 | + progress_task = self._progress_step_data["tasks"].get(step_id) |
| 1020 | + |
| 1021 | + if progress_task is None: |
| 1022 | + # First call - create and start the progress task |
| 1023 | + progress_task = self.hass.async_create_task( |
| 1024 | + func(self, *args, **kwargs), # type: ignore[arg-type] |
| 1025 | + f"Progress step {step_id}", |
| 1026 | + ) |
| 1027 | + self._progress_step_data["tasks"][step_id] = progress_task |
| 1028 | + |
| 1029 | + if not progress_task.done(): |
| 1030 | + # Handle description placeholders |
| 1031 | + placeholders = None |
| 1032 | + if description_placeholders is not None: |
| 1033 | + if callable(description_placeholders): |
| 1034 | + placeholders = description_placeholders(self) |
| 1035 | + else: |
| 1036 | + placeholders = description_placeholders |
| 1037 | + |
| 1038 | + return self.async_show_progress( |
| 1039 | + step_id=step_id, |
| 1040 | + progress_action=step_id, |
| 1041 | + progress_task=progress_task, |
| 1042 | + description_placeholders=placeholders, |
| 1043 | + ) |
| 1044 | + |
| 1045 | + # Task is done or this is a subsequent call |
| 1046 | + try: |
| 1047 | + self._progress_step_data["next_step_result"] = await progress_task |
| 1048 | + except AbortFlow as err: |
| 1049 | + self._progress_step_data["abort_reason"] = err.reason |
| 1050 | + self._progress_step_data["abort_description_placeholders"] = ( |
| 1051 | + err.description_placeholders or {} |
| 1052 | + ) |
| 1053 | + return self.async_show_progress_done( |
| 1054 | + next_step_id="_progress_step_abort" |
| 1055 | + ) |
| 1056 | + finally: |
| 1057 | + # Clean up task reference |
| 1058 | + self._progress_step_data["tasks"].pop(step_id, None) |
| 1059 | + |
| 1060 | + return self.async_show_progress_done( |
| 1061 | + next_step_id="_progress_step_progress_done" |
| 1062 | + ) |
| 1063 | + |
| 1064 | + return wrapper |
| 1065 | + |
| 1066 | + return decorator |
0 commit comments