

if is_torch_available() and is_gguf_available():
-    import gguf
    import torch

    from .utils import (
+        GGML_QUANT_SIZES,
        GGUFParameter,
        _quant_shape_from_byte_shape,
        _replace_with_gguf_linear,


class GGUFQuantizer(DiffusersQuantizer):
+    use_keep_in_fp32_modules = True
+
    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

        self.compute_dtype = quantization_config.compute_dtype
        self.pre_quantized = quantization_config.pre_quantized
+        self.modules_to_not_convert = quantization_config.modules_to_not_convert
+
+        if not isinstance(self.modules_to_not_convert, list):
+            self.modules_to_not_convert = [self.modules_to_not_convert]

    def validate_environment(self, *args, **kwargs):
        if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
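The new `modules_to_not_convert` handling accepts either a single module name or a list; a minimal standalone sketch of the normalization performed in `__init__` above (the module name is illustrative):

# Sketch of the normalization added in __init__; "proj_out" is illustrative.
modules_to_not_convert = "proj_out"                 # a config may pass a single name
if not isinstance(modules_to_not_convert, list):
    modules_to_not_convert = [modules_to_not_convert]
assert modules_to_not_convert == ["proj_out"]       # always a list afterwards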
@@ -70,7 +76,7 @@ def check_quantized_param_shape(self, param_name, current_param, loaded_param):
        current_param_shape = current_param.shape
        quant_type = loaded_param.quant_type

-        block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
+        block_size, type_size = GGML_QUANT_SIZES[quant_type]

        inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
        if inferred_shape != current_param_shape:
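For context, `GGML_QUANT_SIZES` maps each GGML quant type to a `(block_size, type_size)` pair, and the inferred shape expands the stored byte shape back to the logical weight shape. A worked example, assuming the helper mirrors gguf-py's `quant_shape_from_byte_shape`; the tensor shape is illustrative:

block_size, type_size = 32, 18                       # Q4_0: 32 weights per 18-byte block
loaded_param_shape = (3072, 1728)                    # raw byte shape stored in the GGUF file
inferred_shape = (*loaded_param_shape[:-1], loaded_param_shape[-1] // type_size * block_size)
assert inferred_shape == (3072, 3072)                # must equal current_param.shape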
@@ -96,7 +102,7 @@ def check_if_quantized_param(
    def create_quantized_param(
        self,
        model: "ModelMixin",
-        param_value: "torch.Tensor",
+        param_value: Union["GGUFParameter", "torch.Tensor"],
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
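The widened annotation reflects that pre-quantized GGUF weights arrive as `GGUFParameter` (a tensor subclass carrying `quant_type`, as used in `check_quantized_param_shape` above), while unquantized weights are plain tensors. A rough sketch of how a caller might tell them apart; this helper is hypothetical, not part of this commit:

def _is_gguf_quantized(param_value):
    # GGUFParameter wraps the raw quantized bytes and records its GGML quant type;
    # plain torch.Tensor values fall through to the unquantized path.
    return isinstance(param_value, GGUFParameter) and param_value.quant_type is not None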
@@ -119,7 +125,13 @@ def _process_model_before_weight_loading(
        **kwargs,
    ):
        state_dict = kwargs.get("state_dict", None)
-        _replace_with_gguf_linear(model, self.compute_dtype, state_dict)
+
+        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
+
+        _replace_with_gguf_linear(
+            model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert
+        )

    def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
        return model
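End to end, these changes are exercised when a GGUF checkpoint is loaded with a quantization config; a usage sketch follows, where the checkpoint URL and model class are illustrative rather than part of this commit:

import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Illustrative Q4_0 single-file checkpoint; any GGUF checkpoint loads the same way.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_0.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
# With use_keep_in_fp32_modules = True, any _keep_in_fp32_modules declared on the model
# class are merged into modules_to_not_convert and skipped by _replace_with_gguf_linear.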