@@ -250,12 +250,13 @@ def test_training(self, model: keras.Model, input_data, overflow_mode: str, *arg
250250
251251 initial_weights_np = [w .numpy () for w in model .trainable_variables ]
252252
253- opt = keras .optimizers .Lion ( learning_rate = 1.0 )
253+ opt = keras .optimizers .SGD ( )
254254 loss = keras .losses .MeanAbsoluteError ()
255255 model (input_data , training = True ) # Adapt init bitwidth
256256
257257 data_len = len (input_data [0 ]) if isinstance (input_data , Sequence ) else len (input_data )
258- labels = ops .array (np .random .rand (data_len ), dtype = 'float32' )
258+ shape = (data_len , * model .output .shape [1 :]) # type: ignore
259+ labels = ops .array (np .random .rand (* shape ), dtype = 'float32' )
259260 model_wrap .compile (optimizer = opt , loss = loss ) # type: ignore
260261 model_wrap .train_on_batch (input_data , labels )
261262
@@ -265,13 +266,11 @@ def test_training(self, model: keras.Model, input_data, overflow_mode: str, *arg
265266 for w0 , w1 in zip (initial_weights_np , trained_weights ):
266267 if w1 .name in 'bif' :
267268 continue
268- if np .prod (w1 .shape ) < 10 and overflow_mode == 'SAT' :
269+ if np .prod (w1 .shape ) < 10 and 'SAT' in overflow_mode :
269270 # Overflowing weight doesn't receive grad in SAT mode
270271 # Chance of all overflow is high for small-sized weights, skip them
271272 continue
272273 if np .array_equal (w0 , w1 .numpy ()):
273- # if w1.path == 'q_multi_head_attention/key/bias':
274- # continue
275274 boom .append (f'{ w1 .path } ' )
276275 assert not boom , f'Weight { " AND " .join (boom )} did not change'
277276 assert any (np .any (w0 != w1 .numpy ()) for w0 , w1 in zip (initial_weights_np , trained_weights ) if w1 .name in 'bif' )