Skip to content

Commit 3c9aca9

Browse files
committed
Fixed two bugs in dynamic data type creation.
1 parent a06a0f6 commit 3c9aca9

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

bitsandbytes/functional.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -322,10 +322,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
322322
# these are additional items that come from the case
323323
# where all the exponent bits are zero and no
324324
# indicator bit is present
325-
non_sign_bits = total_bits - (1 if signed else 0)
325+
non_sign_bits = total_bits - (1 if signed else 1)
326326
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
327-
if not signed:
328-
additional_items = 2 * additional_items
329327
for i in range(max_exponent_bits):
330328
fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
331329
boundaries = torch.linspace(0.1, 1, fraction_items)
@@ -334,16 +332,18 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
334332
if signed:
335333
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
336334

337-
if additional_items > 0:
338-
boundaries = torch.linspace(0.1, 1, additional_items + 1)
339-
means = (boundaries[:-1] + boundaries[1:]) / 2.0
340-
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
341-
if signed:
342-
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
335+
if additional_items > 0:
336+
boundaries = torch.linspace(0.1, 1, additional_items + 1)
337+
means = (boundaries[:-1] + boundaries[1:]) / 2.0
338+
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
339+
if signed:
340+
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
343341

344342
data.append(0)
345343
data.append(1.0)
346344

345+
assert len(data) == 2**total_bits
346+
347347
gap = 256 - len(data)
348348
for i in range(gap):
349349
data.append(0)

tests/test_functional.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def test_quantile_quantization():
129129
assert diff < 0.001
130130

131131

132+
132133
def test_dynamic_quantization():
133134
diffs = []
134135
reldiffs = []
@@ -141,8 +142,8 @@ def test_dynamic_quantization():
141142
diffs.append(diff.mean().item())
142143
reldiffs.append(reldiff.mean().item())
143144
assert diff.mean().item() < 0.0135
144-
# print(sum(diffs)/len(diffs))
145-
# print(sum(reldiffs)/len(reldiffs))
145+
print(sum(diffs)/len(diffs))
146+
print(sum(reldiffs)/len(reldiffs))
146147

147148
for i in range(100):
148149
A1 = torch.rand(1024, 1024, device="cuda")
@@ -157,7 +158,8 @@ def test_dynamic_quantization():
157158
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
158159
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
159160
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
160-
def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
161+
@pytest.mark.parametrize("signed", [True, False], ids=['signed_True', 'signed_False'])
162+
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
161163
#print('')
162164
diffs = []
163165
reldiffs = []
@@ -178,9 +180,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
178180
assert A2.dtype == dtype
179181

180182
diffs = []
183+
code = F.create_dynamic_map(signed=signed)
181184
for i in range(100):
182185
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
183-
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
186+
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
184187
A2 = F.dequantize_blockwise(C, S)
185188
diff = torch.abs(A1 - A2).float()
186189
reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -189,11 +192,15 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
189192
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
190193
abserr = sum(diffs)/len(diffs)
191194
relerr = sum(reldiffs)/len(reldiffs)
192-
assert abserr < 0.0035
193-
assert relerr < 0.015
195+
if signed:
196+
assert abserr < 0.0035
197+
assert relerr < 0.015
198+
else:
199+
assert abserr < 0.00175
200+
assert relerr < 0.012
194201
assert A2.dtype == dtype
195-
#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
196-
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
202+
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
203+
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
197204

198205

199206

0 commit comments

Comments
 (0)