Skip to content

Commit 6a934d4

Browse files
committed
reorder state_dict
1 parent 1d541b5 commit 6a934d4

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,15 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
229229
besides weight and bias,
230230
fill state_dict with components of quant_state
231231
"""
232+
super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
233+
232234
if getattr(self.weight, "quant_state", None) is not None:
233235
quant_state_dict = self.weight.quant_state.as_dict()
234236
tensor_keys = [k for k, v in quant_state_dict.items() if isinstance(v, torch.Tensor)]
235237
for k in tensor_keys:
236238
destination[prefix + "weight." + k] = quant_state_dict.pop(k) if keep_vars else quant_state_dict.pop(k).detach()
237239
destination[prefix + "weight." + "quant_state_dict"] = quant_state_dict
238240
destination[prefix + "weight." + "quantization_method"] = "bitsandbytes." + quant_state_dict["quant_type"]
239-
super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
240241

241242
def forward(self, x: torch.Tensor):
242243
# weights are cast automatically as Int8Params, but the bias has to be cast manually

0 commit comments

Comments
 (0)