
Conversation


@Viditnegi Viditnegi commented Jan 6, 2025

By default, the transformer always gets transferred to the device in float32, even if we specify the dtype as bfloat16 or float16. This PR fixes it.

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@hlky hlky requested review from linoytsaban and sayakpaul January 6, 2025 09:02
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)

transformer.to(accelerator.device, dtype=weight_dtype)
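
For context, a minimal sketch of where the added line sits, assuming the usual weight_dtype setup in diffusers training scripts (`accelerator` and `transformer` come from the script itself):

```python
import torch

# weight_dtype follows the requested mixed-precision mode.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# The line proposed in this PR: move the transformer to the device in the training dtype.
transformer.to(accelerator.device, dtype=weight_dtype)
```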
Member


We should not be doing this, as we're fine-tuning the transformer. accelerator.prepare() handles this for us (plus any additional hook placements that might be required).
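
A minimal, self-contained sketch of this point (a toy model standing in for the transformer, not the actual script):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")  # or "fp16"
model = torch.nn.Linear(8, 8)                      # stand-in for the transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# prepare() moves the model to accelerator.device and sets up mixed precision:
# the trainable weights stay in float32 and the forward pass runs under autocast.
model, optimizer = accelerator.prepare(model, optimizer)
```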

Author

@Viditnegi Viditnegi Jan 6, 2025


The transformer gets transferred to the GPU in float32 by default with accelerator.prepare(). It takes 40 GB of memory at once!
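
As a back-of-the-envelope check, parameter storage alone is 4 bytes per parameter in fp32 versus 2 in bf16/fp16; the parameter count below is illustrative, not the actual model size:

```python
# Rough parameter-memory estimate; 10B parameters is illustrative only.
num_params = 10e9
print(f"fp32: {num_params * 4 / 1e9:.0f} GB, bf16/fp16: {num_params * 2 / 1e9:.0f} GB")
# fp32: 40 GB, bf16/fp16: 20 GB
```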

Member


Well, it will be trained in mixed precision through autocast. This is handled by accelerate. If the user wants to cast it to a lower precision, it should be done through a CLI argument, not by default IMO.
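
A sketch of what such an opt-in flag could look like; the flag name is hypothetical (not an existing diffusers argument), and `transformer`, `accelerator`, and `weight_dtype` refer to objects already defined in the training script:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cast_transformer_to_weight_dtype",  # hypothetical flag name
    action="store_true",
    help="Move the transformer to the device in the mixed-precision dtype instead of float32.",
)
args = parser.parse_args()

# Only cast when the user explicitly opts in; otherwise accelerate's autocast handles precision.
if args.cast_transformer_to_weight_dtype:
    transformer.to(accelerator.device, dtype=weight_dtype)
```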

Author

@Viditnegi Viditnegi Jan 6, 2025


Okay got it. Thanks.

Collaborator


For context, we do have this line - transformer.to(accelerator.device, dtype=weight_dtype) - in the LoRA training scripts, along with the flag --upcast_before_saving (controls whether to upcast to fp32 before saving), which is False by default, indeed due to the memory requirements.
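
A sketch of that pattern (paraphrased, not a verbatim excerpt from the scripts): cast for training, then optionally upcast to fp32 only when saving the LoRA weights.

```python
import torch

# Cast for training, as in the LoRA scripts.
transformer.to(accelerator.device, dtype=weight_dtype)

# ... training loop ...

# --upcast_before_saving is False by default to keep memory requirements down.
if args.upcast_before_saving:
    transformer.to(torch.float32)   # save the LoRA weights in fp32
else:
    transformer.to(weight_dtype)    # keep the lower-precision training dtype
```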

Author

@Viditnegi Viditnegi Jan 7, 2025


Okay got it, thanks!

@DN6
Collaborator

DN6 commented Jan 8, 2025

Can we close this since it seems we can resolve the issue with a CLI arg?

@Viditnegi Viditnegi closed this Jan 8, 2025
@Viditnegi
Author

> Can we close this since it seems we can resolve the issue with a CLI arg?

Yes, just closed the PR!
