Skip to content

Commit acf45c6

Browse files
committed
only change some called_with into assert_called_with according to py312 unit-tests
1 parent 76efed5 commit acf45c6

File tree

2 files changed

+15
-19
lines changed

2 files changed

+15
-19
lines changed

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1048,7 +1048,7 @@ def mock_upload_data(path, bucket, key_prefix):
10481048

10491049
model_trainer.train()
10501050

1051-
assert mock_local_container.train.called_once_with(
1051+
mock_local_container.train.assert_called_once_with(
10521052
training_job_name=unique_name,
10531053
instance_type=compute.instance_type,
10541054
instance_count=compute.instance_count,

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock
9999
RoleArn=pipeline_role_arn,
100100
)
101101
pipeline.upsert()
102-
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
102+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with(
103103
PipelineName="MyPipeline",
104104
PipelineDefinition=pipeline.definition(),
105105
RoleArn=pipeline_role_arn,
@@ -130,7 +130,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_ar
130130
role_arn=role_arn,
131131
parallelism_config=dict(MaxParallelExecutionSteps=10),
132132
)
133-
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
133+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_with(
134134
PipelineName="MyPipeline",
135135
PipelineDefinition=pipeline.definition(),
136136
RoleArn=role_arn,
@@ -149,7 +149,7 @@ def test_pipeline_create_and_start_with_parallelism_config(sagemaker_session_moc
149149
role_arn=role_arn,
150150
parallelism_config=dict(MaxParallelExecutionSteps=10),
151151
)
152-
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
152+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_with(
153153
PipelineName="MyPipeline",
154154
PipelineDefinition=pipeline.definition(),
155155
RoleArn=role_arn,
@@ -209,7 +209,7 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
209209
assert not pipeline.steps
210210
pipeline.update(role_arn=role_arn)
211211
assert len(json.loads(pipeline.definition())["Steps"]) == 0
212-
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
212+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with(
213213
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
214214
)
215215

@@ -345,7 +345,7 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar
345345
role_arn=role_arn,
346346
parallelism_config=dict(MaxParallelExecutionSteps=10),
347347
)
348-
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
348+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with(
349349
PipelineName="MyPipeline",
350350
PipelineDefinition=pipeline.definition(),
351351
RoleArn=role_arn,
@@ -418,7 +418,7 @@ def _raise_does_already_exists_client_error(**kwargs):
418418
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_once_with(
419419
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
420420
)
421-
assert sagemaker_session_mock.sagemaker_client.list_tags.called_with(
421+
sagemaker_session_mock.sagemaker_client.list_tags.assert_called_with(
422422
ResourceArn="mock_pipeline_arn"
423423
)
424424

@@ -523,7 +523,7 @@ def test_pipeline_delete(sagemaker_session_mock):
523523
sagemaker_session=sagemaker_session_mock,
524524
)
525525
pipeline.delete()
526-
assert sagemaker_session_mock.sagemaker_client.delete_pipeline.called_with(
526+
sagemaker_session_mock.sagemaker_client.delete_pipeline.assert_called_with(
527527
PipelineName="MyPipeline",
528528
)
529529

@@ -536,7 +536,7 @@ def test_pipeline_describe(sagemaker_session_mock):
536536
sagemaker_session=sagemaker_session_mock,
537537
)
538538
pipeline.describe()
539-
assert sagemaker_session_mock.sagemaker_client.describe_pipeline.called_with(
539+
sagemaker_session_mock.sagemaker_client.describe_pipeline.assert_called_with(
540540
PipelineName="MyPipeline",
541541
)
542542

@@ -552,7 +552,7 @@ def test_pipeline_start(sagemaker_session_mock):
552552
sagemaker_session=sagemaker_session_mock,
553553
)
554554
pipeline.start()
555-
assert sagemaker_session_mock.start_pipeline_execution.called_with(
555+
sagemaker_session_mock.start_pipeline_execution.assert_called_with(
556556
PipelineName="MyPipeline",
557557
)
558558

@@ -821,10 +821,8 @@ def test_pipeline_build_parameters_from_execution(sagemaker_session_mock):
821821
pipeline_execution_arn=reference_execution_arn,
822822
parameter_value_overrides=parameter_value_overrides,
823823
)
824-
assert (
825-
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
826-
PipelineExecutionArn=reference_execution_arn
827-
)
824+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_called_with(
825+
PipelineExecutionArn=reference_execution_arn
828826
)
829827
assert len(parameters) == 1
830828
assert parameters["TestParameterName"] == "NewParameterValue"
@@ -850,10 +848,8 @@ def test_pipeline_build_parameters_from_execution_with_invalid_overrides(sagemak
850848
+ f"are not present in the pipeline execution: {reference_execution_arn}"
851849
in str(error)
852850
)
853-
assert (
854-
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
855-
PipelineExecutionArn=reference_execution_arn
856-
)
851+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_called_with(
852+
PipelineExecutionArn=reference_execution_arn
857853
)
858854

859855

@@ -912,7 +908,7 @@ def test_pipeline_execution_basics(sagemaker_session_mock):
912908
PipelineExecutionArn="my:arn"
913909
)
914910
execution.describe()
915-
assert sagemaker_session_mock.sagemaker_client.describe_pipeline_execution.called_with(
911+
sagemaker_session_mock.sagemaker_client.describe_pipeline_execution.assert_called_with(
916912
PipelineExecutionArn="my:arn"
917913
)
918914
steps = execution.list_steps()

0 commit comments

Comments
 (0)