@@ -67,6 +67,10 @@ def sawb_params(
6767 coeff = dic_coeff [num_bits_int ]
6868 clip_val = coeff [1 ] * mu + coeff [0 ] * std
6969
70+ # Overwrite negative clip_vals with abs.max
71+ if clip_val < 0.0 :
72+ clip_val = x .abs ().max ()
73+
7074 n_levels = 2 ** num_bits - 2 if qlevel_lowering else 2 ** num_bits - 1
7175
7276 return n_levels , clip_val
@@ -109,12 +113,18 @@ def sawb_params_code(
109113 raise ValueError (f"SAWB not implemented for code={ code } " )
110114
111115 if perCh :
112- # per-channel
113116 reduce_dim = list (range (1 , len (input_tensor .shape )))
114117 # conv W=[ch_o, ch_i, ki, ij], linear W=[ch_o, ch_i], reduce all dim but ch_out
115118 mu = torch .mean (input_tensor .abs (), dim = reduce_dim )
116119 std = torch .mean (input_tensor ** 2 , dim = reduce_dim ).sqrt ()
117120 clip_val_vec = coeff [1 ] * mu + coeff [0 ] * std
121+
122+ # Overwrite negative clip_vals with abs.max
123+ neg_clip_idx = clip_val_vec < 0.0
124+ if torch .any (neg_clip_idx ):
125+ clip_val_max = torch .max (input_tensor .abs (), dim = 1 ).values
126+ clip_val_vec [neg_clip_idx ] = clip_val_max [neg_clip_idx ]
127+
118128 return None , clip_val_vec
119129
120130 # per-tensor
@@ -123,6 +133,10 @@ def sawb_params_code(
123133 std = x .mul (x ).mean ().sqrt ()
124134 clip_val = coeff [1 ] * mu + coeff [0 ] * std
125135
136+ # Overwrite negative clip_vals with abs.max
137+ if clip_val < 0.0 :
138+ clip_val = x .abs ().max ()
139+
126140 if code in [102 ]:
127141 n_levels = 2 ** num_bits - 1
128142 elif code in [103 , 203 , 403 , 703 , 803 ]:
0 commit comments