File tree Expand file tree Collapse file tree 3 files changed +9
-10
lines changed
src/huggingface_inference_toolkit Expand file tree Collapse file tree 3 files changed +9
-10
lines changed Original file line number Diff line number Diff line change 11import importlib .util
22import logging
33
4+ logger = logging .getLogger (__name__ )
5+ logging .basicConfig (format = "%(asctime)s | %(levelname)s | %(message)s" , level = logging .INFO )
6+
47_diffusers = importlib .util .find_spec ("diffusers" ) is not None
58
69
@@ -12,9 +15,6 @@ def is_diffusers_available():
1215 import torch
1316 from diffusers import AutoPipelineForText2Image , DPMSolverMultistepScheduler , StableDiffusionPipeline
1417
15- logger = logging .getLogger (__name__ )
16- logging .basicConfig (format = "%(asctime)s | %(levelname)s | %(message)s" , level = logging .INFO )
17-
1818
1919class IEAutoPipelineForText2Image :
2020 def __init__ (self , model_dir : str , device : str = None ): # needs "cuda" for GPU
Original file line number Diff line number Diff line change @@ -59,7 +59,7 @@ def create_artifact_filter(framework):
5959 """
6060 Returns a list of regex pattern based on the DL Framework. which will be to used to ignore files when downloading
6161 """
62- ignore_regex_list = list (framework2weight .values ())
62+ ignore_regex_list = list (set ( framework2weight .values () ))
6363
6464 pattern = framework2weight .get (framework , None )
6565 if pattern in ignore_regex_list :
@@ -153,7 +153,6 @@ def _load_repository_from_hf(
153153 if any (f .rfilename .endswith ("safetensors" ) for f in files ):
154154 framework = "safetensors"
155155
156-
157156 # create regex to only include the framework specific weights
158157 ignore_regex = create_artifact_filter (framework )
159158 logger .info (f"Ignore regex pattern for files, which are not downloaded: { ', ' .join (ignore_regex ) } " )
@@ -284,8 +283,8 @@ def convert_params_to_int_or_bool(params):
284283 for k , v in params .items ():
285284 if v .isnumeric ():
286285 params [k ] = int (v )
287- if v == ' false' :
286+ if v == " false" :
288287 params [k ] = False
289- if v == ' true' :
288+ if v == " true" :
290289 params [k ] = True
291290 return params
Original file line number Diff line number Diff line change 33from PIL import Image
44from transformers .testing_utils import require_torch , slow
55
6- from huggingface_inference_toolkit . handler import get_inference_handler_either_custom_or_default_handler
7- from huggingface_inference_toolkit .diffusers_utils import get_diffusers_pipeline , DiffusersPipelineImageToText
6+
7+ from huggingface_inference_toolkit .diffusers_utils import get_diffusers_pipeline , IEAutoPipelineForText2Image
88from huggingface_inference_toolkit .utils import _load_repository_from_hf , get_pipeline
99
1010
@@ -15,7 +15,7 @@ def test_get_diffusers_pipeline():
1515 "hf-internal-testing/tiny-stable-diffusion-torch" , tmpdirname , framework = "pytorch"
1616 )
1717 pipe = get_pipeline ("text-to-image" , storage_dir .as_posix ())
18- assert isinstance (pipe , DiffusersPipelineImageToText )
18+ assert isinstance (pipe , IEAutoPipelineForText2Image )
1919
2020
2121@slow
You can’t perform that action at this time.
0 commit comments