|
11 | 11 | from bitsandbytes import functional as F |
12 | 12 |
|
13 | 13 | torch.set_printoptions( |
14 | | - precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 |
| 14 | + precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 |
15 | 15 | ) |
16 | 16 | k = 20 |
17 | 17 |
|
@@ -2095,49 +2095,43 @@ def test_fp8_quant(): |
2095 | 2095 | def test_few_bit_quant(): |
2096 | 2096 |
|
2097 | 2097 | for bits in range(2, 9): |
2098 | | - code = F.create_linear_map(True, bits=bits).cuda() |
2099 | | - assert code.numel() == 256 |
2100 | | - print(bits) |
2101 | | - for i in range(100): |
2102 | | - |
2103 | | - values = torch.randn(1, 24, device='cuda') |
2104 | | - values /= values.abs().max() |
2105 | | - #values[values.abs() < 1e-6] += 1e-5 |
2106 | | - |
2107 | | - q1 = [] |
2108 | | - v1 = [] |
2109 | | - for v in values[0]: |
2110 | | - idx = torch.abs(v-code).argmin() |
2111 | | - q1.append(idx.item()) |
2112 | | - v1.append(code[idx].item()) |
2113 | | - |
2114 | | - q1 = torch.Tensor(q1).cuda() |
2115 | | - v1 = torch.Tensor(v1).cuda() |
2116 | | - |
2117 | | - q2, S2 = F.quantize(values, code=code) |
2118 | | - v2 = F.dequantize(q2, S2) |
2119 | | - |
2120 | | - idx = torch.isclose(q1.int(), q2.int()) |
2121 | | - if idx.sum(): |
2122 | | - # some weird cases |
2123 | | - err1 = torch.abs(v1-values).mean() |
2124 | | - err2 = torch.abs(v2-values).mean() |
2125 | | - assert err2 <= err1 |
2126 | | - |
2127 | | - else: |
2128 | | - torch.testing.assert_allclose(q1, q2) |
2129 | | - |
2130 | | - #print(e_bits, p_bits) |
2131 | | - #abserr = [] |
2132 | | - #relerr = [] |
2133 | | - #for i in range(100): |
2134 | | - # A1 = torch.randn(1024, 1024, device="cuda") |
2135 | | - # C, SC = F.quantize_blockwise(A1, code=code) |
2136 | | - # A2 = F.dequantize_blockwise(C, SC) |
2137 | | - # diff = torch.abs(A1 - A2) |
2138 | | - # reldiff = diff/torch.abs(A1+1e-8) |
2139 | | - # abserr.append(diff.mean().item()) |
2140 | | - # relerr.append(reldiff.mean().item()) |
2141 | | - # #assert diff < 0.0075 |
2142 | | - #print(sum(abserr)/len(abserr)) |
2143 | | - #print(sum(relerr)/len(relerr)) |
| 2098 | + for method in ['linear', 'fp8']: |
| 2099 | + code = None |
| 2100 | + if method == 'linear': |
| 2101 | + code = F.create_linear_map(True, bits=bits).cuda() |
| 2102 | + elif method == 'fp8': |
| 2103 | + ebits = math.ceil(bits/2) |
| 2104 | + pbits = bits-ebits-1 |
| 2105 | + code = F.create_fp8_map(True, ebits, pbits, bits).cuda() |
| 2106 | + print(ebits, pbits, bits) |
| 2107 | + print(code) |
| 2108 | + assert code.numel() == 256 |
| 2109 | + print(bits) |
| 2110 | + for i in range(10): |
| 2111 | + |
| 2112 | + values = torch.randn(1, 32, device='cuda') |
| 2113 | + values /= values.abs().max() |
| 2114 | + #values[values.abs() < 1e-6] += 1e-5 |
| 2115 | + |
| 2116 | + q1 = [] |
| 2117 | + v1 = [] |
| 2118 | + for v in values[0]: |
| 2119 | + idx = torch.abs(v-code).argmin() |
| 2120 | + q1.append(idx.item()) |
| 2121 | + v1.append(code[idx].item()) |
| 2122 | + |
| 2123 | + q1 = torch.Tensor(q1).cuda() |
| 2124 | + v1 = torch.Tensor(v1).cuda() |
| 2125 | + |
| 2126 | + q2, S2 = F.quantize(values, code=code) |
| 2127 | + v2 = F.dequantize(q2, S2) |
| 2128 | + |
| 2129 | + idx = torch.isclose(q1.int(), q2.int()) |
| 2130 | + if idx.sum(): |
| 2131 | + # some weird cases |
| 2132 | + err1 = torch.abs(v1-values).mean() |
| 2133 | + err2 = torch.abs(v2-values).mean() |
| 2134 | + assert err2 <= err1 |
| 2135 | + |
| 2136 | + else: |
| 2137 | + torch.testing.assert_allclose(q1, q2) |