Skip to content

Commit bc71d77

Browse files
authored
change: add enable_network_isolation to EstimatorBase (#1393)
This removes the need of having redundant logic in each of the subclasses of EstimatorBase.
1 parent 1385733 commit bc71d77

File tree

3 files changed

+23
-39
lines changed

3 files changed

+23
-39
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,17 @@ def __init__(
8282
:class:`~sagemaker.estimator.EstimatorBase`.
8383
"""
8484
super(AmazonAlgorithmEstimatorBase, self).__init__(
85-
role, train_instance_count, train_instance_type, **kwargs
85+
role,
86+
train_instance_count,
87+
train_instance_type,
88+
enable_network_isolation=enable_network_isolation,
89+
**kwargs
8690
)
8791

8892
data_location = data_location or "s3://{}/sagemaker-record-sets/".format(
8993
self.sagemaker_session.default_bucket()
9094
)
9195
self._data_location = data_location
92-
self._enable_network_isolation = enable_network_isolation
9396

9497
def train_image(self):
9598
"""Placeholder docstring"""
@@ -101,14 +104,6 @@ def hyperparameters(self):
101104
"""Placeholder docstring"""
102105
return hp.serialize_all(self)
103106

104-
def enable_network_isolation(self):
105-
"""If this Estimator can use network isolation when running.
106-
107-
Returns:
108-
bool: Whether this Estimator can use network isolation or not.
109-
"""
110-
return self._enable_network_isolation
111-
112107
@property
113108
def data_location(self):
114109
"""Placeholder docstring"""

src/sagemaker/estimator.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(
9898
debugger_hook_config=None,
9999
tensorboard_output_config=None,
100100
enable_sagemaker_metrics=None,
101+
enable_network_isolation=False,
101102
):
102103
"""Initialize an ``EstimatorBase`` instance.
103104
@@ -199,6 +200,11 @@ def __init__(
199200
Series. For more information see:
200201
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
201202
(default: ``None``).
203+
enable_network_isolation (bool): Specifies whether container will
204+
run in network isolation mode (default: ``False``). Network
205+
isolation mode restricts the container access to outside networks
206+
(such as the Internet). The container does not make any inbound or
207+
outbound network calls. Also known as Internet-free mode.
202208
"""
203209
self.role = role
204210
self.train_instance_count = train_instance_count
@@ -260,6 +266,7 @@ def __init__(
260266
self.collection_configs = None
261267

262268
self.enable_sagemaker_metrics = enable_sagemaker_metrics
269+
self._enable_network_isolation = enable_network_isolation
263270

264271
@abstractmethod
265272
def train_image(self):
@@ -290,7 +297,7 @@ def enable_network_isolation(self):
290297
Returns:
291298
bool: Whether this Estimator needs network isolation or not.
292299
"""
293-
return False
300+
return self._enable_network_isolation
294301

295302
def prepare_workflow_for_training(self, job_name=None):
296303
"""Calls _prepare_for_training. Used when setting up a workflow.
@@ -1219,21 +1226,17 @@ def __init__(
12191226
checkpoints will be provided under `/opt/ml/checkpoints/`.
12201227
(default: ``None``).
12211228
enable_network_isolation (bool): Specifies whether container will
1222-
run in network isolation mode. Network isolation mode restricts
1223-
the container access to outside networks (such as the Internet).
1224-
The container does not make any inbound or outbound network
1225-
calls. If ``True``, a channel named "code" will be created for any
1226-
user entry script for training. The user entry script, files in
1227-
source_dir (if specified), and dependencies will be uploaded in
1228-
a tar to S3. Also known as internet-free mode (default: ``False``).
1229+
run in network isolation mode (default: ``False``). Network
1230+
isolation mode restricts the container access to outside networks
1231+
(such as the Internet). The container does not make any inbound or
1232+
outbound network calls. Also known as Internet-free mode.
12291233
enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
12301234
Series. For more information see:
12311235
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
12321236
(default: ``None``).
12331237
"""
12341238
self.image_name = image_name
12351239
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
1236-
self._enable_network_isolation = enable_network_isolation
12371240
super(Estimator, self).__init__(
12381241
role,
12391242
train_instance_count,
@@ -1261,16 +1264,9 @@ def __init__(
12611264
debugger_hook_config=debugger_hook_config,
12621265
tensorboard_output_config=tensorboard_output_config,
12631266
enable_sagemaker_metrics=enable_sagemaker_metrics,
1267+
enable_network_isolation=enable_network_isolation,
12641268
)
12651269

1266-
def enable_network_isolation(self):
1267-
"""If this Estimator can use network isolation when running.
1268-
1269-
Returns:
1270-
bool: Whether this Estimator can use network isolation or not.
1271-
"""
1272-
return self._enable_network_isolation
1273-
12741270
def train_image(self):
12751271
"""Returns the docker image to use for training.
12761272
@@ -1498,15 +1494,15 @@ def __init__(
14981494
>>> |------ train.py
14991495
>>> |------ common
15001496
>>> |------ virtual-env
1497+
15011498
enable_network_isolation (bool): Specifies whether container will
15021499
run in network isolation mode. Network isolation mode restricts
15031500
the container access to outside networks (such as the internet).
15041501
The container does not make any inbound or outbound network
15051502
calls. If True, a channel named "code" will be created for any
15061503
user entry script for training. The user entry script, files in
15071504
source_dir (if specified), and dependencies will be uploaded in
1508-
a tar to S3. Also known as internet-free mode (default: `False`
1509-
).
1505+
a tar to S3. Also known as internet-free mode (default: `False`).
15101506
git_config (dict[str, str]): Git configurations used for cloning
15111507
files, including ``repo``, ``branch``, ``commit``,
15121508
``2FA_enabled``, ``username``, ``password`` and ``token``. The
@@ -1579,7 +1575,7 @@ def __init__(
15791575
You can find additional parameters for initializing this class at
15801576
:class:`~sagemaker.estimator.EstimatorBase`.
15811577
"""
1582-
super(Framework, self).__init__(**kwargs)
1578+
super(Framework, self).__init__(enable_network_isolation=enable_network_isolation, **kwargs)
15831579
if entry_point.startswith("s3://"):
15841580
raise ValueError(
15851581
"Invalid entry point script: {}. Must be a path to a local file.".format(
@@ -1599,7 +1595,6 @@ def __init__(
15991595
self.container_log_level = container_log_level
16001596
self.code_location = code_location
16011597
self.image_name = image_name
1602-
self._enable_network_isolation = enable_network_isolation
16031598

16041599
self.uploaded_code = None
16051600

@@ -1608,14 +1603,6 @@ def __init__(
16081603
self.checkpoint_local_path = checkpoint_local_path
16091604
self.enable_sagemaker_metrics = enable_sagemaker_metrics
16101605

1611-
def enable_network_isolation(self):
1612-
"""Return True if this Estimator can use network isolation to run.
1613-
1614-
Returns:
1615-
bool: Whether this Estimator can use network isolation or not.
1616-
"""
1617-
return self._enable_network_isolation
1618-
16191606
def _prepare_for_training(self, job_name=None):
16201607
"""Set hyperparameters needed for training. This method will also
16211608
validate ``source_dir``.

tests/unit/test_estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def test_framework_all_init_args(sagemaker_session):
209209
checkpoint_s3_uri="s3://bucket/checkpoint",
210210
checkpoint_local_path="file://local/checkpoint",
211211
enable_sagemaker_metrics=True,
212+
enable_network_isolation=True,
212213
)
213214
_TrainingJob.start_new(f, "s3://mydata", None)
214215
sagemaker_session.train.assert_called_once()
@@ -247,6 +248,7 @@ def test_framework_all_init_args(sagemaker_session):
247248
"checkpoint_s3_uri": "s3://bucket/checkpoint",
248249
"checkpoint_local_path": "file://local/checkpoint",
249250
"enable_sagemaker_metrics": True,
251+
"enable_network_isolation": True,
250252
}
251253

252254

0 commit comments

Comments
 (0)