Skip to content

Commit 1130ae0

Browse files
committed
test assert_any_call
1 parent 76efed5 commit 1130ae0

File tree

2 files changed

+29
-34
lines changed

2 files changed

+29
-34
lines changed

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1048,7 +1048,7 @@ def mock_upload_data(path, bucket, key_prefix):
10481048

10491049
model_trainer.train()
10501050

1051-
assert mock_local_container.train.called_once_with(
1051+
mock_local_container.train.assert_any_call(
10521052
training_job_name=unique_name,
10531053
instance_type=compute.instance_type,
10541054
instance_count=compute.instance_count,

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock
9999
RoleArn=pipeline_role_arn,
100100
)
101101
pipeline.upsert()
102-
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
102+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_any_call(
103103
PipelineName="MyPipeline",
104104
PipelineDefinition=pipeline.definition(),
105105
RoleArn=pipeline_role_arn,
@@ -130,7 +130,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_ar
130130
role_arn=role_arn,
131131
parallelism_config=dict(MaxParallelExecutionSteps=10),
132132
)
133-
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
133+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_any_call(
134134
PipelineName="MyPipeline",
135135
PipelineDefinition=pipeline.definition(),
136136
RoleArn=role_arn,
@@ -149,7 +149,7 @@ def test_pipeline_create_and_start_with_parallelism_config(sagemaker_session_moc
149149
role_arn=role_arn,
150150
parallelism_config=dict(MaxParallelExecutionSteps=10),
151151
)
152-
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
152+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_any_call(
153153
PipelineName="MyPipeline",
154154
PipelineDefinition=pipeline.definition(),
155155
RoleArn=role_arn,
@@ -168,7 +168,7 @@ def test_pipeline_create_and_start_with_parallelism_config(sagemaker_session_moc
168168

169169
# Specify ParallelismConfiguration to another value which will be honored in backend
170170
pipeline.start(parallelism_config=dict(MaxParallelExecutionSteps=20))
171-
assert sagemaker_session_mock.sagemaker_client.start_pipeline_execution.called_with(
171+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_any_call(
172172
PipelineName="MyPipeline",
173173
ParallelismConfiguration={"MaxParallelExecutionSteps": 20},
174174
)
@@ -187,11 +187,11 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
187187

188188
pipeline.create(role_arn=role_arn)
189189

190-
assert s3.S3Uploader.upload_string_as_file_body.called_with(
190+
s3.S3Uploader.upload_string_as_file_body.assert_any_call(
191191
body=pipeline.definition(), s3_uri="s3://s3_bucket/MyPipeline"
192192
)
193193

194-
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
194+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_any_call(
195195
PipelineName="MyPipeline",
196196
PipelineDefinitionS3Location={"Bucket": "s3_bucket", "ObjectKey": "MyPipeline"},
197197
RoleArn=role_arn,
@@ -209,7 +209,7 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
209209
assert not pipeline.steps
210210
pipeline.update(role_arn=role_arn)
211211
assert len(json.loads(pipeline.definition())["Steps"]) == 0
212-
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
212+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_any_call(
213213
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
214214
)
215215

@@ -253,7 +253,7 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
253253

254254
pipeline.update(role_arn=role_arn)
255255
assert len(json.loads(pipeline.definition())["Steps"]) == 3
256-
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
256+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_any_call(
257257
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
258258
)
259259

@@ -345,7 +345,7 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar
345345
role_arn=role_arn,
346346
parallelism_config=dict(MaxParallelExecutionSteps=10),
347347
)
348-
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
348+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_any_call(
349349
PipelineName="MyPipeline",
350350
PipelineDefinition=pipeline.definition(),
351351
RoleArn=role_arn,
@@ -366,11 +366,11 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
366366

367367
pipeline.create(role_arn=role_arn)
368368

369-
assert s3.S3Uploader.upload_string_as_file_body.called_with(
369+
s3.S3Uploader.upload_string_as_file_body.assert_any_call(
370370
body=pipeline.definition(), s3_uri="s3://s3_bucket/MyPipeline"
371371
)
372372

373-
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
373+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_any_call(
374374
PipelineName="MyPipeline",
375375
PipelineDefinitionS3Location={"Bucket": "s3_bucket", "ObjectKey": "MyPipeline"},
376376
RoleArn=role_arn,
@@ -418,12 +418,12 @@ def _raise_does_already_exists_client_error(**kwargs):
418418
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_once_with(
419419
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
420420
)
421-
assert sagemaker_session_mock.sagemaker_client.list_tags.called_with(
421+
sagemaker_session_mock.sagemaker_client.list_tags.assert_any_call(
422422
ResourceArn="mock_pipeline_arn"
423423
)
424424

425425
tags.append({"Key": "dummy", "Value": "dummy_tag"})
426-
assert sagemaker_session_mock.sagemaker_client.add_tags.called_with(
426+
sagemaker_session_mock.sagemaker_client.add_tags.assert_any_call(
427427
ResourceArn="mock_pipeline_arn", Tags=tags
428428
)
429429

@@ -523,7 +523,7 @@ def test_pipeline_delete(sagemaker_session_mock):
523523
sagemaker_session=sagemaker_session_mock,
524524
)
525525
pipeline.delete()
526-
assert sagemaker_session_mock.sagemaker_client.delete_pipeline.called_with(
526+
sagemaker_session_mock.sagemaker_client.delete_pipeline.assert_any_call(
527527
PipelineName="MyPipeline",
528528
)
529529

@@ -536,7 +536,7 @@ def test_pipeline_describe(sagemaker_session_mock):
536536
sagemaker_session=sagemaker_session_mock,
537537
)
538538
pipeline.describe()
539-
assert sagemaker_session_mock.sagemaker_client.describe_pipeline.called_with(
539+
sagemaker_session_mock.sagemaker_client.describe_pipeline.assert_any_call(
540540
PipelineName="MyPipeline",
541541
)
542542

@@ -552,17 +552,17 @@ def test_pipeline_start(sagemaker_session_mock):
552552
sagemaker_session=sagemaker_session_mock,
553553
)
554554
pipeline.start()
555-
assert sagemaker_session_mock.start_pipeline_execution.called_with(
555+
sagemaker_session_mock.start_pipeline_execution.assert_any_call(
556556
PipelineName="MyPipeline",
557557
)
558558

559559
pipeline.start(execution_display_name="pipeline-execution")
560-
assert sagemaker_session_mock.start_pipeline_execution.called_with(
560+
sagemaker_session_mock.start_pipeline_execution.assert_any_call(
561561
PipelineName="MyPipeline", PipelineExecutionDisplayName="pipeline-execution"
562562
)
563563

564564
pipeline.start(parameters=dict(alpha="epsilon"))
565-
assert sagemaker_session_mock.start_pipeline_execution.called_with(
565+
sagemaker_session_mock.start_pipeline_execution.assert_any_call(
566566
PipelineName="MyPipeline", PipelineParameters=[{"Name": "alpha", "Value": "epsilon"}]
567567
)
568568

@@ -821,10 +821,8 @@ def test_pipeline_build_parameters_from_execution(sagemaker_session_mock):
821821
pipeline_execution_arn=reference_execution_arn,
822822
parameter_value_overrides=parameter_value_overrides,
823823
)
824-
assert (
825-
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
826-
PipelineExecutionArn=reference_execution_arn
827-
)
824+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_any_call(
825+
PipelineExecutionArn=reference_execution_arn
828826
)
829827
assert len(parameters) == 1
830828
assert parameters["TestParameterName"] == "NewParameterValue"
@@ -850,10 +848,8 @@ def test_pipeline_build_parameters_from_execution_with_invalid_overrides(sagemak
850848
+ f"are not present in the pipeline execution: {reference_execution_arn}"
851849
in str(error)
852850
)
853-
assert (
854-
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
855-
PipelineExecutionArn=reference_execution_arn
856-
)
851+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_any_call(
852+
PipelineExecutionArn=reference_execution_arn
857853
)
858854

859855

@@ -908,24 +904,23 @@ def test_pipeline_execution_basics(sagemaker_session_mock):
908904
)
909905
execution = pipeline.start()
910906
execution.stop()
911-
assert sagemaker_session_mock.sagemaker_client.stop_pipeline_execution.called_with(
907+
sagemaker_session_mock.sagemaker_client.stop_pipeline_execution.assert_any_call(
912908
PipelineExecutionArn="my:arn"
913909
)
914910
execution.describe()
915-
assert sagemaker_session_mock.sagemaker_client.describe_pipeline_execution.called_with(
911+
sagemaker_session_mock.sagemaker_client.describe_pipeline_execution.assert_any_call(
916912
PipelineExecutionArn="my:arn"
917913
)
918914
steps = execution.list_steps()
919-
assert sagemaker_session_mock.sagemaker_client.describe_pipeline_execution_steps.called_with(
915+
sagemaker_session_mock.sagemaker_client.describe_pipeline_execution_steps.assert_any_call(
920916
PipelineExecutionArn="my:arn"
921917
)
922918
assert len(steps) == 1
923919
list_parameters_response = execution.list_parameters()
924-
assert (
925-
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
926-
PipelineExecutionArn="my:arn"
927-
)
920+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_any_call(
921+
PipelineExecutionArn="my:arn"
928922
)
923+
929924
parameter_list = list_parameters_response["PipelineParameters"]
930925
assert len(parameter_list) == 1
931926
assert parameter_list[0]["Name"] == "TestParameterName"

0 commit comments

Comments
 (0)