Skip to content

Commit bb17c3b

Browse files
authored
fix: pass enable_network_isolation in Estimator.create_model (#1038)
1 parent 23a33a7 commit bb17c3b

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,7 @@ def predict_wrapper(endpoint, session):
11421142
vpc_config=self.get_vpc_config(vpc_config_override),
11431143
sagemaker_session=self.sagemaker_session,
11441144
predictor_cls=predictor_cls,
1145+
enable_network_isolation=self.enable_network_isolation(),
11451146
**kwargs
11461147
)
11471148

tests/unit/test_estimator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,6 +1836,25 @@ def test_generic_to_deploy(sagemaker_session):
18361836
assert predictor.sagemaker_session == sagemaker_session
18371837

18381838

1839+
def test_generic_to_deploy_network_isolation(sagemaker_session):
1840+
e = Estimator(
1841+
IMAGE_NAME,
1842+
ROLE,
1843+
INSTANCE_COUNT,
1844+
INSTANCE_TYPE,
1845+
output_path=OUTPUT_PATH,
1846+
enable_network_isolation=True,
1847+
sagemaker_session=sagemaker_session,
1848+
)
1849+
1850+
e.fit()
1851+
e.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
1852+
1853+
sagemaker_session.create_model.assert_called_once()
1854+
_, kwargs = sagemaker_session.create_model.call_args
1855+
assert kwargs["enable_network_isolation"]
1856+
1857+
18391858
def test_generic_training_job_analytics(sagemaker_session):
18401859
sagemaker_session.sagemaker_client.describe_training_job = Mock(
18411860
name="describe_training_job",

0 commit comments

Comments
 (0)