Skip to content

Commit 9d53cc2

Browse files
committed
fix unit tests
1 parent b49ddb7 commit 9d53cc2

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

src/huggingface_inference_toolkit/diffusers_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import importlib.util
22
import logging
33

4+
logger = logging.getLogger(__name__)
5+
logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)
6+
47
_diffusers = importlib.util.find_spec("diffusers") is not None
58

69

@@ -12,9 +15,6 @@ def is_diffusers_available():
1215
import torch
1316
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline
1417

15-
logger = logging.getLogger(__name__)
16-
logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)
17-
1818

1919
class IEAutoPipelineForText2Image:
2020
def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU

src/huggingface_inference_toolkit/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def create_artifact_filter(framework):
5959
"""
6060
Returns a list of regex pattern based on the DL Framework. which will be to used to ignore files when downloading
6161
"""
62-
ignore_regex_list = list(framework2weight.values())
62+
ignore_regex_list = list(set(framework2weight.values()))
6363

6464
pattern = framework2weight.get(framework, None)
6565
if pattern in ignore_regex_list:
@@ -153,7 +153,6 @@ def _load_repository_from_hf(
153153
if any(f.rfilename.endswith("safetensors") for f in files):
154154
framework = "safetensors"
155155

156-
157156
# create regex to only include the framework specific weights
158157
ignore_regex = create_artifact_filter(framework)
159158
logger.info(f"Ignore regex pattern for files, which are not downloaded: { ', '.join(ignore_regex) }")
@@ -284,8 +283,8 @@ def convert_params_to_int_or_bool(params):
284283
for k, v in params.items():
285284
if v.isnumeric():
286285
params[k] = int(v)
287-
if v == 'false':
286+
if v == "false":
288287
params[k] = False
289-
if v == 'true':
288+
if v == "true":
290289
params[k] = True
291290
return params

tests/unit/test_diffusers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from PIL import Image
44
from transformers.testing_utils import require_torch, slow
55

6-
from huggingface_inference_toolkit.handler import get_inference_handler_either_custom_or_default_handler
7-
from huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, DiffusersPipelineImageToText
6+
7+
from huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, IEAutoPipelineForText2Image
88
from huggingface_inference_toolkit.utils import _load_repository_from_hf, get_pipeline
99

1010

@@ -15,7 +15,7 @@ def test_get_diffusers_pipeline():
1515
"hf-internal-testing/tiny-stable-diffusion-torch", tmpdirname, framework="pytorch"
1616
)
1717
pipe = get_pipeline("text-to-image", storage_dir.as_posix())
18-
assert isinstance(pipe, DiffusersPipelineImageToText)
18+
assert isinstance(pipe, IEAutoPipelineForText2Image)
1919

2020

2121
@slow

0 commit comments

Comments
 (0)