1 change: 1 addition & 0 deletions examples/dreambooth/train_dreambooth_flux.py
@@ -1157,6 +1157,7 @@ def main(args):
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)

transformer.to(accelerator.device, dtype=weight_dtype)
Member
We should not be doing this, since we're fine-tuning the transformer: accelerator.prepare() handles the device placement for us (plus any additional hook placements that might be required).
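
For reference, a minimal sketch of what accelerator.prepare() covers here, assuming the usual transformer/optimizer/dataloader/scheduler wiring rather than quoting the exact line in the file: the trained module is moved to the accelerator's device and wrapped for distributed training, so an explicit .to(accelerator.device) call for it is redundant.

# Illustration (assumed wiring, not the exact line in the script): prepare()
# places the trained transformer on accelerator.device and adds any
# distributed-training wrappers/hooks, so no manual transformer.to(...) is
# needed for the module being fine-tuned.
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    transformer, optimizer, train_dataloader, lr_scheduler
)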

@Viditnegi (Author), Jan 6, 2025
The transformer gets transferred to the GPU in float32 by default with accelerator.prepare(). That takes 40 GB of memory at once!

Member
Well, it will be trained in mixed precision through autocast; this is handled by accelerate. If the user wants to cast it to a lower precision, that should be done through a CLI argument, not by default, IMO.
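
To illustrate the point, a sketch only: accelerate applies the lower precision per forward pass via autocast while the parameter storage stays in fp32, and a permanent downcast of the trained weights would be opt-in. The flag name below is hypothetical and does not exist in the script.

# Sketch: mixed precision is configured on the Accelerator and applied via
# autocast during the forward pass; parameter storage remains fp32.
accelerator = Accelerator(mixed_precision="bf16")  # or "fp16"

# Hypothetical opt-in flag (not an existing argument in this script) for users
# who explicitly want the trained transformer held in the lower precision:
if args.cast_transformer_to_weight_dtype:
    transformer.to(accelerator.device, dtype=weight_dtype)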

@Viditnegi (Author), Jan 6, 2025
Okay, got it. Thanks.

Collaborator
For context, we do have this line - transformer.to(accelerator.device, dtype=weight_dtype) - in the LoRA training scripts, along with the flag --upcast_before_saving (controls whether to upcast to fp32 before saving), which is False by default, indeed due to the memory requirements.
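
In condensed form, the pattern the LoRA scripts follow looks roughly like this (a simplified sketch, not the exact surrounding code): train in the lower-precision weight_dtype and only upcast to fp32 right before saving when the user opts in.

# Simplified sketch of the LoRA-script pattern: keep the transformer in the
# lower-precision weight_dtype during training ...
transformer.to(accelerator.device, dtype=weight_dtype)

# ... and upcast to fp32 only at save time, gated by --upcast_before_saving
# (False by default because of the extra memory it requires).
if args.upcast_before_saving:
    transformer.to(torch.float32)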

@Viditnegi (Author), Jan 7, 2025
Okay, got it, thanks!

vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
    text_encoder_one.to(accelerator.device, dtype=weight_dtype)