File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed
bitsandbytes/backends/default Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -88,13 +88,17 @@ def _(A: torch.Tensor, threshold=0.0):
8888 rows = prod (A .shape [:- 1 ])
8989 outlier_cols = None
9090
91+ outlier_restore = None
92+
9193 if threshold > 0.0 :
9294 outliers = A .abs () >= threshold
9395
9496 if outliers .any ():
9597 # Determine which columns contain outliers, and zero out the
96- # outliers ahead of quantization.
98+ # outliers ahead of quantization. We need to keep a backup of these
99+ # outliers to restore them after quantization.
97100 outlier_cols = torch .argwhere (outliers .any (dim = 0 )).view (- 1 )
101+ outlier_restore = A [outliers ].clone ()
98102 A [outliers ] = 0
99103 else :
100104 # Needed for torch.compile support.
@@ -110,4 +114,8 @@ def _(A: torch.Tensor, threshold=0.0):
110114 if rows > 1 and outlier_cols is not None :
111115 out_row [:, outlier_cols ] = 0
112116
117+ # Restore outliers.
118+ if outlier_restore is not None :
119+ A [outliers ] = outlier_restore
120+
113121 return out_row , row_stats , outlier_cols
You can’t perform that action at this time.
0 commit comments