Skip to content

Commit 08c8968

Browse files
authored
Internal download API: Add proper validated directory input (Comfy-Org#4981)
* add internal /folder_paths route returns a json maps of folder paths * (minor) format download_models.py * initial folder path input on download api * actually, require folder_path and clean up some code * partial tests update * fix & logging * also download to a tmp file not the live file to avoid compounding errors from network failure * update tests again * test tweaks * workaround the first tests blocker * fix file handling in tests * rewrite test for create_model_path * minor doc fix * avoid 'mock_directory' use temp dir to avoid accidental fs pollution from tests
1 parent 479a427 commit 08c8968

File tree

4 files changed

+184
-174
lines changed

4 files changed

+184
-174
lines changed

model_filemanager/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# model_manager/__init__.py
2-
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
2+
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename

model_filemanager/download_models.py

Lines changed: 69 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import traceback
55
import logging
6-
from folder_paths import models_dir
6+
from folder_paths import folder_names_and_paths, get_folder_paths
77
import re
88
from typing import Callable, Any, Optional, Awaitable, Dict
99
from enum import Enum
@@ -17,6 +17,7 @@ class DownloadStatusType(Enum):
1717
COMPLETED = "completed"
1818
ERROR = "error"
1919

20+
2021
@dataclass
2122
class DownloadModelStatus():
2223
status: str
@@ -29,7 +30,7 @@ def __init__(self, status: DownloadStatusType, progress_percentage: float, messa
2930
self.progress_percentage = progress_percentage
3031
self.message = message
3132
self.already_existed = already_existed
32-
33+
3334
def to_dict(self) -> Dict[str, Any]:
3435
return {
3536
"status": self.status,
@@ -38,102 +39,112 @@ def to_dict(self) -> Dict[str, Any]:
3839
"already_existed": self.already_existed
3940
}
4041

42+
4143
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
42-
model_name: str,
43-
model_url: str,
44-
model_sub_directory: str,
44+
model_name: str,
45+
model_url: str,
46+
model_directory: str,
47+
folder_path: str,
4548
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
4649
progress_interval: float = 1.0) -> DownloadModelStatus:
4750
"""
4851
Download a model file from a given URL into the models directory.
4952
5053
Args:
51-
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
54+
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
5255
A function that makes an HTTP request. This makes it easier to mock in unit tests.
53-
model_name (str):
56+
model_name (str):
5457
The name of the model file to be downloaded. This will be the filename on disk.
55-
model_url (str):
58+
model_url (str):
5659
The URL from which to download the model.
57-
model_sub_directory (str):
58-
The subdirectory within the main models directory where the model
60+
model_directory (str):
61+
The subdirectory within the main models directory where the model
5962
should be saved (e.g., 'checkpoints', 'loras', etc.).
60-
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
63+
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
6164
An asynchronous function to call with progress updates.
65+
folder_path (str);
66+
Path to which model folder should be used as the root.
6267
6368
Returns:
6469
DownloadModelStatus: The result of the download operation.
6570
"""
66-
if not validate_model_subdirectory(model_sub_directory):
71+
if not validate_filename(model_name):
6772
return DownloadModelStatus(
68-
DownloadStatusType.ERROR,
73+
DownloadStatusType.ERROR,
6974
0,
70-
"Invalid model subdirectory",
75+
"Invalid model name",
7176
False
7277
)
7378

74-
if not validate_filename(model_name):
79+
if not model_directory in folder_names_and_paths:
7580
return DownloadModelStatus(
76-
DownloadStatusType.ERROR,
81+
DownloadStatusType.ERROR,
7782
0,
78-
"Invalid model name",
83+
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
7984
False
8085
)
8186

82-
file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
83-
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
87+
if not folder_path in get_folder_paths(model_directory):
88+
return DownloadModelStatus(
89+
DownloadStatusType.ERROR,
90+
0,
91+
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
92+
False
93+
)
94+
95+
file_path = create_model_path(model_name, folder_path)
96+
existing_file = await check_file_exists(file_path, model_name, progress_callback)
8497
if existing_file:
8598
return existing_file
8699

87100
try:
101+
logging.info(f"Downloading {model_name} from {model_url}")
88102
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
89-
await progress_callback(relative_path, status)
103+
await progress_callback(model_name, status)
90104

91105
response = await model_download_request(model_url)
92106
if response.status != 200:
93107
error_message = f"Failed to download {model_name}. Status code: {response.status}"
94108
logging.error(error_message)
95109
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
96-
await progress_callback(relative_path, status)
110+
await progress_callback(model_name, status)
97111
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
98112

99-
return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
113+
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
100114

101115
except Exception as e:
102116
logging.error(f"Error in downloading model: {e}")
103-
return await handle_download_error(e, model_name, progress_callback, relative_path)
104-
117+
return await handle_download_error(e, model_name, progress_callback)
105118

106-
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
107-
full_model_dir = os.path.join(models_base_dir, model_directory)
108-
os.makedirs(full_model_dir, exist_ok=True)
109-
file_path = os.path.join(full_model_dir, model_name)
119+
120+
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
121+
os.makedirs(folder_path, exist_ok=True)
122+
file_path = os.path.join(folder_path, model_name)
110123

111124
# Ensure the resulting path is still within the base directory
112125
abs_file_path = os.path.abspath(file_path)
113-
abs_base_dir = os.path.abspath(str(models_base_dir))
126+
abs_base_dir = os.path.abspath(folder_path)
114127
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
115-
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
128+
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
116129

130+
return file_path
117131

118-
relative_path = '/'.join([model_directory, model_name])
119-
return file_path, relative_path
120132

121-
async def check_file_exists(file_path: str,
122-
model_name: str,
123-
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
124-
relative_path: str) -> Optional[DownloadModelStatus]:
133+
async def check_file_exists(file_path: str,
134+
model_name: str,
135+
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
136+
) -> Optional[DownloadModelStatus]:
125137
if os.path.exists(file_path):
126138
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
127-
await progress_callback(relative_path, status)
139+
await progress_callback(model_name, status)
128140
return status
129141
return None
130142

131143

132-
async def track_download_progress(response: aiohttp.ClientResponse,
133-
file_path: str,
134-
model_name: str,
135-
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
136-
relative_path: str,
144+
async def track_download_progress(response: aiohttp.ClientResponse,
145+
file_path: str,
146+
model_name: str,
147+
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
137148
interval: float = 1.0) -> DownloadModelStatus:
138149
try:
139150
total_size = int(response.headers.get('Content-Length', 0))
@@ -144,10 +155,11 @@ async def update_progress():
144155
nonlocal last_update_time
145156
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
146157
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
147-
await progress_callback(relative_path, status)
158+
await progress_callback(model_name, status)
148159
last_update_time = time.time()
149160

150-
with open(file_path, 'wb') as f:
161+
temp_file_path = file_path + '.tmp'
162+
with open(temp_file_path, 'wb') as f:
151163
chunk_iterator = response.content.iter_chunked(8192)
152164
while True:
153165
try:
@@ -156,58 +168,39 @@ async def update_progress():
156168
break
157169
f.write(chunk)
158170
downloaded += len(chunk)
159-
171+
160172
if time.time() - last_update_time >= interval:
161173
await update_progress()
162174

175+
os.rename(temp_file_path, file_path)
176+
163177
await update_progress()
164-
178+
165179
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
166180
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
167-
await progress_callback(relative_path, status)
181+
await progress_callback(model_name, status)
168182

169183
return status
170184
except Exception as e:
171185
logging.error(f"Error in track_download_progress: {e}")
172186
logging.error(traceback.format_exc())
173-
return await handle_download_error(e, model_name, progress_callback, relative_path)
187+
return await handle_download_error(e, model_name, progress_callback)
188+
174189

175-
async def handle_download_error(e: Exception,
176-
model_name: str,
177-
progress_callback: Callable[[str, DownloadModelStatus], Any],
178-
relative_path: str) -> DownloadModelStatus:
190+
async def handle_download_error(e: Exception,
191+
model_name: str,
192+
progress_callback: Callable[[str, DownloadModelStatus], Any]
193+
) -> DownloadModelStatus:
179194
error_message = f"Error downloading {model_name}: {str(e)}"
180195
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
181-
await progress_callback(relative_path, status)
196+
await progress_callback(model_name, status)
182197
return status
183198

184-
def validate_model_subdirectory(model_subdirectory: str) -> bool:
185-
"""
186-
Validate that the model subdirectory is safe to install into.
187-
Must not contain relative paths, nested paths or special characters
188-
other than underscores and hyphens.
189-
190-
Args:
191-
model_subdirectory (str): The subdirectory for the specific model type.
192-
193-
Returns:
194-
bool: True if the subdirectory is safe, False otherwise.
195-
"""
196-
if len(model_subdirectory) > 50:
197-
return False
198-
199-
if '..' in model_subdirectory or '/' in model_subdirectory:
200-
return False
201-
202-
if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
203-
return False
204-
205-
return True
206199

207200
def validate_filename(filename: str)-> bool:
208201
"""
209202
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
210-
203+
211204
Args:
212205
filename (str): The filename to validate
213206

server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,18 +689,19 @@ async def report_progress(filename: str, status: DownloadModelStatus):
689689
data = await request.json()
690690
url = data.get('url')
691691
model_directory = data.get('model_directory')
692+
folder_path = data.get('folder_path')
692693
model_filename = data.get('model_filename')
693694
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
694695

695-
if not url or not model_directory or not model_filename:
696+
if not url or not model_directory or not model_filename or not folder_path:
696697
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
697698

698699
session = self.client_session
699700
if session is None:
700701
logging.error("Client session is not initialized")
701702
return web.Response(status=500)
702703

703-
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
704+
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
704705
await task
705706

706707
return web.json_response(task.result().to_dict())

0 commit comments

Comments
 (0)