@@ -65,6 +65,7 @@ def stop(self):
6565 @staticmethod
6666 def _load_config (inputs , estimator , expand_role = True , validate_uri = True ):
6767 """Placeholder docstring"""
68+ model_access_config , hub_access_config = _Job ._get_access_configs (estimator )
6869 input_config = _Job ._format_inputs_to_input_config (inputs , validate_uri )
6970 role = (
7071 estimator .sagemaker_session .expand_role (estimator .role )
@@ -95,19 +96,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
9596 validate_uri ,
9697 content_type = "application/x-sagemaker-model" ,
9798 input_mode = "File" ,
99+ model_access_config = model_access_config ,
100+ hub_access_config = hub_access_config ,
98101 )
99102 if model_channel :
100103 input_config = [] if input_config is None else input_config
101104 input_config .append (model_channel )
102105
103- if estimator .enable_network_isolation ():
104- code_channel = _Job ._prepare_channel (
105- input_config , estimator .code_uri , estimator .code_channel_name , validate_uri
106- )
106+ code_channel = _Job ._prepare_channel (
107+ input_config ,
108+ estimator .code_uri ,
109+ estimator .code_channel_name ,
110+ validate_uri ,
111+ )
107112
108- if code_channel :
109- input_config = [] if input_config is None else input_config
110- input_config .append (code_channel )
113+ if code_channel :
114+ input_config = [] if input_config is None else input_config
115+ input_config .append (code_channel )
111116
112117 return {
113118 "input_config" : input_config ,
@@ -118,6 +123,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
118123 "vpc_config" : vpc_config ,
119124 }
120125
126+ @staticmethod
127+ def _get_access_configs (estimator ):
128+ """Return access configs from estimator object.
129+
130+ JumpStartEstimator uses access configs which need to be added to the model channel,
131+ so they are passed down to the job level.
132+
133+ Args:
134+ estimator (EstimatorBase): estimator object with access config field if applicable
135+ """
136+ model_access_config , hub_access_config = None , None
137+ if hasattr (estimator , "model_access_config" ):
138+ model_access_config = estimator .model_access_config
139+ if hasattr (estimator , "hub_access_config" ):
140+ hub_access_config = estimator .hub_access_config
141+ return model_access_config , hub_access_config
142+
121143 @staticmethod
122144 def _format_inputs_to_input_config (inputs , validate_uri = True ):
123145 """Placeholder docstring"""
@@ -173,6 +195,8 @@ def _format_string_uri_input(
173195 input_mode = None ,
174196 compression = None ,
175197 target_attribute_name = None ,
198+ model_access_config = None ,
199+ hub_access_config = None ,
176200 ):
177201 """Placeholder docstring"""
178202 s3_input_result = TrainingInput (
@@ -181,6 +205,8 @@ def _format_string_uri_input(
181205 input_mode = input_mode ,
182206 compression = compression ,
183207 target_attribute_name = target_attribute_name ,
208+ model_access_config = model_access_config ,
209+ hub_access_config = hub_access_config ,
184210 )
185211 if isinstance (uri_input , str ) and validate_uri and uri_input .startswith ("s3://" ):
186212 return s3_input_result
@@ -193,7 +219,11 @@ def _format_string_uri_input(
193219 )
194220 if isinstance (uri_input , str ):
195221 return s3_input_result
196- if isinstance (uri_input , (TrainingInput , file_input , FileSystemInput )):
222+ if isinstance (uri_input , (file_input , FileSystemInput )):
223+ return uri_input
224+ if isinstance (uri_input , TrainingInput ):
225+ uri_input .add_hub_access_config (hub_access_config = hub_access_config )
226+ uri_input .add_model_access_config (model_access_config = model_access_config )
197227 return uri_input
198228 if is_pipeline_variable (uri_input ):
199229 return s3_input_result
@@ -211,6 +241,8 @@ def _prepare_channel(
211241 validate_uri = True ,
212242 content_type = None ,
213243 input_mode = None ,
244+ model_access_config = None ,
245+ hub_access_config = None ,
214246 ):
215247 """Placeholder docstring"""
216248 if not channel_uri :
@@ -226,7 +258,12 @@ def _prepare_channel(
226258 raise ValueError ("Duplicate channel {} not allowed." .format (channel_name ))
227259
228260 channel_input = _Job ._format_string_uri_input (
229- channel_uri , validate_uri , content_type , input_mode
261+ channel_uri ,
262+ validate_uri ,
263+ content_type ,
264+ input_mode ,
265+ model_access_config = model_access_config ,
266+ hub_access_config = hub_access_config ,
230267 )
231268 channel = _Job ._convert_input_to_channel (channel_name , channel_input )
232269
0 commit comments