Skip to content

Commit fab58dd

Browse files
committed
fix(auth): do not leak authentification for the absolute urls
1 parent b34bc79 commit fab58dd

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

comfy_api_nodes/util/client.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -482,18 +482,6 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
482482
raise ValueError("files tuple must be (filename, file[, content_type])")
483483

484484

485-
def _join_url(base_url: str, path: str) -> str:
486-
return urljoin(base_url.rstrip("/") + "/", path.lstrip("/"))
487-
488-
489-
def _merge_headers(node_cls: type[IO.ComfyNode], endpoint_headers: dict[str, str]) -> dict[str, str]:
490-
headers = {"Accept": "*/*"}
491-
headers.update(get_auth_header(node_cls))
492-
if endpoint_headers:
493-
headers.update(endpoint_headers)
494-
return headers
495-
496-
497485
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
498486
params = dict(endpoint_params or {})
499487
if method.upper() == "GET" and data:
@@ -566,7 +554,11 @@ def _snapshot_request_body_for_logging(
566554

567555
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
568556
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
569-
url = _join_url(default_base_url(), cfg.endpoint.path)
557+
url = cfg.endpoint.path
558+
parsed_url = urlparse(url)
559+
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
560+
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
561+
570562
method = cfg.endpoint.method
571563
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
572564

@@ -598,7 +590,12 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float):
598590
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
599591
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
600592

601-
payload_headers = _merge_headers(cfg.node_cls, cfg.endpoint.headers)
593+
payload_headers = {"Accept": "*/*"}
594+
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
595+
payload_headers.update(get_auth_header(cfg.node_cls))
596+
if cfg.endpoint.headers:
597+
payload_headers.update(cfg.endpoint.headers)
598+
602599
payload_kw: dict[str, Any] = {"headers": payload_headers}
603600
if method == "GET":
604601
payload_headers.pop("Content-Type", None)

comfy_api_nodes/util/download_helpers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from io import BytesIO
55
from pathlib import Path
66
from typing import IO, Optional, Union
7-
from urllib.parse import urlparse
7+
from urllib.parse import urljoin, urlparse
88

99
import aiohttp
1010
import torch
@@ -57,10 +57,11 @@ async def download_url_to_bytesio(
5757
delay = retry_delay
5858
headers: dict[str, str] = {}
5959

60-
if url.startswith("/proxy/"):
60+
parsed_url = urlparse(url)
61+
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
6162
if cls is None:
6263
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
63-
url = default_base_url().rstrip("/") + url
64+
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
6465
headers = get_auth_header(cls)
6566

6667
while True:

0 commit comments

Comments
 (0)