Skip to content

Commit 9e445d3

Browse files
committed
chore: address pr comments, fix formatting
1 parent 557c2d3 commit 9e445d3

File tree

3 files changed

+66
-71
lines changed

3 files changed

+66
-71
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
3232
from sagemaker.jumpstart.factory.utils import (
3333
_set_temp_sagemaker_session_if_not_set,
34-
get_model_info_kwargs,
34+
get_model_info_default_kwargs,
3535
)
3636
from sagemaker.jumpstart.hub.utils import (
3737
construct_hub_model_arn_from_inputs,
@@ -209,14 +209,11 @@ def get_init_kwargs(
209209

210210
estimator_init_kwargs = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_init_kwargs)
211211
estimator_init_kwargs.specs = verify_model_region_and_return_specs(
212-
model_id=estimator_init_kwargs.model_id,
213-
version=estimator_init_kwargs.model_version,
214-
hub_arn=estimator_init_kwargs.hub_arn,
212+
**get_model_info_default_kwargs(
213+
estimator_init_kwargs, include_model_version=False, include_tolerate_flags=False
214+
),
215+
version=estimator_init_kwargs.model_version or "*",
215216
scope=JumpStartScriptScope.TRAINING,
216-
region=estimator_init_kwargs.region,
217-
sagemaker_session=estimator_init_kwargs.sagemaker_session,
218-
model_type=estimator_init_kwargs.model_type,
219-
config_name=estimator_init_kwargs.config_name,
220217
# We set these flags to True to retrieve the json specs.
221218
# Exceptions will be thrown later if these are not tolerated.
222219
tolerate_deprecated_model=True,
@@ -285,14 +282,11 @@ def get_fit_kwargs(
285282

286283
estimator_fit_kwargs = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_fit_kwargs)
287284
estimator_fit_kwargs.specs = verify_model_region_and_return_specs(
288-
model_id=estimator_fit_kwargs.model_id,
289-
version=estimator_fit_kwargs.model_version,
290-
hub_arn=estimator_fit_kwargs.hub_arn,
285+
**get_model_info_default_kwargs(
286+
estimator_fit_kwargs, include_model_version=False, include_tolerate_flags=False
287+
),
288+
version=estimator_fit_kwargs.model_version or "*",
291289
scope=JumpStartScriptScope.TRAINING,
292-
region=estimator_fit_kwargs.region,
293-
sagemaker_session=estimator_fit_kwargs.sagemaker_session,
294-
model_type=estimator_fit_kwargs.model_type,
295-
config_name=estimator_fit_kwargs.config_name,
296290
# We set these flags to True to retrieve the json specs.
297291
# Exceptions will be thrown later if these are not tolerated.
298292
tolerate_deprecated_model=True,
@@ -526,7 +520,7 @@ def _add_instance_type_and_count_to_kwargs(
526520
orig_instance_type = kwargs.instance_type
527521

528522
kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default(
529-
**get_model_info_kwargs(kwargs), scope=JumpStartScriptScope.TRAINING
523+
**get_model_info_default_kwargs(kwargs), scope=JumpStartScriptScope.TRAINING
530524
)
531525

532526
kwargs.instance_count = kwargs.instance_count or 1
@@ -571,7 +565,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
571565
"""Sets image uri in kwargs based on default or override, returns full kwargs."""
572566

573567
kwargs.image_uri = kwargs.image_uri or image_uris.retrieve(
574-
**get_model_info_kwargs(kwargs),
568+
**get_model_info_default_kwargs(kwargs),
575569
instance_type=kwargs.instance_type,
576570
framework=None,
577571
image_scope=JumpStartScriptScope.TRAINING,
@@ -600,17 +594,17 @@ def _add_model_reference_arn_to_kwargs(
600594
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
601595
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
602596

603-
if _model_supports_training_model_uri(**get_model_info_kwargs(kwargs)):
597+
if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)):
604598
default_model_uri = model_uris.retrieve(
605599
model_scope=JumpStartScriptScope.TRAINING,
606600
instance_type=kwargs.instance_type,
607-
**get_model_info_kwargs(kwargs),
601+
**get_model_info_default_kwargs(kwargs),
608602
)
609603

610604
if (
611605
kwargs.model_uri is not None
612606
and kwargs.model_uri != default_model_uri
613-
and not _model_supports_incremental_training(**get_model_info_kwargs(kwargs))
607+
and not _model_supports_incremental_training(**get_model_info_default_kwargs(kwargs))
614608
):
615609
JUMPSTART_LOGGER.warning(
616610
"'%s' does not support incremental training but is being trained with"
@@ -638,7 +632,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart
638632
"""Sets source dir in kwargs based on default or override, returns full kwargs."""
639633

640634
kwargs.source_dir = kwargs.source_dir or script_uris.retrieve(
641-
script_scope=JumpStartScriptScope.TRAINING, **get_model_info_kwargs(kwargs)
635+
script_scope=JumpStartScriptScope.TRAINING, **get_model_info_default_kwargs(kwargs)
642636
)
643637

644638
return kwargs
@@ -650,14 +644,14 @@ def _add_env_to_kwargs(
650644
"""Sets environment in kwargs based on default or override, returns full kwargs."""
651645

652646
extra_env_vars = environment_variables.retrieve_default(
653-
**get_model_info_kwargs(kwargs),
647+
**get_model_info_default_kwargs(kwargs),
654648
script=JumpStartScriptScope.TRAINING,
655649
instance_type=kwargs.instance_type,
656650
include_aws_sdk_env_vars=False,
657651
)
658652

659653
model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri(
660-
**get_model_info_kwargs(kwargs),
654+
**get_model_info_default_kwargs(kwargs),
661655
scope=JumpStartScriptScope.TRAINING,
662656
)
663657

@@ -704,7 +698,7 @@ def _add_training_job_name_to_kwargs(
704698
"""Sets resource name based on default or override, returns full kwargs."""
705699

706700
default_training_job_name = _retrieve_resource_name_base(
707-
**get_model_info_kwargs(kwargs),
701+
**get_model_info_default_kwargs(kwargs),
708702
scope=JumpStartScriptScope.TRAINING,
709703
)
710704

@@ -725,7 +719,7 @@ def _add_hyperparameters_to_kwargs(
725719
)
726720

727721
default_hyperparameters = hyperparameters_utils.retrieve_default(
728-
**get_model_info_kwargs(kwargs),
722+
**get_model_info_default_kwargs(kwargs),
729723
instance_type=kwargs.instance_type,
730724
)
731725

@@ -753,7 +747,7 @@ def _add_metric_definitions_to_kwargs(
753747

754748
default_metric_definitions = (
755749
metric_definitions_utils.retrieve_default(
756-
**get_model_info_kwargs(kwargs),
750+
**get_model_info_default_kwargs(kwargs),
757751
instance_type=kwargs.instance_type,
758752
)
759753
or []
@@ -777,7 +771,7 @@ def _add_estimator_extra_kwargs(
777771
"""Sets extra kwargs based on default or override, returns full kwargs."""
778772

779773
estimator_kwargs_to_add = _retrieve_estimator_init_kwargs(
780-
**get_model_info_kwargs(kwargs), instance_type=kwargs.instance_type
774+
**get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type
781775
)
782776

783777
for key, value in estimator_kwargs_to_add.items():
@@ -795,7 +789,7 @@ def _add_estimator_extra_kwargs(
795789
def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstimatorFitKwargs:
796790
"""Sets extra kwargs based on default or override, returns full kwargs."""
797791

798-
fit_kwargs_to_add = _retrieve_estimator_fit_kwargs(**get_model_info_kwargs(kwargs))
792+
fit_kwargs_to_add = _retrieve_estimator_fit_kwargs(**get_model_info_default_kwargs(kwargs))
799793

800794
for key, value in fit_kwargs_to_add.items():
801795
if getattr(kwargs, key) is None:
@@ -811,7 +805,7 @@ def _add_config_name_to_kwargs(
811805

812806
kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
813807
scope=JumpStartScriptScope.TRAINING,
814-
**get_model_info_kwargs(kwargs, include_config_name=False),
808+
**get_model_info_default_kwargs(kwargs, include_config_name=False),
815809
)
816810

817811
return kwargs

src/sagemaker/jumpstart/factory/model.py

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
)
3030
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
3131
from sagemaker.jumpstart.constants import (
32-
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3332
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
3433
JUMPSTART_DEFAULT_REGION_NAME,
3534
JUMPSTART_LOGGER,
@@ -63,7 +62,7 @@
6362

6463
from sagemaker.jumpstart.factory.utils import (
6564
_set_temp_sagemaker_session_if_not_set,
66-
get_model_info_kwargs,
65+
get_model_info_default_kwargs,
6766
)
6867
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
6968
from sagemaker.base_predictor import Predictor
@@ -224,7 +223,7 @@ def _add_instance_type_to_kwargs(
224223

225224
orig_instance_type = kwargs.instance_type
226225
kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default(
227-
**get_model_info_kwargs(kwargs),
226+
**get_model_info_default_kwargs(kwargs),
228227
scope=JumpStartScriptScope.INFERENCE,
229228
training_instance_type=kwargs.training_instance_type,
230229
)
@@ -265,7 +264,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
265264
return kwargs
266265

267266
kwargs.image_uri = kwargs.image_uri or image_uris.retrieve(
268-
**get_model_info_kwargs(kwargs),
267+
**get_model_info_default_kwargs(kwargs),
269268
framework=None,
270269
image_scope=JumpStartScriptScope.INFERENCE,
271270
instance_type=kwargs.instance_type,
@@ -298,7 +297,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
298297
kwargs.model_data = None
299298
return kwargs
300299

301-
model_info_kwargs = get_model_info_kwargs(kwargs)
300+
model_info_kwargs = get_model_info_default_kwargs(kwargs)
302301
model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve(
303302
**model_info_kwargs,
304303
model_scope=JumpStartScriptScope.INFERENCE,
@@ -336,9 +335,9 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
336335

337336
source_dir = kwargs.source_dir
338337

339-
if _model_supports_inference_script_uri(**get_model_info_kwargs(kwargs)):
338+
if _model_supports_inference_script_uri(**get_model_info_default_kwargs(kwargs)):
340339
source_dir = source_dir or script_uris.retrieve(
341-
**get_model_info_kwargs(kwargs), script_scope=JumpStartScriptScope.INFERENCE
340+
**get_model_info_default_kwargs(kwargs), script_scope=JumpStartScriptScope.INFERENCE
342341
)
343342

344343
kwargs.source_dir = source_dir
@@ -355,7 +354,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
355354

356355
entry_point = kwargs.entry_point
357356

358-
if _model_supports_inference_script_uri(**get_model_info_kwargs(kwargs)):
357+
if _model_supports_inference_script_uri(**get_model_info_default_kwargs(kwargs)):
359358

360359
entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME
361360

@@ -377,7 +376,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw
377376
env = {}
378377

379378
extra_env_vars = environment_variables.retrieve_default(
380-
**get_model_info_kwargs(kwargs),
379+
**get_model_info_default_kwargs(kwargs),
381380
include_aws_sdk_env_vars=False,
382381
script=JumpStartScriptScope.INFERENCE,
383382
instance_type=kwargs.instance_type,
@@ -402,7 +401,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt
402401
"""Sets model package arn based on default or override, returns full kwargs."""
403402

404403
model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn(
405-
**get_model_info_kwargs(kwargs),
404+
**get_model_info_default_kwargs(kwargs),
406405
instance_type=kwargs.instance_type,
407406
scope=JumpStartScriptScope.INFERENCE,
408407
)
@@ -414,7 +413,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt
414413
def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
415414
"""Sets extra kwargs based on default or override, returns full kwargs."""
416415

417-
model_kwargs_to_add = _retrieve_model_init_kwargs(**get_model_info_kwargs(kwargs))
416+
model_kwargs_to_add = _retrieve_model_init_kwargs(**get_model_info_default_kwargs(kwargs))
418417

419418
for key, value in model_kwargs_to_add.items():
420419
if getattr(kwargs, key) is None:
@@ -442,7 +441,7 @@ def _add_endpoint_name_to_kwargs(
442441
) -> JumpStartModelDeployKwargs:
443442
"""Sets resource name based on default or override, returns full kwargs."""
444443

445-
default_endpoint_name = _retrieve_resource_name_base(**get_model_info_kwargs(kwargs))
444+
default_endpoint_name = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs))
446445

447446
kwargs.endpoint_name = kwargs.endpoint_name or (
448447
name_from_base(default_endpoint_name) if default_endpoint_name is not None else None
@@ -456,7 +455,7 @@ def _add_model_name_to_kwargs(
456455
) -> JumpStartModelInitKwargs:
457456
"""Sets resource name based on default or override, returns full kwargs."""
458457

459-
default_model_name = _retrieve_resource_name_base(**get_model_info_kwargs(kwargs))
458+
default_model_name = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs))
460459

461460
kwargs.name = kwargs.name or (
462461
name_from_base(default_model_name) if default_model_name is not None else None
@@ -498,7 +497,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any]
498497
"""Sets extra kwargs based on default or override, returns full kwargs."""
499498

500499
deploy_kwargs_to_add = _retrieve_model_deploy_kwargs(
501-
**get_model_info_kwargs(kwargs), instance_type=kwargs.instance_type
500+
**get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type
502501
)
503502

504503
for key, value in deploy_kwargs_to_add.items():
@@ -512,7 +511,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
512511
"""Sets the resource requirements based on the default or an override. Returns full kwargs."""
513512

514513
kwargs.resources = kwargs.resources or resource_requirements.retrieve_default(
515-
**get_model_info_kwargs(kwargs),
514+
**get_model_info_default_kwargs(kwargs),
516515
scope=JumpStartScriptScope.INFERENCE,
517516
instance_type=kwargs.instance_type,
518517
)
@@ -548,7 +547,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
548547
"""
549548

550549
kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
551-
**get_model_info_kwargs(kwargs, include_config_name=False),
550+
**get_model_info_default_kwargs(kwargs, include_config_name=False),
552551
scope=JumpStartScriptScope.INFERENCE,
553552
)
554553

@@ -605,7 +604,7 @@ def _add_config_name_to_deploy_kwargs(
605604

606605
else:
607606
default_config_name = kwargs.config_name or get_top_ranked_config_name(
608-
**get_model_info_kwargs(kwargs, include_config_name=False),
607+
**get_model_info_default_kwargs(kwargs, include_config_name=False),
609608
scope=JumpStartScriptScope.INFERENCE,
610609
)
611610

@@ -689,14 +688,11 @@ def get_deploy_kwargs(
689688
)
690689
deploy_kwargs = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs)
691690
deploy_kwargs.specs = verify_model_region_and_return_specs(
692-
model_id=model_id,
693-
version=model_version,
694-
hub_arn=hub_arn,
695-
model_type=model_type,
696-
region=region,
691+
**get_model_info_default_kwargs(
692+
deploy_kwargs, include_model_version=False, include_tolerate_flags=False
693+
),
694+
version=deploy_kwargs.model_version or "*",
697695
scope=JumpStartScriptScope.INFERENCE,
698-
sagemaker_session=deploy_kwargs.sagemaker_session,
699-
config_name=config_name,
700696
# We set these flags to True to retrieve the json specs.
701697
# Exceptions will be thrown later if these are not tolerated.
702698
tolerate_deprecated_model=True,
@@ -769,6 +765,7 @@ def get_register_kwargs(
769765
register_kwargs = JumpStartModelRegisterKwargs(
770766
model_id=model_id,
771767
model_version=model_version,
768+
config_name=config_name,
772769
hub_arn=hub_arn,
773770
model_type=model_type,
774771
region=region,
@@ -802,14 +799,11 @@ def get_register_kwargs(
802799
)
803800

804801
register_kwargs.specs = verify_model_region_and_return_specs(
805-
model_id=model_id,
806-
version=model_version,
807-
hub_arn=hub_arn,
808-
model_type=model_type,
809-
region=region,
802+
**get_model_info_default_kwargs(
803+
register_kwargs, include_model_version=False, include_tolerate_flags=False
804+
),
805+
version=register_kwargs.model_version or "*",
810806
scope=JumpStartScriptScope.INFERENCE,
811-
sagemaker_session=sagemaker_session,
812-
config_name=config_name,
813807
# We set these flags to True to retrieve the json specs.
814808
# Exceptions will be thrown later if these are not tolerated.
815809
tolerate_deprecated_model=True,
@@ -898,14 +892,11 @@ def get_init_kwargs(
898892
)
899893
model_init_kwargs = _set_temp_sagemaker_session_if_not_set(kwargs=model_init_kwargs)
900894
model_init_kwargs.specs = verify_model_region_and_return_specs(
901-
model_id=model_init_kwargs.model_id,
902-
version=model_init_kwargs.model_version,
903-
hub_arn=model_init_kwargs.hub_arn,
904-
model_type=model_init_kwargs.model_type,
905-
region=model_init_kwargs.region,
895+
**get_model_info_default_kwargs(
896+
model_init_kwargs, include_model_version=False, include_tolerate_flags=False
897+
),
898+
version=model_init_kwargs.model_version or "*",
906899
scope=JumpStartScriptScope.INFERENCE,
907-
sagemaker_session=model_init_kwargs.sagemaker_session,
908-
config_name=model_init_kwargs.config_name,
909900
# We set these flags to True to retrieve the json specs.
910901
# Exceptions will be thrown later if these are not tolerated.
911902
tolerate_deprecated_model=True,

0 commit comments

Comments
 (0)