Skip to content

Commit df941ec

Browse files
Add int8 dequant function; misc improvements
1 parent 56abdc2 commit df941ec

File tree

5 files changed

+57
-40
lines changed

5 files changed

+57
-40
lines changed

benchmarking/int8/training_benchmark.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
k = 20
1515

16+
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
17+
1618

1719
@pytest.mark.parametrize(
1820
("batch", "seq", "model", "hidden"),

bitsandbytes/autograd/_functions.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,18 +350,17 @@ def forward(
350350
CAt[:, state.idx] = 0
351351

352352
# Extract the input outliers in original precision
353-
subA = A[:, state.idx]
353+
subA = A[:, state.idx].contiguous()
354354

355355
# Extract the corresponding weights
356356
if state.has_fp16_weights:
357357
state.subB = B[:, state.idx].t()
358358
else:
359-
outliers = state.CB[:, state.idx]
360-
361359
# To dequantize our weights associated with the input outliers,
362360
# we want to divide by 127. It's however more performant to multiply
363361
# by the reciprocal.
364-
state.subB = (7.874016e-3 * outliers * state.SCB.view(-1, 1)).t().to(A.dtype)
362+
outliers = state.CB[:, state.idx]
363+
state.subB = (outliers.t() * state.SCB * 7.874015718698502e-3).to(A.dtype)
365364
else:
366365
subA = None
367366

@@ -378,7 +377,7 @@ def forward(
378377

379378
# 4. Mixed-precision decomposition matmul
380379
if subA is not None and state.subB is not None:
381-
output += torch.matmul(subA, state.subB)
380+
output = output.addmm(subA, state.subB)
382381

383382
# 5. Save state
384383
ctx.state = state

bitsandbytes/functional.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2722,6 +2722,20 @@ def int8_double_quant(
27222722
return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
27232723

27242724

2725+
def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor):
    """Dequantize a row-wise quantized `torch.int8` tensor.

    Each row of `A` is scaled by its matching entry in `stats` and by 1/127,
    which undoes a quantization that mapped the row's absmax onto 127.

    Args:
        A (`torch.Tensor` with dtype `torch.int8`): The quantized int8 tensor.
        stats (`torch.Tensor` with dtype `torch.float32`): The row-wise quantization
            statistics, one value per row of `A`.

    Returns:
        `torch.Tensor` with dtype `torch.float32`: The dequantized tensor.
    """
    # Multiplying by the reciprocal of 127 stands in for the division.
    inv_127 = 7.874015718698502e-3

    # Shape the statistics as a column so they broadcast across each row.
    row_scale = stats.view(-1, 1)

    return A * row_scale * inv_127
2737+
2738+
27252739
def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
27262740
"""Quantizes a tensor with dtype `torch.float16` to `torch.int8` in accordance to the `LLM.int8()` algorithm.
27272741
@@ -3026,7 +3040,10 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
30263040
return None
30273041

30283042

3029-
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
3043+
@deprecated(
3044+
"This function is deprecated and will be removed in a future release. Consider using `int8_vectorwise_dequant` instead.",
3045+
category=FutureWarning,
3046+
)
30303047
def vectorwise_dequant(xq, max1, quant_type="vector"):
30313048
if quant_type == "vector":
30323049
x = (xq / C * max1).to(torch.float32)

csrc/kernels.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2159,7 +2159,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
21592159
// Threads will read the row values in a striped access pattern and find a local absmax.
21602160
float row_local_absmax = -FLT_MIN;
21612161
for (int i = threadIdx.x; i < cols; i += THREADS) {
2162-
const float absval = fabsf(__ldg(&(row_data[i])));
2162+
const float absval = fabsf(__ldcs(&(row_data[i])));
21632163

21642164
// For sparse decomposition, values outside of the threshold are not to be
21652165
// included when calculating the row's absmax.
@@ -2171,7 +2171,6 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
21712171
}
21722172

21732173
// Reduce thread-local absmax across the block.
2174-
// TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
21752174
const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
21762175
if (threadIdx.x == 0) {
21772176
// Save our block's absmax to shared memory for the quantization step.

tests/test_functional.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -695,38 +695,6 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
695695
print(sum(err3) / len(err3))
696696

697697

698-
@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
699-
@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
700-
@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
701-
@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
702-
@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
703-
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
704-
@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
705-
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
706-
@pytest.mark.deprecated
707-
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
708-
for i in range(k):
709-
if dims == 2:
710-
A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
711-
elif dims == 3:
712-
A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
713-
714-
A.view(-1)[-1] = -1
715-
if transpose:
716-
At = A.t().contiguous()
717-
out1, S1 = F.nvidia_transform(At, to_order=orderOut)
718-
else:
719-
out1, S1 = F.nvidia_transform(A, to_order=orderOut)
720-
out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)
721-
722-
assert S1[0][0] == S2[0][0]
723-
assert S1[0][1] == S2[0][1]
724-
# print(out1)
725-
# print(out2)
726-
727-
torch.testing.assert_close(out1, out2)
728-
729-
730698
@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
731699
@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
732700
def test_coo_double_quant(dim1, dim2):
@@ -1782,6 +1750,38 @@ def test_percentile_clipping(gtype):
17821750
torch.testing.assert_close(gnorm1, gnorm2)
17831751

17841752

1753+
@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    """Check that `F.transform` matches `F.nvidia_transform` for each output order.

    For every iteration a random tensor is built on the CUDA device and both
    transforms are run with the same `to_order`. When `transpose` is set,
    `F.nvidia_transform` is given an explicitly transposed, contiguous copy,
    while `F.transform` receives the original tensor plus `transpose=True`.
    The first two entries of the first element of each returned state must
    match, and so must the transformed outputs.
    """
    for _ in range(k):
        if dims == 2:
            size = (dim1, dim2)
        elif dims == 3:
            size = (dim1, dim2, dim3)
        A = torch.randint(10, 99, size=size, device="cuda").to(dtype)

        # Overwrite the final element with a value outside the random range.
        A.view(-1)[-1] = -1

        reference_input = A.t().contiguous() if transpose else A
        expected, expected_state = F.nvidia_transform(reference_input, to_order=orderOut)
        actual, actual_state = F.transform(A, to_order=orderOut, transpose=transpose)

        assert expected_state[0][0] == actual_state[0][0]
        assert expected_state[0][1] == actual_state[0][1]

        torch.testing.assert_close(expected, actual)
1783+
1784+
17851785
@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
17861786
@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
17871787
@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))

0 commit comments

Comments
 (0)