# limitations under the License.
import functools
import logging
+ import math
import os
import warnings
from contextlib import ExitStack
from functools import partial
from types import ModuleType
- from typing import Any, Callable, ContextManager, Literal, Optional, OrderedDict, Set, Type
+ from typing import Any, Callable, ContextManager, Literal, Optional, OrderedDict, Set, Tuple, Type, cast

import torch
from lightning_utilities import apply_to_collection
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
+ from torch.nn import init
from torch.nn.modules.module import _IncompatibleKeys
- from typing_extensions import override
+ from typing_extensions import Self, override

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import (

log = logging.getLogger(__name__)

- _BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes>=0.41.0")
+ # TODO: unpin after resolving the `quant_state` format breaking changes
+ _BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes==0.41.0")


class BitsandbytesPrecision(Precision):
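The class body is collapsed in this view. For orientation, a minimal construction sketch (not part of the diff; mode names and arguments as documented for this plugin, with `"lm_head"` used only as an illustrative module name):

    import torch
    from lightning.fabric.plugins import BitsandbytesPrecision

    # 4-bit NormalFloat weights with bfloat16 compute; `ignore_modules` keeps matching submodules unquantized
    precision = BitsandbytesPrecision(mode="nf4", dtype=torch.bfloat16, ignore_modules={"lm_head"})
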
@@ -109,6 +112,7 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
        # convert modules if they haven't been converted already
        bnb = _import_bitsandbytes()
        if not any(isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) for m in module.modules()):
+             # this will not quantize the model but only replace the layer classes
            _convert_layers(module, self._linear_cls, self.ignore_modules)

        # set the compute dtype if necessary
@@ -164,11 +168,36 @@ def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_di


def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _IncompatibleKeys) -> None:
+     # since we manually loaded the weight in the `_quantize_on_load_hook` hook, we need to avoid this missing key false
+     # positive
    for key in reversed(incompatible_keys.missing_keys):
        if key.endswith("weight"):
            incompatible_keys.missing_keys.remove(key)


+ def _replace_param(
+     param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[Tuple] = None
+ ) -> torch.nn.Parameter:
+     bnb = _import_bitsandbytes()
+
+     # doing `param.data = weight` raises a RuntimeError if param.data was on meta-device, so
+     # we need to re-create the parameters instead of overwriting the data
+     if param.device.type == "meta":
+         if isinstance(param, bnb.nn.Params4bit):
+             return bnb.nn.Params4bit(
+                 data,
+                 requires_grad=data.requires_grad,
+                 quant_state=quant_state,
+                 compress_statistics=param.compress_statistics,
+                 quant_type=param.quant_type,
+             )
+         return torch.nn.Parameter(data, requires_grad=data.requires_grad)
+     param.data = data
+     if isinstance(param, bnb.nn.Params4bit):
+         param.quant_state = quant_state
+     return param
+
+
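A minimal illustration of the constraint the comment above describes (hypothetical snippet, not part of the diff): overwriting `.data` of a parameter still allocated on the meta device fails, so a new parameter object is created instead.

    import torch

    p = torch.nn.Parameter(torch.empty(2, 2, device="meta"))
    # p.data = torch.zeros(2, 2)  # per the comment above, this raises a RuntimeError for meta-device params
    p = torch.nn.Parameter(torch.zeros(2, 2), requires_grad=p.requires_grad)  # re-create the parameter instead
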
@functools.lru_cache(maxsize=1)
def _import_bitsandbytes() -> ModuleType:
    if not _BITSANDBYTES_AVAILABLE:
@@ -192,51 +221,160 @@ class _Linear8bitLt(bnb.nn.Linear8bitLt):
        def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None:
            super().__init__(*args, device=device, threshold=threshold, **kwargs)
+             self.weight = cast(bnb.nn.Int8Params, self.weight)  # type: ignore[has-type]
+             self.bias = cast(Optional[torch.nn.Parameter], self.bias)  # type: ignore[has-type]
            # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
            # filling the device memory with float32 weights which could lead to OOM
            if torch.tensor(0, device=device).device.type == "cuda":
-                 self._quantize_weight(self.weight.data)
-             self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self._quantize_weight))
+                 self.quantize_()
+             self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
            self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)

-         def _quantize_weight(self, weight: torch.Tensor) -> None:
+         def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
+             """Inplace quantize."""
+             if weight is None:
+                 weight = self.weight.data
+             if weight.data.dtype == torch.int8:
+                 # already quantized
+                 return
+             assert isinstance(self.weight, bnb.nn.Int8Params)
+             self.weight = self.quantize(self.weight, weight, device)
+
+         @staticmethod
+         def quantize(
+             int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device]
+         ) -> bnb.nn.Int8Params:
+             device = device or torch.device("cuda")
+             if device.type != "cuda":
+                 raise RuntimeError(f"Unexpected device type: {device.type}")
            # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L291-L302
-             B = weight.contiguous().to(device="cuda", dtype=torch.float16)
-             if self.state.has_fp16_weights:
-                 self.weight.data = B
+             B = weight.contiguous().to(device=device, dtype=torch.float16)
+             if int8params.has_fp16_weights:
+                 int8params.data = B
            else:
                CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
                del CBt
                del SCBt
-                 self.weight.data = CB
-                 setattr(self.weight, "CB", CB)
-                 setattr(self.weight, "SCB", SCB)
+                 int8params.data = CB
+                 setattr(int8params, "CB", CB)
+                 setattr(int8params, "SCB", SCB)
+             return int8params
+
+         def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
+             if self.weight.device.type == "meta":
+                 # need custom logic if int8params is on meta device
+                 raise NotImplementedError
+             if self.weight.dtype == torch.int8:  # was quantized
+                 # need the original shape here
+                 raise NotImplementedError
+             device = torch.device(device)
+             weight = torch.empty_like(self.weight.data, device=device)
+             if device.type == "cuda":  # re-quantize
+                 self.quantize_(weight, device)
+             else:
+                 self.weight = _replace_param(self.weight, weight)
+             if self.bias is not None:
+                 self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
+             return self
+
+         def reset_parameters(self) -> None:
+             # from `torch.nn.Linear.reset_parameters`
+             if self.bias is not None:
+                 fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+                 bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                 init.uniform_(self.bias, -bound, bound)
+
+             linear_init_finished = isinstance(self.weight, bnb.nn.Int8Params)
+             if linear_init_finished and self.weight.dtype == torch.int8:  # was quantized
+                 # need the original shape here
+                 raise NotImplementedError
+             weight = self.weight.data
+             torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+             if linear_init_finished:
+                 if self.weight.device.type == "meta":
+                     # need custom logic if int8params is on meta device
+                     raise NotImplementedError
+                 if self.weight.device.type == "cuda":  # re-quantize
+                     self.quantize_(weight)
+                 else:
+                     self.weight = _replace_param(self.weight, weight)

    class _Linear4bit(bnb.nn.Linear4bit):
-         """Wraps `bnb.nn.Linear4bit` and enables instantiation directly on the device and re-quantization when loading
-         the state dict."""
+         """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantization when loading the
+         state dict, meta-device initialization, and materialization."""

        def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None:
            super().__init__(*args, device=device, **kwargs)
+             self.weight = cast(bnb.nn.Params4bit, self.weight)  # type: ignore[has-type]
+             self.bias = cast(Optional[torch.nn.Parameter], self.bias)  # type: ignore[has-type]
            # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
            # filling the device memory with float32 weights which could lead to OOM
            if torch.tensor(0, device=device).device.type == "cuda":
-                 self._quantize_weight(self.weight.data)
-             self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self._quantize_weight))
+                 self.quantize_()
+             self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
            self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)

-         def _quantize_weight(self, weight: torch.Tensor) -> None:
+         def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
+             """Inplace quantize."""
+             if weight is None:
+                 weight = self.weight.data
+             if weight.data.dtype == torch.uint8:
+                 # already quantized
+                 return
+             assert isinstance(self.weight, bnb.nn.Params4bit)
+             self.weight = self.quantize(self.weight, weight, device)
+
+         @staticmethod
+         def quantize(
+             params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device]
+         ) -> bnb.nn.Params4bit:
+             device = device or torch.device("cuda")
+             if device.type != "cuda":
+                 raise RuntimeError(f"Unexpected device type: {device.type}")
            # https://github.com/TimDettmers/bitsandbytes/blob/0.41.0/bitsandbytes/nn/modules.py#L156-L159
-             params4bit = self.weight
-             w = weight.contiguous().to(device="cuda", dtype=torch.half)
+             w = weight.contiguous().to(device=device, dtype=torch.half)
            w_4bit, quant_state = bnb.functional.quantize_4bit(
                w,
                blocksize=params4bit.blocksize,
                compress_statistics=params4bit.compress_statistics,
                quant_type=params4bit.quant_type,
            )
-             params4bit.data = w_4bit
-             params4bit.quant_state = quant_state
+             return _replace_param(params4bit, w_4bit, quant_state)
+
+         def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
+             if self.weight.dtype == torch.uint8:  # was quantized
+                 # cannot init the quantized params directly
+                 weight = torch.empty(self.weight.quant_state[1], device=device, dtype=torch.half)  # type: ignore[arg-type]
+             else:
+                 weight = torch.empty_like(self.weight.data, device=device)  # type: ignore[arg-type]
+             device = torch.device(device)
+             if device.type == "cuda":  # re-quantize
+                 self.quantize_(weight, device)
+             else:
+                 self.weight = _replace_param(self.weight, weight)
+             if self.bias is not None:
+                 self.bias = _replace_param(self.bias, torch.empty_like(self.bias, device=device))
+             return self
+
+         def reset_parameters(self) -> None:
+             # from `torch.nn.Linear.reset_parameters`
+             if self.bias is not None:
+                 fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+                 bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                 init.uniform_(self.bias, -bound, bound)
+
+             linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
+             if linear_init_finished and self.weight.dtype == torch.uint8:  # was quantized
+                 # cannot init the quantized params directly
+                 weight = torch.empty(self.weight.quant_state[1], device=self.weight.device, dtype=torch.half)
+             else:
+                 weight = self.weight.data
+             torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+             if linear_init_finished:
+                 if self.weight.device.type == "cuda":  # re-quantize
+                     self.quantize_(weight)
+                 else:
+                     self.weight = _replace_param(self.weight, weight)

    # Use a class instead of `functools.partial` to respect `isinstance` checks and attribute accesses
    class _Int8LinearInference(_Linear8bitLt):
@@ -281,17 +419,21 @@ def _convert_layers(module: torch.nn.Module, linear_cls: Type, ignore_modules: S
        if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
            log.debug(f"Replacing layer {fullname!r} with bitsandbytes equivalent")
            has_bias = child.bias is not None
+             # since we are going to copy over the child's data, the device doesn't matter. I chose CPU
+             # to avoid spiking CUDA memory even though initialization is slower
+             # 4bit layers support quantizing from meta-device params so this is only relevant for 8-bit
+             _Linear4bit = globals()["_Linear4bit"]
+             device = torch.device("meta" if issubclass(linear_cls, _Linear4bit) else "cpu")
            replacement = linear_cls(
-                 # since we are going to copy over the child's data, the device doesn't matter. I chose CPU
-                 # to avoid spiking CUDA memory even though initialization is slower
                child.in_features,
                child.out_features,
                bias=has_bias,
-                 device=torch.device("cpu"),
+                 device=device,
            )
            if has_bias:
-                 replacement.bias.data = child.bias.data.clone()
-             replacement._quantize_weight(child.weight.data.clone())
+                 replacement.bias = _replace_param(replacement.bias, child.bias.data.clone())
+             state = {"quant_state": replacement.weight.quant_state if issubclass(linear_cls, _Linear4bit) else None}
+             replacement.weight = _replace_param(replacement.weight, child.weight.data.clone(), **state)
            module.__setattr__(name, replacement)
        else:
            _convert_layers(child, linear_cls, ignore_modules, prefix=fullname)
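
Taken together, the changes above let a bitsandbytes-quantized model start out on the meta device and be materialized and quantized later. A rough end-to-end sketch, assuming the public Fabric APIs (`init_module`, `setup_module`) and a hypothetical user model `MyModel` and checkpoint path:

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins import BitsandbytesPrecision

    fabric = Fabric(devices=1, plugins=BitsandbytesPrecision(mode="nf4", dtype=torch.bfloat16))
    with fabric.init_module(empty_init=True):
        model = MyModel()  # depending on the strategy, parameters may start on the meta device
    model = fabric.setup_module(model)  # replaced layers are materialized (to_empty/reset_parameters) and quantized
    # incoming float weights are re-quantized by the registered state-dict pre-hook
    model.load_state_dict(torch.load("checkpoint.pt"))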