
Commit 1efb87d

Added FP8 quantization map.
1 parent 8d87c0b commit 1efb87d

File tree

2 files changed: +85 -0 lines changed

bitsandbytes/functional.py
tests/test_functional.py

bitsandbytes/functional.py

Lines changed: 34 additions & 0 deletions

@@ -6,6 +6,7 @@
 import operator
 import random
 import torch
+import itertools

 from typing import Tuple
 from torch import Tensor
@@ -136,6 +137,39 @@ def create_linear_map(signed=True):
     return torch.linspace(0.0, 1.0, 256)


+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
+    e = exponent_bits
+    p = precision_bits
+    assert e+p == 7
+    # the exponent is biased to 2^(e-1) - 1 == 0
+    evalues = []
+    pvalues = []
+    for i, val in enumerate(range(-(2**(exponent_bits-1)), 2**(exponent_bits-1), 1)):
+        evalues.append(2**val)
+
+    # enumerate all mantissa bit patterns; each contributes 1 + sum(bit_i * 2^-(i+1))
+    lst = list(itertools.product([0, 1], repeat=precision_bits))
+    for bit_pattern in lst:
+        value = 1
+        for i, pval in enumerate(list(bit_pattern)):
+            value += pval*(2**-(i+1))
+        pvalues.append(value)
+
+    assert len(evalues)*len(pvalues) == 128
+    values = []
+    for ev in evalues:
+        for pv in pvalues:
+            values.append(-ev*pv)
+            values.append(ev*pv)
+    values.sort()
+    code = torch.Tensor(values)
+    code /= code.max()
+    code[127] = 0  # pin the midpoint to an exact zero
+
+    return code
+
+
 def create_dynamic_map(signed=True, n=7):
     """
     Creates the dynamic quantization map.
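
For the default E5M2 configuration this enumerates exponents 2^-16 through 2^15 and mantissa multipliers {1.0, 1.25, 1.5, 1.75}, i.e. 128 magnitudes that are mirrored into 256 signed entries and normalized to [-1, 1]. A minimal usage sketch of the new function (the assertions follow directly from the construction above):

    import torch
    from bitsandbytes import functional as F

    # Build the default E5M2 map: 5 exponent bits, 2 precision (mantissa) bits.
    code = F.create_fp8_map(signed=True, exponent_bits=5, precision_bits=2)

    assert code.numel() == 256        # 128 magnitudes, each with both signs
    assert code.max().item() == 1.0   # normalized so the largest entry is 1
    assert code[127].item() == 0.0    # midpoint pinned to an exact zero
    print(code[:2], code[-2:])        # most negative and most positive entries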

tests/test_functional.py

Lines changed: 51 additions & 0 deletions

@@ -2040,3 +2040,54 @@ def test_blockwise_cpu_large():
     assert diffs[-1] < 0.011
     # print(sum(diffs)/len(diffs))
     # print(sum(reldiffs)/len(reldiffs))
+
+
+def test_fp8_quant():
+    for e_bits in range(1, 7):
+        p_bits = 7-e_bits
+        code = F.create_fp8_map(True, e_bits, p_bits).cuda()
+
+        print(e_bits, p_bits)
+        abserr = []
+        relerr = []
+        for i in range(100):
+            A1 = torch.randn(1024, 1024, device="cuda")
+            C, SC = F.quantize_blockwise(A1, code=code)
+            A2 = F.dequantize_blockwise(C, SC)
+            diff = torch.abs(A1 - A2)
+            reldiff = diff/torch.abs(A1+1e-8)
+            abserr.append(diff.mean().item())
+            relerr.append(reldiff.mean().item())
+            #assert diff < 0.0075
+        print(sum(abserr)/len(abserr))
+        print(sum(relerr)/len(relerr))
+
+        abserr = []
+        relerr = []
+        for i in range(100):
+            A1 = torch.rand(1024, 1024, device="cuda")
+            C, SC = F.quantize_blockwise(A1, code=code)
+            A2 = F.dequantize_blockwise(C, SC)
+            diff = torch.abs(A1 - A2)
+            reldiff = diff/torch.abs(A1+1e-8)
+            abserr.append(diff.mean().item())
+            relerr.append(reldiff.mean().item())
+            #assert diff < 0.0075
+        print(sum(abserr)/len(abserr))
+        print(sum(relerr)/len(relerr))
+
+        abserr = []
+        relerr = []
+        for i in range(100):
+            A1 = torch.randn(1024, 1024, device="cuda")
+            C, SC = F.quantize_blockwise(A1)
+            A2 = F.dequantize_blockwise(C, SC)
+            diff = torch.abs(A1 - A2)
+            reldiff = diff/torch.abs(A1+1e-8)
+            abserr.append(diff.mean().item())
+            relerr.append(reldiff.mean().item())
+            #assert diff < 0.0075
+        print(3, sum(abserr)/len(abserr))
+        print(3, sum(relerr)/len(relerr))
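
The test sweeps every exponent/precision split and reports mean absolute and relative round-trip error for three inputs: signed normal data with the FP8 code, unsigned uniform data with the FP8 code, and normal data with the default dynamic map as a baseline. Conceptually, quantize_blockwise normalizes each block by its absolute maximum and snaps each normalized value to the nearest entry of the 256-value code. A rough pure-PyTorch sketch of that round trip for a single block (an illustration of the idea, not the library's kernel; block size and round-to-nearest are assumptions here):

    import torch

    def roundtrip_block(block, code):
        # Normalize the block into [-1, 1] by its absolute maximum.
        absmax = block.abs().max()
        normed = block / absmax
        # Snap each value to the nearest entry of the 256-value code.
        idx = (normed.reshape(-1, 1) - code.reshape(1, -1)).abs().argmin(dim=1)
        # Dequantize: look up the code values and rescale by absmax.
        return code[idx].reshape(block.shape) * absmax

    block = torch.randn(4096)
    code = torch.linspace(-1.0, 1.0, 256)  # stand-in for a create_fp8_map output
    err = (block - roundtrip_block(block, code)).abs().mean()
    print(err.item())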
