Skip to content

Commit e524659

Browse files
committed
Fix and make style
1 parent 6545ead commit e524659

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

examples/model_search/pipeline_easy.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# coding=utf-8
2-
# Copyright 2024 suzukimain
2+
# Copyright 2025 suzukimain
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
1717
import re
1818
import types
1919
from collections import OrderedDict
20-
from dataclasses import asdict, dataclass
20+
from dataclasses import asdict, dataclass, field
2121
from typing import Dict, List, Optional, Union
2222

2323
import requests
@@ -321,9 +321,9 @@ class SearchResult:
321321
model_path: str = ""
322322
loading_method: Union[str, None] = None
323323
checkpoint_format: Union[str, None] = None
324-
repo_status: RepoStatus = RepoStatus()
325-
model_status: ModelStatus = ModelStatus()
326-
extra_status: ExtraStatus = ExtraStatus()
324+
repo_status: RepoStatus = field(default_factory=RepoStatus)
325+
model_status: ModelStatus = field(default_factory=ModelStatus)
326+
extra_status: ExtraStatus = field(default_factory=ExtraStatus)
327327

328328

329329
@validate_hf_hub_args
@@ -589,10 +589,18 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
589589
gated = kwargs.pop("gated", False)
590590
skip_error = kwargs.pop("skip_error", False)
591591

592+
file_list = []
593+
hf_repo_info = {}
594+
hf_security_info = {}
595+
model_path = ""
596+
repo_id, file_name = "", ""
597+
diffusers_model_exists = False
598+
592599
# Get the type and loading method for the keyword
593600
search_word_status = get_keyword_types(search_word)
594601

595602
if search_word_status["type"]["hf_repo"]:
603+
hf_repo_info = hf_api.model_info(repo_id=search_word, securityStatus=True)
596604
if download:
597605
model_path = DiffusionPipeline.download(
598606
search_word,
@@ -635,13 +643,6 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
635643
)
636644
model_dicts = [asdict(value) for value in list(hf_models)]
637645

638-
file_list = []
639-
hf_repo_info = {}
640-
hf_security_info = {}
641-
model_path = ""
642-
repo_id, file_name = "", ""
643-
diffusers_model_exists = False
644-
645646
# Loop through models to find a suitable candidate
646647
for repo_info in model_dicts:
647648
repo_id = repo_info["id"]
@@ -706,6 +707,10 @@ def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, N
706707
force_download=force_download,
707708
)
708709

710+
# `pathlib.PosixPath` may be returned
711+
if model_path:
712+
model_path = str(model_path)
713+
709714
if file_name:
710715
download_url = f"https://huggingface.co/{repo_id}/blob/main/{file_name}"
711716
else:

0 commit comments

Comments
 (0)