
Commit caf1832

Added k-bit linear quantization.
1 parent 1efb87d commit caf1832

File tree

2 files changed: +60 -4 lines changed


bitsandbytes/functional.py

Lines changed: 10 additions & 4 deletions
@@ -130,11 +130,17 @@ def get_instance(cls):
         return cls._instance


-def create_linear_map(signed=True):
-    if signed:
-        return torch.linspace(-1.0, 1.0, 256)
+def create_linear_map(signed=True, bits=8):
+    sign = (-1.0 if signed else 0.0)
+
+    values = torch.linspace(sign, 1.0, 2**bits)
+    gap = 256 - values.numel()
+    if gap == 0:
+        return values
     else:
-        return torch.linspace(0.0, 1.0, 256)
+        l = values.numel()//2
+        #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
+        return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())


 def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
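The new bits parameter keeps the returned code at 256 entries so the existing 8-bit kernels can still index it with a byte: the 2**bits linearly spaced grid points are split in the middle and the unused slots are filled with zeros. A minimal sketch of the resulting layout (an illustration, not part of the commit; it assumes a bitsandbytes build containing this change, importable as bitsandbytes.functional, and needs no GPU):

# Sketch only: shows the layout create_linear_map produces for bits < 8.
import torch
import bitsandbytes.functional as F

code = F.create_linear_map(signed=True, bits=3)
assert code.numel() == 256           # still a full byte-indexable code
grid = code[code != 0]               # the 2**3 = 8 linear grid points (none is exactly zero)
print(grid)                          # tensor([-1.0000, -0.7143, ..., 0.7143, 1.0000])
print((code == 0).sum().item())      # 248 zero-filled padding slots in the middle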

tests/test_functional.py

Lines changed: 50 additions & 0 deletions
@@ -2091,3 +2091,53 @@ def test_fp8_quant():
     print(3, sum(abserr)/len(abserr))
     print(3, sum(relerr)/len(relerr))

+
+def test_few_bit_quant():
+
+    for bits in range(2, 9):
+        code = F.create_linear_map(True, bits=bits).cuda()
+        assert code.numel() == 256
+        print(bits)
+        for i in range(100):
+
+            values = torch.randn(1, 24, device='cuda')
+            values /= values.abs().max()
+            #values[values.abs() < 1e-6] += 1e-5
+
+            q1 = []
+            v1 = []
+            for v in values[0]:
+                idx = torch.abs(v-code).argmin()
+                q1.append(idx.item())
+                v1.append(code[idx].item())
+
+            q1 = torch.Tensor(q1).cuda()
+            v1 = torch.Tensor(v1).cuda()
+
+            q2, S2 = F.quantize(values, code=code)
+            v2 = F.dequantize(q2, S2)
+
+            idx = torch.isclose(q1.int(), q2.int())
+            if idx.sum():
+                # some weird cases
+                err1 = torch.abs(v1-values).mean()
+                err2 = torch.abs(v2-values).mean()
+                assert err2 <= err1
+
+            else:
+                torch.testing.assert_allclose(q1, q2)
+
+    #print(e_bits, p_bits)
+    #abserr = []
+    #relerr = []
+    #for i in range(100):
+    #    A1 = torch.randn(1024, 1024, device="cuda")
+    #    C, SC = F.quantize_blockwise(A1, code=code)
+    #    A2 = F.dequantize_blockwise(C, SC)
+    #    diff = torch.abs(A1 - A2)
+    #    reldiff = diff/torch.abs(A1+1e-8)
+    #    abserr.append(diff.mean().item())
+    #    relerr.append(reldiff.mean().item())
+    #    #assert diff < 0.0075
+    #print(sum(abserr)/len(abserr))
+    #print(sum(relerr)/len(relerr))
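The test builds a brute-force reference by taking the argmin over the 256-entry code for every element and compares it against the F.quantize / F.dequantize kernel path; when the two index assignments disagree (values near the zero-filled middle of the code can tie), it instead checks that the kernel's mean reconstruction error is no worse than the reference's. For orientation, a minimal round trip through one of the new k-bit codes (a usage sketch, assuming a CUDA device and a bitsandbytes build containing this commit; the import alias F follows the test file):

# Sketch: quantize a tensor with a 4-bit linear code via the 8-bit kernels.
import torch
import bitsandbytes.functional as F

code = F.create_linear_map(signed=True, bits=4).cuda()

x = torch.randn(1, 24, device='cuda')
x /= x.abs().max()                   # keep inputs inside the [-1, 1] code range

q, state = F.quantize(x, code=code)  # uint8 indices into the 256-entry code
x_hat = F.dequantize(q, state)       # nearest available 4-bit grid value

print((x - x_hat).abs().mean())      # error shrinks as bits increases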
