Skip to content

Commit da5a587

Browse files
committed
maybe we hardcode bfloat16
1 parent 8201938 commit da5a587

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

src/huggingface_inference_toolkit/diffusers_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ def __call__(
7474

7575
# Call pipeline with parameters
7676
if self.pipeline.device.type == "cuda":
77-
model_dtype = next(self.pipeline.parameters()).dtype
78-
with torch.autocast("cuda", dtype=model_dtype):
77+
with torch.autocast("cuda", dtype=torch.bfloat16):
7978
out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
8079
else:
8180
out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)

0 commit comments

Comments
 (0)