Skip to content

Commit 5ec22c0

Browse files
authored
fix: add enable_network_isolation to generic Estimator class (#1027)
1 parent 228a81d commit 5ec22c0

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/sagemaker/estimator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ def __init__(
905905
train_max_wait=None,
906906
checkpoint_s3_uri=None,
907907
checkpoint_local_path=None,
908+
enable_network_isolation=False,
908909
):
909910
"""Initialize an ``Estimator`` instance.
910911
@@ -1008,9 +1009,18 @@ def __init__(
10081009
started. If the path is unset then SageMaker assumes the
10091010
checkpoints will be provided under `/opt/ml/checkpoints/`.
10101011
(default: ``None``).
1012+
enable_network_isolation (bool): Specifies whether container will
1013+
run in network isolation mode. Network isolation mode restricts
1014+
the container access to outside networks (such as the Internet).
1015+
The container does not make any inbound or outbound network
1016+
calls. If ``True``, a channel named "code" will be created for any
1017+
user entry script for training. The user entry script, files in
1018+
source_dir (if specified), and dependencies will be uploaded in
1019+
a tar to S3. Also known as internet-free mode (default: ``False``).
10111020
"""
10121021
self.image_name = image_name
10131022
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
1023+
self._enable_network_isolation = enable_network_isolation
10141024
super(Estimator, self).__init__(
10151025
role,
10161026
train_instance_count,
@@ -1036,6 +1046,14 @@ def __init__(
10361046
checkpoint_local_path=checkpoint_local_path,
10371047
)
10381048

1049+
def enable_network_isolation(self):
1050+
"""If this Estimator can use network isolation when running.
1051+
1052+
Returns:
1053+
bool: Whether this Estimator can use network isolation or not.
1054+
"""
1055+
return self._enable_network_isolation
1056+
10391057
def train_image(self):
10401058
"""Returns the docker image to use for training.
10411059

tests/unit/test_estimator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,24 @@ def test_generic_to_fit_with_encrypt_inter_container_traffic_flag(sagemaker_sess
17791779
assert args["encrypt_inter_container_traffic"] is True
17801780

17811781

1782+
def test_generic_to_fit_with_network_isolation(sagemaker_session):
1783+
e = Estimator(
1784+
IMAGE_NAME,
1785+
ROLE,
1786+
INSTANCE_COUNT,
1787+
INSTANCE_TYPE,
1788+
output_path=OUTPUT_PATH,
1789+
sagemaker_session=sagemaker_session,
1790+
enable_network_isolation=True,
1791+
)
1792+
1793+
e.fit()
1794+
1795+
sagemaker_session.train.assert_called_once()
1796+
args = sagemaker_session.train.call_args[1]
1797+
assert args["enable_network_isolation"]
1798+
1799+
17821800
def test_generic_to_deploy(sagemaker_session):
17831801
e = Estimator(
17841802
IMAGE_NAME,

0 commit comments

Comments
 (0)