Skip to content

Commit 5df8f99

Browse files
committed
Fixing the style of pipeline
1 parent efe5fba commit 5df8f99

File tree

1 file changed

+40
-41
lines changed

1 file changed

+40
-41
lines changed

examples/model_search/pipeline_easy.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616
import os
1717
import re
1818
import requests
19-
from tqdm.auto import tqdm
20-
from typing import Union
2119
from collections import OrderedDict
2220
from dataclasses import (
2321
dataclass,
2422
asdict
2523
)
24+
from typing import Union
2625

2726
from huggingface_hub.file_download import http_get
2827
from huggingface_hub.utils import validate_hf_hub_args
@@ -31,11 +30,10 @@
3130
hf_hub_download,
3231
)
3332

34-
from diffusers.utils import logging
3533
from diffusers.loaders.single_file_utils import (
34+
_extract_repo_id_and_weights_name,
3635
infer_diffusers_model_type,
3736
load_single_file_checkpoint,
38-
_extract_repo_id_and_weights_name,
3937
VALID_URL_PREFIXES,
4038
)
4139
from diffusers.pipelines.auto_pipeline import (
@@ -48,6 +46,7 @@
4846
StableDiffusionControlNetInpaintPipeline,
4947
StableDiffusionControlNetPipeline,
5048
)
49+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
5150
from diffusers.pipelines.stable_diffusion import (
5251
StableDiffusionImg2ImgPipeline,
5352
StableDiffusionInpaintPipeline,
@@ -58,7 +57,7 @@
5857
StableDiffusionXLInpaintPipeline,
5958
StableDiffusionXLPipeline,
6059
)
61-
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
60+
from diffusers.utils import logging
6261

6362
logger = logging.get_logger(__name__)
6463

@@ -189,7 +188,7 @@ class SearchResult:
189188
The status of the model.
190189
"""
191190
model_path: str = ""
192-
loading_method: Union[str, None] = None
191+
loading_method: Union[str, None] = None
193192
checkpoint_format: Union[str, None] = None
194193
repo_status: RepoStatus = RepoStatus()
195194
model_status: ModelStatus = ModelStatus()
@@ -282,7 +281,7 @@ def get_keyword_types(keyword):
282281
`dict`: A dictionary containing the model format, loading method,
283282
and various types and extra types flags.
284283
"""
285-
284+
286285
# Initialize the status dictionary with default values
287286
status = {
288287
"checkpoint_format": None,
@@ -299,30 +298,30 @@ def get_keyword_types(keyword):
299298
"missing_model_index": None,
300299
},
301300
}
302-
301+
303302
# Check if the keyword is an HTTP or HTTPS URL
304303
status["extra_type"]["url"] = bool(re.search(r"^(https?)://", keyword))
305-
304+
306305
# Check if the keyword is a file
307306
if os.path.isfile(keyword):
308307
status["type"]["local"] = True
309308
status["checkpoint_format"] = "single_file"
310309
status["loading_method"] = "from_single_file"
311-
310+
312311
# Check if the keyword is a directory
313312
elif os.path.isdir(keyword):
314313
status["type"]["local"] = True
315314
status["checkpoint_format"] = "diffusers"
316315
status["loading_method"] = "from_pretrained"
317316
if not os.path.exists(os.path.join(keyword, "model_index.json")):
318317
status["extra_type"]["missing_model_index"] = True
319-
318+
320319
# Check if the keyword is a Civitai URL
321320
elif keyword.startswith("https://civitai.com/"):
322321
status["type"]["civitai_url"] = True
323322
status["checkpoint_format"] = "single_file"
324323
status["loading_method"] = None
325-
324+
326325
# Check if the keyword starts with any valid URL prefixes
327326
elif any(keyword.startswith(prefix) for prefix in VALID_URL_PREFIXES):
328327
repo_id, weights_name = _extract_repo_id_and_weights_name(keyword)
@@ -334,19 +333,19 @@ def get_keyword_types(keyword):
334333
status["type"]["hf_repo"] = True
335334
status["checkpoint_format"] = "diffusers"
336335
status["loading_method"] = "from_pretrained"
337-
336+
338337
# Check if the keyword matches a Hugging Face repository format
339338
elif re.match(r"^[^/]+/[^/]+$", keyword):
340339
status["type"]["hf_repo"] = True
341340
status["checkpoint_format"] = "diffusers"
342341
status["loading_method"] = "from_pretrained"
343-
342+
344343
# If none of the above apply
345344
else:
346345
status["type"]["other"] = True
347346
status["checkpoint_format"] = None
348347
status["loading_method"] = None
349-
348+
350349
return status
351350

352351

@@ -532,15 +531,15 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
532531
):
533532
diffusers_model_exists = True
534533
break
535-
534+
536535
elif (
537536
any(file_path.endswith(ext) for ext in EXTENSION)
538537
and not any(config in file_path for config in CONFIG_FILE_LIST)
539538
and not any(exc in file_path for exc in exclusion)
540539
and os.path.basename(os.path.dirname(file_path)) not in DIFFUSERS_CONFIG_DIR
541540
):
542541
file_list.append(file_path)
543-
542+
544543
# Exit from the loop if a multi-folder diffusers model or valid file is found
545544
if diffusers_model_exists or file_list:
546545
break
@@ -560,7 +559,7 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
560559
)
561560
else:
562561
model_path = repo_id
563-
562+
564563
elif file_list:
565564
# Sort and find the safest model
566565
file_name = next(
@@ -571,7 +570,7 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
571570
),
572571
file_list[0]
573572
)
574-
573+
575574

