Update transformer to accelerator.device with its weight_dtype #10466
Conversation
By default, the transformer always gets transferred to the device in float32, even when we specify the dtype as bfloat16 or float16. This fixes it.
| "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." | ||
| ) | ||
|  | ||
| transformer.to(accelerator.device, dtype=weight_dtype) | 
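For reference, a minimal, self-contained sketch of the pattern the PR proposes. The dtype mapping below is an assumption that mirrors the usual weight_dtype logic in the diffusers training scripts, and the tiny Linear module only stands in for the transformer being fine-tuned:

    import torch
    from accelerate import Accelerator

    # Assumes a machine where bf16 autocast is supported.
    accelerator = Accelerator(mixed_precision="bf16")

    # Assumed mapping of the --mixed_precision setting to a torch dtype,
    # mirroring the usual weight_dtype logic in the training scripts.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # A tiny module stands in for the transformer being fine-tuned.
    transformer = torch.nn.Linear(8, 8)

    # The change under discussion: move the model to the accelerator device
    # in weight_dtype instead of the default float32.
    transformer.to(accelerator.device, dtype=weight_dtype)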
We should not be doing this, since we're fine-tuning the transformer. accelerator.prepare() handles the device placement for us (plus any additional hook placements that might be required).
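For context, a minimal sketch of the prepare()-based flow this comment refers to; the model, optimizer, and dataloader here are placeholders, not the script's real objects:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    # Placeholders; in the training script these are the transformer being
    # fine-tuned, its optimizer, and the real dataloader.
    transformer = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(transformer.parameters(), lr=1e-4)
    dataloader = torch.utils.data.DataLoader(torch.randn(16, 8), batch_size=4)

    # prepare() moves the model to the right device and installs any hooks
    # needed for distributed / mixed-precision training, so an explicit
    # transformer.to(accelerator.device, ...) is not needed here.
    transformer, optimizer, dataloader = accelerator.prepare(
        transformer, optimizer, dataloader
    )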
With accelerator.prepare(), the transformer gets transferred to the GPU in float32 by default. That takes 40 GB of memory at once!
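To make the memory point concrete, a back-of-the-envelope calculation; the ~10B parameter count is an illustrative assumption, not the actual size of the model in question:

    # Weights only, ignoring gradients, optimizer state, and activations.
    num_params = 10e9                 # illustrative ~10B-parameter transformer
    fp32_gb = num_params * 4 / 1e9    # 4 bytes per float32 weight  -> ~40 GB
    bf16_gb = num_params * 2 / 1e9    # 2 bytes per bfloat16 weight -> ~20 GB
    print(f"fp32: ~{fp32_gb:.0f} GB, bf16: ~{bf16_gb:.0f} GB")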
Well, it will be trained in mixed precision through autocast; this is handled by accelerate. If the user wants to cast it to a lower precision, that should be done through a CLI argument, not by default, IMO.
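A sketch of the opt-in CLI approach suggested here; the flag name --cast_transformer_dtype is hypothetical, not an existing option in the script:

    import argparse
    import torch

    # Hypothetical flag illustrating an opt-in lower-precision cast.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cast_transformer_dtype",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help="Optionally cast the transformer weights to a lower-precision dtype.",
    )
    args = parser.parse_args(["--cast_transformer_dtype", "bf16"])  # example invocation

    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16}
    if args.cast_transformer_dtype != "no":
        weight_dtype = dtype_map[args.cast_transformer_dtype]
        # Only cast when the user explicitly opted in; otherwise leave the
        # model in float32 and let accelerate's autocast handle mixed precision.
        # transformer.to(accelerator.device, dtype=weight_dtype)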
Okay got it. Thanks.
For context, we do have the line transformer.to(accelerator.device, dtype=weight_dtype) in the LoRA training scripts, along with the flag --upcast_before_saving (which controls whether to upcast to fp32 before saving); it is False by default, indeed because of the memory requirements.
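And a small sketch of the upcast-before-saving idea described above; transformer and upcast_before_saving are placeholders mirroring the described behaviour, not the scripts' actual code:

    import torch

    # Placeholders: a tiny module stands in for the LoRA-wrapped transformer,
    # and the boolean mirrors the --upcast_before_saving flag (False by default).
    transformer = torch.nn.Linear(8, 8).to(torch.bfloat16)
    upcast_before_saving = False

    state_dict = transformer.state_dict()
    if upcast_before_saving:
        # Opt-in: store full-precision weights at the cost of a larger checkpoint.
        state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}

    torch.save(state_dict, "transformer_lora_example.bin")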
Okay got it, thanks!
Can we close this since it seems we can resolve the issue with a CLI arg?
Yes, just closed the PR!