1818import torch
1919from huggingface_hub .utils import validate_hf_hub_args
2020
21- from ..quantizers .bitsandbytes import dequantize_bnb_weight
2221from ..utils import (
2322 USE_PEFT_BACKEND ,
2423 deprecate ,
2524 get_submodule_by_name ,
25+ is_bitsandbytes_available ,
2626 is_peft_available ,
2727 is_peft_version ,
2828 is_torch_version ,
4848)
4949
5050
51+ if is_bitsandbytes_available ():
52+ from ..quantizers .bitsandbytes import dequantize_bnb_weight
53+
5154_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
5255if is_torch_version (">=" , "1.9.0" ):
5356 if (
@@ -1971,11 +1974,13 @@ def _maybe_expand_transformer_param_shape_or_error_(
19711974 is_peft_loaded = getattr (transformer , "peft_config" , None ) is not None
19721975 for name , module in transformer .named_modules ():
19731976 if isinstance (module , torch .nn .Linear ):
1974- module_weight = (
1975- dequantize_bnb_weight (module .weight , state = module .weight .quant_state ).data
1976- if module .weight .__class__ .__name__ == "Params4bit"
1977- else module .weight .data
1978- )
1977+ is_quantized = module .weight .__class__ .__name__ == "Params4bit"
1978+ if is_quantized and not is_bitsandbytes_available ():
1979+ raise ValueError ("Install `bitsandbytes` to load quantized checkpoints." )
1980+ elif is_quantized :
1981+ module_weight = dequantize_bnb_weight (module .weight , state = module .weight .quant_state ).data
1982+ else :
1983+ module_weight = module .weight .data
19791984 module_bias = module .bias .data if module .bias is not None else None
19801985 bias = module_bias is not None
19811986
@@ -1997,8 +2002,6 @@ def _maybe_expand_transformer_param_shape_or_error_(
19972002 if tuple (module_weight_shape ) == (out_features , in_features ):
19982003 continue
19992004
2000- # TODO (sayakpaul): We still need to consider if the module we're expanding is
2001- # quantized and handle it accordingly if that is the case.
20022005 module_out_features , module_in_features = module_weight_shape
20032006 debug_message = ""
20042007 if in_features > module_in_features :
0 commit comments