Skip to content

Commit 3c84562

Browse files
authored
[API Nodes]: fixes and refactor (#11104)
* chore(api-nodes): applied ruff's pyupgrade(python3.10) to api-nodes client's to folder * chore(api-nodes): add validate_video_frame_count function from LTX PR * chore(api-nodes): replace deprecated V1 imports * fix(api-nodes): the types returned by the "poll_op" function are now correct.
1 parent 9bc893c commit 3c84562

File tree

8 files changed

+146
-135
lines changed

8 files changed

+146
-135
lines changed

comfy_api_nodes/util/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
validate_string,
4848
validate_video_dimensions,
4949
validate_video_duration,
50+
validate_video_frame_count,
5051
)
5152

5253
__all__ = [
@@ -94,6 +95,7 @@
9495
"validate_string",
9596
"validate_video_dimensions",
9697
"validate_video_duration",
98+
"validate_video_frame_count",
9799
# Misc functions
98100
"get_fs_object_size",
99101
]

comfy_api_nodes/util/_helpers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import contextlib
33
import os
44
import time
5+
from collections.abc import Callable
56
from io import BytesIO
6-
from typing import Callable, Optional, Union
77

88
from comfy.cli_args import args
99
from comfy.model_management import processing_interrupted
@@ -35,12 +35,12 @@ def default_base_url() -> str:
3535

