|
23 | 23 | model_uris, |
24 | 24 | script_uris, |
25 | 25 | ) |
26 | | -from sagemaker.jumpstart.artifacts import _model_supports_incremental_training |
| 26 | +from sagemaker.jumpstart.artifacts import ( |
| 27 | + _model_supports_incremental_training, |
| 28 | + _retrieve_model_package_model_artifact_s3_uri, |
| 29 | +) |
27 | 30 | from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base |
28 | 31 | from sagemaker.session import Session |
29 | 32 | from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig |
|
37 | 40 | from sagemaker.jumpstart.artifacts import ( |
38 | 41 | _retrieve_estimator_init_kwargs, |
39 | 42 | _retrieve_estimator_fit_kwargs, |
| 43 | + _model_supports_training_model_uri, |
40 | 44 | ) |
41 | 45 | from sagemaker.jumpstart.constants import ( |
42 | 46 | JUMPSTART_DEFAULT_REGION_NAME, |
43 | 47 | JUMPSTART_LOGGER, |
44 | 48 | TRAINING_ENTRY_POINT_SCRIPT_NAME, |
| 49 | + SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, |
45 | 50 | ) |
46 | 51 | from sagemaker.jumpstart.enums import JumpStartScriptScope |
47 | 52 | from sagemaker.jumpstart.factory import model |
@@ -187,6 +192,7 @@ def get_init_kwargs( |
187 | 192 | estimator_init_kwargs = _add_metric_definitions_to_kwargs(estimator_init_kwargs) |
188 | 193 | estimator_init_kwargs = _add_estimator_extra_kwargs(estimator_init_kwargs) |
189 | 194 | estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs) |
| 195 | + estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs) |
190 | 196 |
|
191 | 197 | return estimator_init_kwargs |
192 | 198 |
|
@@ -446,32 +452,39 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE |
446 | 452 | def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs: |
447 | 453 | """Sets model uri in kwargs based on default or override, returns full kwargs.""" |
448 | 454 |
|
449 | | - default_model_uri = model_uris.retrieve( |
450 | | - model_scope=JumpStartScriptScope.TRAINING, |
| 455 | + if _model_supports_training_model_uri( |
451 | 456 | model_id=kwargs.model_id, |
452 | 457 | model_version=kwargs.model_version, |
| 458 | + region=kwargs.region, |
453 | 459 | tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
454 | 460 | tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
455 | | - ) |
456 | | - |
457 | | - if ( |
458 | | - kwargs.model_uri is not None |
459 | | - and kwargs.model_uri != default_model_uri |
460 | | - and not _model_supports_incremental_training( |
| 461 | + ): |
| 462 | + default_model_uri = model_uris.retrieve( |
| 463 | + model_scope=JumpStartScriptScope.TRAINING, |
461 | 464 | model_id=kwargs.model_id, |
462 | 465 | model_version=kwargs.model_version, |
463 | | - region=kwargs.region, |
464 | 466 | tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
465 | 467 | tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
466 | 468 | ) |
467 | | - ): |
468 | | - JUMPSTART_LOGGER.warning( |
469 | | - "'%s' does not support incremental training but is being trained with" |
470 | | - " non-default model artifact.", |
471 | | - kwargs.model_id, |
472 | | - ) |
473 | 469 |
|
474 | | - kwargs.model_uri = kwargs.model_uri or default_model_uri |
| 470 | + if ( |
| 471 | + kwargs.model_uri is not None |
| 472 | + and kwargs.model_uri != default_model_uri |
| 473 | + and not _model_supports_incremental_training( |
| 474 | + model_id=kwargs.model_id, |
| 475 | + model_version=kwargs.model_version, |
| 476 | + region=kwargs.region, |
| 477 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 478 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 479 | + ) |
| 480 | + ): |
| 481 | + JUMPSTART_LOGGER.warning( |
| 482 | + "'%s' does not support incremental training but is being trained with" |
| 483 | + " non-default model artifact.", |
| 484 | + kwargs.model_id, |
| 485 | + ) |
| 486 | + |
| 487 | + kwargs.model_uri = kwargs.model_uri or default_model_uri |
475 | 488 |
|
476 | 489 | return kwargs |
477 | 490 |
|
@@ -501,6 +514,31 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart |
501 | 514 | return kwargs |
502 | 515 |
|
503 | 516 |
|
| 517 | +def _add_env_to_kwargs( |
| 518 | + kwargs: JumpStartEstimatorInitKwargs, |
| 519 | +) -> JumpStartEstimatorInitKwargs: |
| 520 | + """Sets environment in kwargs based on default or override, returns full kwargs.""" |
| 521 | + |
| 522 | + model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( |
| 523 | + model_id=kwargs.model_id, |
| 524 | + model_version=kwargs.model_version, |
| 525 | + region=kwargs.region, |
| 526 | + scope=JumpStartScriptScope.TRAINING, |
| 527 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 528 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 529 | + ) |
| 530 | + |
| 531 | + if model_package_artifact_uri: |
| 532 | + if kwargs.environment is None: |
| 533 | + kwargs.environment = {} |
| 534 | + kwargs.environment = { |
| 535 | + **{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: model_package_artifact_uri}, |
| 536 | + **kwargs.environment, |
| 537 | + } |
| 538 | + |
| 539 | + return kwargs |
| 540 | + |
| 541 | + |
504 | 542 | def _add_entry_point_to_kwargs( |
505 | 543 | kwargs: JumpStartEstimatorInitKwargs, |
506 | 544 | ) -> JumpStartEstimatorInitKwargs: |
|
0 commit comments