
Commit 4430fe9

Update BriaPipeline example to use bfloat16 for the precision-sensitive components, for better results
1 parent 4331603 commit 4430fe9

File tree

1 file changed: +4 additions, -3 deletions

src/diffusers/pipelines/bria/pipeline_bria.py

Lines changed: 4 additions & 3 deletions
@@ -49,11 +49,11 @@
         >>> import torch
         >>> from diffusers import BriaPipeline

-        >>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.float16)
+        >>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
-        # BRIA's T5 text encoder is sensitive to precision. We need to cast it to float16 and keep the final layer in float32.
+        # BRIA's T5 text encoder is sensitive to precision. We need to cast it to bfloat16 and keep the final layer in float32.

-        >>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.float16)
+        >>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.bfloat16)
         >>> for block in pipe.text_encoder.encoder.block:
         ...     block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
         # BRIA's VAE is not supported in mixed precision, so we use float32.

@@ -267,6 +267,7 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt

+
     def check_inputs(
         self,
         prompt,
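
For context, here is a minimal end-to-end sketch of how the updated docstring example is meant to be used after this commit. It is not the verbatim docstring: the prompt, the output filename, and the explicit `pipe.vae` cast are illustrative assumptions (the VAE cast mirrors the float32 comment in the hunk above, but the corresponding line sits outside the shown diff context), and the generation call assumes the standard diffusers pipeline interface.

import torch
from diffusers import BriaPipeline

# Load BRIA 3.2 in bfloat16, as the updated example recommends.
pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# BRIA's T5 text encoder is precision-sensitive: cast it to bfloat16, but keep
# the final feed-forward output projection of every encoder block in float32.
pipe.text_encoder = pipe.text_encoder.to(dtype=torch.bfloat16)
for block in pipe.text_encoder.encoder.block:
    block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)

# BRIA's VAE is not supported in mixed precision, so keep it in float32.
# (Assumption: the pipeline exposes the VAE as `pipe.vae`, like other
# diffusers pipelines; this cast is not part of the diff shown above.)
pipe.vae.to(dtype=torch.float32)

# Illustrative generation call; the prompt is made up for this sketch.
image = pipe("a photo of a red vintage bicycle leaning against a brick wall").images[0]
image.save("bria_example.png")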

Comments (0)