Skip to content

Commit c1f3f56

Browse files
committed
Rearrange the weights directly in state dict before loading
1 parent f734076 commit c1f3f56

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

bitsandbytes/nn/modules.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,15 @@ def to(self, *args, **kwargs):
306306
return new_param
307307

308308

309+
def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
310+
weight = state_dict[f"{prefix}weight"]
311+
weight_format = state_dict.pop(f"{prefix}weight_format", "row")
312+
313+
if weight_format != "row":
314+
tile_indices = get_tile_inds(weight_format, weight.device)
315+
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
316+
317+
309318
class Linear8bitLt(nn.Linear):
310319
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
311320
memory_efficient_backward=False, threshold=0.0, index=None):
@@ -321,6 +330,7 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=
321330
self.state.use_pool = True
322331

323332
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
333+
self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
324334

325335
def _save_to_state_dict(self, destination, prefix, keep_vars):
326336
super()._save_to_state_dict(destination, prefix, keep_vars)
@@ -370,12 +380,6 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
370380
self.state.SCB = self.weight.SCB
371381

372382
unexpected_keys.remove(key)
373-
if input_name == "weight_format":
374-
input_param = state_dict[key]
375-
if input_param != "row":
376-
tile_indices = get_tile_inds(input_param, self.weight.device)
377-
self.weight.data = self.weight.CB = undo_layout(self.weight.data, tile_indices)
378-
unexpected_keys.remove(key)
379383

380384
def init_8bit_state(self):
381385
self.state.CB = self.weight.CB

0 commit comments

Comments
 (0)