Skip to content

Commit 412fd0e

Browse files
committed
Added better default compute_dtype handling for Linear4bit layers.
1 parent c82f51c commit 412fd0e

File tree

2 files changed

+60
-6
lines changed

2 files changed

+60
-6
lines changed

bitsandbytes/nn/modules.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import Optional, TypeVar, Union, overload
66

7+
import warnings
78
import torch
89
import torch.nn.functional as F
910
from torch import Tensor, device, dtype, nn
@@ -205,6 +206,28 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non
205206
super().__init__(input_features, output_features, bias, device)
206207
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
207208
self.compute_dtype = compute_dtype
209+
self.compute_type_is_set = False
210+
211+
def set_compute_type(self, x):
212+
if x.dtype in [torch.float32, torch.bfloat16]:
213+
# the input is in a dtype that is safe to compute in, we switch
214+
# to this type for speed and stability
215+
self.compute_dtype = x.dtype
216+
elif x.dtype == torch.float16:
217+
# we take the compoute dtype passed into the layer
218+
if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
219+
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
220+
# warn the user about this
221+
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference.')
222+
warnings.filterwarnings('ignore', message='.*inference.')
223+
if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
224+
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
225+
warnings.filterwarnings('ignore', message='.*inference or training')
226+
227+
228+
229+
230+
208231

209232
def forward(self, x: torch.Tensor):
210233
# weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -213,6 +236,10 @@ def forward(self, x: torch.Tensor):
213236

214237
if getattr(self.weight, 'quant_state', None) is None:
215238
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
239+
if not self.compute_type_is_set:
240+
self.set_compute_type(x)
241+
self.compute_type_is_set = True
242+
216243
inp_dtype = x.dtype
217244
if self.compute_dtype is not None:
218245
x = x.to(self.compute_dtype)

tests/test_modules.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,10 @@ def test_linear_kbit_fp32_bias(module):
516516
modules.append(bnb.nn.LinearNF4)
517517
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
518518
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
519-
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
519+
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
520+
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
521+
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
522+
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16']
520523
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
521524
@pytest.mark.parametrize("module", modules, ids=names)
522525
def test_kbit_backprop(module):
@@ -563,10 +566,10 @@ def test_kbit_backprop(module):
563566
relerrs2.append(relerr2.mean().item())
564567

565568
if isinstance(module, bnb.nn.Linear8bitLt):
566-
torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
569+
assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
567570
torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
568571
else:
569-
torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
572+
assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
570573
torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
571574
ref.zero_grad()
572575
kbit.zero_grad()
@@ -608,9 +611,33 @@ def test_fp8linear():
608611
assert graderr < 0.00002
609612
assert bgraderr < 0.00002
610613

611-
612-
613-
614+
def test_4bit_warnings():
615+
dim1 = 64
616+
617+
with pytest.warns(UserWarning, match=r'inference or training'):
618+
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
619+
net = net.cuda()
620+
inp = torch.rand(10, dim1).cuda().half()
621+
net(inp)
622+
with pytest.warns(UserWarning, match=r'inference.'):
623+
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
624+
net = net.cuda()
625+
inp = torch.rand(1, dim1).cuda().half()
626+
net(inp)
627+
628+
with pytest.warns(UserWarning) as record:
629+
630+
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
631+
net = net.cuda()
632+
inp = torch.rand(10, dim1).cuda().half()
633+
net(inp)
634+
635+
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
636+
net = net.cuda()
637+
inp = torch.rand(1, dim1).cuda().half()
638+
net(inp)
639+
640+
assert len(record) == 2
614641

615642

616643

0 commit comments

Comments
 (0)