support hf_quantizer in cache warmup. #12043
Changes from 2 commits
@@ -19,6 +19,7 @@

import importlib
import types
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Dict, List, Union

from packaging import version
@@ -278,6 +279,29 @@ def create_quantized_param(

        module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
        quantize_(module, self.quantization_config.get_apply_tensor_subclass())

    def get_cuda_warm_up_factor(self):
        """
        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
        - A factor of 2 means we pre-allocate the full memory footprint of the model.
        - A factor of 4 means we pre-allocate half of that, and so on.

        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
        the correct size for quantized weights (like int4 or int8). That's because TorchAO internally represents
        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
        torch_dtype, not the actual bit-width of the quantized data.

        To correct for this:
        - Use a division factor of 8 for int4 weights
        - Use a division factor of 4 for int8 weights
        """
        # Original mapping for non-AOBaseConfig types
        map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "float8*": 4}

Review comment: Need to handle more of these exhaustively:
Reply: Took a best guess of "8" for the unsigned int types. I think we can tackle more of these nuanced / lesser-used types as they become a bit more used. I think the int8 and fp8 types are far more common for now 👀 So, I have added a comment as well.

        quant_type = self.quantization_config.quant_type
        for pattern, target_dtype in map_to_target_dtype.items():
            if fnmatch(quant_type, pattern):
                return target_dtype
        raise ValueError(f"Unsupported quant_type: {quant_type!r}")

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
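For context, here is a minimal sketch of how a loader might consume the returned factor during cache warmup. It assumes, following the docstring's "factor of 2 = full footprint" convention, that the warmup derives the naive byte count from param.numel() * param.element_size() and then allocates byte_count // factor elements of a 2-byte dtype; the helper name caching_allocator_warmup_sketch and its signature are illustrative, not the actual diffusers code.

import torch
import torch.nn as nn


def caching_allocator_warmup_sketch(model: nn.Module, device: torch.device, factor: int) -> int:
    # Naive footprint: element_size() reflects the torch_dtype, not the quantized bit-width.
    byte_count = sum(p.numel() * p.element_size() for p in model.parameters())

    # Allocating byte_count // factor elements of a 2-byte dtype reserves (2 / factor) * byte_count bytes:
    # factor 2 -> full naive footprint, factor 4 -> half (int8/fp8), factor 8 -> a quarter (int4).
    _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device)
    return byte_count // factor * 2  # bytes touched up front, so the caching allocator grows once

So a quantizer that reports 8 for an int4 quant_type makes the warmup reserve roughly what 4-bit storage actually occupies, instead of the bf16/fp16-sized block that a factor of 2 would reserve.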
Uh oh!
There was an error while loading. Please reload this page.
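The glob patterns are matched with fnmatch, so any quant_type string with the right prefix picks up the corresponding factor. A quick illustration (the quant_type strings below are made-up examples chosen to match the patterns, not an exhaustive list of TorchAO config names):

from fnmatch import fnmatch

# Same mapping as in the diff: glob pattern -> warmup division factor.
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "float8*": 4}


def warm_up_factor(quant_type: str) -> int:
    # fnmatch does shell-style globbing, e.g. "int4_*" matches anything starting with "int4_".
    for pattern, factor in map_to_target_dtype.items():
        if fnmatch(quant_type, pattern):
            return factor
    raise ValueError(f"Unsupported quant_type: {quant_type!r}")


print(warm_up_factor("int4_weight_only"))    # 8 -> reserve a quarter of the naive footprint
print(warm_up_factor("int8_weight_only"))    # 4 -> reserve half
print(warm_up_factor("float8_e4m3_tensor"))  # 4, since any name starting with "float8" matches
# An unsigned-int name such as "uint6_weight_only" would raise ValueError here,
# which is the gap the review thread above is discussing.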