support hf_quantizer in cache warmup. #12043
        
Changes from 2 commits
```diff
@@ -19,6 +19,7 @@
 
 import importlib
 import types
+from fnmatch import fnmatch
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from packaging import version
@@ -278,6 +279,29 @@ def create_quantized_param(
         module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
         quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+        - A factor of 2 means we pre-allocate the full memory footprint of the model.
+        - A factor of 4 means we pre-allocate half of that, and so on.
+
+        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+        the correct size for quantized weights (like int4 or int8). That's because TorchAO internally represents
+        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+        torch_dtype, not the actual bit-width of the quantized data.
+
+        To correct for this:
+        - Use a division factor of 8 for int4 weights
+        - Use a division factor of 4 for int8 weights
+        """
+        # Original mapping for non-AOBaseConfig types
+        map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "float8*": 4}
+
+        quant_type = self.quantization_config.quant_type
+        for pattern, target_dtype in map_to_target_dtype.items():
+            if fnmatch(quant_type, pattern):
+                return target_dtype
+        raise ValueError(f"Unsupported quant_type: {quant_type!r}")
+
     def _process_model_before_weight_loading(
         self,
         model: "ModelMixin",
```
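For context, below is a minimal sketch of how a cache-warmup routine could consume the factor returned by this method. It is not the actual `caching_allocator_warmup` implementation; the helper name `warm_up_cuda_cache`, its signature, and the byte accounting are illustrative assumptions based on the docstring (a factor of 2 pre-allocates the full naive footprint, 4 half of it, 8 a quarter).

```python
# Minimal sketch, not the actual diffusers implementation. `warm_up_cuda_cache`
# and its signature are hypothetical; only get_cuda_warm_up_factor() comes from
# the diff above.
import torch


def warm_up_cuda_cache(model: torch.nn.Module, device: torch.device, hf_quantizer=None) -> None:
    # Naive footprint: numel() * element_size(). For TorchAO-quantized weights this
    # over-counts, because element_size() reflects the torch_dtype, not the packed bit-width.
    naive_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

    # Factor 2 -> pre-allocate the full naive footprint, 4 -> half, 8 -> a quarter.
    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
    warmup_bytes = 2 * naive_bytes // factor

    # One large throwaway allocation grows the CUDA caching allocator's pool up front,
    # so the many small per-parameter transfers during weight loading avoid repeated
    # cudaMalloc calls.
    _ = torch.empty(warmup_bytes, dtype=torch.uint8, device=device, requires_grad=False)
```

With the mapping in the diff, a quant_type such as "int4_weight_only" (assuming that is the string carried by the config) matches "int4_*" and returns 8, so only a quarter of the naive estimate is pre-allocated, which, for a 16-bit torch_dtype, roughly corresponds to the packed 4-bit weights plus metadata.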