3636
async def sleep_with_interrupt(
3737
seconds: float,
38-
node_cls: Optional[type[IO.ComfyNode]],
39-
label: Optional[str] = None,
40-
start_ts: Optional[float] = None,
41-
estimated_total: Optional[int] = None,
38+
node_cls: type[IO.ComfyNode] | None,
39+
label: str | None = None,
40+
start_ts: float | None = None,
41+
estimated_total: int | None = None,
4242
*,
43-
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
43+
display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
4444
):
4545
"""
4646
Sleep in 1s slices while:
@@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
6565
return mime_type.split("/")[-1].lower()
6666

6767

68-
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
68+
def get_fs_object_size(path_or_object: str | BytesIO) -> int:
6969
if isinstance(path_or_object, str):
7070
return os.path.getsize(path_or_object)
7171
return len(path_or_object.getvalue())

comfy_api_nodes/util/client.py

Lines changed: 73 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import logging
55
import time
66
import uuid
7+
from collections.abc import Callable, Iterable
78
from dataclasses import dataclass
89
from enum import Enum
910
from io import BytesIO
10-
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
11+
from typing import Any, Literal, TypeVar
1112
from urllib.parse import urljoin, urlparse
1213

1314
import aiohttp
@@ -37,8 +38,8 @@ def __init__(
3738
path: str,
3839
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
3940
*,
40-
query_params: Optional[dict[str, Any]] = None,
41-
headers: Optional[dict[str, str]] = None,
41+
query_params: dict[str, Any] | None = None,
42+
headers: dict[str, str] | None = None,
4243
):
4344
self.path = path
4445
self.method = method
@@ -52,29 +53,29 @@ class _RequestConfig:
5253
endpoint: ApiEndpoint
5354
timeout: float
5455
content_type: str
55-
data: Optional[dict[str, Any]]
56-
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
57-
multipart_parser: Optional[Callable]
56+
data: dict[str, Any] | None
57+
files: dict[str, Any] | list[tuple[str, Any]] | None
58+
multipart_parser: Callable | None
5859
max_retries: int
5960
retry_delay: float
6061
retry_backoff: float
6162
wait_label: str = "Waiting"
6263
monitor_progress: bool = True
63-
estimated_total: Optional[int] = None
64-
final_label_on_success: Optional[str] = "Completed"
65-
progress_origin_ts: Optional[float] = None
66-
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
64+
estimated_total: int | None = None
65+
final_label_on_success: str | None = "Completed"
66+
progress_origin_ts: float | None = None
67+
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
6768

6869

6970
@dataclass
7071
class _PollUIState:
7172
started: float
7273
status_label: str = "Queued"
7374
is_queued: bool = True
74-
price: Optional[float] = None
75-
estimated_duration: Optional[int] = None
75+
price: float | None = None
76+
estimated_duration: int | None = None
7677
base_processing_elapsed: float = 0.0 # sum of completed active intervals
77-
active_since: Optional[float] = None # start time of current active interval (None if queued)
78+
active_since: float | None = None # start time of current active interval (None if queued)
7879

7980

8081
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@@ -87,20 +88,20 @@ async def sync_op(
8788
cls: type[IO.ComfyNode],
8889
endpoint: ApiEndpoint,
8990
*,
90-
response_model: Type[M],
91-
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
92-
data: Optional[BaseModel] = None,
93-
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
91+
response_model: type[M],
92+
price_extractor: Callable[[M | Any], float | None] | None = None,
93+
data: BaseModel | None = None,
94+
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
9495
content_type: str = "application/json",
9596
timeout: float = 3600.0,
96-
multipart_parser: Optional[Callable] = None,
97+
multipart_parser: Callable | None = None,
9798
max_retries: int = 3,
9899
retry_delay: float = 1.0,
99100
retry_backoff: float = 2.0,
100101
wait_label: str = "Waiting for server",
101-
estimated_duration: Optional[int] = None,
102-
final_label_on_success: Optional[str] = "Completed",
103-
progress_origin_ts: Optional[float] = None,
102+
estimated_duration: int | None = None,
103+
final_label_on_success: str | None = "Completed",
104+
progress_origin_ts: float | None = None,
104105
monitor_progress: bool = True,
105106
) -> M:
106107
raw = await sync_op_raw(
@@ -131,22 +132,22 @@ async def poll_op(
131132
cls: type[IO.ComfyNode],
132133
poll_endpoint: ApiEndpoint,
133134
*,
134-
response_model: Type[M],
135-
status_extractor: Callable[[M], Optional[Union[str, int]]],
136-
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
137-
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
138-
completed_statuses: Optional[list[Union[str, int]]] = None,
139-
failed_statuses: Optional[list[Union[str, int]]] = None,
140-
queued_statuses: Optional[list[Union[str, int]]] = None,
141-
data: Optional[BaseModel] = None,
135+
response_model: type[M],
136+
status_extractor: Callable[[M | Any], str | int | None],
137+
progress_extractor: Callable[[M | Any], int | None] | None = None,
138+
price_extractor: Callable[[M | Any], float | None] | None = None,
139+
completed_statuses: list[str | int] | None = None,
140+
failed_statuses: list[str | int] | None = None,
141+
queued_statuses: list[str | int] | None = None,
142+
data: BaseModel | None = None,
142143
poll_interval: float = 5.0,
143144
max_poll_attempts: int = 120,
144145
timeout_per_poll: float = 120.0,
145146
max_retries_per_poll: int = 3,
146147
retry_delay_per_poll: float = 1.0,
147148
retry_backoff_per_poll: float = 2.0,
148-
estimated_duration: Optional[int] = None,
149-
cancel_endpoint: Optional[ApiEndpoint] = None,
149+
estimated_duration: int | None = None,
150+
cancel_endpoint: ApiEndpoint | None = None,
150151
cancel_timeout: float = 10.0,
151152
) -> M:
152153
raw = await poll_op_raw(
@@ -178,22 +179,22 @@ async def sync_op_raw(
178179
cls: type[IO.ComfyNode],
179180
endpoint: ApiEndpoint,
180181
*,
181-
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
182-
data: Optional[Union[dict[str, Any], BaseModel]] = None,
183-
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
182+
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
183+
data: dict[str, Any] | BaseModel | None = None,
184+
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
184185
content_type: str = "application/json",
185186
timeout: float = 3600.0,
186-
multipart_parser: Optional[Callable] = None,
187+
multipart_parser: Callable | None = None,
187188
max_retries: int = 3,
188189
retry_delay: float = 1.0,
189190
retry_backoff: float = 2.0,
190191
wait_label: str = "Waiting for server",
191-
estimated_duration: Optional[int] = None,
192+
estimated_duration: int | None = None,
192193
as_binary: bool = False,
193-
final_label_on_success: Optional[str] = "Completed",
194-
progress_origin_ts: Optional[float] = None,
194+
final_label_on_success: str | None = "Completed",
195+
progress_origin_ts: float | None = None,
195196
monitor_progress: bool = True,
196-
) -> Union[dict[str, Any], bytes]:
197+
) -> dict[str, Any] | bytes:
197198
"""
198199
Make a single network request.
199200
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
@@ -229,21 +230,21 @@ async def poll_op_raw(
229230
cls: type[IO.ComfyNode],
230231
poll_endpoint: ApiEndpoint,
231232
*,
232-
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
233-
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
234-
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
235-
completed_statuses: Optional[list[Union[str, int]]] = None,
236-
failed_statuses: Optional[list[Union[str, int]]] = None,
237-
queued_statuses: Optional[list[Union[str, int]]] = None,
238-
data: Optional[Union[dict[str, Any], BaseModel]] = None,
233+
status_extractor: Callable[[dict[str, Any]], str | int | None],
234+
progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
235+
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
236+
completed_statuses: list[str | int] | None = None,
237+
failed_statuses: list[str | int] | None = None,
238+
queued_statuses: list[str | int] | None = None,
239+
data: dict[str, Any] | BaseModel | None = None,
239240
poll_interval: float = 5.0,
240241
max_poll_attempts: int = 120,
241242
timeout_per_poll: float = 120.0,
242243
max_retries_per_poll: int = 3,
243244
retry_delay_per_poll: float = 1.0,
244245
retry_backoff_per_poll: float = 2.0,
245-
estimated_duration: Optional[int] = None,
246-
cancel_endpoint: Optional[ApiEndpoint] = None,
246+
estimated_duration: int | None = None,
247+
cancel_endpoint: ApiEndpoint | None = None,
247248
cancel_timeout: float = 10.0,
248249
) -> dict[str, Any]:
249250
"""
@@ -261,7 +262,7 @@ async def poll_op_raw(
261262
consumed_attempts = 0 # counts only non-queued polls
262263

263264
progress_bar = utils.ProgressBar(100) if progress_extractor else None
264-
last_progress: Optional[int] = None
265+
last_progress: int | None = None
265266

266267
state = _PollUIState(started=started, estimated_duration=estimated_duration)
267268
stop_ticker = asyncio.Event()
@@ -420,10 +421,10 @@ async def _ticker():
420421

421422
def _display_text(
422423
node_cls: type[IO.ComfyNode],
423-
text: Optional[str],
424+
text: str | None,
424425
*,
425-
status: Optional[Union[str, int]] = None,
426-
price: Optional[float] = None,
426+
status: str | int | None = None,
427+
price: float | None = None,
427428
) -> None:
428429
display_lines: list[str] = []
429430
if status:
@@ -440,13 +441,13 @@ def _display_text(
440441

441442
def _display_time_progress(
442443
node_cls: type[IO.ComfyNode],
443-
status: Optional[Union[str, int]],
444+
status: str | int | None,
444445
elapsed_seconds: int,
445-
estimated_total: Optional[int] = None,
446+
estimated_total: int | None = None,
446447
*,
447-
price: Optional[float] = None,
448-
is_queued: Optional[bool] = None,
449-
processing_elapsed_seconds: Optional[int] = None,
448+
price: float | None = None,
449+
is_queued: bool | None = None,
450+
processing_elapsed_seconds: int | None = None,
450451
) -> None:
451452
if estimated_total is not None and estimated_total > 0 and is_queued is False:
452453
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
488489
raise ValueError("files tuple must be (filename, file[, content_type])")
489490

490491

491-
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
492+
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
492493
params = dict(endpoint_params or {})
493494
if method.upper() == "GET" and data:
494495
for k, v in data.items():
@@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
534535
def _snapshot_request_body_for_logging(
535536
content_type: str,
536537
method: str,
537-
data: Optional[dict[str, Any]],
538-
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
539-
) -> Optional[Union[dict[str, Any], str]]:
538+
data: dict[str, Any] | None,
539+
files: dict[str, Any] | list[tuple[str, Any]] | None,
540+
) -> dict[str, Any] | str | None:
540541
if method.upper() == "GET":
541542
return None
542543
if content_type == "multipart/form-data":
@@ -586,13 +587,13 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float):
586587
attempt = 0
587588
delay = cfg.retry_delay
588589
operation_succeeded: bool = False
589-
final_elapsed_seconds: Optional[int] = None
590-
extracted_price: Optional[float] = None
590+
final_elapsed_seconds: int | None = None
591+
extracted_price: float | None = None
591592
while True:
592593
attempt += 1
593594
stop_event = asyncio.Event()
594-
monitor_task: Optional[asyncio.Task] = None
595-
sess: Optional[aiohttp.ClientSession] = None
595+
monitor_task: asyncio.Task | None = None
596+
sess: aiohttp.ClientSession | None = None
596597

597598
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
598599
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
@@ -887,7 +888,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float):
887888
)
888889

889890

890-
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
891+
def _validate_or_raise(response_model: type[M], payload: Any) -> M:
891892
try:
892893
return response_model.model_validate(payload)
893894
except Exception as e:
@@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
902903

903904

904905
def _wrap_model_extractor(
905-
response_model: Type[M],
906-
extractor: Optional[Callable[[M], Any]],
907-
) -> Optional[Callable[[dict[str, Any]], Any]]:
906+
response_model: type[M],
907+
extractor: Callable[[M], Any] | None,
908+
) -> Callable[[dict[str, Any]], Any] | None:
908909
"""Wrap a typed extractor so it can be used by the dict-based poller.
909910
Validates the dict into `response_model` before invoking `extractor`.
910911
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@@ -929,18 +930,18 @@ def _wrapped(d: dict[str, Any]) -> Any:
929930
return _wrapped
930931

931932

932-
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
933+
def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
933934
if not values:
934935
return set()
935-
out: set[Union[str, int]] = set()
936+
out: set[str | int] = set()
936937
for v in values:
937938
nv = _normalize_status_value(v)
938939
if nv is not None:
939940
out.add(nv)
940941
return out
941942

942943

943-
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
944+
def _normalize_status_value(val: str | int | None) -> str | int | None:
944945
if isinstance(val, str):
945946
return val.strip().lower()
946947
return val

0 commit comments

Comments
 (0)