We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9588098 commit 8201938Copy full SHA for 8201938
src/huggingface_inference_toolkit/diffusers_utils.py
@@ -73,7 +73,13 @@ def __call__(
73
logger.warning("The `output_type` cannot be modified, and PIL will be used by default instead.")
74
75
# Call pipeline with parameters
76
- out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
+ if self.pipeline.device.type == "cuda":
77
+ model_dtype = next(self.pipeline.parameters()).dtype
78
+ with torch.autocast("cuda", dtype=model_dtype):
79
+ out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
80
+ else:
81
82
+
83
return out.images[0]
84
85
0 commit comments