[Quantization] Add quantization support for bitsandbytes
#9213
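For orientation, here is a rough sketch of the user-facing API this PR works toward. The class name, arguments, and checkpoint below are illustrative (modeled on the transformers bitsandbytes API), not taken from this diff.

```python
# Hypothetical usage sketch: assumes diffusers exposes a `BitsAndBytesConfig` mirroring
# the transformers API and that model classes accept `quantization_config` in
# `from_pretrained()`, which is what this PR wires up.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```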
Changes from 45 commits
```diff
@@ -25,6 +25,7 @@
 import torch
 from huggingface_hub.utils import EntryNotFoundError
 
+from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
```
```diff
@@ -53,11 +54,36 @@
 # Adapted from `transformers` (see modeling_utils.py)
-def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+def _determine_device_map(
+    model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
+):
     if isinstance(device_map, str):
+        special_dtypes = {}
+        if hf_quantizer is not None:
+            special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
+        special_dtypes.update(
+            {
+                name: torch.float32
+                for name, _ in model.named_parameters()
+                if any(m in name for m in keep_in_fp32_modules)
+            }
+        )
+
+        target_dtype = torch_dtype
+        if hf_quantizer is not None:
+            target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
+
         no_split_modules = model._get_no_split_modules(device_map)
         device_map_kwargs = {"no_split_module_classes": no_split_modules}
 
+        if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+            device_map_kwargs["special_dtypes"] = special_dtypes
+        elif len(special_dtypes) > 0:
+            logger.warning(
+                "This model has some weights that should be kept in higher precision, you need to upgrade "
+                "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+            )
+
         if device_map != "sequential":
             max_memory = get_balanced_memory(
                 model,
```
```diff
@@ -69,8 +95,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
         else:
             max_memory = get_max_memory(max_memory)
 
+        if hf_quantizer is not None:
+            max_memory = hf_quantizer.adjust_max_memory(max_memory)
+
         device_map_kwargs["max_memory"] = max_memory
-        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+        device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
 
+    if hf_quantizer is not None:
+        hf_quantizer.validate_environment(device_map=device_map)
+
     return device_map
```
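To make the new `hf_quantizer` hook points above concrete, here is a toy stand-in exposing the four methods that `_determine_device_map` now calls. The method names and call signatures mirror the diff; the class and its heuristics are hypothetical and are not the real bitsandbytes quantizer.

```python
# Toy illustration only: a minimal object implementing the hooks consulted by
# `_determine_device_map` when a quantizer is present.
import torch


class ToyFourBitQuantizer:
    def get_special_dtypes_update(self, model, torch_dtype):
        # Parameters the backend cannot quantize (e.g. norms) keep the compute dtype,
        # so the device-map planner sizes them correctly. The "norm" rule is made up.
        return {name: torch_dtype for name, _ in model.named_parameters() if "norm" in name}

    def adjust_target_dtype(self, target_dtype):
        # Plan placement with a smaller storage dtype than the compute dtype; real
        # quantizers return an int8/int4 marker so memory estimates match reality.
        return torch.int8

    def adjust_max_memory(self, max_memory):
        # Keep some headroom per device for quantization buffers (10% is arbitrary).
        return {device: int(mem * 0.9) for device, mem in max_memory.items()}

    def validate_environment(self, device_map=None):
        # Reject placements the backend cannot serve, e.g. CPU offload of 4-bit weights.
        if device_map is not None and "cpu" in device_map.values():
            raise ValueError("This toy quantizer does not support CPU offload.")
```

The net effect is that `infer_auto_device_map` plans placement using the quantized storage footprint and an adjusted memory budget, and the quantizer can veto device maps it cannot serve.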
```diff
@@ -99,6 +131,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
+    if isinstance(checkpoint_file, dict):
+        return checkpoint_file
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
         if file_extension == SAFETENSORS_FILE_EXTENSION:
```
```diff
@@ -136,29 +170,57 @@ def load_model_dict_into_meta(
     device: Optional[Union[str, torch.device]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
     model_name_or_path: Optional[str] = None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
 ) -> List[str]:
-    device = device or torch.device("cpu")
+    device = device or torch.device("cpu") if hf_quantizer is None else device
     dtype = dtype or torch.float32
+    is_quantized = hf_quantizer is not None
 
     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
 
-    unexpected_keys = []
     empty_state_dict = model.state_dict()
+    unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
+    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
 
     for param_name, param in state_dict.items():
         if param_name not in empty_state_dict:
-            unexpected_keys.append(param_name)
             continue
 
-        if empty_state_dict[param_name].shape != param.shape:
+        # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
+        # in int/uint/bool and not cast them.
+        is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
+        if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
+            if (
+                keep_in_fp32_modules is not None
+                and any(
+                    module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
+                )
+                and dtype == torch.float16
+            ):
+                param = param.to(torch.float32)
+            else:
+                param = param.to(dtype)
+
+        is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+        if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
             model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
             raise ValueError(
                 f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
             )
 
-        if accepts_dtype:
-            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        if (
+            not is_quantized
+            or (not hf_quantizer.requires_parameters_quantization)
+            or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device))
+        ):
+            if accepts_dtype:
+                set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+            else:
+                set_module_tensor_to_device(model, param_name, device, value=param)
         else:
-            set_module_tensor_to_device(model, param_name, device, value=param)
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
 
     return unexpected_keys
```
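As a rough usage sketch of the non-quantized path: `load_model_dict_into_meta` fills a model whose parameters were created on the meta device, materializing each tensor via `set_module_tensor_to_device` or handing it to the quantizer. The import path, repo id, and file path below are assumptions for illustration.

```python
# Sketch only: assumes these helpers live in diffusers.models.model_loading_utils and
# that a local safetensors checkpoint matching the model exists.
import torch
from accelerate import init_empty_weights
from diffusers import UNet2DConditionModel
from diffusers.models.model_loading_utils import load_model_dict_into_meta, load_state_dict

config = UNet2DConditionModel.load_config("runwayml/stable-diffusion-v1-5", subfolder="unet")
with init_empty_weights():
    model = UNet2DConditionModel.from_config(config)  # parameters live on the "meta" device

state_dict = load_state_dict("/path/to/unet/diffusion_pytorch_model.safetensors")  # hypothetical local file
unexpected = load_model_dict_into_meta(model, state_dict, device="cpu", dtype=torch.float16)
print(f"{len(unexpected)} unexpected keys")
```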
```diff
@@ -228,3 +290,32 @@ def _fetch_index_file(
         index_file = None
 
     return index_file
+
+
+# Adapted from
+# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
+def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
+    weight_map = sharded_metadata.get("weight_map", None)
+    if weight_map is None:
+        raise KeyError("'weight_map' key not found in the shard index file.")
+
+    # Collect all unique safetensors files from weight_map
+    files_to_load = set(weight_map.values())
+    is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
+    merged_state_dict = {}
+
+    # Load tensors from each unique file
+    for file_name in files_to_load:
+        part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
+        if not os.path.exists(part_file_path):
+            raise FileNotFoundError(f"Part file {file_name} not found.")
+
+        if is_safetensors:
+            with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
+                for tensor_key in f.keys():
+                    if tensor_key in weight_map:
+                        merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
+        else:
+            merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
+
+    return merged_state_dict
```
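A possible way to drive the new `_merge_sharded_checkpoints` helper, assuming the usual sharded layout where an index JSON containing a `weight_map` sits alongside the shard files (the import path, folder, and file names below are hypothetical):

```python
# Assumed layout: a cached folder holding the shard files plus an index JSON whose
# "weight_map" maps each tensor name to the shard that stores it.
import json
import os

from diffusers.models.model_loading_utils import _merge_sharded_checkpoints  # assumed import path

folder = "/path/to/cached/sharded/checkpoint"  # hypothetical local folder
with open(os.path.join(folder, "diffusion_pytorch_model.safetensors.index.json")) as f:
    sharded_metadata = json.load(f)  # provides the "weight_map" consumed above

state_dict = _merge_sharded_checkpoints(folder, sharded_metadata)
num_shards = len(set(sharded_metadata["weight_map"].values()))
print(f"Merged {len(state_dict)} tensors from {num_shards} shard(s).")
```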
From the review discussion on how `quantization_config` should be handled in the model config (`extract_init_dict`):

Because `quantization_config` isn't a part of any model's `__init__()`.
I think it is better to not add it to `config_dict` if it is not going into `__init__`, i.e. at line 511.
We cannot remove `quantization_config` from the config of a model, as that would prevent loading of quantized models via `from_pretrained()`. `quantization_config` isn't used for initializing a model; it's used to determine what kind of quantization configuration to inject inside the given model. This is why it's only used in `from_pretrained()` of `ModelMixin`. LMK if you have a better idea to handle it.
We do not remove them from the config, just not adding to the `config_dict` inside this `extract_init_dict` method. Basically, the `config_dict` in this function goes through these steps:

- `init_dict`: the quantisation config will not go there, so it is not affected if we do not add it to `config_dict`.
- The warning for config attributes left over after building `init_dict`: if the quantisation configs were not there, we do not need to throw a warning for it.
- `unused_kwargs`: so I think this is the only difference it would make. Do we need the quantisation config to be in the `unused_kwargs` returned by `extract_init_dict`? I think `unused_kwargs` is only used to send additional warnings for unexpected stuff, but since the quantisation config is expected, and we have already decided not to send a warning for it here inside `extract_init_dict`, I think it does not need to go to the `unused_kwargs` here?

```diff
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
    ...
    config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

    # remove private attributes
    config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

+   # remove quantization_config
+   config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}

    ## here we use config_dict to create `init_dict` which will be passed to `__init__` method
    init_dict = {}
    for key in expected_keys:
        ...
        init_dict[key] = config_dict.pop(key)

-   only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict
-   if len(config_dict) > 0 and not only_quant_config_remaining:
+   if len(config_dict) > 0:
        logger.warning(
            f"The config attributes {config_dict} were passed to {cls.__name__}, "
            "but are not expected and will be ignored. Please verify your "
            f"{cls.config_name} configuration file."
        )
    ....

    # 6. Define unused keyword arguments
    unused_kwargs = {**config_dict, **kwargs}

    return init_dict, unused_kwargs, hidden_config_dict
```
Makes sense. Resolved in 555a5ae.
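To spell out the behavior this thread converges on: `quantization_config` stays in the serialized config so `from_pretrained()` can re-apply quantization, but it is filtered out before the `__init__()` kwargs are built, and no "unexpected attribute" warning is raised for it. A minimal self-contained sketch (key names are illustrative; the real logic lives in `ConfigMixin.extract_init_dict`):

```python
# Illustrative only: mimics the filtering suggested above, outside of diffusers.
config_dict = {
    "in_channels": 4,                               # accepted by __init__
    "sample_size": 64,                              # accepted by __init__
    "quantization_config": {"load_in_4bit": True},  # consumed by from_pretrained(), not __init__
}

expected_keys = {"in_channels", "sample_size"}

# Drop quantization_config before building the kwargs passed to __init__ ...
filtered = {k: v for k, v in config_dict.items() if k != "quantization_config"}
init_dict = {k: filtered.pop(k) for k in expected_keys if k in filtered}

# ... so nothing "unexpected" is left over and no warning is needed.
assert filtered == {}
assert "quantization_config" not in init_dict
assert "quantization_config" in config_dict  # still serialized for from_pretrained()
```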