serialize scales as bf16 and serialize in Named Data Map #11031
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/11031
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit ebd5d87 with merge base 1bc36c7.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@mcr229 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
num_scales = scale.numel()

if quant_params.is_per_channel_group:
    scale = scale.to(torch.bfloat16)
I think this should be a flag (default=bf16), which lines up with the QB4W XNNPACK flag to use bf16. We can error out at AoT for fp32, given we can't run that yet.
Not sure what you mean; won't almost all scales that come to us be fp32?
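For illustration, a minimal sketch of the flag idea from this thread. The `scale_dtype` parameter and helper name are hypothetical, not the actual ExecuTorch API; this only shows the default-to-bf16 behavior and the AoT error for fp32:

```python
import torch


def convert_group_scales(
    scale: torch.Tensor, scale_dtype: torch.dtype = torch.bfloat16
) -> torch.Tensor:
    """Hypothetical helper: convert groupwise-quantization scales ahead of time.

    Defaults to bf16 to line up with XNNPACK's QB4W path, which consumes
    bf16 scales at runtime.
    """
    if scale_dtype == torch.float32:
        # fp32 groupwise scales can't be run by the XNNPACK delegate yet,
        # so fail at ahead-of-time export rather than at load time.
        raise NotImplementedError(
            "fp32 scales for groupwise quantized weights are not supported yet"
        )
    return scale.to(scale_dtype)
```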
channel_dim:int;
group_size:int;
- scale_bf16:[ushort];
+ scale_bf16:[ushort] (deprecated);
Curious, why mark this as deprecated but not the float field, if we are moving to the named data map for everything?
Because this field was actually never used; we never added the export path to serialize into it. scale:[float] is still used by older versions.
Summary:

XNNPACK currently uses BF16 scales for running GEMMs with groupwise quantized weights. Today we serialize scales as FP32 and then convert them to BF16 before passing them to XNNPACK. We can save both memory and file size by serializing the scales as BF16 in the first place.

As an additional step here, we move the serialization of scales for both channelwise and groupwise quantized weights into the named data map. In the future, swapping weight data becomes a potential feature, because the scales are no longer tied to the XNNPACK payload and can be swapped through the .ptd file.

cc lucylq for the scale serialization

### Llama Experiments
```
-rw-r--r--  1 maxren  staff  1746392320 May 20 16:49 llama3_fp32_scales.pte
-rw-r--r--  1 maxren  staff  1707798912 May 20 18:47 llama3_bf16_scales.pte
```
We see a ~40 MB reduction in model size.

Reviewed By: kirklandsign

Differential Revision: D75151974

Pulled By: mcr229
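As a rough illustration of how the two pieces of this change fit together (bf16 conversion plus storing the raw bytes by name), here is a minimal sketch that uses a plain dict to stand in for the named data map. The helper name and key format are assumptions, not the actual ExecuTorch serializer API:

```python
from typing import Dict

import torch


def serialize_group_scales(
    weight_name: str, scale: torch.Tensor, named_data: Dict[str, bytes]
) -> str:
    """Hypothetical sketch: store scales as raw bf16 bytes under a name.

    The real ExecuTorch serializer has its own named-data-map API; this only
    demonstrates the dtype conversion and the by-name storage described in
    the summary.
    """
    # Convert fp32 scales to bf16 up front instead of carrying fp32 in the
    # payload and converting at runtime.
    scale_bf16 = scale.detach().cpu().to(torch.bfloat16).contiguous()
    # Reinterpret the bf16 bits as int16 so numpy can hand back the raw bytes.
    raw = scale_bf16.view(torch.int16).numpy().tobytes()
    key = f"{weight_name}_scale"  # assumed key format
    named_data[key] = raw
    return key
```

Because the scales live in the named data map rather than inside the XNNPACK payload, they could later be swapped by shipping a different .ptd file without touching the delegate blob.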
This pull request was exported from Phabricator. Differential Revision: D75151974