576575
if download:
577576
model_path = hf_hub_download(
@@ -581,12 +580,12 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
581580
token=token,
582581
force_download=force_download,
583582
)
584-
583+
585584
if file_name:
586585
download_url = f"https://huggingface.co/{repo_id}/blob/main/{file_name}"
587586
else:
588587
download_url = f"https://huggingface.co/{repo_id}"
589-
588+
590589
output_info = get_keyword_types(model_path)
591590

592591
if include_params:
@@ -606,10 +605,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
606605
local=download,
607606
)
608607
)
609-
608+
610609
else:
611610
return model_path
612-
611+
613612

614613
def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]:
615614
r"""
@@ -693,7 +692,7 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
693692
return None
694693
else:
695694
raise ValueError("Invalid JSON response")
696-
695+
697696
# Sort repositories by download count in descending order
698697
sorted_repos = sorted(data["items"], key=lambda x: x["stats"]["downloadCount"], reverse=True)
699698

@@ -737,14 +736,14 @@ def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]
737736
else:
738737
continue
739738
break
740-
739+
741740
# Exception handling when search candidates are not found
742741
if not selected_model:
743742
if skip_error:
744743
return None
745744
else:
746745
raise ValueError("No model found. Please try changing the word you are searching for.")
747-
746+
748747
# Define model file status
749748
file_name = selected_model["filename"]
750749
download_url = selected_model["download_url"]
@@ -814,7 +813,7 @@ def __init__(self, *args, **kwargs):
814813
# EnvironmentError is returned
815814
super().__init__()
816815

817-
816+
818817
@classmethod
819818
@validate_hf_hub_args
820819
def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
@@ -929,10 +928,10 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
929928
kwargs.update(_status)
930929

931930
# Search for the model on Hugging Face and get the model status
932-
hf_model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
931+
hf_model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
933932
logger.warning(f"checkpoint_path: {hf_model_status.model_status.download_url}")
934933
checkpoint_path = hf_model_status.model_path
935-
934+
936935
# Check the format of the model checkpoint
937936
if hf_model_status.checkpoint_format == "single_file":
938937
# Load the pipeline from a single file checkpoint
@@ -943,7 +942,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
943942
)
944943
else:
945944
return cls.from_pretrained(checkpoint_path, **kwargs)
946-
945+
947946
@classmethod
948947
def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
949948
r"""
@@ -1047,7 +1046,7 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
10471046
pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,
10481047
**kwargs
10491048
)
1050-
1049+
10511050

10521051

10531052
class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
@@ -1071,7 +1070,7 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
10711070
def __init__(self, *args, **kwargs):
10721071
# EnvironmentError is returned
10731072
super().__init__()
1074-
1073+
10751074
@classmethod
10761075
@validate_hf_hub_args
10771076
def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
@@ -1186,10 +1185,10 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
11861185
kwargs.update(_parmas)
11871186

11881187
# Search for the model on Hugging Face and get the model status
1189-
model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
1188+
model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
11901189
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
11911190
checkpoint_path = model_status.model_path
1192-
1191+
11931192
# Check the format of the model checkpoint
11941193
if model_status.checkpoint_format == "single_file":
11951194
# Load the pipeline from a single file checkpoint
@@ -1200,7 +1199,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
12001199
)
12011200
else:
12021201
return cls.from_pretrained(checkpoint_path, **kwargs)
1203-
1202+
12041203
@classmethod
12051204
def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
12061205
r"""
@@ -1305,7 +1304,7 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
13051304
**kwargs
13061305
)
13071306

1308-
1307+
13091308

13101309
class EasyPipelineForInpainting(AutoPipelineForInpainting):
13111310
r"""
@@ -1328,7 +1327,7 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
13281327
def __init__(self, *args, **kwargs):
13291328
# EnvironmentError is returned
13301329
super().__init__()
1331-
1330+
13321331
@classmethod
13331332
@validate_hf_hub_args
13341333
def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
@@ -1443,10 +1442,10 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
14431442
kwargs.update(_status)
14441443

14451444
# Search for the model on Hugging Face and get the model status
1446-
model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
1445+
model_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
14471446
logger.warning(f"checkpoint_path: {model_status.model_status.download_url}")
14481447
checkpoint_path = model_status.model_path
1449-
1448+
14501449
# Check the format of the model checkpoint
14511450
if model_status.checkpoint_format == "single_file":
14521451
# Load the pipeline from a single file checkpoint
@@ -1457,7 +1456,7 @@ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
14571456
)
14581457
else:
14591458
return cls.from_pretrained(checkpoint_path, **kwargs)
1460-
1459+
14611460
@classmethod
14621461
def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
14631462
r"""
@@ -1560,4 +1559,4 @@ def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
15601559
pretrained_model_or_path=checkpoint_path,
15611560
pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
15621561
**kwargs
1563-
)
1562+
)

0 commit comments

Comments
 (0)