1+ from __future__ import annotations
2+
13import os
24import time
35import logging
4- from typing import Set , List , Dict , Tuple
6+ from collections . abc import Collection
57
6- supported_pt_extensions : Set [str ] = set ([ '.ckpt' , '.pt' , '.bin' , '.pth' , '.safetensors' , '.pkl' , '.sft' ])
8+ supported_pt_extensions : set [str ] = { '.ckpt' , '.pt' , '.bin' , '.pth' , '.safetensors' , '.pkl' , '.sft' }
79
8- SupportedFileExtensionsType = Set [str ]
9- ScanPathType = List [str ]
10- folder_names_and_paths : Dict [str , Tuple [ScanPathType , SupportedFileExtensionsType ]] = {}
10+ folder_names_and_paths : dict [str , tuple [list [str ], set [str ]]] = {}
1111
1212base_path = os .path .dirname (os .path .realpath (__file__ ))
1313models_dir = os .path .join (base_path , "models" )
4242input_directory = os .path .join (os .path .dirname (os .path .realpath (__file__ )), "input" )
4343user_directory = os .path .join (os .path .dirname (os .path .realpath (__file__ )), "user" )
4444
45- filename_list_cache = {}
45+ filename_list_cache : dict [ str , tuple [ list [ str ], dict [ str , float ], float ]] = {}
4646
4747if not os .path .exists (input_directory ):
4848 try :
4949 os .makedirs (input_directory )
5050 except :
5151 logging .error ("Failed to create input directory" )
5252
53- def set_output_directory (output_dir ) :
53+ def set_output_directory (output_dir : str ) -> None :
5454 global output_directory
5555 output_directory = output_dir
5656
57- def set_temp_directory (temp_dir ) :
57+ def set_temp_directory (temp_dir : str ) -> None :
5858 global temp_directory
5959 temp_directory = temp_dir
6060
61- def set_input_directory (input_dir ) :
61+ def set_input_directory (input_dir : str ) -> None :
6262 global input_directory
6363 input_directory = input_dir
6464
65- def get_output_directory ():
65+ def get_output_directory () -> str :
6666 global output_directory
6767 return output_directory
6868
69- def get_temp_directory ():
69+ def get_temp_directory () -> str :
7070 global temp_directory
7171 return temp_directory
7272
73- def get_input_directory ():
73+ def get_input_directory () -> str :
7474 global input_directory
7575 return input_directory
7676
7777
7878#NOTE: used in http server so don't put folders that should not be accessed remotely
79- def get_directory_by_type (type_name ) :
79+ def get_directory_by_type (type_name : str ) -> str | None :
8080 if type_name == "output" :
8181 return get_output_directory ()
8282 if type_name == "temp" :
@@ -88,7 +88,7 @@ def get_directory_by_type(type_name):
8888
8989# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
9090# otherwise use default_path as base_dir
91- def annotated_filepath (name ) :
91+ def annotated_filepath (name : str ) -> tuple [ str , str | None ] :
9292 if name .endswith ("[output]" ):
9393 base_dir = get_output_directory ()
9494 name = name [:- 9 ]
@@ -104,7 +104,7 @@ def annotated_filepath(name):
104104 return name , base_dir
105105
106106
107- def get_annotated_filepath (name , default_dir = None ):
107+ def get_annotated_filepath (name : str , default_dir : str | None = None ) -> str :
108108 name , base_dir = annotated_filepath (name )
109109
110110 if base_dir is None :
@@ -116,7 +116,7 @@ def get_annotated_filepath(name, default_dir=None):
116116 return os .path .join (base_dir , name )
117117
118118
119- def exists_annotated_filepath (name ):
119+ def exists_annotated_filepath (name ) -> bool :
120120 name , base_dir = annotated_filepath (name )
121121
122122 if base_dir is None :
@@ -126,17 +126,17 @@ def exists_annotated_filepath(name):
126126 return os .path .exists (filepath )
127127
128128
129- def add_model_folder_path (folder_name , full_folder_path ) :
129+ def add_model_folder_path (folder_name : str , full_folder_path : str ) -> None :
130130 global folder_names_and_paths
131131 if folder_name in folder_names_and_paths :
132132 folder_names_and_paths [folder_name ][0 ].append (full_folder_path )
133133 else :
134134 folder_names_and_paths [folder_name ] = ([full_folder_path ], set ())
135135
136- def get_folder_paths (folder_name ) :
136+ def get_folder_paths (folder_name : str ) -> list [ str ] :
137137 return folder_names_and_paths [folder_name ][0 ][:]
138138
139- def recursive_search (directory , excluded_dir_names = None ):
139+ def recursive_search (directory : str , excluded_dir_names : list [ str ] | None = None ) -> tuple [ list [ str ], dict [ str , float ]] :
140140 if not os .path .isdir (directory ):
141141 return [], {}
142142
@@ -153,14 +153,18 @@ def recursive_search(directory, excluded_dir_names=None):
153153 logging .warning (f"Warning: Unable to access { directory } . Skipping this path." )
154154
155155 logging .debug ("recursive file list on directory {}" .format (directory ))
156+ dirpath : str
157+ subdirs : list [str ]
158+ filenames : list [str ]
159+
156160 for dirpath , subdirs , filenames in os .walk (directory , followlinks = True , topdown = True ):
157161 subdirs [:] = [d for d in subdirs if d not in excluded_dir_names ]
158162 for file_name in filenames :
159163 relative_path = os .path .relpath (os .path .join (dirpath , file_name ), directory )
160164 result .append (relative_path )
161165
162166 for d in subdirs :
163- path = os .path .join (dirpath , d )
167+ path : str = os .path .join (dirpath , d )
164168 try :
165169 dirs [path ] = os .path .getmtime (path )
166170 except FileNotFoundError :
@@ -169,12 +173,12 @@ def recursive_search(directory, excluded_dir_names=None):
169173 logging .debug ("found {} files" .format (len (result )))
170174 return result , dirs
171175
172- def filter_files_extensions (files , extensions ) :
176+ def filter_files_extensions (files : Collection [ str ] , extensions : Collection [ str ]) -> list [ str ] :
173177 return sorted (list (filter (lambda a : os .path .splitext (a )[- 1 ].lower () in extensions or len (extensions ) == 0 , files )))
174178
175179
176180
177- def get_full_path (folder_name , filename ) :
181+ def get_full_path (folder_name : str , filename : str ) -> str | None :
178182 global folder_names_and_paths
179183 if folder_name not in folder_names_and_paths :
180184 return None
@@ -189,7 +193,7 @@ def get_full_path(folder_name, filename):
189193
190194 return None
191195
192- def get_filename_list_ (folder_name ) :
196+ def get_filename_list_ (folder_name : str ) -> tuple [ list [ str ], dict [ str , float ], float ] :
193197 global folder_names_and_paths
194198 output_list = set ()
195199 folders = folder_names_and_paths [folder_name ]
@@ -199,9 +203,9 @@ def get_filename_list_(folder_name):
199203 output_list .update (filter_files_extensions (files , folders [1 ]))
200204 output_folders = {** output_folders , ** folders_all }
201205
202- return ( sorted (list (output_list )), output_folders , time .perf_counter () )
206+ return sorted (list (output_list )), output_folders , time .perf_counter ()
203207
204- def cached_filename_list_ (folder_name ) :
208+ def cached_filename_list_ (folder_name : str ) -> tuple [ list [ str ], dict [ str , float ], float ] | None :
205209 global filename_list_cache
206210 global folder_names_and_paths
207211 if folder_name not in filename_list_cache :
@@ -222,25 +226,25 @@ def cached_filename_list_(folder_name):
222226
223227 return out
224228
225- def get_filename_list (folder_name ) :
229+ def get_filename_list (folder_name : str ) -> list [ str ] :
226230 out = cached_filename_list_ (folder_name )
227231 if out is None :
228232 out = get_filename_list_ (folder_name )
229233 global filename_list_cache
230234 filename_list_cache [folder_name ] = out
231235 return list (out [0 ])
232236
233- def get_save_image_path (filename_prefix , output_dir , image_width = 0 , image_height = 0 ):
234- def map_filename (filename ) :
237+ def get_save_image_path (filename_prefix : str , output_dir : str , image_width = 0 , image_height = 0 ) -> tuple [ str , str , int , str , str ] :
238+ def map_filename (filename : str ) -> tuple [ int , str ] :
235239 prefix_len = len (os .path .basename (filename_prefix ))
236240 prefix = filename [:prefix_len + 1 ]
237241 try :
238242 digits = int (filename [prefix_len + 1 :].split ('_' )[0 ])
239243 except :
240244 digits = 0
241- return ( digits , prefix )
245+ return digits , prefix
242246
243- def compute_vars (input , image_width , image_height ) :
247+ def compute_vars (input : str , image_width : int , image_height : int ) -> str :
244248 input = input .replace ("%width%" , str (image_width ))
245249 input = input .replace ("%height%" , str (image_height ))
246250 return input
0 commit comments