diff --git a/llmc/compression/quantization/module_utils.py b/llmc/compression/quantization/module_utils.py index 7fab3798d..a0cb5edf1 100644 --- a/llmc/compression/quantization/module_utils.py +++ b/llmc/compression/quantization/module_utils.py @@ -448,6 +448,10 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): keep_vars=keep_vars) if f'{prefix}fc.weight' in state_dict: state_dict[f'{prefix}weight'] = state_dict.pop(f'{prefix}fc.weight') + if f'{prefix}fc.weight_scale' in state_dict: + state_dict[f'{prefix}weight_scale'] = state_dict.pop(f'{prefix}fc.weight_scale') + if f'{prefix}fc.input_scale' in state_dict: + state_dict[f'{prefix}input_scale'] = state_dict.pop(f'{prefix}fc.input_scale') return state_dict def _fp32_forward(self, hidden_states):