Skip to content

Commit ea943c0

Browse files
committed
change called_with into assert_called_once_with
1 parent d8660c0 commit ea943c0

File tree

6 files changed

+32
-36
lines changed

6 files changed

+32
-36
lines changed

tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def remote_decorator_config_with_filter(sagemaker_session):
100100
def test_prepare_and_upload_callable(mock_stored_function, config_uploader, wrapped_func):
101101
mock_stored_function.save(wrapped_func).return_value = None
102102
config_uploader._prepare_and_upload_callable(wrapped_func, "s3_base_uri", sagemaker_session)
103-
assert mock_stored_function.assert_called_once_with(
103+
mock_stored_function.assert_called_once_with(
104104
s3_base_uri="s3_base_uri",
105105
s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key,
106106
hmac_key="some_secret_key",
@@ -244,7 +244,7 @@ def test_prepare_step_input_channel(
244244
)
245245
remote_decorator_config = config_uploader.remote_decorator_config
246246

247-
assert mock_upload_callable.assert_called_once_with(wrapped_func)
247+
mock_upload_callable.assert_called_once_with(wrapped_func)
248248

249249
mock_script_upload.assert_called_once_with(
250250
spark_config=config_uploader.remote_decorator_config.spark_config,

tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def test_to_pipeline(
269269
)
270270
assert pipeline_arn == PIPELINE_ARN
271271

272-
assert mock_upload_callable.assert_called_once_with(job_function)
272+
mock_upload_callable.assert_called_once_with(job_function)
273273
local_dependencies_path = mock_runtime_manager().snapshot()
274274
mock_python_version = mock_runtime_manager()._current_python_version()
275275
container_args.extend(["--client_python_version", mock_python_version])

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1048,7 +1048,7 @@ def mock_upload_data(path, bucket, key_prefix):
10481048

10491049
model_trainer.train()
10501050

1051-
assert mock_local_container.train.assert_called_once_with(
1051+
mock_local_container.train.assert_called_once_with(
10521052
training_job_name=unique_name,
10531053
instance_type=compute.instance_type,
10541054
instance_count=compute.instance_count,

tests/unit/sagemaker/remote_function/runtime_environment/test_runtime_environment_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def test_run_pre_exec_script_cmd_error(isfile):
450450
def test_change_dir_permission(mock_subprocess_run):
451451
RuntimeEnvironmentManager().change_dir_permission(dirs=["a", "b", "c"], new_permission="777")
452452
expected_command = ["sudo", "chmod", "-R", "777", "a", "b", "c"]
453-
assert mock_subprocess_run.assert_called_once_with(
453+
mock_subprocess_run.assert_called_once_with(
454454
expected_command, check=True, stderr=subprocess.PIPE
455455
)
456456

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1190,7 +1190,7 @@ def test_start_with_spark(
11901190

11911191
assert job.job_name.startswith("job-function")
11921192

1193-
assert mock_stored_function.assert_called_once_with(
1193+
mock_stored_function.assert_called_once_with(
11941194
sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", s3_kms_key=None
11951195
)
11961196

tests/unit/test_session.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5599,7 +5599,7 @@ def test_feature_group_create(sagemaker_session, feature_group_dummy_definitions
55995599
feature_definitions=feature_group_dummy_definitions,
56005600
role_arn="dummy_role",
56015601
)
5602-
assert sagemaker_session.sagemaker_client.create_feature_group.called_with(
5602+
sagemaker_session.sagemaker_client.create_feature_group.assert_called_once_with(
56035603
FeatureGroupName="MyFeatureGroup",
56045604
RecordIdentifierFeatureName="feature1",
56055605
EventTimeFeatureName="feature2",
@@ -5610,14 +5610,14 @@ def test_feature_group_create(sagemaker_session, feature_group_dummy_definitions
56105610

56115611
def test_feature_group_delete(sagemaker_session):
56125612
sagemaker_session.delete_feature_group(feature_group_name="MyFeatureGroup")
5613-
assert sagemaker_session.sagemaker_client.delete_feature_group.called_with(
5613+
sagemaker_session.sagemaker_client.delete_feature_group.assert_called_once_with(
56145614
FeatureGroupName="MyFeatureGroup",
56155615
)
56165616

56175617

56185618
def test_feature_group_describe(sagemaker_session):
56195619
sagemaker_session.describe_feature_group(feature_group_name="MyFeatureGroup")
5620-
assert sagemaker_session.sagemaker_client.describe_feature_group.called_with(
5620+
sagemaker_session.sagemaker_client.describe_feature_group.assert_called_once_with(
56215621
FeatureGroupName="MyFeatureGroup",
56225622
)
56235623

@@ -5627,7 +5627,7 @@ def test_feature_group_feature_additions_update(sagemaker_session, feature_group
56275627
feature_group_name="MyFeatureGroup",
56285628
feature_additions=feature_group_dummy_definitions,
56295629
)
5630-
assert sagemaker_session.sagemaker_client.update_feature_group.called_with(
5630+
sagemaker_session.sagemaker_client.update_feature_group.assert_called_once_with(
56315631
FeatureGroupName="MyFeatureGroup",
56325632
FeatureAdditions=feature_group_dummy_definitions,
56335633
)
@@ -5639,7 +5639,7 @@ def test_feature_group_online_store_config_update(sagemaker_session):
56395639
feature_group_name="MyFeatureGroup",
56405640
online_store_config=os_conf_update,
56415641
)
5642-
assert sagemaker_session.sagemaker_client.update_feature_group.called_with(
5642+
sagemaker_session.sagemaker_client.update_feature_group.assert_called_once_with(
56435643
FeatureGroupName="MyFeatureGroup", OnlineStoreConfig=os_conf_update
56445644
)
56455645

@@ -5654,7 +5654,7 @@ def test_feature_group_throughput_config_update(sagemaker_session):
56545654
feature_group_name="MyFeatureGroup",
56555655
throughput_config=tp_update,
56565656
)
5657-
assert sagemaker_session.sagemaker_client.update_feature_group.called_with(
5657+
sagemaker_session.sagemaker_client.update_feature_group.assert_called_once_with(
56585658
FeatureGroupName="MyFeatureGroup", ThroughputConfig=tp_update
56595659
)
56605660

@@ -5675,7 +5675,7 @@ def test_feature_metadata_update(sagemaker_session):
56755675
parameter_additions=parameter_additions,
56765676
parameter_removals=parameter_removals,
56775677
)
5678-
assert sagemaker_session.sagemaker_client.update_feature_group.called_with(
5678+
sagemaker_session.sagemaker_client.update_feature_group.assert_called_once_with(
56795679
feature_group_name="TestFeatureGroup",
56805680
FeatureName="TestFeature",
56815681
Description="TestDescription",
@@ -5686,7 +5686,7 @@ def test_feature_metadata_update(sagemaker_session):
56865686
feature_group_name="TestFeatureGroup",
56875687
feature_name="TestFeature",
56885688
)
5689-
assert sagemaker_session.sagemaker_client.update_feature_group.called_with(
5689+
sagemaker_session.sagemaker_client.update_feature_group.assert_called_once_with(
56905690
feature_group_name="TestFeatureGroup",
56915691
FeatureName="TestFeature",
56925692
)
@@ -5696,7 +5696,7 @@ def test_feature_metadata_describe(sagemaker_session):
56965696
sagemaker_session.describe_feature_metadata(
56975697
feature_group_name="MyFeatureGroup", feature_name="TestFeature"
56985698
)
5699-
assert sagemaker_session.sagemaker_client.describe_feature_metadata.called_with(
5699+
sagemaker_session.sagemaker_client.describe_feature_metadata.assert_called_once_with(
57005700
FeatureGroupName="MyFeatureGroup", FeatureName="TestFeature"
57015701
)
57025702

@@ -5725,7 +5725,7 @@ def test_list_feature_groups(sagemaker_session):
57255725
next_token="token",
57265726
)
57275727
assert sagemaker_session.sagemaker_client.list_feature_groups.called_once()
5728-
assert sagemaker_session.sagemaker_client.list_feature_groups.called_with(
5728+
sagemaker_session.sagemaker_client.list_feature_groups.assert_called_once_with(
57295729
**expected_list_feature_groups_args
57305730
)
57315731

@@ -5746,7 +5746,7 @@ def test_feature_group_put_record(sagemaker_session_with_fs_runtime_client):
57465746
)
57475747
fs_client_mock = sagemaker_session_with_fs_runtime_client.sagemaker_featurestore_runtime_client
57485748

5749-
assert fs_client_mock.put_record.called_with(
5749+
fs_client_mock.put_record.assert_called_once_with(
57505750
FeatureGroupName="MyFeatureGroup",
57515751
record=[{"FeatureName": "feature1", "ValueAsString": "value1"}],
57525752
)
@@ -5762,7 +5762,7 @@ def test_feature_group_put_record_with_ttl_and_target_stores(
57625762
target_stores=["OnlineStore", "OfflineStore"],
57635763
)
57645764
fs_client_mock = sagemaker_session_with_fs_runtime_client.sagemaker_featurestore_runtime_client
5765-
assert fs_client_mock.put_record.called_with(
5765+
fs_client_mock.put_record.assert_called_once_with(
57665766
FeatureGroupName="MyFeatureGroup",
57675767
record=[{"FeatureName": "feature1", "ValueAsString": "value1"}],
57685768
target_stores=["OnlineStore", "OfflineStore"],
@@ -5781,7 +5781,7 @@ def test_start_query_execution(sagemaker_session):
57815781
query_string="query",
57825782
output_location="s3://results",
57835783
)
5784-
assert athena_mock.start_query_execution.assert_called_once_with(
5784+
athena_mock.start_query_execution.assert_called_once_with(
57855785
QueryString="query",
57865786
QueryExecutionContext={"Catalog": "catalog", "Database": "database"},
57875787
OutputLocation="s3://results",
@@ -5794,7 +5794,7 @@ def test_get_query_execution(sagemaker_session):
57945794
"athena", region_name=sagemaker_session.boto_region_name
57955795
).return_value = athena_mock
57965796
sagemaker_session.get_query_execution(query_execution_id="query_id")
5797-
assert athena_mock.get_query_execution.called_with(QueryExecutionId="query_id")
5797+
athena_mock.get_query_execution.assert_called_once_with(QueryExecutionId="query_id")
57985798

57995799

58005800
def test_download_athena_query_result(sagemaker_session):
@@ -5805,7 +5805,7 @@ def test_download_athena_query_result(sagemaker_session):
58055805
query_execution_id="query_id",
58065806
filename="filename",
58075807
)
5808-
assert sagemaker_session.s3_client.download_file.called_with(
5808+
sagemaker_session.s3_client.download_file.assert_called_once_with(
58095809
Bucket="bucket",
58105810
Key="prefix/query_id.csv",
58115811
Filename="filename",
@@ -5819,7 +5819,7 @@ def test_update_monitoring_alert(sagemaker_session):
58195819
data_points_to_alert=1,
58205820
evaluation_period=1,
58215821
)
5822-
assert sagemaker_session.sagemaker_client.update_monitoring_alert.called_with(
5822+
sagemaker_session.sagemaker_client.update_monitoring_alert.assert_called_once_with(
58235823
MonitoringScheduleName="schedule-name",
58245824
MonitoringAlertName="alert-name",
58255825
DatapointsToAlert=1,
@@ -5833,7 +5833,7 @@ def test_list_monitoring_alerts(sagemaker_session):
58335833
next_token="next_token",
58345834
max_results=100,
58355835
)
5836-
assert sagemaker_session.sagemaker_client.list_monitoring_alerts.called_with(
5836+
sagemaker_session.sagemaker_client.list_monitoring_alerts.assert_called_once_with(
58375837
MonitoringScheduleName="schedule-name",
58385838
NextToken="next_token",
58395839
MaxResults=100,
@@ -5852,7 +5852,7 @@ def test_list_monitoring_alert_history(sagemaker_session):
58525852
creation_time_before="creation_time_before",
58535853
creation_time_after="creation_time_after",
58545854
)
5855-
assert sagemaker_session.sagemaker_client.list_monitoring_alerts.called_with(
5855+
sagemaker_session.sagemaker_client.list_monitoring_alerts.assert_called_once_with(
58565856
MonitoringScheduleName="schedule-name",
58575857
MonitoringAlertName="alert-name",
58585858
SortBy="CreationTime",
@@ -5869,7 +5869,7 @@ def test_list_monitoring_alert_history(sagemaker_session):
58695869
def test_wait_for_athena_query(query_execution, sagemaker_session):
58705870
query_execution.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}}
58715871
sagemaker_session.wait_for_athena_query(query_execution_id="query_id")
5872-
assert query_execution.called_with(query_execution_id="query_id")
5872+
query_execution.assert_called_once_with(query_execution_id="query_id")
58735873

58745874

58755875
def test_search(sagemaker_session):
@@ -5908,7 +5908,7 @@ def test_search(sagemaker_session):
59085908
max_results=50,
59095909
)
59105910
assert sagemaker_session.sagemaker_client.search.called_once()
5911-
assert sagemaker_session.sagemaker_client.search.called_with(**expected_search_args)
5911+
sagemaker_session.sagemaker_client.search.assert_called_once_with(**expected_search_args)
59125912

59135913

59145914
def test_batch_get_record(sagemaker_session):
@@ -5931,7 +5931,7 @@ def test_batch_get_record(sagemaker_session):
59315931
]
59325932
)
59335933
assert sagemaker_session.sagemaker_client.batch_get_record.called_once()
5934-
assert sagemaker_session.sagemaker_client.batch_get_record.called_with(
5934+
sagemaker_session.sagemaker_client.batch_get_record.assert_called_once_with(
59355935
**expected_batch_get_record_args
59365936
)
59375937

@@ -5958,7 +5958,7 @@ def test_batch_get_record_expiration_time_response(sagemaker_session):
59585958
expiration_time_response="Disabled",
59595959
)
59605960
assert sagemaker_session.sagemaker_client.batch_get_record.called_once()
5961-
assert sagemaker_session.sagemaker_client.batch_get_record.called_with(
5961+
sagemaker_session.sagemaker_client.batch_get_record.assert_called_once_with(
59625962
**expected_batch_get_record_args
59635963
)
59645964

@@ -6291,19 +6291,15 @@ def test_create_inference_recommendations_job_propogate_other_exception(
62916291

62926292
def test_create_presigned_mlflow_tracking_server_url(sagemaker_session):
62936293
sagemaker_session.create_presigned_mlflow_tracking_server_url("ts", 1, 2)
6294-
assert (
6295-
sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.called_with(
6296-
TrackingServerName="ts", ExpiresInSeconds=1, SessionExpirationDurationInSeconds=2
6297-
)
6294+
sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.assert_called_once_with(
6295+
TrackingServerName="ts", ExpiresInSeconds=1, SessionExpirationDurationInSeconds=2
62986296
)
62996297

63006298

63016299
def test_create_presigned_mlflow_tracking_server_url_minimal(sagemaker_session):
63026300
sagemaker_session.create_presigned_mlflow_tracking_server_url("ts")
6303-
assert (
6304-
sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.called_with(
6305-
TrackingServerName="ts"
6306-
)
6301+
sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.assert_called_once_with(
6302+
TrackingServerName="ts"
63076303
)
63086304

63096305

0 commit comments

Comments
 (0)