Skip to content

Commit 8201938

Browse files
committed
maybe fix torch dtype issue
1 parent 9588098 commit 8201938

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/huggingface_inference_toolkit/diffusers_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,13 @@ def __call__(
7373
logger.warning("The `output_type` cannot be modified, and PIL will be used by default instead.")
7474

7575
# Call pipeline with parameters
76-
out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
76+
if self.pipeline.device.type == "cuda":
77+
model_dtype = next(self.pipeline.parameters()).dtype
78+
with torch.autocast("cuda", dtype=model_dtype):
79+
out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
80+
else:
81+
out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
82+
7783
return out.images[0]
7884

7985

0 commit comments

Comments
 (0)