 if is_torch_available():
     import torch
 
-if is_accelerate_available():
-    pass
-
-if is_nunchaku_available():
-    from .utils import replace_with_nunchaku_linear
 
 logger = logging.get_logger(__name__)
 
@@ -79,13 +74,14 @@ def check_if_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ):
-        from nunchaku.models.linear import SVDQW4A4Linear
-
-        module, tensor_name = get_module_from_name(model, param_name)
-        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
-            return True
-
-        return False
+        # TODO: revisit
+        # Check if the param_name is not in self.modules_to_not_convert
+        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
+            return False
+        else:
+            # We only quantize the weight of nn.Linear
+            module, _ = get_module_from_name(model, param_name)
+            return isinstance(module, torch.nn.Linear)
 
     def create_quantized_param(
         self,
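The new `check_if_quantized_param` body relies on a dotted-prefix match, so an entry in `modules_to_not_convert` excludes both the module itself and everything nested under it. A minimal standalone sketch of that matching rule (the module and parameter names here are made up for illustration):

```python
# Standalone sketch of the prefix-based exclusion check above.
# The names are illustrative, not taken from an actual checkpoint.
modules_to_not_convert = ["proj_out", "norm_out"]


def is_excluded(param_name: str) -> bool:
    # An entry excludes the parameter when it equals the full name or is a dotted prefix of it.
    return any((key + "." in param_name) or (key == param_name) for key in modules_to_not_convert)


print(is_excluded("proj_out.weight"))  # True -> left unquantized
print(is_excluded("transformer_blocks.0.attn.to_q.weight"))  # False -> candidate for quantization
```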
@@ -112,13 +108,32 @@ def create_quantized_param(
                 module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device))
 
         elif isinstance(module, torch.nn.Linear):
-            if tensor_name in module._parameters:
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-            if tensor_name in module._buffers:
-                module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(target_device)
-
-            new_module = SVDQW4A4Linear.from_linear(module)
-            setattr(model, param_name, new_module)
+            # TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module.
+            # But we need to have a utility that can take a pretrained param value and quantize it. Not sure
+            # how to do that yet.
+            # Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better
+            # way to do it?
+            is_param = tensor_name in module._parameters
+            is_buffer = tensor_name in module._buffers
+            new_module = SVDQW4A4Linear.from_linear(
+                module, precision=self.quantization_config.precision, rank=self.quantization_config.rank
+            )
+            module_name = ".".join(param_name.split(".")[:-1])
+            if "." in module_name:
+                parent_name, leaf = module_name.rsplit(".", 1)
+                parent = model.get_submodule(parent_name)
+            else:
+                parent, leaf = model, module_name
+
+            # rebind the tensor on the new module
+            # this will result in:
+            # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'.
+            if is_param:
+                new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+            elif is_buffer:
+                new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+
+            setattr(parent, leaf, new_module)
 
     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
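The rebinding logic in `create_quantized_param` resolves the parent of the target module from the dotted parameter name and swaps the child in place with `setattr`. Below is a runnable sketch of just that mechanics, with a plain `nn.Linear` standing in for `SVDQW4A4Linear.from_linear(...)` so it does not require nunchaku:

```python
import torch

# Sketch of the parent/leaf rebinding used above; a plain nn.Linear stands in
# for the quantized `SVDQW4A4Linear` layer so the example runs without nunchaku.
model = torch.nn.Sequential()
model.add_module("block", torch.nn.Sequential())
model.get_submodule("block").add_module("to_q", torch.nn.Linear(8, 8))

param_name = "block.to_q.weight"
module_name = ".".join(param_name.split(".")[:-1])  # "block.to_q"
if "." in module_name:
    parent_name, leaf = module_name.rsplit(".", 1)   # "block", "to_q"
    parent = model.get_submodule(parent_name)
else:
    parent, leaf = model, module_name

new_module = torch.nn.Linear(8, 8, bias=False)       # placeholder for the quantized layer
setattr(parent, leaf, new_module)                     # rebind the child on its parent
print(model.get_submodule("block.to_q"))              # Linear(in_features=8, out_features=8, bias=False)
```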
@@ -157,24 +172,25 @@ def _process_model_before_weight_loading(
         keep_in_fp32_modules: List[str] = [],
         **kwargs,
     ):
-        # TODO: deal with `device_map`
         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
 
         if not isinstance(self.modules_to_not_convert, list):
             self.modules_to_not_convert = [self.modules_to_not_convert]
 
         self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        # TODO: revisit
+        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+        #     keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+        #     self.modules_to_not_convert.extend(keys_on_cpu)
+
         # Purge `None`.
         # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
         # in case of diffusion transformer models. For language models and others alike, `lm_head`
         # and tied modules are usually kept in FP32.
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
-        model = replace_with_nunchaku_linear(
-            model,
-            modules_to_not_convert=self.modules_to_not_convert,
-            quantization_config=self.quantization_config,
-        )
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model, **kwargs):
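The commented-out `device_map` handling above would fold offloaded keys into the exclusion list so that anything mapped to `"cpu"` or `"disk"` stays unconverted. A small sketch of that filtering with a made-up device map:

```python
# Sketch of the offload filtering described in the commented-out TODO in
# `_process_model_before_weight_loading`. The device_map here is made up.
device_map = {"transformer_blocks.0": 0, "transformer_blocks.1": "cpu", "proj_out": "disk"}
modules_to_not_convert = ["norm_out"]

if isinstance(device_map, dict) and len(device_map.keys()) > 1:
    keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
    modules_to_not_convert.extend(keys_on_cpu)

print(modules_to_not_convert)  # ['norm_out', 'transformer_blocks.1', 'proj_out']
```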