Skip to content

Commit 3bd9401

Browse files
committed
fix: Fixed perCh path in sawb_params_code
Signed-off-by: Brandon Groth <[email protected]>
1 parent dd1d8f7 commit 3bd9401

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

fms_mo/quant_refactor/sawb_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,10 @@ def sawb_params_code(
110110

111111
if perCh:
112112
# per-channel
113-
reduce_dim = list(range(1, len(input.shape)))
113+
reduce_dim = list(range(1, len(input_tensor.shape)))
114114
# conv W=[ch_o, ch_i, ki, ij], linear W=[ch_o, ch_i], reduce all dim but ch_out
115-
mu = torch.mean(input.abs(), dim=reduce_dim)
116-
std = torch.mean(input**2, dim=reduce_dim).sqrt()
115+
mu = torch.mean(input_tensor.abs(), dim=reduce_dim)
116+
std = torch.mean(input_tensor**2, dim=reduce_dim).sqrt()
117117
clip_val_vec = torch.tensor(coeff[1] * mu + coeff[0] * std)
118118
return None, clip_val_vec
119119

0 commit comments

Comments
 (0)