@@ -188,6 +188,7 @@ def getenv_side_effect(arg, default=None):
188
188
Mock (spec = requests .Response ),
189
189
Mock (spec = requests .Response ),
190
190
],
191
+ "https://test.sagemaker.aws/api/2.0/mlflow/runs/update" : Mock (spec = requests .Response ),
191
192
"https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate" : Mock (spec = requests .Response ),
192
193
}
193
194
@@ -215,13 +216,20 @@ def getenv_side_effect(arg, default=None):
215
216
mock_response .status_code = 200
216
217
mock_response .text = json .dumps ({})
217
218
219
+ mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/update" ].status_code = 200
220
+ mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/update" ].text = json .dumps ({
221
+ "run_id" : "test_run_id" ,
222
+ "status" : "FINISHED"
223
+ })
224
+
218
225
mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate" ].status_code = 200
219
226
mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate" ].text = json .dumps ({})
220
227
221
228
mock_request .side_effect = [
222
229
mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name" ],
223
230
mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/create" ],
224
231
* mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch" ],
232
+ mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/update" ],
225
233
mock_responses ["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate" ],
226
234
]
227
235
@@ -231,7 +239,7 @@ def getenv_side_effect(arg, default=None):
231
239
232
240
log_to_mlflow (metrics , params , tags )
233
241
234
- assert mock_request .call_count == 6 # Total number of API calls
242
+ assert mock_request .call_count == 7 # Total number of API calls
235
243
236
244
237
245
@patch ("sagemaker.mlflow.forward_sagemaker_metrics.get_training_job_details" )
0 commit comments