-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from ..base import DiffusersQuantizer

     is_accelerate_available,
     is_accelerate_version,
     is_gguf_available,
+    is_gguf_version,
     is_torch_available,
     logging,
 )
     import gguf
     import torch

-    from .utils import GGUFParameter, _quant_shape_from_byte_shape, _replace_with_gguf_linear
+    from .utils import (
+        GGUFParameter,
+        _quant_shape_from_byte_shape,
+        _replace_with_gguf_linear,
+    )


 logger = logging.get_logger(__name__)
@@ -39,11 +44,26 @@ def validate_environment(self, *args, **kwargs):
             raise ImportError(
                 "Loading GGUF Parameters requires `accelerate` installed in your environment: `pip install 'accelerate>=0.26.0'`"
             )
-        if not is_gguf_available():
+        if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
             raise ImportError(
-                "To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf`"
+                "To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`"
             )

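For context on the new version gate: `is_gguf_version` comes from diffusers' import utilities alongside `is_gguf_available`. The helper name is real, but the body below is only a sketch of the comparison it plausibly performs, assuming a `packaging`-based implementation:

```python
# Illustrative sketch — the real helper lives in diffusers' import utilities.
import importlib.metadata

from packaging import version


def is_gguf_version(operation: str, target: str) -> bool:
    # Compare the installed gguf version against `target`,
    # e.g. is_gguf_version("<", "0.10.0") -> True for gguf 0.9.x.
    ops = {"<": "__lt__", "<=": "__le__", ">": "__gt__", ">=": "__ge__", "==": "__eq__"}
    installed = version.parse(importlib.metadata.version("gguf"))
    return getattr(installed, ops[operation])(version.parse(target))
```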
+    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+        # need more space for buffers that are created during quantization
+        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+        return max_memory
+
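A quick worked example of the 10% haircut `adjust_max_memory` applies, assuming the budgets have already been resolved to byte counts (as accelerate does before dispatch); the sample values are hypothetical:

```python
# Hypothetical input: per-device budgets in bytes, as produced by accelerate.
max_memory = {0: 24 * 1024**3, "cpu": 64 * 1024**3}  # 24 GiB GPU, 64 GiB CPU

adjusted = {key: val * 0.90 for key, val in max_memory.items()}
# {0: 23192823398.4, 'cpu': 61847528857.6} — roughly 10% is held back
# for the temporary buffers created while dequantizing GGUF blocks.
```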
+    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        if target_dtype != torch.uint8:
+            logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization")
+        return torch.uint8
+
+    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+        if torch_dtype is None:
+            torch_dtype = self.compute_dtype
+        return torch_dtype
+
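Taken together, the two dtype hooks mean GGUF weights are planned for as raw `torch.uint8` byte blocks, while everything unquantized falls back to the configured compute dtype. A small illustration, with `quantizer` standing in for a configured instance (hypothetical variable, real method names):

```python
import torch

# Hypothetical quantizer instance; compute_dtype is set from the quantization config.
quantizer.compute_dtype = torch.bfloat16

quantizer.adjust_target_dtype(torch.float16)
# -> torch.uint8 (after logging a notice): quantized GGUF blocks are stored as bytes.

quantizer.update_torch_dtype(None)
# -> torch.bfloat16: unquantized modules and activations use the compute dtype.
```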
     def check_quantized_param_shape(self, param_name, current_param, loaded_param):
         loaded_param_shape = loaded_param.shape
         current_param_shape = current_param.shape
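The shape check leans on `_quant_shape_from_byte_shape` (imported from `.utils` above) to translate the on-disk byte layout back into logical tensor dimensions. The helper is real, but the body below is a reconstruction, shown with gguf's Q4_0 constants (block size 32, type size 18 bytes) as a worked example:

```python
def _quant_shape_from_byte_shape(shape, type_size, block_size):
    # Each `type_size`-byte block on disk encodes `block_size` logical weights,
    # so only the last dimension needs rescaling.
    *leading, last = shape
    return (*leading, last // type_size * block_size)


# Q4_0 packs 32 4-bit weights plus an fp16 scale into 18 bytes, so a
# (4096, 2304)-byte tensor maps back to (4096, 2304 // 18 * 32) = (4096, 4096).
assert _quant_shape_from_byte_shape((4096, 2304), type_size=18, block_size=32) == (4096, 4096)
```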
@@ -62,7 +82,7 @@ def check_quantized_param_shape(self, param_name, current_param, loaded_param):
     def check_if_quantized_param(
         self,
         model: "ModelMixin",
-        param_value: "torch.Tensor",
+        param_value: Union["GGUFParameter", "torch.Tensor"],
         param_name: str,
         state_dict: Dict[str, Any],
         **kwargs,
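The widened annotation hints at how the check likely works: quantized entries arrive in the state dict as `GGUFParameter`, a tensor subclass that carries its quantization type, so an `isinstance` test is enough to route them to the custom loading path. A hedged sketch of such a body (not the verbatim implementation):

```python
def check_if_quantized_param(self, model, param_value, param_name, state_dict, **kwargs) -> bool:
    # Sketch: a GGUFParameter in the state dict marks the tensor as
    # block-quantized and in need of create_quantized_param below.
    return isinstance(param_value, GGUFParameter)
```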
@@ -82,10 +102,13 @@ def create_quantized_param(
         unexpected_keys: Optional[List[str]] = None,
     ):
         module, tensor_name = get_module_from_name(model, param_name)
-        if tensor_name not in module._parameters:
+        if tensor_name not in module._parameters and tensor_name not in module._buffers:
             raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

-        module._parameters[tensor_name] = param_value
+        if tensor_name in module._parameters:
+            module._parameters[tensor_name] = param_value.to(target_device)
+        if tensor_name in module._buffers:
+            module._buffers[tensor_name] = param_value.to(target_device)

     def _process_model_before_weight_loading(
         self,
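For a sense of how these hooks are exercised end to end, a usage sketch based on the GGUF single-file loading API this PR wires up; the checkpoint URL and model class are illustrative:

```python
import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Illustrative checkpoint; any single-file GGUF with supported quant types should work.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```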