
Commit 803163c

Bugfix for bias cloning
1 parent 2e134d8 commit 803163c

File tree

1 file changed, +4 −3 lines changed

auto_fp8/quantize.py

Lines changed: 4 additions & 3 deletions
@@ -1,6 +1,7 @@
 import gc
 import re
 from typing import List, Tuple
+import copy
 
 import torch
 import tqdm
@@ -47,7 +48,7 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
         )
     else:
         min_val, max_val = tensor.aminmax()
-    amax = min_val.abs().max(max_val.abs())
+    amax = torch.maximum(min_val.abs(), max_val.abs())
     scale = finfo.max / amax.clamp(min=1e-12)
     # scale and clamp the tensor to bring it to
     # the representative range of float8 data type
@@ -202,8 +203,8 @@ def quantize_weights(
             or name in quantize_config.ignored_layers
         ):
             continue
-        quant_weight, quant_scale = per_tensor_quantize(linear.weight.clone())
-        bias = linear.bias.clone() if linear.bias is not None else None
+        quant_weight, quant_scale = per_tensor_quantize(linear.weight)
+        bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
         quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias)
         replace_module(model, name, quant_linear)
         del linear.weight

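A minimal sketch (not part of the commit) of why the bias handling switches from Tensor.clone() to copy.deepcopy(): cloning an nn.Parameter returns a plain, non-leaf tensor that stays attached to the autograd graph, while deepcopy produces an independent nn.Parameter leaf. The variable names below are illustrative only.

    import copy
    import torch

    bias = torch.nn.Parameter(torch.randn(4))  # stand-in for linear.bias

    cloned = bias.clone()          # plain torch.Tensor, non-leaf, carries a grad_fn
    copied = copy.deepcopy(bias)   # independent nn.Parameter, leaf, no grad_fn

    print(type(cloned).__name__, cloned.is_leaf, cloned.grad_fn is None)  # Tensor False False
    print(type(copied).__name__, copied.is_leaf, copied.grad_fn is None)  # Parameter True True

The other change in the first hunk is behaviorally equivalent: torch.maximum(min_val.abs(), max_val.abs()) computes the same elementwise maximum as min_val.abs().max(max_val.abs()), just through the more explicit API.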