33import os
44import traceback
55import logging
6- from folder_paths import models_dir
6+ from folder_paths import folder_names_and_paths , get_folder_paths
77import re
88from typing import Callable , Any , Optional , Awaitable , Dict
99from enum import Enum
@@ -17,6 +17,7 @@ class DownloadStatusType(Enum):
1717 COMPLETED = "completed"
1818 ERROR = "error"
1919
20+
2021@dataclass
2122class DownloadModelStatus ():
2223 status : str
@@ -29,7 +30,7 @@ def __init__(self, status: DownloadStatusType, progress_percentage: float, messa
2930 self .progress_percentage = progress_percentage
3031 self .message = message
3132 self .already_existed = already_existed
32-
33+
3334 def to_dict (self ) -> Dict [str , Any ]:
3435 return {
3536 "status" : self .status ,
@@ -38,102 +39,112 @@ def to_dict(self) -> Dict[str, Any]:
3839 "already_existed" : self .already_existed
3940 }
4041
42+
4143async def download_model (model_download_request : Callable [[str ], Awaitable [aiohttp .ClientResponse ]],
42- model_name : str ,
43- model_url : str ,
44- model_sub_directory : str ,
44+ model_name : str ,
45+ model_url : str ,
46+ model_directory : str ,
47+ folder_path : str ,
4548 progress_callback : Callable [[str , DownloadModelStatus ], Awaitable [Any ]],
4649 progress_interval : float = 1.0 ) -> DownloadModelStatus :
4750 """
4851 Download a model file from a given URL into the models directory.
4952
5053 Args:
51- model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
54+ model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
5255 A function that makes an HTTP request. This makes it easier to mock in unit tests.
53- model_name (str):
56+ model_name (str):
5457 The name of the model file to be downloaded. This will be the filename on disk.
55- model_url (str):
58+ model_url (str):
5659 The URL from which to download the model.
57- model_sub_directory (str):
58- The subdirectory within the main models directory where the model
60+ model_directory (str):
61+ The subdirectory within the main models directory where the model
5962 should be saved (e.g., 'checkpoints', 'loras', etc.).
60- progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
63+ progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
6164 An asynchronous function to call with progress updates.
65+ folder_path (str);
66+ Path to which model folder should be used as the root.
6267
6368 Returns:
6469 DownloadModelStatus: The result of the download operation.
6570 """
66- if not validate_model_subdirectory ( model_sub_directory ):
71+ if not validate_filename ( model_name ):
6772 return DownloadModelStatus (
68- DownloadStatusType .ERROR ,
73+ DownloadStatusType .ERROR ,
6974 0 ,
70- "Invalid model subdirectory" ,
75+ "Invalid model name" ,
7176 False
7277 )
7378
74- if not validate_filename ( model_name ) :
79+ if not model_directory in folder_names_and_paths :
7580 return DownloadModelStatus (
76- DownloadStatusType .ERROR ,
81+ DownloadStatusType .ERROR ,
7782 0 ,
78- "Invalid model name" ,
83+ "Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working." ,
7984 False
8085 )
8186
82- file_path , relative_path = create_model_path (model_name , model_sub_directory , models_dir )
83- existing_file = await check_file_exists (file_path , model_name , progress_callback , relative_path )
87+ if not folder_path in get_folder_paths (model_directory ):
88+ return DownloadModelStatus (
89+ DownloadStatusType .ERROR ,
90+ 0 ,
91+ f"Invalid folder path '{ folder_path } ', does not match the list of known directories ({ get_folder_paths (model_directory )} ). If you're seeing this in the downloader UI, you may need to refresh the page." ,
92+ False
93+ )
94+
95+ file_path = create_model_path (model_name , folder_path )
96+ existing_file = await check_file_exists (file_path , model_name , progress_callback )
8497 if existing_file :
8598 return existing_file
8699
87100 try :
101+ logging .info (f"Downloading { model_name } from { model_url } " )
88102 status = DownloadModelStatus (DownloadStatusType .PENDING , 0 , f"Starting download of { model_name } " , False )
89- await progress_callback (relative_path , status )
103+ await progress_callback (model_name , status )
90104
91105 response = await model_download_request (model_url )
92106 if response .status != 200 :
93107 error_message = f"Failed to download { model_name } . Status code: { response .status } "
94108 logging .error (error_message )
95109 status = DownloadModelStatus (DownloadStatusType .ERROR , 0 , error_message , False )
96- await progress_callback (relative_path , status )
110+ await progress_callback (model_name , status )
97111 return DownloadModelStatus (DownloadStatusType .ERROR , 0 , error_message , False )
98112
99- return await track_download_progress (response , file_path , model_name , progress_callback , relative_path , progress_interval )
113+ return await track_download_progress (response , file_path , model_name , progress_callback , progress_interval )
100114
101115 except Exception as e :
102116 logging .error (f"Error in downloading model: { e } " )
103- return await handle_download_error (e , model_name , progress_callback , relative_path )
104-
117+ return await handle_download_error (e , model_name , progress_callback )
105118
106- def create_model_path ( model_name : str , model_directory : str , models_base_dir : str ) -> tuple [ str , str ]:
107- full_model_dir = os . path . join ( models_base_dir , model_directory )
108- os .makedirs (full_model_dir , exist_ok = True )
109- file_path = os .path .join (full_model_dir , model_name )
119+
120+ def create_model_path ( model_name : str , folder_path : str ) -> tuple [ str , str ]:
121+ os .makedirs (folder_path , exist_ok = True )
122+ file_path = os .path .join (folder_path , model_name )
110123
111124 # Ensure the resulting path is still within the base directory
112125 abs_file_path = os .path .abspath (file_path )
113- abs_base_dir = os .path .abspath (str ( models_base_dir ) )
126+ abs_base_dir = os .path .abspath (folder_path )
114127 if os .path .commonprefix ([abs_file_path , abs_base_dir ]) != abs_base_dir :
115- raise Exception (f"Invalid model directory: { model_directory } /{ model_name } " )
128+ raise Exception (f"Invalid model directory: { folder_path } /{ model_name } " )
116129
130+ return file_path
117131
118- relative_path = '/' .join ([model_directory , model_name ])
119- return file_path , relative_path
120132
121- async def check_file_exists (file_path : str ,
122- model_name : str ,
123- progress_callback : Callable [[str , DownloadModelStatus ], Awaitable [Any ]],
124- relative_path : str ) -> Optional [DownloadModelStatus ]:
133+ async def check_file_exists (file_path : str ,
134+ model_name : str ,
135+ progress_callback : Callable [[str , DownloadModelStatus ], Awaitable [Any ]]
136+ ) -> Optional [DownloadModelStatus ]:
125137 if os .path .exists (file_path ):
126138 status = DownloadModelStatus (DownloadStatusType .COMPLETED , 100 , f"{ model_name } already exists" , True )
127- await progress_callback (relative_path , status )
139+ await progress_callback (model_name , status )
128140 return status
129141 return None
130142
131143
132- async def track_download_progress (response : aiohttp .ClientResponse ,
133- file_path : str ,
134- model_name : str ,
135- progress_callback : Callable [[str , DownloadModelStatus ], Awaitable [Any ]],
136- relative_path : str ,
144+ async def track_download_progress (response : aiohttp .ClientResponse ,
145+ file_path : str ,
146+ model_name : str ,
147+ progress_callback : Callable [[str , DownloadModelStatus ], Awaitable [Any ]],
137148 interval : float = 1.0 ) -> DownloadModelStatus :
138149 try :
139150 total_size = int (response .headers .get ('Content-Length' , 0 ))
@@ -144,10 +155,11 @@ async def update_progress():
144155 nonlocal last_update_time
145156 progress = (downloaded / total_size ) * 100 if total_size > 0 else 0
146157 status = DownloadModelStatus (DownloadStatusType .IN_PROGRESS , progress , f"Downloading { model_name } " , False )
147- await progress_callback (relative_path , status )
158+ await progress_callback (model_name , status )
148159 last_update_time = time .time ()
149160
150- with open (file_path , 'wb' ) as f :
161+ temp_file_path = file_path + '.tmp'
162+ with open (temp_file_path , 'wb' ) as f :
151163 chunk_iterator = response .content .iter_chunked (8192 )
152164 while True :
153165 try :
@@ -156,58 +168,39 @@ async def update_progress():
156168 break
157169 f .write (chunk )
158170 downloaded += len (chunk )
159-
171+
160172 if time .time () - last_update_time >= interval :
161173 await update_progress ()
162174
175+ os .rename (temp_file_path , file_path )
176+
163177 await update_progress ()
164-
178+
165179 logging .info (f"Successfully downloaded { model_name } . Total downloaded: { downloaded } " )
166180 status = DownloadModelStatus (DownloadStatusType .COMPLETED , 100 , f"Successfully downloaded { model_name } " , False )
167- await progress_callback (relative_path , status )
181+ await progress_callback (model_name , status )
168182
169183 return status
170184 except Exception as e :
171185 logging .error (f"Error in track_download_progress: { e } " )
172186 logging .error (traceback .format_exc ())
173- return await handle_download_error (e , model_name , progress_callback , relative_path )
187+ return await handle_download_error (e , model_name , progress_callback )
188+
174189
175- async def handle_download_error (e : Exception ,
176- model_name : str ,
177- progress_callback : Callable [[str , DownloadModelStatus ], Any ],
178- relative_path : str ) -> DownloadModelStatus :
190+ async def handle_download_error (e : Exception ,
191+ model_name : str ,
192+ progress_callback : Callable [[str , DownloadModelStatus ], Any ]
193+ ) -> DownloadModelStatus :
179194 error_message = f"Error downloading { model_name } : { str (e )} "
180195 status = DownloadModelStatus (DownloadStatusType .ERROR , 0 , error_message , False )
181- await progress_callback (relative_path , status )
196+ await progress_callback (model_name , status )
182197 return status
183198
184- def validate_model_subdirectory (model_subdirectory : str ) -> bool :
185- """
186- Validate that the model subdirectory is safe to install into.
187- Must not contain relative paths, nested paths or special characters
188- other than underscores and hyphens.
189-
190- Args:
191- model_subdirectory (str): The subdirectory for the specific model type.
192-
193- Returns:
194- bool: True if the subdirectory is safe, False otherwise.
195- """
196- if len (model_subdirectory ) > 50 :
197- return False
198-
199- if '..' in model_subdirectory or '/' in model_subdirectory :
200- return False
201-
202- if not re .match (r'^[a-zA-Z0-9_-]+$' , model_subdirectory ):
203- return False
204-
205- return True
206199
207200def validate_filename (filename : str )-> bool :
208201 """
209202 Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
210-
203+
211204 Args:
212205 filename (str): The filename to validate
213206
0 commit comments