Commit 2d321a7

Merge pull request #503 from TimDettmers/efficient_8bit_serialize
Make 8-bit serialization more memory-efficient (v2)
2 parents: ac5550a + b599fdb

2 files changed: 53 additions & 36 deletions

bitsandbytes/autograd/_functions.py

Lines changed: 14 additions & 11 deletions
@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
     return True


+def _get_tile_size(format):
+    assert format in (
+        "col_turing",
+        "col_ampere",
+    ), f"please find this assert and manually enter tile size for {format}"
+    return (8, 32) if format == "col_turing" else (32, 32)
+
+
+def get_tile_inds(format, device):
+    transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
+    with torch.no_grad():
+        return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
+
 @dataclass
 class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None
@@ -267,20 +280,10 @@ def reset_grads(self):
         self.SBt = None
         self.CBt = None

-    def get_tile_size(self):
-        assert self.formatB in (
-            "col_turing",
-            "col_ampere",
-        ), f"please find this assert and manually enter tile size for {self.formatB}"
-        return (8, 32) if self.formatB == "col_turing" else (32, 32)
-
     @property
     def tile_indices(self):
         if self._tile_indices is None:
-            device = self.CxB.device
-            transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
-            with torch.no_grad():
-                self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
+            self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
         return self._tile_indices


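The hunks above lift the tile-size and tile-index helpers out of MatmulLtState to module level so that the serialization code in bitsandbytes/nn/modules.py can reuse them. A minimal usage sketch (not part of the diff) of the new helpers, assuming a CUDA GPU where F.transform supports the col_turing layout and a weight shape that is a multiple of the tile size:

import torch
import bitsandbytes.functional as F
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout

device = torch.device("cuda")

# Tile indices depend only on (format, device), so they can be computed once
# per layout and reused for any weight stored in that layout.
tile_indices = get_tile_inds("col_turing", device)

# Round trip: transform a row-major int8 tensor into the tiled layout, then
# map it back to row-major order with undo_layout.
w_row = torch.randint(-128, 127, (1024, 1024), dtype=torch.int8, device=device)
w_tiled, _ = F.transform(w_row, from_order="row", to_order="col_turing")
assert torch.equal(undo_layout(w_tiled, tile_indices), w_row)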
bitsandbytes/nn/modules.py

Lines changed: 39 additions & 25 deletions
@@ -10,7 +10,7 @@

 import bitsandbytes as bnb
 import bitsandbytes.functional
-from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
+from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import OutlierTracer, find_outlier_dims

@@ -306,6 +306,17 @@ def to(self, *args, **kwargs):
             return new_param


+def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    weight = state_dict.get(f"{prefix}weight")
+    if weight is None:
+        # if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing
+        return
+    weight_format = state_dict.pop(f"{prefix}weight_format", "row")
+
+    if weight_format != "row":
+        tile_indices = get_tile_inds(weight_format, weight.device)
+        state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
+

 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
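The hook above runs before nn.Module copies tensors into the module, so a checkpoint saved in a GPU-specific layout is restored to row-major order first. A small sketch (not part of the diff; the layer name and tensor size are illustrative) of the contract it relies on: with the default "row" format the hook only consumes the weight_format marker, so this much can run on CPU, provided bitsandbytes imports in your environment.

import torch
from bitsandbytes.nn.modules import maybe_rearrange_weight

state_dict = {
    "layer.weight": torch.zeros(4, 4, dtype=torch.int8),  # illustrative int8 weight
    "layer.weight_format": "row",  # marker written by _save_to_state_dict
}
# Same argument order that _register_load_state_dict_pre_hook uses.
maybe_rearrange_weight(state_dict, "layer.", None, True, [], [], [])
assert "layer.weight_format" not in state_dict  # the marker is always popped
# For "col_turing" / "col_ampere" checkpoints the hook would additionally call
# get_tile_inds(...) and undo_layout(...) to restore a row-major weight.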
@@ -322,52 +333,55 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=
             self.state.use_pool = True

         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+        self._register_load_state_dict_pre_hook(maybe_rearrange_weight)

     def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
-            # reorder weight layout back from ampere/turing to row
-            reorder_layout = True
-            weight_clone = self.weight.data.clone()
-        else:
-            reorder_layout = False
+        super()._save_to_state_dict(destination, prefix, keep_vars)

-        try:
-            if reorder_layout:
-                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
+        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+        scb_name = "SCB"

-            super()._save_to_state_dict(destination, prefix, keep_vars)
+        # case 1: .cuda was called, SCB is in self.weight
+        param_from_weight = getattr(self.weight, scb_name)
+        # case 2: self.init_8bit_state was called, SCB is in self.state
+        param_from_state = getattr(self.state, scb_name)
+        # case 3: SCB is in self.state, weight layout reordered after first forward()
+        layout_reordered = self.state.CxB is not None

-            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
-            weight_name = "SCB"
+        key_name = prefix + f"{scb_name}"
+        format_name = prefix + "weight_format"

-            # case 1: .cuda was called, SCB is in self.weight
-            param_from_weight = getattr(self.weight, weight_name)
-            # case 2: self.init_8bit_state was called, SCB is in self.state
-            param_from_state = getattr(self.state, weight_name)
-
-            key_name = prefix + f"{weight_name}"
+        if not self.state.has_fp16_weights:
             if param_from_weight is not None:
                 destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-            elif not self.state.has_fp16_weights and param_from_state is not None:
+                destination[format_name] = "row"
+            elif param_from_state is not None and not layout_reordered:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-        finally:
-            if reorder_layout:
-                self.weight.data = weight_clone
+                destination[format_name] = "row"
+            elif param_from_state is not None:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+                destination[format_name] = self.state.formatB

     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                       error_msgs)
-        for key in unexpected_keys:
+        unexpected_copy = list(unexpected_keys)
+
+        for key in unexpected_copy:
             input_name = key[len(prefix):]
             if input_name == "SCB":
                 if self.weight.SCB is None:
-                    # buffers not yet initialized, can't call them directly without
+                    # buffers not yet initialized, can't access them directly without quantizing first
                     raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                                        "not supported. Please call module.cuda() before module.load_state_dict()")

                 input_param = state_dict[key]
                 self.weight.SCB.copy_(input_param)
+
+                if self.state.SCB is not None:
+                    self.state.SCB = self.weight.SCB
+
                 unexpected_keys.remove(key)

     def init_8bit_state(self):
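Taken together, _save_to_state_dict no longer clones the weight and undoes the tiled layout at save time; it stores the weight as-is plus a weight_format entry, and the load-time pre-hook reverses the layout instead. A rough end-to-end sketch of the workflow this commit targets (not part of the diff; the sizes and threshold are illustrative, and it assumes a CUDA GPU with int8 matmul support):

import torch
import bitsandbytes as bnb

# Quantized linear layer: .cuda() converts the weight to int8 (CB + SCB scales).
layer = bnb.nn.Linear8bitLt(1024, 1024, bias=False, has_fp16_weights=False, threshold=6.0).cuda()

x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
layer(x)  # the first forward pass reorders the weight into col_turing/col_ampere (state.CxB)

# The checkpoint now holds the weight in its tiled layout, the SCB scales, and a
# "weight_format" marker; no row-major clone of the weight is created at save time.
sd = layer.state_dict()

# On load, the maybe_rearrange_weight pre-hook undoes the tiled layout before the
# tensors are copied into the fresh module.
fresh = bnb.nn.Linear8bitLt(1024, 1024, bias=False, has_fp16_weights=False, threshold=6.0).cuda()
fresh.load_state_dict(sd)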
