Skip to content

Commit 98cbc4b

Browse files
committed
Added k-bit fp8 map.
1 parent caf1832 commit 98cbc4b

File tree

2 files changed

+52
-52
lines changed

2 files changed

+52
-52
lines changed

bitsandbytes/functional.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8):
143143
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
144144

145145

146-
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
146+
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
147147
e = exponent_bits
148148
p = precision_bits
149-
assert e+p == 7
149+
has_sign = 1 if signed else 0
150+
assert e+p == total_bits-has_sign
150151
# the exponent is biased to 2^(e-1) -1 == 0
151152
evalues = []
152153
pvalues = []
153-
for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)):
154+
for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
154155
evalues.append(2**val)
155156

156157

@@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
161162
value += pval*(2**-(i+1))
162163
pvalues.append(value)
163164

164-
assert len(evalues)*len(pvalues) == 128
165+
assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
165166
values = []
166167
for ev in evalues:
167168
for pv in pvalues:
168-
values.append(-ev*pv)
169+
if signed:
170+
values.append(-ev*pv)
169171
values.append(ev*pv)
172+
if total_bits < 8:
173+
gap = 256 - len(values)
174+
for i in range(gap):
175+
values.append(0)
170176
values.sort()
171177
code = torch.Tensor(values)
172178
code /= code.max()

tests/test_functional.py

Lines changed: 41 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from bitsandbytes import functional as F
1212

1313
torch.set_printoptions(
14-
precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
14+
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
1515
)
1616
k = 20
1717

@@ -2095,49 +2095,43 @@ def test_fp8_quant():
20952095
def test_few_bit_quant():
20962096

20972097
for bits in range(2, 9):
2098-
code = F.create_linear_map(True, bits=bits).cuda()
2099-
assert code.numel() == 256
2100-
print(bits)
2101-
for i in range(100):
2102-
2103-
values = torch.randn(1, 24, device='cuda')
2104-
values /= values.abs().max()
2105-
#values[values.abs() < 1e-6] += 1e-5
2106-
2107-
q1 = []
2108-
v1 = []
2109-
for v in values[0]:
2110-
idx = torch.abs(v-code).argmin()
2111-
q1.append(idx.item())
2112-
v1.append(code[idx].item())
2113-
2114-
q1 = torch.Tensor(q1).cuda()
2115-
v1 = torch.Tensor(v1).cuda()
2116-
2117-
q2, S2 = F.quantize(values, code=code)
2118-
v2 = F.dequantize(q2, S2)
2119-
2120-
idx = torch.isclose(q1.int(), q2.int())
2121-
if idx.sum():
2122-
# some weird cases
2123-
err1 = torch.abs(v1-values).mean()
2124-
err2 = torch.abs(v2-values).mean()
2125-
assert err2 <= err1
2126-
2127-
else:
2128-
torch.testing.assert_allclose(q1, q2)
2129-
2130-
#print(e_bits, p_bits)
2131-
#abserr = []
2132-
#relerr = []
2133-
#for i in range(100):
2134-
# A1 = torch.randn(1024, 1024, device="cuda")
2135-
# C, SC = F.quantize_blockwise(A1, code=code)
2136-
# A2 = F.dequantize_blockwise(C, SC)
2137-
# diff = torch.abs(A1 - A2)
2138-
# reldiff = diff/torch.abs(A1+1e-8)
2139-
# abserr.append(diff.mean().item())
2140-
# relerr.append(reldiff.mean().item())
2141-
# #assert diff < 0.0075
2142-
#print(sum(abserr)/len(abserr))
2143-
#print(sum(relerr)/len(relerr))
2098+
for method in ['linear', 'fp8']:
2099+
code = None
2100+
if method == 'linear':
2101+
code = F.create_linear_map(True, bits=bits).cuda()
2102+
elif method == 'fp8':
2103+
ebits = math.ceil(bits/2)
2104+
pbits = bits-ebits-1
2105+
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
2106+
print(ebits, pbits, bits)
2107+
print(code)
2108+
assert code.numel() == 256
2109+
print(bits)
2110+
for i in range(10):
2111+
2112+
values = torch.randn(1, 32, device='cuda')
2113+
values /= values.abs().max()
2114+
#values[values.abs() < 1e-6] += 1e-5
2115+
2116+
q1 = []
2117+
v1 = []
2118+
for v in values[0]:
2119+
idx = torch.abs(v-code).argmin()
2120+
q1.append(idx.item())
2121+
v1.append(code[idx].item())
2122+
2123+
q1 = torch.Tensor(q1).cuda()
2124+
v1 = torch.Tensor(v1).cuda()
2125+
2126+
q2, S2 = F.quantize(values, code=code)
2127+
v2 = F.dequantize(q2, S2)
2128+
2129+
idx = torch.isclose(q1.int(), q2.int())
2130+
if idx.sum():
2131+
# some weird cases
2132+
err1 = torch.abs(v1-values).mean()
2133+
err2 = torch.abs(v2-values).mean()
2134+
assert err2 <= err1
2135+
2136+
else:
2137+
torch.testing.assert_allclose(q1, q2)

0 commit comments

Comments (0)