Skip to content

Commit f40e8ae

Browse files
int8 tests passing
1 parent d02b536 commit f40e8ae

File tree

1 file changed

+9
-1
lines changed
  • bitsandbytes/backends/default

1 file changed

+9
-1
lines changed

bitsandbytes/backends/default/ops.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -88,13 +88,17 @@ def _(A: torch.Tensor, threshold=0.0):
8888
rows = prod(A.shape[:-1])
8989
outlier_cols = None
9090

91+
outlier_restore = None
92+
9193
if threshold > 0.0:
9294
outliers = A.abs() >= threshold
9395

9496
if outliers.any():
9597
# Determine which columns contain outliers, and zero out the
96-
# outliers ahead of quantization.
98+
# outliers ahead of quantization. We need to keep a backup of these
99+
# outliers to restore them after quantization.
97100
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
101+
outlier_restore = A[outliers].clone()
98102
A[outliers] = 0
99103
else:
100104
# Needed for torch.compile support.
@@ -110,4 +114,8 @@ def _(A: torch.Tensor, threshold=0.0):
110114
if rows > 1 and outlier_cols is not None:
111115
out_row[:, outlier_cols] = 0
112116

117+
# Restore outliers.
118+
if outlier_restore is not None:
119+
A[outliers] = outlier_restore
120+
113121
return out_row, row_stats, outlier_cols

0 commit comments

Comments
 (0)