Skip to content

Commit 285b435

Browse files
authored
Merge pull request #45 from huggingface/sd-xl-fix
Fix Stable diffusion pipeline
2 parents 8b8d3aa + 1597c1f commit 285b435

File tree

4 files changed

+6
-9
lines changed

4 files changed

+6
-9
lines changed

dockerfiles/pytorch/cpu/environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ dependencies:
   - transformers[sklearn,sentencepiece,audio,vision]==4.31.0
   - sentence_transformers==2.2.2
   - torchvision==0.14.1
-  - diffusers==0.19.3
+  - diffusers==0.20.0
   - accelerate==0.21.0
   - safetensors

dockerfiles/pytorch/gpu/environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ dependencies:
   - transformers[sklearn,sentencepiece,audio,vision]==4.31.0
  - sentence_transformers==2.2.2
   - torchvision==0.14.1
-  - diffusers==0.19.3
+  - diffusers==0.20.0
   - accelerate==0.21.0
   - safetensors

src/huggingface_inference_toolkit/diffusers_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
 import importlib.util
 import logging
 
+from transformers.utils.import_utils import is_torch_bf16_gpu_available
+
 logger = logging.getLogger(__name__)
 logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)
 

@@ -20,7 +22,7 @@ class IEAutoPipelineForText2Image:
     def __init__(self, model_dir: str, device: str = None):  # needs "cuda" for GPU
         dtype = torch.float32
         if device == "cuda":
-            dtype = torch.float16
+            dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
         device_map = "auto" if device == "cuda" else None

         self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
@@ -43,11 +45,7 @@ def __call__(
             logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")

         # Call pipeline with parameters
-        if self.pipeline.device.type == "cuda":
-            with torch.autocast("cuda"):
-                out = self.pipeline(prompt, num_images_per_prompt=1)
-        else:
-            out = self.pipeline(prompt, num_images_per_prompt=1)
+        out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
         return out.images[0]
5250

5351

src/huggingface_inference_toolkit/webservice_starlette.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ async def predict(request):
     # check for query parameter and add them to the body
     if request.query_params and "parameters" not in deserialized_body:
         deserialized_body["parameters"] = convert_params_to_int_or_bool(dict(request.query_params))
-    print(deserialized_body)

     # tracks request time
     start_time = perf_counter()

0 commit comments

Comments (0)