@@ -98,6 +98,7 @@ def __init__(
9898 debugger_hook_config = None ,
9999 tensorboard_output_config = None ,
100100 enable_sagemaker_metrics = None ,
101+ enable_network_isolation = False ,
101102 ):
102103 """Initialize an ``EstimatorBase`` instance.
103104
@@ -199,6 +200,11 @@ def __init__(
199200 Series. For more information see:
200201 https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
201202 (default: ``None``).
203+ enable_network_isolation (bool): Specifies whether container will
204+ run in network isolation mode (default: ``False``). Network
205+ isolation mode restricts the container access to outside networks
206+ (such as the Internet). The container does not make any inbound or
207+ outbound network calls. Also known as Internet-free mode.
202208 """
203209 self .role = role
204210 self .train_instance_count = train_instance_count
@@ -260,6 +266,7 @@ def __init__(
260266 self .collection_configs = None
261267
262268 self .enable_sagemaker_metrics = enable_sagemaker_metrics
269+ self ._enable_network_isolation = enable_network_isolation
263270
264271 @abstractmethod
265272 def train_image (self ):
@@ -290,7 +297,7 @@ def enable_network_isolation(self):
290297 Returns:
291298 bool: Whether this Estimator needs network isolation or not.
292299 """
293- return False
300+ return self . _enable_network_isolation
294301
295302 def prepare_workflow_for_training (self , job_name = None ):
296303 """Calls _prepare_for_training. Used when setting up a workflow.
@@ -1219,21 +1226,17 @@ def __init__(
12191226 checkpoints will be provided under `/opt/ml/checkpoints/`.
12201227 (default: ``None``).
12211228 enable_network_isolation (bool): Specifies whether container will
1222- run in network isolation mode. Network isolation mode restricts
1223- the container access to outside networks (such as the Internet).
1224- The container does not make any inbound or outbound network
1225- calls. If ``True``, a channel named "code" will be created for any
1226- user entry script for training. The user entry script, files in
1227- source_dir (if specified), and dependencies will be uploaded in
1228- a tar to S3. Also known as internet-free mode (default: ``False``).
1229+ run in network isolation mode (default: ``False``). Network
1230+ isolation mode restricts the container access to outside networks
1231+ (such as the Internet). The container does not make any inbound or
1232+ outbound network calls. Also known as Internet-free mode.
12291233 enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
12301234 Series. For more information see:
12311235 https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
12321236 (default: ``None``).
12331237 """
12341238 self .image_name = image_name
12351239 self .hyperparam_dict = hyperparameters .copy () if hyperparameters else {}
1236- self ._enable_network_isolation = enable_network_isolation
12371240 super (Estimator , self ).__init__ (
12381241 role ,
12391242 train_instance_count ,
@@ -1261,16 +1264,9 @@ def __init__(
12611264 debugger_hook_config = debugger_hook_config ,
12621265 tensorboard_output_config = tensorboard_output_config ,
12631266 enable_sagemaker_metrics = enable_sagemaker_metrics ,
1267+ enable_network_isolation = enable_network_isolation ,
12641268 )
12651269
1266- def enable_network_isolation (self ):
1267- """If this Estimator can use network isolation when running.
1268-
1269- Returns:
1270- bool: Whether this Estimator can use network isolation or not.
1271- """
1272- return self ._enable_network_isolation
1273-
12741270 def train_image (self ):
12751271 """Returns the docker image to use for training.
12761272
@@ -1498,15 +1494,15 @@ def __init__(
14981494 >>> |------ train.py
14991495 >>> |------ common
15001496 >>> |------ virtual-env
1497+
15011498 enable_network_isolation (bool): Specifies whether container will
15021499 run in network isolation mode. Network isolation mode restricts
15031500 the container access to outside networks (such as the internet).
15041501 The container does not make any inbound or outbound network
15051502 calls. If True, a channel named "code" will be created for any
15061503 user entry script for training. The user entry script, files in
15071504 source_dir (if specified), and dependencies will be uploaded in
1508- a tar to S3. Also known as internet-free mode (default: `False`
1509- ).
1505+ a tar to S3. Also known as internet-free mode (default: `False`).
15101506 git_config (dict[str, str]): Git configurations used for cloning
15111507 files, including ``repo``, ``branch``, ``commit``,
15121508 ``2FA_enabled``, ``username``, ``password`` and ``token``. The
@@ -1579,7 +1575,7 @@ def __init__(
15791575 You can find additional parameters for initializing this class at
15801576 :class:`~sagemaker.estimator.EstimatorBase`.
15811577 """
1582- super (Framework , self ).__init__ (** kwargs )
1578+ super (Framework , self ).__init__ (enable_network_isolation = enable_network_isolation , ** kwargs )
15831579 if entry_point .startswith ("s3://" ):
15841580 raise ValueError (
15851581 "Invalid entry point script: {}. Must be a path to a local file." .format (
@@ -1599,7 +1595,6 @@ def __init__(
15991595 self .container_log_level = container_log_level
16001596 self .code_location = code_location
16011597 self .image_name = image_name
1602- self ._enable_network_isolation = enable_network_isolation
16031598
16041599 self .uploaded_code = None
16051600
@@ -1608,14 +1603,6 @@ def __init__(
16081603 self .checkpoint_local_path = checkpoint_local_path
16091604 self .enable_sagemaker_metrics = enable_sagemaker_metrics
16101605
1611- def enable_network_isolation (self ):
1612- """Return True if this Estimator can use network isolation to run.
1613-
1614- Returns:
1615- bool: Whether this Estimator can use network isolation or not.
1616- """
1617- return self ._enable_network_isolation
1618-
16191606 def _prepare_for_training (self , job_name = None ):
16201607 """Set hyperparameters needed for training. This method will also
16211608 validate ``source_dir``.
0 commit comments