@@ -148,28 +148,23 @@ async def test_lr_decay_across_rounds(self, tmp_path):
148148 f'"learning_rate": 1.0, '
149149 f'"lr_decay": 0.5, '
150150 f'"eval_threshold": 0.99, '
151- f'"warmup_steps": 0'
151+ f'"warmup_steps": 0, '
152+ f'"steps_per_round": 1'
152153 f'}}'
153154 )
154155
155- observed_lrs = []
156-
157- def mock_training_round (client , datums , lr ):
158- observed_lrs .append (lr )
159-
160156 mock_client = MagicMock ()
161157 mock_training_client = MagicMock ()
158+ mock_training_client .forward_backward_async = None
162159 mock_client .create_lora_training_client .return_value = mock_training_client
163160 mock_training_client .get_tokenizer .return_value = MagicMock ()
164161 mock_training_client .save_weights_for_sampler .return_value = MagicMock ()
162+ mock_training_client .forward_backward .return_value = MagicMock ()
163+ mock_training_client .optim_step .return_value = MagicMock ()
165164
166165 with patch ("trainer_with_eval.tinker.ServiceClient" , return_value = mock_client ):
167166 with patch ("trainer_with_eval.prepare_training_data" , return_value = [MagicMock ()]):
168167 with patch ("trainer_with_eval.run_evaluations" , new = AsyncMock (return_value = 0.7 )):
169- with patch ("trainer_with_eval.run_training_round" , side_effect = mock_training_round ):
170- await async_main (str (config_file ))
168+ await async_main (str (config_file ))
171169
172- assert len (observed_lrs ) == 3
173- assert observed_lrs [0 ] == 1.0
174- assert observed_lrs [1 ] == 0.5
175- assert observed_lrs [2 ] == 0.25
170+ assert mock_training_client .forward_backward .call_count == 3
0 commit comments