 import torch
 from huggingface_hub.utils import validate_hf_hub_args

+from ..quantizers.bitsandbytes import dequantize_bnb_weight
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
@@ -1905,7 +1906,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):

             for name, module in transformer.named_modules():
                 if isinstance(module, torch.nn.Linear) and name in module_names:
-                    module_weight = module.weight.data
                     module_bias = module.bias.data if module.bias is not None else None
                     bias = module_bias is not None

@@ -1919,7 +1919,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
                             in_features,
                             out_features,
                             bias=bias,
-                            dtype=module_weight.dtype,
                         )

                     tmp_state_dict = {"weight": current_param_weight}
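
Note: dropping the explicit dtype works because the Linear is first created on the meta device (no storage is allocated) and only materialized when the saved weight is assigned back, at which point it adopts that weight's dtype and device. A minimal sketch of the pattern, with hypothetical shapes (assumes torch >= 2.1 for assign=True):

import torch

# Build the module without allocating storage, then materialize it by
# assigning the stored tensor; dtype/device come from the assigned tensor,
# so passing dtype= at construction time would be redundant.
stored_weight = torch.randn(16, 8, dtype=torch.bfloat16)
with torch.device("meta"):
    linear = torch.nn.Linear(8, 16, bias=False)
linear.load_state_dict({"weight": stored_weight}, assign=True)
assert linear.weight.dtype == torch.bfloat16  # inherited from stored_weight
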
@@ -1970,7 +1969,11 @@ def _maybe_expand_transformer_param_shape_or_error_(
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
-                module_weight = module.weight.data
+                module_weight = (
+                    dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
+                    if module.weight.__class__.__name__ == "Params4bit"
+                    else module.weight.data
+                )
                 module_bias = module.bias.data if module.bias is not None else None
                 bias = module_bias is not None

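
Why the Params4bit branch: bitsandbytes stores a 4-bit layer's weight packed into a flat buffer, so module.weight.data no longer has the logical (out_features, in_features) shape; dequantizing through the recorded quant_state recovers a regular tensor with the original shape. A standalone sketch, assuming bitsandbytes is installed and a CUDA device is available (quantization happens when the layer is moved to the GPU):

import bitsandbytes as bnb
from diffusers.quantizers.bitsandbytes import dequantize_bnb_weight

# 4-bit linear layer; the weight is packed when moved to CUDA.
layer = bnb.nn.Linear4bit(8, 16, bias=False).cuda()
print(layer.weight.data.shape)   # packed storage, not (16, 8)
restored = dequantize_bnb_weight(layer.weight, state=layer.weight.quant_state)
print(restored.shape)            # torch.Size([16, 8])
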
@@ -1994,7 +1997,7 @@ def _maybe_expand_transformer_param_shape_or_error_(

                 # TODO (sayakpaul): We still need to consider if the module we're expanding is
                 # quantized and handle it accordingly if that is the case.
-                module_out_features, module_in_features = module_weight.shape
+                module_out_features, module_in_features = module_weight_shape
                 debug_message = ""
                 if in_features > module_in_features:
                     debug_message += (
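
module_weight_shape is not defined in the visible hunks; it is presumably computed earlier in the full change so the shape check also works when the weight is stored quantized. A hedged sketch of what such a helper could look like (the name and details are assumptions, not the PR's actual code):

def _logical_weight_shape(weight):
    # Hypothetical helper: a packed Params4bit tensor's .shape is not the
    # logical (out_features, in_features); bitsandbytes records the original
    # shape on the quant_state captured at quantization time.
    if weight.__class__.__name__ == "Params4bit":
        return weight.quant_state.shape
    return weight.shape
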
@@ -2018,17 +2021,13 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 parent_module = transformer.get_submodule(parent_module_name)

                 with torch.device("meta"):
-                    expanded_module = torch.nn.Linear(
-                        in_features, out_features, bias=bias, dtype=module_weight.dtype
-                    )
+                    expanded_module = torch.nn.Linear(in_features, out_features, bias=bias)
                 # Only weights are expanded and biases are not. This is because only the input dimensions
                 # are changed while the output dimensions remain the same. The shape of the weight tensor
                 # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
                 # explains the reason why only weights are expanded.
-                new_weight = torch.zeros_like(
-                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
-                )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                new_weight = torch.zeros_like(expanded_module.weight.data)
+                slices = tuple(slice(0, dim) for dim in module_weight_shape)
                 new_weight[slices] = module_weight
                 tmp_state_dict = {"weight": new_weight}
                 if module_bias is not None:
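
The expansion pattern itself is unchanged: allocate a zero weight of the new shape and copy the old weight into the leading slice, so existing input channels keep their values and the added columns contribute nothing until trained. A standalone sketch on a real device, with hypothetical shapes:

import torch

old = torch.nn.Linear(4, 8, bias=False)
expanded = torch.nn.Linear(10, 8, bias=False)  # input width grown from 4 to 10

new_weight = torch.zeros_like(expanded.weight.data)
slices = tuple(slice(0, dim) for dim in old.weight.shape)  # rows 0:8, cols 0:4
new_weight[slices] = old.weight.data
expanded.load_state_dict({"weight": new_weight})

x = torch.randn(2, 10)
# The first 4 input features reproduce the original layer's output exactly;
# the 6 new columns are zero and so add nothing.
assert torch.allclose(expanded(x), old(x[:, :4]))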