Skip to content

Commit 6e38a7a

Browse files
committed
fix: _internal_config definition
1 parent 9346011 commit 6e38a7a

File tree

2 files changed

+22
-26
lines changed

2 files changed

+22
-26
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -535,18 +535,6 @@ def _validate_model_id_and_get_type_hook():
535535
if not self.model_type and not hub_arn:
536536
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))
537537

538-
self._internal_config = JumpStartModelInternalConfig(
539-
specs=verify_model_region_and_return_specs(
540-
region=self.region,
541-
model_id=self.model_id,
542-
version=self.model_version,
543-
hub_arn=self.hub_arn,
544-
model_type=self.model_type,
545-
scope=JumpStartScriptScope.TRAINING,
546-
sagemaker_session=self.sagemaker_session,
547-
)
548-
)
549-
550538
estimator_init_kwargs = get_init_kwargs(
551539
model_id=model_id,
552540
model_version=model_version,
@@ -620,7 +608,17 @@ def _validate_model_id_and_get_type_hook():
620608
self.role = estimator_init_kwargs.role
621609
self.sagemaker_session = estimator_init_kwargs.sagemaker_session
622610
self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation
623-
611+
self._internal_config = JumpStartModelInternalConfig(
612+
specs=verify_model_region_and_return_specs(
613+
region=self.region,
614+
model_id=self.model_id,
615+
version=self.model_version,
616+
hub_arn=self.hub_arn,
617+
model_type=self.model_type,
618+
scope=JumpStartScriptScope.TRAINING,
619+
sagemaker_session=self.sagemaker_session,
620+
)
621+
)
624622
super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict())
625623

626624
def fit(

src/sagemaker/jumpstart/model.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -312,18 +312,6 @@ def _validate_model_id_and_type():
312312
if not self.model_type and not hub_arn:
313313
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))
314314

315-
self._internal_config = JumpStartModelInternalConfig(
316-
specs=verify_model_region_and_return_specs(
317-
region=self.region,
318-
model_id=self.model_id,
319-
version=self.model_version,
320-
hub_arn=self.hub_arn,
321-
model_type=self.model_type,
322-
scope=JumpStartScriptScope.INFERENCE,
323-
sagemaker_session=self.sagemaker_session,
324-
)
325-
)
326-
327315
self._model_data_is_set = model_data is not None
328316
model_init_kwargs = get_init_kwargs(
329317
model_id=model_id,
@@ -373,7 +361,17 @@ def _validate_model_id_and_type():
373361
self.log_subscription_warning()
374362

375363
model_init_kwargs_dict = model_init_kwargs.to_kwargs_dict()
376-
364+
self._internal_config = JumpStartModelInternalConfig(
365+
specs=verify_model_region_and_return_specs(
366+
region=self.region,
367+
model_id=self.model_id,
368+
version=self.model_version,
369+
hub_arn=self.hub_arn,
370+
model_type=self.model_type,
371+
scope=JumpStartScriptScope.INFERENCE,
372+
sagemaker_session=self.sagemaker_session,
373+
)
374+
)
377375
super(JumpStartModel, self).__init__(**model_init_kwargs_dict)
378376

379377
self.model_package_arn = model_init_kwargs.model_package_arn

0 commit comments

Comments
 (0)