Skip to content

Commit 2d69000

Browse files
committed
fix(CivitAI): correct parse URL with query params
1 parent c674bc8 commit 2d69000

File tree

2 files changed

+126
-29
lines changed

2 files changed

+126
-29
lines changed

comfy_cli/command/models/models.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import contextlib
12
import os
23
import pathlib
34
import sys
45
from typing import Annotated, Optional
5-
from urllib.parse import unquote, urlparse
6+
from urllib.parse import parse_qs, unquote, urlparse
67

78
import requests
89
import typer
@@ -74,39 +75,65 @@ def check_huggingface_url(url: str) -> tuple[bool, Optional[str], Optional[str],
7475
return True, repo_id, filename, folder_name, branch_name
7576

7677

77-
def check_civitai_url(url: str) -> tuple[bool, bool, int, int]:
78+
def check_civitai_url(url: str) -> tuple[bool, bool, Optional[int], Optional[int]]:
7879
"""
7980
Returns:
80-
is_civitai_model_url: True if the url is a civitai model url
81-
is_civitai_api_url: True if the url is a civitai api url
82-
model_id: The model id or None if it's api url
83-
version_id: The version id or None if it doesn't have version id info
81+
is_civitai_model_url: True if the url is a civitai *web* model url (e.g. /models/12345)
82+
is_civitai_api_url: True if the url is a civitai *api* url useful for resolving downloads
83+
model_id: The model id (for /models/*), else None
84+
version_id: The version id (for /api/download/models/* or ?modelVersionId=), else None
8485
"""
85-
prefix = "civitai.com"
8686
try:
87-
if prefix in url:
88-
# URL is civitai api download url: https://civitai.com/api/download/models/12345
89-
if "civitai.com/api/download" in url:
90-
# This is a direct download link
91-
version_id = url.strip("/").split("/")[-1]
92-
return False, True, None, int(version_id)
93-
94-
# URL is civitai web url (e.g.
95-
# - https://civitai.com/models/43331
96-
# - https://civitai.com/models/43331/majicmix-realistic
97-
subpath = url[url.find(prefix) + len(prefix) :].strip("/")
98-
url_parts = subpath.split("?")
99-
if len(url_parts) > 1:
100-
model_id = url_parts[0].split("/")[1]
101-
version_id = url_parts[1].split("=")[1]
102-
return True, False, int(model_id), int(version_id)
103-
else:
104-
model_id = subpath.split("/")[1]
105-
return True, False, int(model_id), None
106-
except (ValueError, IndexError):
87+
parsed = urlparse(url)
88+
host = (parsed.hostname or "").lower()
89+
if host != "civitai.com" and not host.endswith(".civitai.com"):
90+
return False, False, None, None
91+
p_parts = [p for p in parsed.path.split("/") if p]
92+
query = parse_qs(parsed.query)
93+
94+
if len(p_parts) >= 4 and p_parts[0] == "api":
95+
# Case 1: /api/download/models/<version_id>
96+
# e.g. https://civitai.com/api/download/models/1617665?type=Model&format=SafeTensor
97+
if p_parts[1] == "download" and p_parts[2] == "models":
98+
try:
99+
version_id = int(p_parts[3])
100+
return False, True, None, version_id
101+
except ValueError:
102+
return False, True, None, None
103+
104+
# Case 2: /api/v1/model-versions/<version_id>
105+
if p_parts[1] == "v1" and p_parts[2] in ("model-versions", "modelVersions"):
106+
try:
107+
version_id = int(p_parts[3])
108+
return False, True, None, version_id
109+
except ValueError:
110+
return False, True, None, None
111+
112+
# Case 3: /models/<model_id>[/*] with optional ?modelVersionId=<id>
113+
# e.g. https://civitai.com/models/43331
114+
# https://civitai.com/models/43331/majicmix-realistic?modelVersionId=485088
115+
if len(p_parts) >= 2 and p_parts[0] == "models":
116+
try:
117+
model_id = int(p_parts[1])
118+
except ValueError:
119+
return False, False, None, None
120+
version_id = None
121+
mv = query.get("modelVersionId")
122+
if mv and len(mv) > 0:
123+
with contextlib.suppress(ValueError):
124+
version_id = int(mv[0])
125+
if version_id is None:
126+
mv = query.get("version")
127+
if mv and len(mv) > 0:
128+
with contextlib.suppress(ValueError):
129+
version_id = int(mv[0])
130+
return True, False, model_id, version_id
131+
132+
return False, False, None, None
133+
134+
except Exception:
107135
print("Error parsing CivitAI model URL")
108-
109-
return False, False, None, None
136+
return False, False, None, None
110137

111138

112139
def request_civitai_model_version_api(version_id: int, headers: Optional[dict] = None):

tests/comfy_cli/command/models/test_models.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ def test_valid_model_url_with_version():
1111
assert check_civitai_url(url) == (True, False, 43331, None)
1212

1313

14+
def test_valid_model_url_with_version_and_additional_segments():
15+
url = "https://civitai.com/models/43331/majicmix-realistic/extra"
16+
assert check_civitai_url(url) == (True, False, 43331, None)
17+
18+
1419
def test_valid_model_url_with_query():
1520
url = "https://civitai.com/models/43331?version=12345"
1621
assert check_civitai_url(url) == (True, False, 43331, 12345)
@@ -31,8 +36,73 @@ def test_malformed_url():
3136
assert check_civitai_url(url) == (False, False, None, None)
3237

3338

39+
def test_invalid_model_id_url():
40+
url = "https://civitai.com/models/invalid_id"
41+
assert check_civitai_url(url) == (False, False, None, None)
42+
43+
3444
def test_malformed_query_url():
3545
url = "https://civitai.com/models/43331?version="
46+
assert check_civitai_url(url) == (True, False, 43331, None)
47+
48+
49+
def test_model_url_with_model_version_id_query():
50+
url = "https://civitai.com/models/43331?modelVersionId=485088"
51+
assert check_civitai_url(url) == (True, False, 43331, 485088)
52+
53+
54+
def test_model_url_with_model_version_id_invalid():
55+
url = "https://civitai.com/models/43331?modelVersionId=abc"
56+
assert check_civitai_url(url) == (True, False, 43331, None)
57+
58+
59+
def test_valid_api_v1_model_versions_url():
60+
url = "https://civitai.com/api/v1/model-versions/1617665"
61+
assert check_civitai_url(url) == (False, True, None, 1617665)
62+
63+
64+
def test_valid_api_v1_model_versions_camelcase_segment():
65+
url = "https://civitai.com/api/v1/modelVersions/1617665"
66+
assert check_civitai_url(url) == (False, True, None, 1617665)
67+
68+
69+
def test_valid_api_download_with_query_params():
70+
url = "https://civitai.com/api/download/models/1617665?type=Model&format=SafeTensor"
71+
assert check_civitai_url(url) == (False, True, None, 1617665)
72+
73+
74+
def test_api_download_trailing_slash_is_ok():
75+
url = "https://civitai.com/api/download/models/1617665/"
76+
assert check_civitai_url(url) == (False, True, None, 1617665)
77+
78+
79+
def test_api_download_non_numeric_id_models_version():
80+
url = "https://civitai.com/api/v1/modelVersions/notanumber"
81+
assert check_civitai_url(url) == (False, True, None, None)
82+
83+
84+
def test_api_download_non_numeric_id():
85+
url = "https://civitai.com/api/download/models/notanumber"
86+
assert check_civitai_url(url) == (False, True, None, None)
87+
88+
89+
def test_model_url_with_slug_and_query():
90+
url = "https://civitai.com/models/43331/majicmix-realistic?modelVersionId=485088"
91+
assert check_civitai_url(url) == (True, False, 43331, 485088)
92+
93+
94+
def test_www_subdomain_is_accepted():
95+
url = "https://www.civitai.com/models/43331?version=12345"
96+
assert check_civitai_url(url) == (True, False, 43331, 12345)
97+
98+
99+
def test_completly_mailformed_civitai_url():
100+
url = "https://civitai.com/"
101+
assert check_civitai_url(url) == (False, False, None, None)
102+
103+
104+
def test_non_evil_civitai_url():
105+
url = "https://evilcivitai.com/models/43331?version=12345"
36106
assert check_civitai_url(url) == (False, False, None, None)
37107

38108

0 commit comments

Comments
 (0)