@@ -1982,9 +1982,19 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 out_features = state_dict[lora_B_weight_name].shape[0]
 
                 # This means there's no need for an expansion in the params, so we simply skip.
-                if tuple(module_weight.shape) == (out_features, in_features):
+                module_weight_shape = module_weight.shape
+                expansion_shape = (out_features, in_features)
+                quantization_config = getattr(transformer, "quantization_config", None)
+                if quantization_config and quantization_config.quant_method == "bitsandbytes":
+                    if quantization_config.load_in_4bit:
+                        expansion_shape = torch.Size(expansion_shape).numel()
+                        expansion_shape = ((expansion_shape + 1) // 2, 1)
+
+                if tuple(module_weight_shape) == expansion_shape:
                     continue
 
+                # TODO (sayakpaul): We still need to consider if the module we're expanding is
+                # quantized and handle it accordingly if that is the case.
                 module_out_features, module_in_features = module_weight.shape
                 debug_message = ""
                 if in_features > module_in_features:
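
Context for the shape arithmetic in the hunk above: when a `Linear` layer is quantized with bitsandbytes in 4-bit mode, its weight is stored as a flat `uint8` tensor that packs two 4-bit values per byte, so the parameter reports a shape of `((numel + 1) // 2, 1)` rather than `(out_features, in_features)`. A minimal sketch of that check outside the diff, with illustrative sizes (the helper name `packed_4bit_shape` is not part of the PR):

```python
import torch


def packed_4bit_shape(out_features: int, in_features: int) -> tuple:
    # bitsandbytes 4-bit storage packs two 4-bit values per uint8 byte and
    # flattens the matrix, so (out_features, in_features) is reported as
    # ((numel + 1) // 2, 1).
    numel = torch.Size((out_features, in_features)).numel()
    return ((numel + 1) // 2, 1)


# Illustrative sizes: a 3072x3072 Linear loaded with load_in_4bit.
module_weight_shape = ((3072 * 3072 + 1) // 2, 1)  # what the quantized param reports
expansion_shape = packed_4bit_shape(3072, 3072)

# Same comparison the hunk performs: matching shapes mean no expansion is needed.
assert tuple(module_weight_shape) == expansion_shape
```
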
@@ -2080,13 +2090,22 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
 
-            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+            # TODO (sayakpaul): Handle the cases when we actually need to expand.
+            base_out_feature_shape = base_weight_param.shape[1]
+            lora_A_out_feature_shape = lora_A_param.shape[1]
+            quantization_config = getattr(transformer, "quantization_config", None)
+            if quantization_config and quantization_config.quant_method == "bitsandbytes":
+                if quantization_config.load_in_4bit:
+                    lora_A_out_feature_shape = lora_A_param.shape.numel()
+                    lora_A_out_feature_shape = ((lora_A_out_feature_shape + 1) // 2, 1)[1]
+
+            if base_out_feature_shape > lora_A_out_feature_shape:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+            elif base_out_feature_shape < lora_A_out_feature_shape:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}."
                 )
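
The unquantized expansion branch in the second hunk zero-pads `lora_A` along its input dimension so it lines up with a base weight whose `in_features` were already expanded. A standalone sketch of that padding step, with made-up rank and feature sizes:

```python
import torch

# Made-up sizes: the base Linear was expanded from 3072 to 3200 in_features,
# while the LoRA A matrix was trained against the original 3072.
base_weight_param = torch.randn(3072, 3200)
lora_A_param = torch.randn(16, 3072)  # (rank, in_features)

if base_weight_param.shape[1] > lora_A_param.shape[1]:
    shape = (lora_A_param.shape[0], base_weight_param.shape[1])
    expanded = torch.zeros(shape, device=base_weight_param.device)
    # Copy the trained columns; the newly added input features stay zero, so
    # they contribute nothing to the LoRA update for those positions.
    expanded[:, : lora_A_param.shape[1]].copy_(lora_A_param)
    lora_A_param = expanded

assert lora_A_param.shape == (16, 3200)
```
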