@@ -188,6 +188,7 @@ def getenv_side_effect(arg, default=None):
188188 Mock (spec = requests .Response ),
189189 Mock (spec = requests .Response ),
190190 ],
191+ "https://test.sagemaker.aws/api/2.0/mlflow/runs/update" : Mock (spec = requests .Response ),
191192 "https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate" : Mock (spec = requests .Response ),
192193 }
193194
@@ -215,13 +216,20 @@ def getenv_side_effect(arg, default=None):
215216 mock_response .status_code = 200
216217 mock_response .text = json .dumps ({})
217218
219+ mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/update" ].status_code = 200
220+ mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/update" ].text = json .dumps ({
221+ "run_id" : "test_run_id" ,
222+ "status" : "FINISHED"
223+ })
224+
218225 mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate" ].status_code = 200
219226 mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate" ].text = json .dumps ({})
220227
221228 mock_request .side_effect = [
222229 mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name" ],
223230 mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/create" ],
224231 * mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch" ],
232+ mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/update" ],
225233 mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate" ],
226234 ]
227235
@@ -231,7 +239,7 @@ def getenv_side_effect(arg, default=None):
231239
232240 log_to_mlflow (metrics , params , tags )
233241
234- assert mock_request .call_count == 6 # Total number of API calls
242+ assert mock_request .call_count == 7 # Total number of API calls
235243
236244
237245@patch ("sagemaker.mlflow.forward_sagemaker_metrics.get_training_job_details" )
0 commit comments