Skip to content

Commit 277228b

Browse files
committed
fix: sawb_utils — set negative clip_val to abs().max()
Signed-off-by: Brandon Groth <[email protected]>
1 parent 49ae7d9 commit 277228b

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

fms_mo/quant_refactor/sawb_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def sawb_params(
6767
coeff = dic_coeff[num_bits_int]
6868
clip_val = coeff[1] * mu + coeff[0] * std
6969

70+
# Overwrite negative clip_vals with abs.max
71+
if clip_val < 0.0:
72+
clip_val = x.abs().max()
73+
7074
n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
7175

7276
return n_levels, clip_val
@@ -109,12 +113,18 @@ def sawb_params_code(
109113
raise ValueError(f"SAWB not implemented for code={code}")
110114

111115
if perCh:
112-
# per-channel
113116
reduce_dim = list(range(1, len(input_tensor.shape)))
114117
# conv W=[ch_o, ch_i, ki, ij], linear W=[ch_o, ch_i], reduce all dim but ch_out
115118
mu = torch.mean(input_tensor.abs(), dim=reduce_dim)
116119
std = torch.mean(input_tensor**2, dim=reduce_dim).sqrt()
117120
clip_val_vec = coeff[1] * mu + coeff[0] * std
121+
122+
# Overwrite negative clip_vals with abs.max
123+
neg_clip_idx = clip_val_vec < 0.0
124+
if torch.any(neg_clip_idx):
125+
clip_val_max = torch.max(input_tensor.abs(), dim=1).values
126+
clip_val_vec[neg_clip_idx] = clip_val_max[neg_clip_idx]
127+
118128
return None, clip_val_vec
119129

120130
# per-tensor
@@ -123,6 +133,10 @@ def sawb_params_code(
123133
std = x.mul(x).mean().sqrt()
124134
clip_val = coeff[1] * mu + coeff[0] * std
125135

136+
# Overwrite negative clip_vals with abs.max
137+
if clip_val < 0.0:
138+
clip_val = x.abs().max()
139+
126140
if code in [102]:
127141
n_levels = 2**num_bits - 1
128142
elif code in [103, 203, 403, 703, 803]:

0 commit comments

Comments (0)