1616import  os 
1717import  re 
1818import  requests 
19- from  tqdm .auto  import  tqdm 
20- from  typing  import  Union 
2119from  collections  import  OrderedDict 
2220from  dataclasses  import  (
2321    dataclass ,
2422    asdict 
2523)
24+ from  typing  import  Union 
2625
2726from  huggingface_hub .file_download  import  http_get 
2827from  huggingface_hub .utils  import  validate_hf_hub_args 
3130    hf_hub_download ,
3231)
3332
34- from  diffusers .utils  import  logging 
3533from  diffusers .loaders .single_file_utils  import  (
34+     _extract_repo_id_and_weights_name ,
3635    infer_diffusers_model_type ,
3736    load_single_file_checkpoint ,
38-     _extract_repo_id_and_weights_name ,
3937    VALID_URL_PREFIXES ,
4038)
4139from  diffusers .pipelines .auto_pipeline  import  (
4846    StableDiffusionControlNetInpaintPipeline ,
4947    StableDiffusionControlNetPipeline ,
5048)
49+ from  diffusers .pipelines .pipeline_utils  import  DiffusionPipeline 
5150from  diffusers .pipelines .stable_diffusion  import  (
5251    StableDiffusionImg2ImgPipeline ,
5352    StableDiffusionInpaintPipeline ,
5857    StableDiffusionXLInpaintPipeline ,
5958    StableDiffusionXLPipeline ,
6059)
61- from  diffusers .pipelines . pipeline_utils  import  DiffusionPipeline 
60+ from  diffusers .utils  import  logging 
6261
6362logger  =  logging .get_logger (__name__ )
6463
@@ -189,7 +188,7 @@ class SearchResult:
189188            The status of the model. 
190189    """ 
191190    model_path : str  =  "" 
192-     loading_method : Union [str , None ] =  None    
191+     loading_method : Union [str , None ] =  None 
193192    checkpoint_format : Union [str , None ] =  None 
194193    repo_status : RepoStatus  =  RepoStatus ()
195194    model_status : ModelStatus  =  ModelStatus ()
@@ -282,7 +281,7 @@ def get_keyword_types(keyword):
282281        `dict`: A dictionary containing the model format, loading method, 
283282                and various types and extra types flags. 
284283    """ 
285-      
284+ 
286285    # Initialize the status dictionary with default values 
287286    status  =  {
288287        "checkpoint_format" : None ,
@@ -299,30 +298,30 @@ def get_keyword_types(keyword):
299298            "missing_model_index" : None ,
300299        },
301300    }
302-      
301+ 
303302    # Check if the keyword is an HTTP or HTTPS URL 
304303    status ["extra_type" ]["url" ] =  bool (re .search (r"^(https?)://" , keyword ))
305-      
304+ 
306305    # Check if the keyword is a file 
307306    if  os .path .isfile (keyword ):
308307        status ["type" ]["local" ] =  True 
309308        status ["checkpoint_format" ] =  "single_file" 
310309        status ["loading_method" ] =  "from_single_file" 
311-      
310+ 
312311    # Check if the keyword is a directory 
313312    elif  os .path .isdir (keyword ):
314313        status ["type" ]["local" ] =  True 
315314        status ["checkpoint_format" ] =  "diffusers" 
316315        status ["loading_method" ] =  "from_pretrained" 
317316        if  not  os .path .exists (os .path .join (keyword , "model_index.json" )):
318317            status ["extra_type" ]["missing_model_index" ] =  True 
319-      
318+ 
320319    # Check if the keyword is a Civitai URL 
321320    elif  keyword .startswith ("https://civitai.com/" ):
322321        status ["type" ]["civitai_url" ] =  True 
323322        status ["checkpoint_format" ] =  "single_file" 
324323        status ["loading_method" ] =  None 
325-      
324+ 
326325    # Check if the keyword starts with any valid URL prefixes 
327326    elif  any (keyword .startswith (prefix ) for  prefix  in  VALID_URL_PREFIXES ):
328327        repo_id , weights_name  =  _extract_repo_id_and_weights_name (keyword )
@@ -334,19 +333,19 @@ def get_keyword_types(keyword):
334333            status ["type" ]["hf_repo" ] =  True 
335334            status ["checkpoint_format" ] =  "diffusers" 
336335            status ["loading_method" ] =  "from_pretrained" 
337-      
336+ 
338337    # Check if the keyword matches a Hugging Face repository format 
339338    elif  re .match (r"^[^/]+/[^/]+$" , keyword ):
340339        status ["type" ]["hf_repo" ] =  True 
341340        status ["checkpoint_format" ] =  "diffusers" 
342341        status ["loading_method" ] =  "from_pretrained" 
343-      
342+ 
344343    # If none of the above apply 
345344    else :
346345        status ["type" ]["other" ] =  True 
347346        status ["checkpoint_format" ] =  None 
348347        status ["loading_method" ] =  None 
349-      
348+ 
350349    return  status 
351350
352351
@@ -532,15 +531,15 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
532531                    ):
533532                        diffusers_model_exists  =  True 
534533                        break 
535-                      
534+ 
536535                    elif  (
537536                        any (file_path .endswith (ext ) for  ext  in  EXTENSION )
538537                        and  not  any (config  in  file_path  for  config  in  CONFIG_FILE_LIST )
539538                        and  not  any (exc  in  file_path  for  exc  in  exclusion )
540539                        and  os .path .basename (os .path .dirname (file_path )) not  in DIFFUSERS_CONFIG_DIR 
541540                    ):
542541                        file_list .append (file_path )
543-              
542+ 
544543            # Exit from the loop if a multi-folder diffusers model or valid file is found 
545544            if  diffusers_model_exists  or  file_list :
546545                break 
@@ -560,7 +559,7 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
560559                )
561560            else :
562561                model_path  =  repo_id 
563-                  
562+ 
564563        elif  file_list :
565564            # Sort and find the safest model 
566565            file_name  =  next (
@@ -571,7 +570,7 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
571570                ),
572571                file_list [0 ]
573572            )
574-              
573+ 
575574
576575            if  download :
577576                model_path  =  hf_hub_download (
@@ -581,12 +580,12 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
581580                    token = token ,
582581                    force_download = force_download ,
583582                )
584-              
583+ 
585584    if  file_name :
586585        download_url  =  f"https://huggingface.co/{ repo_id } { file_name }  
587586    else :
588587        download_url  =  f"https://huggingface.co/{ repo_id }  
589-      
588+ 
590589    output_info  =  get_keyword_types (model_path )
591590
592591    if  include_params :
@@ -606,10 +605,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
606605                local = download ,
607606            )
608607        )
609-      
608+ 
610609    else :
611610        return  model_path 
612-      
611+ 
613612
614613def  search_civitai (search_word : str , ** kwargs ) ->  Union [str , SearchResult , None ]:
615614    r""" 
@@ -693,7 +692,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
693692                return  None 
694693            else :
695694                raise  ValueError ("Invalid JSON response" )
696-      
695+ 
697696    # Sort repositories by download count in descending order 
698697    sorted_repos  =  sorted (data ["items" ], key = lambda  x : x ["stats" ]["downloadCount" ], reverse = True )
699698
@@ -737,14 +736,14 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
737736        else :
738737            continue 
739738        break 
740-      
739+ 
741740    # Exception handling when search candidates are not found 
742741    if  not  selected_model :
743742        if  skip_error :
744743            return  None 
745744        else :
746745            raise  ValueError ("No model found. Please try changing the word you are searching for." )
747-      
746+ 
748747    # Define model file status 
749748    file_name  =  selected_model ["filename" ]
750749    download_url  =  selected_model ["download_url" ]
@@ -814,7 +813,7 @@ def __init__(self, *args, **kwargs):
814813        # EnvironmentError is returned 
815814        super ().__init__ ()
816815
817-      
816+ 
818817    @classmethod  
819818    @validate_hf_hub_args  
820819    def  from_huggingface (cls , pretrained_model_link_or_path , ** kwargs ):
@@ -929,10 +928,10 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
929928        kwargs .update (_status )
930929
931930        # Search for the model on Hugging Face and get the model status 
932-         hf_model_status  =  search_huggingface (pretrained_model_link_or_path , ** kwargs )    
931+         hf_model_status  =  search_huggingface (pretrained_model_link_or_path , ** kwargs )
933932        logger .warning (f"checkpoint_path: { hf_model_status .model_status .download_url }  )
934933        checkpoint_path  =  hf_model_status .model_path 
935-          
934+ 
936935        # Check the format of the model checkpoint 
937936        if  hf_model_status .checkpoint_format  ==  "single_file" :
938937            # Load the pipeline from a single file checkpoint 
@@ -943,7 +942,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
943942            )
944943        else :
945944            return  cls .from_pretrained (checkpoint_path , ** kwargs )
946-      
945+ 
947946    @classmethod  
948947    def  from_civitai (cls , pretrained_model_link_or_path , ** kwargs ):
949948        r""" 
@@ -1047,7 +1046,7 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
10471046            pipeline_mapping = SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING ,
10481047            ** kwargs 
10491048        )
1050-      
1049+ 
10511050
10521051
10531052class  EasyPipelineForImage2Image (AutoPipelineForImage2Image ):
@@ -1071,7 +1070,7 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
10711070    def  __init__ (self , * args , ** kwargs ):
10721071        # EnvironmentError is returned 
10731072        super ().__init__ ()
1074-      
1073+ 
10751074    @classmethod  
10761075    @validate_hf_hub_args  
10771076    def  from_huggingface (cls , pretrained_model_link_or_path , ** kwargs ):
@@ -1186,10 +1185,10 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
11861185        kwargs .update (_parmas )
11871186
11881187        # Search for the model on Hugging Face and get the model status 
1189-         model_status  =  search_huggingface (pretrained_model_link_or_path , ** kwargs )    
1188+         model_status  =  search_huggingface (pretrained_model_link_or_path , ** kwargs )
11901189        logger .warning (f"checkpoint_path: { model_status .model_status .download_url }  )
11911190        checkpoint_path  =  model_status .model_path 
1192-          
1191+ 
11931192        # Check the format of the model checkpoint 
11941193        if  model_status .checkpoint_format  ==  "single_file" :
11951194            # Load the pipeline from a single file checkpoint 
@@ -1200,7 +1199,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
12001199            )
12011200        else :
12021201            return  cls .from_pretrained (checkpoint_path , ** kwargs )
1203-      
1202+ 
12041203    @classmethod  
12051204    def  from_civitai (cls , pretrained_model_link_or_path , ** kwargs ):
12061205        r""" 
@@ -1305,7 +1304,7 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
13051304            ** kwargs 
13061305        )
13071306
1308-      
1307+ 
13091308
13101309class  EasyPipelineForInpainting (AutoPipelineForInpainting ):
13111310    r""" 
@@ -1328,7 +1327,7 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
13281327    def  __init__ (self , * args , ** kwargs ):
13291328        # EnvironmentError is returned 
13301329        super ().__init__ ()
1331-      
1330+ 
13321331    @classmethod  
13331332    @validate_hf_hub_args  
13341333    def  from_huggingface (cls , pretrained_model_link_or_path , ** kwargs ):
@@ -1443,10 +1442,10 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
14431442        kwargs .update (_status )
14441443
14451444        # Search for the model on Hugging Face and get the model status 
1446-         model_status  =  search_huggingface (pretrained_model_link_or_path , ** kwargs )    
1445+         model_status  =  search_huggingface (pretrained_model_link_or_path , ** kwargs )
14471446        logger .warning (f"checkpoint_path: { model_status .model_status .download_url }  )
14481447        checkpoint_path  =  model_status .model_path 
1449-          
1448+ 
14501449        # Check the format of the model checkpoint 
14511450        if  model_status .checkpoint_format  ==  "single_file" :
14521451            # Load the pipeline from a single file checkpoint 
@@ -1457,7 +1456,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
14571456            )
14581457        else :
14591458            return  cls .from_pretrained (checkpoint_path , ** kwargs )
1460-      
1459+ 
14611460    @classmethod  
14621461    def  from_civitai (cls , pretrained_model_link_or_path , ** kwargs ):
14631462        r""" 
@@ -1560,4 +1559,4 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
15601559            pretrained_model_or_path = checkpoint_path ,
15611560            pipeline_mapping = SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING ,
15621561            ** kwargs 
1563-         )
1562+         )
0 commit comments