44import logging
55import time
66import uuid
7+ from collections .abc import Callable , Iterable
78from dataclasses import dataclass
89from enum import Enum
910from io import BytesIO
10- from typing import Any , Callable , Iterable , Literal , Optional , Type , TypeVar , Union
11+ from typing import Any , Literal , TypeVar
1112from urllib .parse import urljoin , urlparse
1213
1314import aiohttp
@@ -37,8 +38,8 @@ def __init__(
3738 path : str ,
3839 method : Literal ["GET" , "POST" , "PUT" , "DELETE" , "PATCH" ] = "GET" ,
3940 * ,
40- query_params : Optional [ dict [str , Any ]] = None ,
41- headers : Optional [ dict [str , str ]] = None ,
41+ query_params : dict [str , Any ] | None = None ,
42+ headers : dict [str , str ] | None = None ,
4243 ):
4344 self .path = path
4445 self .method = method
@@ -52,29 +53,29 @@ class _RequestConfig:
5253 endpoint : ApiEndpoint
5354 timeout : float
5455 content_type : str
55- data : Optional [ dict [str , Any ]]
56- files : Optional [ Union [ dict [str , Any ], list [tuple [str , Any ]]]]
57- multipart_parser : Optional [ Callable ]
56+ data : dict [str , Any ] | None
57+ files : dict [str , Any ] | list [tuple [str , Any ]] | None
58+ multipart_parser : Callable | None
5859 max_retries : int
5960 retry_delay : float
6061 retry_backoff : float
6162 wait_label : str = "Waiting"
6263 monitor_progress : bool = True
63- estimated_total : Optional [ int ] = None
64- final_label_on_success : Optional [ str ] = "Completed"
65- progress_origin_ts : Optional [ float ] = None
66- price_extractor : Optional [ Callable [[dict [str , Any ]], Optional [ float ]]] = None
64+ estimated_total : int | None = None
65+ final_label_on_success : str | None = "Completed"
66+ progress_origin_ts : float | None = None
67+ price_extractor : Callable [[dict [str , Any ]], float | None ] | None = None
6768
6869
6970@dataclass
7071class _PollUIState :
7172 started : float
7273 status_label : str = "Queued"
7374 is_queued : bool = True
74- price : Optional [ float ] = None
75- estimated_duration : Optional [ int ] = None
75+ price : float | None = None
76+ estimated_duration : int | None = None
7677 base_processing_elapsed : float = 0.0 # sum of completed active intervals
77- active_since : Optional [ float ] = None # start time of current active interval (None if queued)
78+ active_since : float | None = None # start time of current active interval (None if queued)
7879
7980
8081_RETRY_STATUS = {408 , 429 , 500 , 502 , 503 , 504 }
@@ -87,20 +88,20 @@ async def sync_op(
8788 cls : type [IO .ComfyNode ],
8889 endpoint : ApiEndpoint ,
8990 * ,
90- response_model : Type [M ],
91- price_extractor : Optional [ Callable [[M ], Optional [ float ]]] = None ,
92- data : Optional [ BaseModel ] = None ,
93- files : Optional [ Union [ dict [str , Any ], list [tuple [str , Any ]]]] = None ,
91+ response_model : type [M ],
92+ price_extractor : Callable [[M | Any ], float | None ] | None = None ,
93+ data : BaseModel | None = None ,
94+ files : dict [str , Any ] | list [tuple [str , Any ]] | None = None ,
9495 content_type : str = "application/json" ,
9596 timeout : float = 3600.0 ,
96- multipart_parser : Optional [ Callable ] = None ,
97+ multipart_parser : Callable | None = None ,
9798 max_retries : int = 3 ,
9899 retry_delay : float = 1.0 ,
99100 retry_backoff : float = 2.0 ,
100101 wait_label : str = "Waiting for server" ,
101- estimated_duration : Optional [ int ] = None ,
102- final_label_on_success : Optional [ str ] = "Completed" ,
103- progress_origin_ts : Optional [ float ] = None ,
102+ estimated_duration : int | None = None ,
103+ final_label_on_success : str | None = "Completed" ,
104+ progress_origin_ts : float | None = None ,
104105 monitor_progress : bool = True ,
105106) -> M :
106107 raw = await sync_op_raw (
@@ -131,22 +132,22 @@ async def poll_op(
131132 cls : type [IO .ComfyNode ],
132133 poll_endpoint : ApiEndpoint ,
133134 * ,
134- response_model : Type [M ],
135- status_extractor : Callable [[M ], Optional [ Union [ str , int ]] ],
136- progress_extractor : Optional [ Callable [[M ], Optional [ int ]]] = None ,
137- price_extractor : Optional [ Callable [[M ], Optional [ float ]]] = None ,
138- completed_statuses : Optional [ list [Union [ str , int ]]] = None ,
139- failed_statuses : Optional [ list [Union [ str , int ]]] = None ,
140- queued_statuses : Optional [ list [Union [ str , int ]]] = None ,
141- data : Optional [ BaseModel ] = None ,
135+ response_model : type [M ],
136+ status_extractor : Callable [[M | Any ], str | int | None ],
137+ progress_extractor : Callable [[M | Any ], int | None ] | None = None ,
138+ price_extractor : Callable [[M | Any ], float | None ] | None = None ,
139+ completed_statuses : list [str | int ] | None = None ,
140+ failed_statuses : list [str | int ] | None = None ,
141+ queued_statuses : list [str | int ] | None = None ,
142+ data : BaseModel | None = None ,
142143 poll_interval : float = 5.0 ,
143144 max_poll_attempts : int = 120 ,
144145 timeout_per_poll : float = 120.0 ,
145146 max_retries_per_poll : int = 3 ,
146147 retry_delay_per_poll : float = 1.0 ,
147148 retry_backoff_per_poll : float = 2.0 ,
148- estimated_duration : Optional [ int ] = None ,
149- cancel_endpoint : Optional [ ApiEndpoint ] = None ,
149+ estimated_duration : int | None = None ,
150+ cancel_endpoint : ApiEndpoint | None = None ,
150151 cancel_timeout : float = 10.0 ,
151152) -> M :
152153 raw = await poll_op_raw (
@@ -178,22 +179,22 @@ async def sync_op_raw(
178179 cls : type [IO .ComfyNode ],
179180 endpoint : ApiEndpoint ,
180181 * ,
181- price_extractor : Optional [ Callable [[dict [str , Any ]], Optional [ float ]]] = None ,
182- data : Optional [ Union [ dict [str , Any ], BaseModel ]] = None ,
183- files : Optional [ Union [ dict [str , Any ], list [tuple [str , Any ]]]] = None ,
182+ price_extractor : Callable [[dict [str , Any ]], float | None ] | None = None ,
183+ data : dict [str , Any ] | BaseModel | None = None ,
184+ files : dict [str , Any ] | list [tuple [str , Any ]] | None = None ,
184185 content_type : str = "application/json" ,
185186 timeout : float = 3600.0 ,
186- multipart_parser : Optional [ Callable ] = None ,
187+ multipart_parser : Callable | None = None ,
187188 max_retries : int = 3 ,
188189 retry_delay : float = 1.0 ,
189190 retry_backoff : float = 2.0 ,
190191 wait_label : str = "Waiting for server" ,
191- estimated_duration : Optional [ int ] = None ,
192+ estimated_duration : int | None = None ,
192193 as_binary : bool = False ,
193- final_label_on_success : Optional [ str ] = "Completed" ,
194- progress_origin_ts : Optional [ float ] = None ,
194+ final_label_on_success : str | None = "Completed" ,
195+ progress_origin_ts : float | None = None ,
195196 monitor_progress : bool = True ,
196- ) -> Union [ dict [str , Any ], bytes ] :
197+ ) -> dict [str , Any ] | bytes :
197198 """
198199 Make a single network request.
199200 - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
@@ -229,21 +230,21 @@ async def poll_op_raw(
229230 cls : type [IO .ComfyNode ],
230231 poll_endpoint : ApiEndpoint ,
231232 * ,
232- status_extractor : Callable [[dict [str , Any ]], Optional [ Union [ str , int ]] ],
233- progress_extractor : Optional [ Callable [[dict [str , Any ]], Optional [ int ]]] = None ,
234- price_extractor : Optional [ Callable [[dict [str , Any ]], Optional [ float ]]] = None ,
235- completed_statuses : Optional [ list [Union [ str , int ]]] = None ,
236- failed_statuses : Optional [ list [Union [ str , int ]]] = None ,
237- queued_statuses : Optional [ list [Union [ str , int ]]] = None ,
238- data : Optional [ Union [ dict [str , Any ], BaseModel ]] = None ,
233+ status_extractor : Callable [[dict [str , Any ]], str | int | None ],
234+ progress_extractor : Callable [[dict [str , Any ]], int | None ] | None = None ,
235+ price_extractor : Callable [[dict [str , Any ]], float | None ] | None = None ,
236+ completed_statuses : list [str | int ] | None = None ,
237+ failed_statuses : list [str | int ] | None = None ,
238+ queued_statuses : list [str | int ] | None = None ,
239+ data : dict [str , Any ] | BaseModel | None = None ,
239240 poll_interval : float = 5.0 ,
240241 max_poll_attempts : int = 120 ,
241242 timeout_per_poll : float = 120.0 ,
242243 max_retries_per_poll : int = 3 ,
243244 retry_delay_per_poll : float = 1.0 ,
244245 retry_backoff_per_poll : float = 2.0 ,
245- estimated_duration : Optional [ int ] = None ,
246- cancel_endpoint : Optional [ ApiEndpoint ] = None ,
246+ estimated_duration : int | None = None ,
247+ cancel_endpoint : ApiEndpoint | None = None ,
247248 cancel_timeout : float = 10.0 ,
248249) -> dict [str , Any ]:
249250 """
@@ -261,7 +262,7 @@ async def poll_op_raw(
261262 consumed_attempts = 0 # counts only non-queued polls
262263
263264 progress_bar = utils .ProgressBar (100 ) if progress_extractor else None
264- last_progress : Optional [ int ] = None
265+ last_progress : int | None = None
265266
266267 state = _PollUIState (started = started , estimated_duration = estimated_duration )
267268 stop_ticker = asyncio .Event ()
@@ -420,10 +421,10 @@ async def _ticker():
420421
421422def _display_text (
422423 node_cls : type [IO .ComfyNode ],
423- text : Optional [ str ] ,
424+ text : str | None ,
424425 * ,
425- status : Optional [ Union [ str , int ]] = None ,
426- price : Optional [ float ] = None ,
426+ status : str | int | None = None ,
427+ price : float | None = None ,
427428) -> None :
428429 display_lines : list [str ] = []
429430 if status :
@@ -440,13 +441,13 @@ def _display_text(
440441
441442def _display_time_progress (
442443 node_cls : type [IO .ComfyNode ],
443- status : Optional [ Union [ str , int ]] ,
444+ status : str | int | None ,
444445 elapsed_seconds : int ,
445- estimated_total : Optional [ int ] = None ,
446+ estimated_total : int | None = None ,
446447 * ,
447- price : Optional [ float ] = None ,
448- is_queued : Optional [ bool ] = None ,
449- processing_elapsed_seconds : Optional [ int ] = None ,
448+ price : float | None = None ,
449+ is_queued : bool | None = None ,
450+ processing_elapsed_seconds : int | None = None ,
450451) -> None :
451452 if estimated_total is not None and estimated_total > 0 and is_queued is False :
452453 pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
488489 raise ValueError ("files tuple must be (filename, file[, content_type])" )
489490
490491
491- def _merge_params (endpoint_params : dict [str , Any ], method : str , data : Optional [ dict [str , Any ]] ) -> dict [str , Any ]:
492+ def _merge_params (endpoint_params : dict [str , Any ], method : str , data : dict [str , Any ] | None ) -> dict [str , Any ]:
492493 params = dict (endpoint_params or {})
493494 if method .upper () == "GET" and data :
494495 for k , v in data .items ():
@@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
534535def _snapshot_request_body_for_logging (
535536 content_type : str ,
536537 method : str ,
537- data : Optional [ dict [str , Any ]] ,
538- files : Optional [ Union [ dict [str , Any ], list [tuple [str , Any ]]]] ,
539- ) -> Optional [ Union [ dict [str , Any ], str ]] :
538+ data : dict [str , Any ] | None ,
539+ files : dict [str , Any ] | list [tuple [str , Any ]] | None ,
540+ ) -> dict [str , Any ] | str | None :
540541 if method .upper () == "GET" :
541542 return None
542543 if content_type == "multipart/form-data" :
@@ -586,13 +587,13 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float):
586587 attempt = 0
587588 delay = cfg .retry_delay
588589 operation_succeeded : bool = False
589- final_elapsed_seconds : Optional [ int ] = None
590- extracted_price : Optional [ float ] = None
590+ final_elapsed_seconds : int | None = None
591+ extracted_price : float | None = None
591592 while True :
592593 attempt += 1
593594 stop_event = asyncio .Event ()
594- monitor_task : Optional [ asyncio .Task ] = None
595- sess : Optional [ aiohttp .ClientSession ] = None
595+ monitor_task : asyncio .Task | None = None
596+ sess : aiohttp .ClientSession | None = None
596597
597598 operation_id = _generate_operation_id (method , cfg .endpoint .path , attempt )
598599 logging .debug ("[DEBUG] HTTP %s %s (attempt %d)" , method , url , attempt )
@@ -887,7 +888,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float):
887888 )
888889
889890
890- def _validate_or_raise (response_model : Type [M ], payload : Any ) -> M :
891+ def _validate_or_raise (response_model : type [M ], payload : Any ) -> M :
891892 try :
892893 return response_model .model_validate (payload )
893894 except Exception as e :
@@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
902903
903904
904905def _wrap_model_extractor (
905- response_model : Type [M ],
906- extractor : Optional [ Callable [[M ], Any ]] ,
907- ) -> Optional [ Callable [[dict [str , Any ]], Any ]] :
906+ response_model : type [M ],
907+ extractor : Callable [[M ], Any ] | None ,
908+ ) -> Callable [[dict [str , Any ]], Any ] | None :
908909 """Wrap a typed extractor so it can be used by the dict-based poller.
909910 Validates the dict into `response_model` before invoking `extractor`.
910911 Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@@ -929,18 +930,18 @@ def _wrapped(d: dict[str, Any]) -> Any:
929930 return _wrapped
930931
931932
932- def _normalize_statuses (values : Optional [ Iterable [Union [ str , int ]]] ) -> set [Union [ str , int ] ]:
933+ def _normalize_statuses (values : Iterable [str | int ] | None ) -> set [str | int ]:
933934 if not values :
934935 return set ()
935- out : set [Union [ str , int ] ] = set ()
936+ out : set [str | int ] = set ()
936937 for v in values :
937938 nv = _normalize_status_value (v )
938939 if nv is not None :
939940 out .add (nv )
940941 return out
941942
942943
943- def _normalize_status_value (val : Union [ str , int , None ] ) -> Union [ str , int , None ] :
944+ def _normalize_status_value (val : str | int | None ) -> str | int | None :
944945 if isinstance (val , str ):
945946 return val .strip ().lower ()
946947 return val
0 commit comments