@@ -44,7 +44,13 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase):
4444 repo_version = None
4545
4646 def __init__ (
47- self , role , train_instance_count , train_instance_type , data_location = None , ** kwargs
47+ self ,
48+ role ,
49+ train_instance_count ,
50+ train_instance_type ,
51+ data_location = None ,
52+ enable_network_isolation = False ,
53+ ** kwargs
4854 ):
4955 """Initialize an AmazonAlgorithmEstimatorBase.
5056
@@ -63,6 +69,10 @@ def __init__(
6369 "s3://example-bucket/some-key-prefix/". Objects will be saved in
6470 a unique sub-directory of the specified location. If None, a
6571 default data location will be used.
72+ enable_network_isolation (bool): Specifies whether container will
73+ run in network isolation mode. Network isolation mode restricts
74+ the container access to outside networks (such as the internet).
75+ Also known as internet-free mode (default: ``False``).
6676 **kwargs: Additional parameters passed to
6777 :class:`~sagemaker.estimator.EstimatorBase`.
6878
@@ -71,14 +81,6 @@ def __init__(
7181 You can find additional parameters for initializing this class at
7282 :class:`~sagemaker.estimator.EstimatorBase`.
7383 """
74-
75- if "enable_network_isolation" in kwargs :
76- logger .debug (
77- "removing unused enable_network_isolation argument: %s" ,
78- str (kwargs ["enable_network_isolation" ]),
79- )
80- del kwargs ["enable_network_isolation" ]
81-
8284 super (AmazonAlgorithmEstimatorBase , self ).__init__ (
8385 role , train_instance_count , train_instance_type , ** kwargs
8486 )
@@ -87,6 +89,7 @@ def __init__(
8789 self .sagemaker_session .default_bucket ()
8890 )
8991 self ._data_location = data_location
92+ self ._enable_network_isolation = enable_network_isolation
9093
9194 def train_image (self ):
9295 """Placeholder docstring"""
@@ -98,6 +101,14 @@ def hyperparameters(self):
98101 """Placeholder docstring"""
99102 return hp .serialize_all (self )
100103
104+ def enable_network_isolation (self ):
105+ """If this Estimator can use network isolation when running.
106+
107+ Returns:
108+ bool: Whether this Estimator can use network isolation or not.
109+ """
110+ return self ._enable_network_isolation
111+
101112 @property
102113 def data_location (self ):
103114 """Placeholder docstring"""
0 commit comments