Skip to content

Commit 9ebd7d8

Browse files
authored
Feat: Add CivitAI API token support for download (#62)
1 parent f74cedc commit 9ebd7d8

File tree

3 files changed

+75
-11
lines changed

3 files changed

+75
-11
lines changed

comfy_cli/command/models/models.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77
from typing_extensions import Annotated
88

99
from comfy_cli import tracking, ui
10+
from comfy_cli import constants
11+
from comfy_cli.config_manager import ConfigManager
1012
from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH
1113
from comfy_cli.file_utils import download_file, DownloadException
1214
from comfy_cli.workspace_manager import WorkspaceManager
1315

1416
app = typer.Typer()
1517

1618
workspace_manager = WorkspaceManager()
19+
config_manager = ConfigManager()
1720

1821

1922
def get_workspace() -> pathlib.Path:
@@ -66,10 +69,12 @@ def check_civitai_url(url: str) -> Tuple[bool, bool, int, int]:
6669
return False, False, None, None
6770

6871

69-
def request_civitai_model_version_api(version_id: int):
72+
def request_civitai_model_version_api(version_id: int, headers: Optional[dict] = None):
7073
# Make a request to the Civitai API to get the model information
7174
response = requests.get(
72-
f"https://civitai.com/api/v1/model-versions/{version_id}", timeout=10
75+
f"https://civitai.com/api/v1/model-versions/{version_id}",
76+
headers=headers,
77+
timeout=10,
7378
)
7479
response.raise_for_status() # Raise an error for bad status codes
7580

@@ -81,9 +86,13 @@ def request_civitai_model_version_api(version_id: int):
8186
return model_name, download_url
8287

8388

84-
def request_civitai_model_api(model_id: int, version_id: int = None):
89+
def request_civitai_model_api(
90+
model_id: int, version_id: int = None, headers: Optional[dict] = None
91+
):
8592
# Make a request to the Civitai API to get the model information
86-
response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=10)
93+
response = requests.get(
94+
f"https://civitai.com/api/v1/models/{model_id}", headers=headers, timeout=10
95+
)
8796
response.raise_for_status() # Raise an error for bad status codes
8897

8998
model_data = response.json()
@@ -123,18 +132,42 @@ def download(
123132
show_default=True,
124133
),
125134
] = DEFAULT_COMFY_MODEL_PATH,
135+
set_civitai_api_token: Annotated[
136+
Optional[str],
137+
typer.Option(
138+
"--set-civitai-api-token",
139+
help="Set the CivitAI API token to use for model listing.",
140+
show_default=False,
141+
),
142+
] = None,
126143
):
127144

128145
local_filename = None
146+
headers = None
147+
civitai_api_token = None
148+
149+
if set_civitai_api_token is not None:
150+
config_manager.set(constants.CIVITAI_API_TOKEN_KEY, set_civitai_api_token)
151+
civitai_api_token = set_civitai_api_token
152+
153+
else:
154+
civitai_api_token = config_manager.get(constants.CIVITAI_API_TOKEN_KEY)
155+
156+
if civitai_api_token is not None:
157+
headers = {
158+
"Content-Type": "application/json",
159+
"Authorization": f"Bearer {civitai_api_token}",
160+
}
129161

130162
is_civitai_model_url, is_civitai_api_url, model_id, version_id = check_civitai_url(
131163
url
132164
)
165+
133166
is_huggingface = False
134167
if is_civitai_model_url:
135-
local_filename, url = request_civitai_model_api(model_id, version_id)
168+
local_filename, url = request_civitai_model_api(model_id, version_id, headers)
136169
elif is_civitai_api_url:
137-
local_filename, url = request_civitai_model_version_api(version_id)
170+
local_filename, url = request_civitai_model_version_api(version_id, headers)
138171
elif check_huggingface_url(url):
139172
is_huggingface = True
140173
local_filename = potentially_strip_param_url(url.split("/")[-1])
@@ -157,7 +190,7 @@ def download(
157190

158191
# File does not exist, proceed with download
159192
print(f"Start downloading URL: {url} into {local_filepath}")
160-
download_file(url, local_filepath)
193+
download_file(url, local_filepath, headers)
161194

162195

163196
@app.command()
@@ -236,6 +269,7 @@ def list(
236269
show_default=True,
237270
),
238271
):
272+
239273
"""Display a list of all models currently downloaded in a table format."""
240274
model_dir = get_workspace() / relative_path
241275
models = list_models(model_dir)

comfy_cli/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class OS(Enum):
3838
CONFIG_KEY_INSTALL_EVENT_TRIGGERED = "install_event_triggered"
3939
CONFIG_KEY_BACKGROUND = "background"
4040

41+
CIVITAI_API_TOKEN_KEY = "civitai_api_token"
42+
4143
DEFAULT_TRACKING_VALUE = True
4244

4345
COMFY_LOCK_YAML_FILE = "comfy.lock.yaml"

comfy_cli/file_utils.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import pathlib
3+
from typing import Optional
34
import zipfile
45

56
import requests
@@ -12,8 +13,30 @@ class DownloadException(Exception):
1213
pass
1314

1415

15-
def guess_status_code_reason(status_code: int) -> str:
16+
def guess_status_code_reason(status_code: int, message: str) -> str:
1617
if status_code == 401:
18+
import json
19+
20+
def parse_json(input_data):
21+
try:
22+
# Check if the input is a byte string
23+
if isinstance(input_data, bytes):
24+
# Decode the byte string to a regular string
25+
input_data = input_data.decode("utf-8")
26+
27+
# Parse the string as JSON
28+
json_object = json.loads(input_data)
29+
30+
return json_object
31+
32+
except json.JSONDecodeError as e:
33+
# Handle JSON decoding error
34+
print(f"JSON decoding error: {e}")
35+
36+
msg_json = parse_json(message)
37+
if msg_json is not None:
38+
if "message" in msg_json:
39+
return f"Unauthorized download ({status_code}).\n{msg_json['message']}\nor you can set civitai api token using `comfy model download --set-civitai-api-token <token>`"
1740
return f"Unauthorized download ({status_code}), you might need to manually log into browser to download one"
1841
elif status_code == 403:
1942
return f"Forbidden url ({status_code}), you might need to manually log into browser to download one"
@@ -22,7 +45,10 @@ def guess_status_code_reason(status_code: int) -> str:
2245
return f"Unknown error occurred (status code: {status_code})"
2346

2447

25-
def download_file(url: str, local_filepath: pathlib.Path):
48+
def download_file(
49+
url: str, local_filepath: pathlib.Path, headers: Optional[dict] = None
50+
):
51+
2652
"""Helper function to download a file."""
2753

2854
import httpx
@@ -31,7 +57,7 @@ def download_file(url: str, local_filepath: pathlib.Path):
3157
parents=True, exist_ok=True
3258
) # Ensure the directory exists
3359

34-
with httpx.stream("GET", url, follow_redirects=True) as response:
60+
with httpx.stream("GET", url, follow_redirects=True, headers=headers) as response:
3561
if response.status_code == 200:
3662
total = int(response.headers["Content-Length"])
3763
try:
@@ -49,7 +75,9 @@ def download_file(url: str, local_filepath: pathlib.Path):
4975
if delete_eh:
5076
local_filepath.unlink()
5177
else:
52-
status_reason = guess_status_code_reason(response.status_code)
78+
status_reason = guess_status_code_reason(
79+
response.status_code, response.read()
80+
)
5381
raise DownloadException(f"Failed to download file.\n{status_reason}")
5482

5583

0 commit comments

Comments
 (0)