
Commit 6dfa5cc

Support 4bit BNB layers meta-device materialization (#19150)
1 parent 1284713 commit 6dfa5cc

File tree

7 files changed: +330 additions, -32 deletions


requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion
@@ -8,4 +8,4 @@ hydra-core >=1.0.5, <1.4.0
 jsonargparse[signatures] >=4.26.1, <4.27.0
 rich >=12.3.0, <13.6.0
 tensorboardX >=2.2, <2.7.0  # min version is set by torch.onnx missing attribute
-bitsandbytes <=0.41.1
+bitsandbytes ==0.41.1  # strict

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `lightning.fabric.utilities.AttributeDict` for convenient dict-attribute access to represent state in script ([#18943](https://github.com/Lightning-AI/lightning/pull/18943))
 
 
+- Added support for meta-device initialization and materialization of 4-bit Bitsandbytes layers ([#19150](https://github.com/Lightning-AI/lightning/pull/19150))
+
+
 - Added `TransformerEnginePrecision(fallback_compute_dtype=)` to control the dtype of operations that don't support fp8 ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
 

src/lightning/fabric/plugins/precision/bitsandbytes.py

Lines changed: 168 additions & 26 deletions
@@ -13,19 +13,21 @@
 # limitations under the License.
 import functools
 import logging
+import math
 import os
 import warnings
 from contextlib import ExitStack
 from functools import partial
 from types import ModuleType
-from typing import Any, Callable, ContextManager, Literal, Optional, OrderedDict, Set, Type
+from typing import Any, Callable, ContextManager, Literal, Optional, OrderedDict, Set, Tuple, Type, cast
 
 import torch
 from lightning_utilities import apply_to_collection
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
+from torch.nn import init
 from torch.nn.modules.module import _IncompatibleKeys
-from typing_extensions import override
+from typing_extensions import Self, override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import (
@@ -37,7 +39,8 @@
 
 log = logging.getLogger(__name__)
 
-_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes>=0.41.0")
+# TODO: unpin after resolving the `quant_state` format breaking changes
+_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes==0.41.0")
 
 
 class BitsandbytesPrecision(Precision):
@@ -109,6 +112,7 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         # convert modules if they haven't been converted already
         bnb = _import_bitsandbytes()
         if not any(isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) for m in module.modules()):
+            # this will not quantize the model but only replace the layer classes
             _convert_layers(module, self._linear_cls, self.ignore_modules)
 
         # set the compute dtype if necessary
@@ -164,11 +168,36 @@ def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_di
 
 
 def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _IncompatibleKeys) -> None:
+    # since we manually loaded the weight in the `_quantize_on_load_hook` hook, we need to avoid this missing key false
+    # positive
     for key in reversed(incompatible_keys.missing_keys):
         if key.endswith("weight"):
             incompatible_keys.missing_keys.remove(key)
 
 
+def _replace_param(
+    param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[Tuple] = None
+) -> torch.nn.Parameter:
+    bnb = _import_bitsandbytes()
+
+    # doing `param.data = weight` raises a RuntimeError if param.data was on meta-device, so
+    # we need to re-create the parameters instead of overwriting the data
+    if param.device.type == "meta":
+        if isinstance(param, bnb.nn.Params4bit):
+            return bnb.nn.Params4bit(
+                data,
+                requires_grad=data.requires_grad,
+                quant_state=quant_state,
+                compress_statistics=param.compress_statistics,
+                quant_type=param.quant_type,
+            )
+        return torch.nn.Parameter(data, requires_grad=data.requires_grad)
+    param.data = data
+    if isinstance(param, bnb.nn.Params4bit):
+        param.quant_state = quant_state
+    return param
+
+
 @functools.lru_cache(maxsize=1)
 def _import_bitsandbytes() -> ModuleType:
     if not _BITSANDBYTES_AVAILABLE:
@@ -192,51 +221,160 @@ class _Linear8bitLt(bnb.nn.Linear8bitLt):
 
         def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None:
             super().__init__(*args, device=device, threshold=threshold, **kwargs)
+            self.weight = cast(bnb.nn.Int8Params, self.weight)  # type: ignore[has-type]
+            self.bias = cast(Optional[torch.nn.Parameter], self.bias)  # type: ignore[has-type]
             # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
             # filling the device memory with float32 weights which could lead to OOM
             if torch.tensor(0, device=device).device.type == "cuda":
-                self._quantize_weight(self.weight.data)
-            self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self._quantize_weight))
+                self.quantize_()
+            self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
             self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
 
-        def _quantize_weight(self, weight: torch.Tensor) -> None:
+        def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
+            """Inplace quantize."""
+            if weight is None:
+                weight = self.weight.data
+            if weight.data.dtype == torch.int8:
+                # already quantized
+                return
+            assert isinstance(self.weight, bnb.nn.Int8Params)
+            self.weight = self.quantize(self.weight, weight, device)
+
+        @staticmethod
+        def quantize(
+            int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device]
+        ) -> bnb.nn.Int8Params:
+            device = device or torch.device("cuda")
+            if device.type != "cuda":
+                raise RuntimeError(f"Unexpected device type: {device.type}")
             # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
-            B = weight.contiguous().to(device="cuda", dtype=torch.float16)
-            if self.state.has_fp16_weights:
-                self.weight.data = B
+            B = weight.contiguous().to(device=device, dtype=torch.float16)
+            if int8params.has_fp16_weights:
+                int8params.data = B
             else:
                 CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
                 del CBt
                 del SCBt
-                self.weight.data = CB
-                setattr(self.weight, "CB", CB)
-                setattr(self.weight, "SCB", SCB)
+                int8params.data = CB
+                setattr(int8params, "CB", CB)
+                setattr(int8params, "SCB", SCB)
+            return int8params
+
+        def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
+            if self.weight.device.type == "meta":
+                # need custom logic if int8params is on meta device
+                raise NotImplementedError
+            if self.weight.dtype == torch.uint8:  # was quantized
+                # need the original shape here
+                raise NotImplementedError
+            device = torch.device(device)
+            weight = torch.empty_like(self.weight.data, device=device)
+            if device.type == "cuda":  # re-quantize
+                self.quantize_(weight, device)
+            else:
+                self.weight = _replace_param(self.weight, weight)
+            if self.bias is not None:
+                self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
+            return self
+
+        def reset_parameters(self) -> None:
+            # from `torch.nn.Linear.reset_parameters`
+            if self.bias is not None:
+                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                init.uniform_(self.bias, -bound, bound)
+
+            linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
+            if linear_init_finished and self.weight.dtype == torch.uint8:  # was quantized
+                # need the original shape here
+                raise NotImplementedError
+            weight = self.weight.data
+            torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+            if linear_init_finished:
+                if self.weight.device.type == "meta":
+                    # need custom logic if int8params is on meta device
+                    raise NotImplementedError
+                if self.weight.device.type == "cuda":  # re-quantize
+                    self.quantize_(weight)
+                else:
+                    self.weight = _replace_param(self.weight, weight)
 
     class _Linear4bit(bnb.nn.Linear4bit):
-        """Wraps `bnb.nn.Linear4bit` and enables instantiation directly on the device and re-quantizaton when loading
-        the state dict."""
+        """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantizaton when loading the
+        state dict, meta-device initialization, and materialization."""
 
         def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None:
             super().__init__(*args, device=device, **kwargs)
+            self.weight = cast(bnb.nn.Params4bit, self.weight)  # type: ignore[has-type]
+            self.bias = cast(Optional[torch.nn.Parameter], self.bias)  # type: ignore[has-type]
             # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
             # filling the device memory with float32 weights which could lead to OOM
             if torch.tensor(0, device=device).device.type == "cuda":
-                self._quantize_weight(self.weight.data)
-            self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self._quantize_weight))
+                self.quantize_()
+            self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
             self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)
 
-        def _quantize_weight(self, weight: torch.Tensor) -> None:
+        def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
+            """Inplace quantize."""
+            if weight is None:
+                weight = self.weight.data
+            if weight.data.dtype == torch.uint8:
+                # already quantized
+                return
+            assert isinstance(self.weight, bnb.nn.Params4bit)
+            self.weight = self.quantize(self.weight, weight, device)
+
+        @staticmethod
+        def quantize(
+            params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device]
+        ) -> bnb.nn.Params4bit:
+            device = device or torch.device("cuda")
+            if device.type != "cuda":
+                raise RuntimeError(f"Unexpected device type: {device.type}")
             # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L156-L159
-            params4bit = self.weight
-            w = weight.contiguous().to(device="cuda", dtype=torch.half)
+            w = weight.contiguous().to(device=device, dtype=torch.half)
             w_4bit, quant_state = bnb.functional.quantize_4bit(
                 w,
                 blocksize=params4bit.blocksize,
                 compress_statistics=params4bit.compress_statistics,
                 quant_type=params4bit.quant_type,
             )
-            params4bit.data = w_4bit
-            params4bit.quant_state = quant_state
+            return _replace_param(params4bit, w_4bit, quant_state)
+
+        def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
+            if self.weight.dtype == torch.uint8:  # was quantized
+                # cannot init the quantized params directly
+                weight = torch.empty(self.weight.quant_state[1], device=device, dtype=torch.half)  # type: ignore[arg-type]
+            else:
+                weight = torch.empty_like(self.weight.data, device=device)  # type: ignore[arg-type]
+            device = torch.device(device)
+            if device.type == "cuda":  # re-quantize
+                self.quantize_(weight, device)
+            else:
+                self.weight = _replace_param(self.weight, weight)
+            if self.bias is not None:
+                self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
+            return self
+
+        def reset_parameters(self) -> None:
+            # from `torch.nn.Linear.reset_parameters`
+            if self.bias is not None:
+                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                init.uniform_(self.bias, -bound, bound)
+
+            linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
+            if linear_init_finished and self.weight.dtype == torch.uint8:  # was quantized
+                # cannot init the quantized params directly
+                weight = torch.empty(self.weight.quant_state[1], device=self.weight.device, dtype=torch.half)
+            else:
+                weight = self.weight.data
+            torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+            if linear_init_finished:
+                if self.weight.device.type == "cuda":  # re-quantize
+                    self.quantize_(weight)
+                else:
+                    self.weight = _replace_param(self.weight, weight)
 
     # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses
     class _Int8LinearInference(_Linear8bitLt):
@@ -281,17 +419,21 @@ def _convert_layers(module: torch.nn.Module, linear_cls: Type, ignore_modules: S
         if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
             log.debug(f"Replacing layer {fullname!r} with bitsandbytes equivalent")
             has_bias = child.bias is not None
+            # since we are going to copy over the child's data, the device doesn't matter. I chose CPU
+            # to avoid spiking CUDA memory even though initialization is slower
+            # 4bit layers support quantizing from meta-device params so this is only relevant for 8-bit
+            _Linear4bit = globals()["_Linear4bit"]
+            device = torch.device("meta" if issubclass(linear_cls, _Linear4bit) else "cpu")
             replacement = linear_cls(
-                # since we are going to copy over the child's data, the device doesn't matter. I chose CPU
-                # to avoid spiking CUDA memory even though initialization is slower
                 child.in_features,
                 child.out_features,
                 bias=has_bias,
-                device=torch.device("cpu"),
+                device=device,
             )
             if has_bias:
-                replacement.bias.data = child.bias.data.clone()
-            replacement._quantize_weight(child.weight.data.clone())
+                replacement.bias = _replace_param(replacement.bias, child.bias.data.clone())
+            state = {"quant_state": replacement.weight.quant_state if issubclass(linear_cls, _Linear4bit) else None}
+            replacement.weight = _replace_param(replacement.weight, child.weight.data.clone(), **state)
             module.__setattr__(name, replacement)
         else:
             _convert_layers(child, linear_cls, ignore_modules, prefix=fullname)
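
Together with the materialization utilities added below in src/lightning/fabric/utilities/init.py, these changes let a model be instantiated on the meta device, converted to 4-bit layers without allocating any memory, and quantized only once it is materialized. A rough usage sketch, not part of this commit (it assumes a CUDA machine, torch >= 2.1 and bitsandbytes 0.41.x; the toy `Sequential` model is only for illustration):

import torch
from lightning.fabric.plugins.precision.bitsandbytes import BitsandbytesPrecision
from lightning.fabric.utilities.init import _materialize_meta_tensors

precision = BitsandbytesPrecision(mode="nf4", dtype=torch.float16)

# instantiate on the meta device: parameters have shapes but no storage, so nothing is quantized yet
with torch.device("meta"):
    model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))

# swaps `torch.nn.Linear` for the patched 4-bit class; the replacement is also created on the
# meta device, so the conversion no longer allocates CPU memory or spikes CUDA memory
model = precision.convert_module(model)

# materialization runs `to_empty()` + `reset_parameters()` on every meta-device submodule,
# which quantizes the freshly initialized weight directly in 4 bits on the GPU
_materialize_meta_tensors(model, torch.device("cuda"))

# in practice a checkpoint would be loaded next; the load-state-dict pre-hook registered above
# re-quantizes the incoming full-precision weight in place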

src/lightning/fabric/utilities/init.py

Lines changed: 25 additions & 1 deletion
@@ -11,9 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 from typing import Any, Callable, Dict, Optional, Sequence
 
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
+import torch
+
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.types import _DEVICE
 
 if _TORCH_GREATER_EQUAL_1_13:
     from torch.overrides import TorchFunctionMode
@@ -54,3 +58,23 @@ def __torch_function__(
                 return kwargs["tensor"]
             return args[0]
         return func(*args, **kwargs)
+
+
+def _materialize(module: torch.nn.Module, device: _DEVICE) -> None:
+    """Materialize a module."""
+    if not _TORCH_GREATER_EQUAL_2_1:
+        raise RuntimeError("recurse=False requires torch 2.1")
+    module.to_empty(device=device, recurse=False)  # type: ignore[arg-type]
+    if not hasattr(module, "reset_parameters"):
+        raise TypeError(
+            f"Materialization requires that the `{type(module).__name__}.reset_parameters` method is implemented."
+            " This method is used to initialize any children parameters or buffers in this module."
+        )
+    module.reset_parameters()
+
+
+def _materialize_meta_tensors(module: torch.nn.Module, device: _DEVICE) -> None:
+    """Materialize all tensors in a given module."""
+    for module in module.modules():
+        if any(t.is_meta for t in itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False))):
+            _materialize(module, device)
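
For reference, a small sketch (also not from this commit) of how these helpers behave on an ordinary module, assuming torch >= 2.1:

import torch
from lightning.fabric.utilities.init import _materialize_meta_tensors

# build a toy model on the meta device: parameters have shapes but no storage
with torch.device("meta"):
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))
assert all(p.is_meta for p in model.parameters())

# every submodule that owns meta parameters gets `to_empty()` followed by `reset_parameters()`;
# a submodule without a `reset_parameters` method would raise the TypeError defined above
_materialize_meta_tensors(model, torch.device("cpu"))
assert not any(p.is_meta for p in model.parameters())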

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The Trainer now restores the training mode set through `.train()` or `.eval()` on a submodule-level when switching from validation to training ([#18951](https://github.com/Lightning-AI/lightning/pull/18951))
 
 
+- Added support for meta-device initialization and materialization of 4-bit Bitsandbytes layers ([#19150](https://github.com/Lightning-AI/lightning/pull/19150))
+
+
 - Added `TransformerEnginePrecision(fallback_compute_dtype=)` to control the dtype of operations that don't support fp8 ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
 
 