From 2c3d606a4ade22cca1585b45e5e267e5c000b91d Mon Sep 17 00:00:00 2001 From: Kalyani Nikure <110067132+knikure@users.noreply.github.com> Date: Thu, 6 Jun 2024 15:58:12 -0700 Subject: [PATCH 01/45] feat: Benchmark feature initial commit (#1463) * Sync Master benchmark feature (#1461) * feat: support config_name in all JumpStart interfaces (#4583) (#4607) * add-config-name * address comments * updates for set config * docstyle * updates * fix * format * format * remove tests * Add ReadOnly APIs (#4606) * Add ReadOnly APIs * Resolving PR review comments * Resolve PR review comments * Refactoring * Refactoring * Add Caching * Refactor * Resolving conflicts * Add Unit Tests * Fix Unit Tests * Fix unit tests * Fix UT * Refactoring * Fix Integ tests * refactoring after Notebook testing * Fix code styles --------- Co-authored-by: Jonathan Makunga * feat: tag JumpStart resource with config names (#4608) * tag config name * format * resolving comments * format * format * update * fix * format * updates inference component config name * fix: tests * ModelBuilder: Add functionalities to get and set deployment config. (#4614) * Add functionalities to get and set deployment config * Resolve PR comments * ModelBuilder-JS * Add Unit tests * Refactoring * Testing with Notebook * Test backward compatibility * Remove Accelerated column if all not enabled * Fix docstring * Resolved PR Review comments * Docstring * increase code coverage --------- Co-authored-by: Jonathan Makunga * Benchmark feature v2 (#4618) * Add functionalities to get and set deployment config * Resolve PR comments * ModelBuilder-JS * Add Unit tests * Refactoring * Testing with Notebook * Test backward compatibility * Remove Accelerated column if all not enabled * Fix docstring * Resolved PR Review comments * Docstring * increase code coverage * Testing fix with Notebook * Only fetch instance rate metrics if not present * Increase code coverage --------- Co-authored-by: Jonathan Makunga * fix: populate default config name to model (#4617) * fix: populate default config name to model * update condition * fix * format * flake8 * fix tests * fix coverage * temporarily skip integ test vulnerability * fix tolerate attach method * format * fix predictor * format * Fix fetch instance rate bug (#4624) Co-authored-by: Jonathan Makunga * chore: require config name and instance type in set_deployment_config (#4625) * require config_name and instance_type in set config * docstring * add supported instance types check * add more tests * format * fix tests * Deployment Configs - Follow-ups (#4626) * Init Deployment configs outside Model init. 
* Testing with NB * Testing with NB-V2 * Refactoring, NB testing * NB Testing and Refactoring * Testing * Refactoring * Testing with NB * Debug * Debug display API * Debug with NB * Testing with NB * Refactoring * Refactoring * Refactoring and NB testing * Testing with NB * Refactoring * Prefix instance type with ml * Fix unit tests --------- Co-authored-by: Jonathan Makunga * fix: use different separator to flatten dict (#4629) * Use separate tags for inference and training configs (#4635) * Use separate tags for inference and training * format * format * format * format * Add supported inference and incremental training configs (#4637) * supported inference configs * add tests * format * tests * tests * address comments * format and address comments * updates * format * format * Benchmark feature fixes (#4632) * Filter down Benchmark Metrics * Filter down Benchmark Metrics * Testing NB * Testing MB * Testing * Refactoring * Unit tests * Display instance type first, and instance rate last * Display unbalanced metrics * Testing with NB * Testing with NB * Debug * Debug * Testing with NB * Testing with NB * Testing with NB * Refactoring * Refactoring * Refactoring * Unit tests * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Refactoring * Debug * Config ranking * Debug * Debug * Debug * Debug * Debug * Ranking * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * Debug * Debug * Debug * Debug * Refactoring * Contact JumpStart team to fix flaky test. test_list_jumpstart_models_script_filter --------- Co-authored-by: Jonathan Makunga * fix: typo and merge with master branch (#4649) * Merge master into benchmark feature (#4652) * Merge master into master-benchmark-feature (#4656) * Master benchmark feature (#4658) * Master benchmark feature merge master (#4661) * Master benchmark feature (#4672) * fix: mainline alt config parsing (#4602) * fix: parsing * fix: commit tests * fix: types * updated * fix * Add Triton v24.03 URI (#4605) Co-authored-by: Nikhil Kulkarni * feature: support session tag chaining for training job (#4596) * feature: support session tag chaining for training job * fix: resolve typo * fix: resolve typo and build failure * fix: resolve typo and unit test failure --------- Co-authored-by: Jessica Zhu * prepare release v2.217.0 * update development version to v2.217.1.dev0 * fix: properly close files in lineage queries and tests (#4587) Closes #4458 * feature: set default allow_pickle param to False (#4557) * breaking: set default allow_pickle param to False * breaking: fix unit tests and linting NumpyDeserializer will not allow deserialization unless allow_pickle flag is set to True explicitly * fix: black-check --------- Co-authored-by: Ashwin Krishna * Fix: invalid component error with new metadata (#4634) * fix: invalid component name * tests * format * fix vulnerable model integ tests llama 2 * updated * fix: training dataset location * prepare release v2.218.0 * update development version to v2.218.1.dev0 * chore: update skipped flaky tests (#4644) * Update skipped flaky tests * flake8 * format * format * chore: release tgi 2.0.1 (#4642) * chore: release tgi 2.0.1 * minor fix --------- Co-authored-by: Zhaoqi <52220743+zhaoqizqwang@users.noreply.github.com> * fix: Fix UserAgent logging in Python SDK (#4647) * prepare release v2.218.1 * update development version to v2.218.2.dev0 * feature: allow choosing js payload by alias in private method * 
Updates for SMP v2.3.1 (#4660) Co-authored-by: Suhit Kodgule * chore(deps): bump jinja2 from 3.1.3 to 3.1.4 in /doc (#4655) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/jinja/compare/3.1.3...3.1.4) --- updated-dependencies: - dependency-name: jinja2 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump tqdm from 4.66.2 to 4.66.3 in /tests/data/serve_resources/mlflow/pytorch (#4650) Bumps [tqdm](https://github.com/tqdm/tqdm) from 4.66.2 to 4.66.3. - [Release notes](https://github.com/tqdm/tqdm/releases) - [Commits](https://github.com/tqdm/tqdm/compare/v4.66.2...v4.66.3) --- updated-dependencies: - dependency-name: tqdm dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump jinja2 from 3.1.3 to 3.1.4 in /requirements/extras (#4654) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/jinja/compare/3.1.3...3.1.4) --- updated-dependencies: - dependency-name: jinja2 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * prepare release v2.219.0 * update development version to v2.219.1.dev0 * fix: skip flaky tests pending investigation (#4667) * change: update image_uri_configs 05-09-2024 07:17:41 PST * Add tensorflow_serving support for mlflow models and enable lineage tracking for mlflow models (#4662) * Initial commit for tensorflow_serving support of MLflow * Add integ tests for mlflow tf_serving * fix style issues * remove unused attributes from tf builder * Add deep ping for tf_serving local mode * Initial commit for lineage impl * Initial commit for tensorflow_serving support of MLflow * Add integ tests for mlflow tf_serving * fix style issues * remove unused attributes from tf builder * Add deep ping for tf_serving local mode * Add integ tests and uts * fix local mode for tf_serving * Allow lineage tracking only in sagemaker endpoint mode * fix regex pattern * fix style issues * fix regex pattern and hard coded py version in ut * fix missing session * Resolve PR comments and fix regex for mlflow registry and ids * fix: model builder race condition on sagemaker session (#4673) Co-authored-by: Jonathan Makunga * feat: Add telemetry support for mlflow models (#4674) * Initial commit for telemetry support * Fix style issues and add more logger messages * fix value error messages in ut * feat: add new images for HF TGI release (#4677) * chore: add new images for HF TGI release * test * feature: AutoGluon 1.1.0 image_uris update (#4679) Co-authored-by: Ubuntu * change: add debug logs to workflow container dist creation (#4682) * prepare release v2.220.0 * update development version to v2.220.1.dev0 * fix: Image URI should take precedence for HF models (#4684) * Fix: Image URI should take precedence for HF models * Fix formatting * Fix formatting * Fix formatting * Increase coverage - UT pass
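As a quick orientation for reviewers, the deployment-config surface added by the commits above can be exercised roughly as follows. This is an illustrative sketch: the model ID, config name, and instance type below are placeholders, not values taken from this patch; the method names follow the ReadOnly APIs (#4606, #4707) and set_deployment_config (#4625) commits above.

    from sagemaker.jumpstart.model import JumpStartModel

    # Placeholder model ID for illustration only.
    model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b")

    # Enumerate the named deployment configs attached to the model spec and
    # inspect their benchmark metrics (instance type shown first, instance
    # rate last, per the "Benchmark feature fixes" commit #4632).
    print(model.list_deployment_configs())
    model.display_benchmark_metrics()

    # Per #4625, both config_name and instance_type are required when pinning
    # a config; both values here are placeholders.
    model.set_deployment_config(config_name="tgi-config", instance_type="ml.g5.2xlarge")
    predictor = model.deploy()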
* Master benchmark feature (#4658) * Remove duplicate line in types.py * Remove duplicate lines * Remove duplicate lines * Remove duplicate lines * Remove duplicate lines * fix unit test --------- Signed-off-by: dependabot[bot] Co-authored-by: Haotian An <33510317+Captainia@users.noreply.github.com> Co-authored-by: Nikhil Kulkarni Co-authored-by: Nikhil Kulkarni Co-authored-by: jessicazhu3 <106775307+jessicazhu3@users.noreply.github.com> Co-authored-by: Jessica Zhu Co-authored-by: ci Co-authored-by: Justin Co-authored-by: ASHWIN KRISHNA <38850354+akrishna1995@users.noreply.github.com> Co-authored-by: Ashwin Krishna Co-authored-by: Haixin Wang <98612668+haixiw@users.noreply.github.com> Co-authored-by: Zhaoqi <52220743+zhaoqizqwang@users.noreply.github.com> Co-authored-by: Kalyani Nikure <110067132+knikure@users.noreply.github.com> Co-authored-by: Keerthan Vasist Co-authored-by: SuhitK Co-authored-by: Suhit Kodgule Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: sagemaker-bot Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> Co-authored-by: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Co-authored-by: Jonathan Makunga Co-authored-by: Prateek M Desai Co-authored-by: Ubuntu Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Co-authored-by: Samrudhi Sharma <154457034+samruds@users.noreply.github.com> Co-authored-by: evakravi <69981223+evakravi@users.noreply.github.com> * fix benchmark feature read-only apis (#4675) * Rearrange benchmark metric table * Refactoring * Refactoring * Refactoring * Refactoring * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug
* Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Refactoring * Refactoring * Refactoring * Refactoring * Refactoring * Add Unit tests * Refactoring * Refactoring * hide index from DataFrame --------- Co-authored-by: Jonathan Makunga * feat: update alt config to work with model packages (#4706) * feat: update alt config to work with model packages * format * remove env vars for model package * fix tests * Update: ReadOnly APIs (#4707) * Model data arn * Refactoring * Refactoring * acceleration_configs * Refactoring * UT * Add Filter * UT * Revert "UT" * UT * UT --------- Co-authored-by: Jonathan Makunga * ModelBuilder to support display with filter. (#4712) Co-authored-by: Jonathan Makunga * Sync branch (#4718) * feat: onboard tei image config to pysdk (#4681) * feat: onboard tei image config to pysdk * fix formatting issue * minor fix func name * fix unit tests --------- Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> * fix: model builder limited container support for endpoint mode. 
(#4683) * Allow ModelBuilder's endpoint mode for JumpStart models packaged with containers other than TGI and DJL * increase coverage * Add JS Support for MMS Serving * Add JS Support for MMS Serving * Unit tests * Refactoring * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga * change: Add more debugging (#4687) * change: cover tei with image_uris.retrieve API (#4689) * fix: JS Model with non-TGI/non-DJL deployment failure (#4688) * Debug * Debug * Debug * Debug * Debug * Debug * fix docstyle * Refactoring * Add Integ tests --------- Co-authored-by: Jonathan Makunga * Feat: Pull latest tei container for sentence similarity models on HuggingFace hub (#4686) * Update: Pull latest tei container for sentence similarity models * Fix formatting * Address PR comments * Fix formatting * Fix check * Switch sentence similarity to be deployed on tgi * Fix formatting * Fix formatting * Fix formatting * Fix formatting * Introduce TEI builder with TGI server * Fix formatting * Add integ test * Fix formatting * Add integ test * Add integ test * Add integ test * Add integ test * Add integ test * Fix formatting * Move to G5 for integ test * Fix formatting * Integ test updates * Integ test updates * Integ test updates * Fix formatting * Integ test updates * Move back to generate for ping * Integ test updates * Integ test updates * Fix: Add Image URI overrides for transformers models (#4693) * Fix: Add Image URI overrides for transformers models * Increase coverage * Fix formatting * prepare release v2.221.0 * update development version to v2.221.1.dev0 * Add tei cpu image (#4695) * Add tei cpu image * fix format issue * fix unit tests * fix typo * fix typo * Feat: Add TEI support for ModelBuilder (#4694) * Add TEI Serving * Add TEI Serving * Add TEI Serving * Add TEI Serving * Add TEI Serving * Add TEI Serving * Notebook testing * Notebook testing * Notebook testing * Refactoring * Refactoring * UT * UT * Refactoring * Test coverage * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga * Convert pytorchddp distribution to smdistributed distribution (#4698) * rewrite pytorchddp to smdistributed * remove instance type check * Update estimator.py * remove validate_pytorch_distribution * fix * fix unit tests * fix formatting * check instance type not None * prepare release v2.221.1 * update development version to v2.221.2.dev0 * Update: SM Endpoint Routing Strategy Support. (#4702) * RoutingConfig * Refactoring * Docstring * UT * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga * change: update image_uri_configs 05-29-2024 07:17:35 PST * Making project name in workflow files dynamic (#4708) * fix: Fix ci unit-tests (#4713) * chore(deps): bump requests from 2.31.0 to 2.32.2 in /tests/data/serve_resources/mlflow/pytorch (#4709) Bumps [requests](https://github.com/psf/requests) from 2.31.0 to 2.32.2. - [Release notes](https://github.com/psf/requests/releases) - [Changelog](https://github.com/psf/requests/blob/main/HISTORY.md) - [Commits](https://github.com/psf/requests/compare/v2.31.0...v2.32.2) --- updated-dependencies: - dependency-name: requests dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump apache-airflow from 2.9.0 to 2.9.1 in /requirements/extras (#4703) * chore(deps): bump apache-airflow in /requirements/extras Bumps [apache-airflow](https://github.com/apache/airflow) from 2.9.0 to 2.9.1. 
- [Release notes](https://github.com/apache/airflow/releases) - [Changelog](https://github.com/apache/airflow/blob/main/RELEASE_NOTES.rst) - [Commits](https://github.com/apache/airflow/compare/2.9.0...2.9.1) --- updated-dependencies: - dependency-name: apache-airflow dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Update tox.ini to bump apache-airflow --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Kalyani Nikure <110067132+knikure@users.noreply.github.com> * chore(deps): bump mlflow from 2.10.2 to 2.12.1 in /tests/data/serve_resources/mlflow/pytorch (#4690) Bumps [mlflow](https://github.com/mlflow/mlflow) from 2.10.2 to 2.12.1. - [Release notes](https://github.com/mlflow/mlflow/releases) - [Changelog](https://github.com/mlflow/mlflow/blob/master/CHANGELOG.md) - [Commits](https://github.com/mlflow/mlflow/compare/v2.10.2...v2.12.1) --- updated-dependencies: - dependency-name: mlflow dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump mlflow from 2.11.1 to 2.12.1 in /tests/data/serve_resources/mlflow/xgboost (#4692) Bumps [mlflow](https://github.com/mlflow/mlflow) from 2.11.1 to 2.12.1. - [Release notes](https://github.com/mlflow/mlflow/releases) - [Changelog](https://github.com/mlflow/mlflow/blob/master/CHANGELOG.md) - [Commits](https://github.com/mlflow/mlflow/compare/v2.11.1...v2.12.1) --- updated-dependencies: - dependency-name: mlflow dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * chore(deps): bump mlflow from 2.11.1 to 2.12.1 in /tests/data/serve_resources/mlflow/tensorflow (#4691) Bumps [mlflow](https://github.com/mlflow/mlflow) from 2.11.1 to 2.12.1. - [Release notes](https://github.com/mlflow/mlflow/releases) - [Changelog](https://github.com/mlflow/mlflow/blob/master/CHANGELOG.md) - [Commits](https://github.com/mlflow/mlflow/compare/v2.11.1...v2.12.1) --- updated-dependencies: - dependency-name: mlflow dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * change: Updates for DJL 0.28.0 release (#4701) * Sync Branch --------- Signed-off-by: dependabot[bot] Co-authored-by: Haotian An <33510317+Captainia@users.noreply.github.com> Co-authored-by: Nikhil Kulkarni Co-authored-by: Nikhil Kulkarni Co-authored-by: jessicazhu3 <106775307+jessicazhu3@users.noreply.github.com> Co-authored-by: Jessica Zhu Co-authored-by: ci Co-authored-by: Justin Co-authored-by: ASHWIN KRISHNA <38850354+akrishna1995@users.noreply.github.com> Co-authored-by: Ashwin Krishna Co-authored-by: Haixin Wang <98612668+haixiw@users.noreply.github.com> Co-authored-by: Zhaoqi <52220743+zhaoqizqwang@users.noreply.github.com> Co-authored-by: Kalyani Nikure <110067132+knikure@users.noreply.github.com> Co-authored-by: Keerthan Vasist Co-authored-by: SuhitK Co-authored-by: Suhit Kodgule Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: sagemaker-bot Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> Co-authored-by: Jonathan Makunga Co-authored-by: Prateek M Desai Co-authored-by: Ubuntu Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Co-authored-by: Samrudhi Sharma <154457034+samruds@users.noreply.github.com> Co-authored-by: Tom Bousso Co-authored-by: Zhaoqi Co-authored-by: Tyler Osterberg * Merge --------- Signed-off-by: dependabot[bot] Co-authored-by: Haotian An <33510317+Captainia@users.noreply.github.com> Co-authored-by: Jonathan Makunga Co-authored-by: evakravi <69981223+evakravi@users.noreply.github.com> Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Co-authored-by: Nikhil Kulkarni Co-authored-by: Nikhil Kulkarni Co-authored-by: jessicazhu3 <106775307+jessicazhu3@users.noreply.github.com> Co-authored-by: Jessica Zhu Co-authored-by: Justin Co-authored-by: ASHWIN KRISHNA <38850354+akrishna1995@users.noreply.github.com> Co-authored-by: Ashwin Krishna Co-authored-by: Haixin Wang <98612668+haixiw@users.noreply.github.com> Co-authored-by: Zhaoqi <52220743+zhaoqizqwang@users.noreply.github.com> Co-authored-by: Kalyani Nikure <110067132+knikure@users.noreply.github.com> Co-authored-by: Keerthan Vasist Co-authored-by: SuhitK Co-authored-by: Suhit Kodgule Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: sagemaker-bot Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> Co-authored-by: Prateek M Desai Co-authored-by: Ubuntu Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Co-authored-by: Samrudhi Sharma <154457034+samruds@users.noreply.github.com> Co-authored-by: Tom Bousso Co-authored-by: Zhaoqi Co-authored-by: Tyler Osterberg * Fix UT (#1465) Co-authored-by: Jonathan Makunga --------- Signed-off-by: dependabot[bot] Co-authored-by: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Co-authored-by: Haotian An <33510317+Captainia@users.noreply.github.com> Co-authored-by: Jonathan Makunga Co-authored-by: evakravi <69981223+evakravi@users.noreply.github.com> Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Co-authored-by: Nikhil Kulkarni Co-authored-by: Nikhil Kulkarni Co-authored-by: jessicazhu3 <106775307+jessicazhu3@users.noreply.github.com> Co-authored-by: Jessica Zhu Co-authored-by: Justin Co-authored-by: ASHWIN KRISHNA <38850354+akrishna1995@users.noreply.github.com> Co-authored-by: 
Ashwin Krishna Co-authored-by: Haixin Wang <98612668+haixiw@users.noreply.github.com> Co-authored-by: Zhaoqi <52220743+zhaoqizqwang@users.noreply.github.com> Co-authored-by: Keerthan Vasist Co-authored-by: SuhitK Co-authored-by: Suhit Kodgule Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: sagemaker-bot Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> Co-authored-by: Prateek M Desai Co-authored-by: Ubuntu Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Co-authored-by: Samrudhi Sharma <154457034+samruds@users.noreply.github.com> Co-authored-by: Tom Bousso Co-authored-by: Zhaoqi Co-authored-by: Tyler Osterberg --- src/sagemaker/accept_types.py | 3 + src/sagemaker/content_types.py | 3 + src/sagemaker/deserializers.py | 3 + src/sagemaker/environment_variables.py | 3 + src/sagemaker/hyperparameters.py | 3 + src/sagemaker/image_uris.py | 3 + src/sagemaker/instance_types.py | 3 + .../artifacts/environment_variables.py | 7 + .../jumpstart/artifacts/hyperparameters.py | 3 + .../jumpstart/artifacts/image_uris.py | 4 + .../artifacts/incremental_training.py | 3 + .../jumpstart/artifacts/instance_types.py | 6 + src/sagemaker/jumpstart/artifacts/kwargs.py | 12 + .../jumpstart/artifacts/metric_definitions.py | 3 + .../jumpstart/artifacts/model_packages.py | 11 +- .../jumpstart/artifacts/model_uris.py | 7 + src/sagemaker/jumpstart/artifacts/payloads.py | 3 + .../jumpstart/artifacts/predictors.py | 24 + .../jumpstart/artifacts/resource_names.py | 6 +- .../artifacts/resource_requirements.py | 3 + .../jumpstart/artifacts/script_uris.py | 5 + src/sagemaker/jumpstart/enums.py | 3 + src/sagemaker/jumpstart/estimator.py | 60 +- src/sagemaker/jumpstart/factory/estimator.py | 60 +- src/sagemaker/jumpstart/factory/model.py | 142 ++++- src/sagemaker/jumpstart/model.py | 178 +++++- src/sagemaker/jumpstart/notebook_utils.py | 2 + src/sagemaker/jumpstart/session_utils.py | 86 ++- src/sagemaker/jumpstart/types.py | 236 +++++++- src/sagemaker/jumpstart/utils.py | 405 +++++++++++-- src/sagemaker/jumpstart/validators.py | 3 + src/sagemaker/metric_definitions.py | 3 + src/sagemaker/model_uris.py | 4 + src/sagemaker/predictor.py | 13 +- src/sagemaker/resource_requirements.py | 3 + src/sagemaker/script_uris.py | 3 + src/sagemaker/serializers.py | 6 + .../serve/builder/jumpstart_builder.py | 50 +- .../serve/builder/transformers_builder.py | 4 +- src/sagemaker/utils.py | 173 +++++- tests/unit/sagemaker/jumpstart/constants.py | 251 +++++++- .../jumpstart/estimator/test_estimator.py | 293 ++++++++- .../sagemaker/jumpstart/model/test_model.py | 561 +++++++++++++++++- .../jumpstart/model/test_sagemaker_config.py | 32 + .../jumpstart/test_notebook_utils.py | 13 +- .../sagemaker/jumpstart/test_predictor.py | 27 +- .../sagemaker/jumpstart/test_session_utils.py | 255 +++++--- tests/unit/sagemaker/jumpstart/test_types.py | 150 ++++- tests/unit/sagemaker/jumpstart/test_utils.py | 504 ++++++++++++---- tests/unit/sagemaker/jumpstart/utils.py | 144 ++++- .../serve/builder/test_js_builder.py | 237 ++++++++ tests/unit/sagemaker/serve/constants.py | 150 +++++ tests/unit/test_utils.py | 154 ++++- 53 files changed, 3900 insertions(+), 423 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 78aa655e04..7541425868 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -77,6 +77,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session 
= DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -98,6 +99,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default accept type to use for the model. @@ -117,4 +119,5 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 46d0361f67..627feca0d6 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -77,6 +77,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -98,6 +99,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default content type to use for the model. @@ -117,6 +119,7 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 1a4be43897..02e61149ec 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -97,6 +97,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. @@ -118,6 +119,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseDeserializer: The default deserializer to use for the model. @@ -138,4 +140,5 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index b67066fcde..8fa52c3ec8 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -36,6 +36,7 @@ def retrieve_default( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the default container environment variables for the model matching the arguments. 
@@ -65,6 +66,7 @@ def retrieve_default( variables specific for the instance type. script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: The variables to use for the model. @@ -87,4 +89,5 @@ def retrieve_default( sagemaker_session=sagemaker_session, instance_type=instance_type, script=script, + config_name=config_name, ) diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 5873e37b9f..5c22409c50 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -36,6 +36,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the default training hyperparameters for the model matching the given arguments. @@ -66,6 +67,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: The hyperparameters to use for the model. @@ -86,6 +88,7 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index be5167dcc7..743f6b1f99 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -70,6 +70,7 @@ def retrieve( inference_tool=None, serverless_inference_config=None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name=None, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -123,6 +124,7 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The ECR URI for the corresponding SageMaker Docker image. @@ -162,6 +164,7 @@ def retrieve( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 48aaab0ac8..c4af4b2036 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -36,6 +36,7 @@ def retrieve_default( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -64,6 +65,7 @@ def retrieve_default( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). 
Returns: str: The default instance type to use for the model. @@ -88,6 +90,7 @@ def retrieve_default( sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index c28c27ed4e..fcb3ce3bf2 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -39,6 +39,7 @@ def _retrieve_default_environment_variables( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> Dict[str, str]: """Retrieves the inference environment variables for the model matching the given arguments. @@ -68,6 +69,7 @@ def _retrieve_default_environment_variables( environment variables specific for the instance type. script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the inference environment variables to use for the model. """ @@ -84,6 +86,7 @@ def _retrieve_default_environment_variables( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_environment_variables: Dict[str, str] = {} @@ -121,6 +124,7 @@ def _retrieve_default_environment_variables( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, instance_type=instance_type, + config_name=config_name, ) ) @@ -167,6 +171,7 @@ def _retrieve_gated_model_uri_env_var_value( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves the gated model env var URI matching the given arguments. @@ -190,6 +195,7 @@ def _retrieve_gated_model_uri_env_var_value( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get environment variables specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: Optional[str]: the s3 URI to use for the environment variable, or None if the model does not @@ -211,6 +217,7 @@ def _retrieve_gated_model_uri_env_var_value( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) s3_key: Optional[str] = ( diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index d19530ecfb..67db7d260f 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -36,6 +36,7 @@ def _retrieve_default_hyperparameters( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ): """Retrieves the training hyperparameters for the model matching the given arguments. @@ -66,6 +67,7 @@ def _retrieve_default_hyperparameters( chain. 
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the hyperparameters to use for the model. """ @@ -82,6 +84,7 @@ def _retrieve_default_hyperparameters( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_hyperparameters: Dict[str, str] = {} diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 9d19d5e069..72633320f5 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -46,6 +46,7 @@ def _retrieve_image_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the container image URI for JumpStart models. @@ -95,6 +96,7 @@ def _retrieve_image_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the ECR URI for the corresponding SageMaker Docker image. @@ -116,6 +118,7 @@ def _retrieve_image_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if image_scope == JumpStartScriptScope.INFERENCE: @@ -200,4 +203,5 @@ def _retrieve_image_uri( distribution=distribution, base_framework_version=base_framework_version_override or base_framework_version, training_compiler_config=training_compiler_config, + config_name=config_name, ) diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 1b3c6f4b29..8bbe089354 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -33,6 +33,7 @@ def _model_supports_incremental_training( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports incremental training. @@ -54,6 +55,7 @@ def _model_supports_incremental_training( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: bool: the support status for incremental training. 
""" @@ -70,6 +72,7 @@ def _model_supports_incremental_training( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.supports_incremental_training() diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index e7c9c5911d..f4bf212c1c 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -40,6 +40,7 @@ def _retrieve_default_instance_type( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model. @@ -68,6 +69,7 @@ def _retrieve_default_instance_type( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default instance type to use for the model or None. @@ -89,6 +91,7 @@ def _retrieve_default_instance_type( tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -128,6 +131,7 @@ def _retrieve_instance_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported instance types for the model. @@ -156,6 +160,7 @@ def _retrieve_instance_types( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported instance types to use for the model or None. @@ -176,6 +181,7 @@ def _retrieve_instance_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 9cd152b0bb..ceb88d9b26 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -37,6 +37,7 @@ def _retrieve_model_init_kwargs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Model`. @@ -58,6 +59,7 @@ def _retrieve_model_init_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. 
""" @@ -75,6 +77,7 @@ def _retrieve_model_init_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -94,6 +97,7 @@ def _retrieve_model_deploy_kwargs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -117,6 +121,7 @@ def _retrieve_model_deploy_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. @@ -135,6 +140,7 @@ def _retrieve_model_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: @@ -151,6 +157,7 @@ def _retrieve_estimator_init_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Estimator`. @@ -174,6 +181,7 @@ def _retrieve_estimator_init_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. """ @@ -190,6 +198,7 @@ def _retrieve_estimator_init_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) kwargs = deepcopy(model_specs.estimator_kwargs) @@ -210,6 +219,7 @@ def _retrieve_estimator_fit_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Estimator.fit`. @@ -231,6 +241,7 @@ def _retrieve_estimator_fit_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. 
@@ -248,6 +259,7 @@ def _retrieve_estimator_fit_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.fit_kwargs diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 57f66155c7..f23b66aed4 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -35,6 +35,7 @@ def _retrieve_default_training_metric_definitions( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model. @@ -58,6 +59,7 @@ def _retrieve_default_training_metric_definitions( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the default training metric definitions to use for the model or None. """ @@ -74,6 +76,7 @@ def _retrieve_default_training_metric_definitions( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) default_metric_definitions = ( diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index aa22351771..67459519f3 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -37,6 +37,7 @@ def _retrieve_model_package_arn( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -60,6 +61,7 @@ def _retrieve_model_package_arn( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model package arn to use for the model or None. 
@@ -78,6 +80,7 @@ def _retrieve_model_package_arn( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -93,7 +96,10 @@ def _retrieve_model_package_arn( if instance_specific_arn is not None: return instance_specific_arn - if model_specs.hosting_model_package_arns is None: + if ( + model_specs.hosting_model_package_arns is None + or model_specs.hosting_model_package_arns == {} + ): return None regional_arn = model_specs.hosting_model_package_arns.get(region) @@ -118,6 +124,7 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves s3 artifact uri associated with model package. @@ -141,6 +148,7 @@ def _retrieve_model_package_model_artifact_s3_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model package artifact uri to use for the model or None. @@ -162,6 +170,7 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if model_specs.training_model_package_artifact_uris is None: diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 6bb2e576fc..00c6d8b9aa 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -95,6 +95,7 @@ def _retrieve_model_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -120,6 +121,8 @@ def _retrieve_model_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + Returns: str: the model artifact S3 URI for the corresponding model. @@ -141,6 +144,7 @@ def _retrieve_model_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) model_artifact_key: str @@ -182,6 +186,7 @@ def _model_supports_training_model_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports training with model uri field. @@ -203,6 +208,7 @@ def _model_supports_training_model_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. 
(Default: None). Returns: bool: the support status for model uri with training. """ @@ -219,6 +225,7 @@ def _model_supports_training_model_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.use_training_model_artifact() diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3359e32732..2f4a8bb0ac 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -37,6 +37,7 @@ def _retrieve_example_payloads( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. @@ -58,6 +59,7 @@ def _retrieve_example_payloads( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases to the serializable payload object. @@ -76,6 +78,7 @@ def _retrieve_example_payloads( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_payloads = model_specs.default_payloads diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 4f6dfe1fe3..635f063e05 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -78,6 +78,7 @@ def _retrieve_default_deserializer( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -98,6 +99,7 @@ def _retrieve_default_deserializer( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseDeserializer: the default deserializer to use for the model. @@ -111,6 +113,7 @@ def _retrieve_default_deserializer( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -124,6 +127,7 @@ def _retrieve_default_serializer( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -144,6 +148,7 @@ def _retrieve_default_serializer( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. 
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseSerializer: the default serializer to use for the model. """ @@ -156,6 +161,7 @@ def _retrieve_default_serializer( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -169,6 +175,7 @@ def _retrieve_deserializer_options( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -189,6 +196,7 @@ def _retrieve_deserializer_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[BaseDeserializer]: the supported deserializers to use for the model. """ @@ -201,6 +209,7 @@ def _retrieve_deserializer_options( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) seen_classes: Set[Type] = set() @@ -227,6 +236,7 @@ def _retrieve_serializer_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model. @@ -247,6 +257,7 @@ def _retrieve_serializer_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[BaseSerializer]: the supported serializers to use for the model. """ @@ -258,6 +269,7 @@ def _retrieve_serializer_options( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) seen_classes: Set[Type] = set() @@ -285,6 +297,7 @@ def _retrieve_default_content_type( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the default content type for the model. @@ -305,6 +318,7 @@ def _retrieve_default_content_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default content type to use for the model. 
""" @@ -322,6 +336,7 @@ def _retrieve_default_content_type( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -336,6 +351,7 @@ def _retrieve_default_accept_type( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default accept type for the model. @@ -356,6 +372,7 @@ def _retrieve_default_accept_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default accept type to use for the model. """ @@ -373,6 +390,7 @@ def _retrieve_default_accept_type( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -388,6 +406,7 @@ def _retrieve_supported_accept_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -408,6 +427,7 @@ def _retrieve_supported_accept_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported accept types to use for the model. """ @@ -425,6 +445,7 @@ def _retrieve_supported_accept_types( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -440,6 +461,7 @@ def _retrieve_supported_content_types( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported content types for the model. @@ -460,6 +482,7 @@ def _retrieve_supported_content_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported content types to use for the model. 
""" @@ -477,6 +500,7 @@ def _retrieve_supported_content_types( tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index cffd46d043..b4fdac770b 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -35,6 +35,8 @@ def _retrieve_resource_name_base( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> bool: """Returns default resource name. @@ -56,6 +58,7 @@ def _retrieve_resource_name_base( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config. (Default: None). Returns: str: the default resource name. """ @@ -67,12 +70,13 @@ def _retrieve_resource_name_base( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, - scope=JumpStartScriptScope.INFERENCE, + scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.resource_name_base diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 369acac85f..49126da336 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -54,6 +54,7 @@ def _retrieve_default_resources( model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model. @@ -79,6 +80,7 @@ def _retrieve_default_resources( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get host requirements specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default resource requirements to use for the model or None. 
@@ -102,6 +104,7 @@ def _retrieve_default_resources( tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index f69732d2e0..97313ec626 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -37,6 +37,7 @@ def _retrieve_script_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -62,6 +63,7 @@ def _retrieve_script_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model script URI for the corresponding model. @@ -83,6 +85,7 @@ def _retrieve_script_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if script_scope == JumpStartScriptScope.INFERENCE: @@ -108,6 +111,7 @@ def _model_supports_inference_script_uri( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> bool: """Returns True if the model supports inference with script uri field. 
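The script-URI helpers get the same treatment, and the factory code later in this patch calls the public `script_uris.retrieve` with `config_name`, so a config that overrides the hosting script key changes the URI returned here. A hedged sketch of that call shape (model ID and config name are placeholders):

```python
from sagemaker import script_uris
from sagemaker.jumpstart.enums import JumpStartScriptScope

# Resolve the inference script URI under a named config.
uri = script_uris.retrieve(
    model_id="huggingface-llm-falcon-7b-bf16",  # placeholder model ID
    model_version="*",
    script_scope=JumpStartScriptScope.INFERENCE,
    config_name="lmi-optimized",                # placeholder config name
)
```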
@@ -145,6 +149,7 @@ def _model_supports_inference_script_uri( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.use_inference_script_uri() diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index ca49fd41a3..9666ce828f 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -93,6 +93,9 @@ class JumpStartTag(str, Enum): MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" + INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name" + TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name" + class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index f53d109dc8..5f7e0ed82c 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -33,8 +33,10 @@ from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job +from sagemaker.jumpstart.session_utils import get_model_info_from_training_job +from sagemaker.jumpstart.types import JumpStartMetadataConfig from sagemaker.jumpstart.utils import ( + get_jumpstart_configs, validate_model_id_and_get_type, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, @@ -109,6 +111,7 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: Optional[bool] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ): """Initializes a ``JumpStartEstimator``. @@ -501,6 +504,8 @@ def __init__( to Amazon S3 without compression after training finishes. enable_remote_debug (bool or PipelineVariable): Optional. Specifies whether RemoteDebug is enabled for the training job + config_name (Optional[str]): + Name of the training configuration to apply to the Estimator. (Default: None). enable_session_tag_chaining (bool or PipelineVariable): Optional. 
Specifies whether SessionTagChaining is enabled for the training job @@ -581,6 +586,7 @@ def _validate_model_id_and_get_type_hook(): disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, ) @@ -595,6 +601,8 @@ def _validate_model_id_and_get_type_hook(): self.role = estimator_init_kwargs.role self.sagemaker_session = estimator_init_kwargs.sagemaker_session self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation + self.config_name = estimator_init_kwargs.config_name + self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False) super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) @@ -669,6 +677,7 @@ def fit( tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, sagemaker_session=self.sagemaker_session, + config_name=self.config_name, ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -681,6 +690,7 @@ def attach( model_version: Optional[str] = None, sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_channel_name: str = "model", + config_name: Optional[str] = None, ) -> "JumpStartEstimator": """Attach to an existing training job. @@ -716,6 +726,8 @@ def attach( model data will be downloaded (default: 'model'). If no channel with the same name exists in the training job, this option will be ignored. + config_name (str): Optional. Name of the training configuration to use + when attaching to the training job. (Default: None). Returns: Instance of the calling ``JumpStartEstimator`` class with the attached training job. Raises: ValueError: if the model ID or version cannot be inferred from the training job. """ - if model_id is None: - - model_id, model_version = get_model_id_version_from_training_job( + model_id, model_version, _, config_name = get_model_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) @@ -741,6 +752,9 @@ def attach( "tolerate_deprecated_model": True, # model is already trained } + if config_name: + additional_kwargs.update({"config_name": config_name}) + model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, @@ -749,6 +763,7 @@ def attach( tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable sagemaker_session=sagemaker_session, + config_name=config_name, ) # eula was already accepted if the model was successfully trained @@ -798,6 +813,7 @@ def deploy( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, use_compiled_model: bool = False, + inference_config_name: Optional[str] = None, ) -> PredictorBase: """Creates endpoint from training job. @@ -1033,6 +1049,8 @@ def deploy( (Default: None). use_compiled_model (bool): Flag to select whether to use compiled (optimized) model. (Default: False). + inference_config_name (Optional[str]): Name of the inference configuration to + be used in the model. (Default: None).
""" self.orig_predictor_cls = predictor_cls @@ -1085,6 +1103,8 @@ def deploy( git_config=git_config, use_compiled_model=use_compiled_model, training_instance_type=self.instance_type, + training_config_name=self.config_name, + inference_config_name=inference_config_name, ) predictor = super(JumpStartEstimator, self).deploy( @@ -1101,11 +1121,43 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + config_name=estimator_deploy_kwargs.config_name, ) # If a predictor class was passed, do not mutate predictor return predictor + def list_training_configs(self) -> List[JumpStartMetadataConfig]: + """Returns a list of configs associated with the estimator. + + Raises: + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + configs_dict = get_jumpstart_configs( + model_id=self.model_id, + model_version=self.model_version, + model_type=self.model_type, + region=self.region, + scope=JumpStartScriptScope.TRAINING, + sagemaker_session=self.sagemaker_session, + ) + return list(configs_dict.values()) + + def set_training_config(self, config_name: str) -> None: + """Sets the config to apply to the model. + + Args: + config_name (str): The name of the config. + """ + self.__init__( + model_id=self.model_id, + model_version=self.model_version, + config_name=config_name, + ) + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 70b205bc74..c936b2f5eb 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -60,7 +60,7 @@ JumpStartModelInitKwargs, ) from sagemaker.jumpstart.utils import ( - add_jumpstart_model_id_version_tags, + add_jumpstart_model_info_tags, get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, update_dict_if_key_not_present, @@ -130,6 +130,7 @@ def get_init_kwargs( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -189,6 +190,7 @@ def get_init_kwargs( disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, ) @@ -207,6 +209,7 @@ def get_init_kwargs( estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_config_name_to_kwargs(estimator_init_kwargs) return estimator_init_kwargs @@ -223,6 +226,7 @@ def get_fit_kwargs( tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + config_name: Optional[str] = None, ) -> 
JumpStartEstimatorFitKwargs: """Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object.""" @@ -238,6 +242,7 @@ def get_fit_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) estimator_fit_kwargs = _add_model_version_to_kwargs(estimator_fit_kwargs) @@ -289,6 +294,8 @@ def get_deploy_kwargs( use_compiled_model: Optional[bool] = None, model_name: Optional[str] = None, training_instance_type: Optional[str] = None, + training_config_name: Optional[str] = None, + inference_config_name: Optional[str] = None, ) -> JumpStartEstimatorDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object.""" @@ -316,6 +323,8 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + training_config_name=training_config_name, + config_name=inference_config_name, ) model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs( @@ -344,6 +353,7 @@ def get_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, training_instance_type=training_instance_type, disable_instance_type_logging=True, + config_name=model_deploy_kwargs.config_name, ) estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( @@ -388,6 +398,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model, use_compiled_model=use_compiled_model, + config_name=model_deploy_kwargs.config_name, ) return estimator_deploy_kwargs @@ -448,6 +459,7 @@ def _add_instance_type_and_count_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.instance_count = kwargs.instance_count or 1 @@ -471,11 +483,16 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: - kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version + kwargs.tags = add_jumpstart_model_info_tags( + kwargs.tags, + kwargs.model_id, + full_model_version, + config_name=kwargs.config_name, + scope=JumpStartScriptScope.TRAINING, ) return kwargs @@ -493,6 +510,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -518,6 +536,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE sagemaker_session=kwargs.sagemaker_session, region=kwargs.region, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) if ( @@ -530,6 +549,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + 
config_name=kwargs.config_name, ) ): JUMPSTART_LOGGER.warning( @@ -565,6 +585,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, region=kwargs.region, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -585,6 +606,7 @@ def _add_env_to_kwargs( sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.TRAINING, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( @@ -595,6 +617,7 @@ def _add_env_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) if model_package_artifact_uri: @@ -622,6 +645,7 @@ def _add_env_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) if model_specs.is_gated_model(): raise ValueError( @@ -651,9 +675,11 @@ def _add_training_job_name_to_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, region=kwargs.region, + scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.job_name = kwargs.job_name or ( @@ -680,6 +706,7 @@ def _add_hyperparameters_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) for key, value in default_hyperparameters.items(): @@ -713,6 +740,7 @@ def _add_metric_definitions_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) or [] ) @@ -742,6 +770,7 @@ def _add_estimator_extra_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) for key, value in estimator_kwargs_to_add.items(): @@ -766,6 +795,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) for key, value in fit_kwargs_to_add.items(): @@ -773,3 +803,27 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim setattr(kwargs, key, value) return kwargs + + +def _add_config_name_to_kwargs( + kwargs: JumpStartEstimatorInitKwargs, +) -> JumpStartEstimatorInitKwargs: + """Sets the default training config name in kwargs if not already set, returns full kwargs.""" + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.TRAINING, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, + ) + + if specs.training_configs and
specs.training_configs.get_top_config_from_ranking(): + kwargs.config_name = ( + kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name + ) + + return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 89b0578342..6cdb3d8382 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -41,9 +41,10 @@ JumpStartModelDeployKwargs, JumpStartModelInitKwargs, JumpStartModelRegisterKwargs, + JumpStartModelSpecs, ) from sagemaker.jumpstart.utils import ( - add_jumpstart_model_id_version_tags, + add_jumpstart_model_info_tags, get_default_jumpstart_session_with_user_agent_suffix, update_dict_if_key_not_present, resolve_model_sagemaker_config_field, @@ -72,6 +73,7 @@ def get_default_predictor( tolerate_deprecated_model: bool, sagemaker_session: Session, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -94,6 +96,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -103,6 +106,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -112,6 +116,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -121,6 +126,7 @@ def get_default_predictor( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return predictor @@ -189,7 +195,6 @@ def _add_instance_type_to_kwargs( """Sets instance type based on default or override, returns full kwargs.""" orig_instance_type = kwargs.instance_type - kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default( region=kwargs.region, model_id=kwargs.model_id, @@ -200,6 +205,7 @@ def _add_instance_type_to_kwargs( sagemaker_session=kwargs.sagemaker_session, training_instance_type=kwargs.training_instance_type, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) if not disable_instance_type_logging and orig_instance_type is None: @@ -231,6 +237,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) return kwargs @@ -252,6 +259,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"): @@ -292,6 +300,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_deprecated_model=kwargs.tolerate_deprecated_model, 
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ): source_dir = source_dir or script_uris.retrieve( script_scope=JumpStartScriptScope.INFERENCE, @@ -301,6 +310,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ) kwargs.source_dir = source_dir @@ -324,6 +334,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + config_name=kwargs.config_name, ): entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME @@ -355,6 +366,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.INFERENCE, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, ) for key, value in extra_env_vars.items(): @@ -385,6 +397,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.model_package_arn = model_package_arn @@ -402,6 +415,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) for key, value in model_kwargs_to_add.items(): @@ -438,6 +452,7 @@ def _add_endpoint_name_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.endpoint_name = kwargs.endpoint_name or ( @@ -460,6 +475,7 @@ def _add_model_name_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) kwargs.name = kwargs.name or ( @@ -481,11 +497,17 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: - kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type + kwargs.tags = add_jumpstart_model_info_tags( + kwargs.tags, + kwargs.model_id, + full_model_version, + kwargs.model_type, + config_name=kwargs.config_name, + scope=JumpStartScriptScope.INFERENCE, ) return kwargs @@ -503,6 +525,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, model_type=kwargs.model_type, + config_name=kwargs.config_name, ) for key, value in deploy_kwargs_to_add.items(): @@ -525,8 +548,106 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel sagemaker_session=kwargs.sagemaker_session, 
model_type=kwargs.model_type, instance_type=kwargs.instance_type, + config_name=kwargs.config_name, + ) + + return kwargs + + +def _select_inference_config_from_training_config( + specs: JumpStartModelSpecs, training_config_name: str +) -> Optional[str]: + """Selects the inference config from the training config. + + Args: + specs (JumpStartModelSpecs): The specs for the model. + training_config_name (str): The name of the training config. + + Returns: + str: The name of the inference config. + """ + if specs.training_configs: + resolved_training_config = specs.training_configs.configs.get(training_config_name) + if resolved_training_config: + return resolved_training_config.default_inference_config + + return None + + +def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: + """Sets default config name to the kwargs. Returns full kwargs. + + Raises: + ValueError: If the instance_type is not supported with the current config. + """ + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + if specs.inference_configs: + default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name + kwargs.config_name = kwargs.config_name or default_config_name + + if not kwargs.config_name: + return kwargs + + if kwargs.config_name not in set(specs.inference_configs.configs.keys()): + raise ValueError( + f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}." + ) + + resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config + supported_instance_types = resolved_config.get("supported_inference_instance_types", []) + if kwargs.instance_type not in supported_instance_types: + raise ValueError( + f"Instance type {kwargs.instance_type} " + f"is not supported for config {kwargs.config_name}." + ) + + return kwargs + + +def _add_config_name_to_deploy_kwargs( + kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None +) -> JumpStartModelInitKwargs: + """Sets default config name to the kwargs. Returns full kwargs. + + If a training_config_name is passed, then choose the inference config + based on the supported inference configs in that training config. + + Raises: + ValueError: If the instance_type is not supported with the current config. 
+ """ + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, ) + if training_config_name: + kwargs.config_name = _select_inference_config_from_training_config( + specs=specs, training_config_name=training_config_name + ) + + if specs.inference_configs: + default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name + kwargs.config_name = kwargs.config_name or default_config_name + return kwargs @@ -560,6 +681,8 @@ def get_deploy_kwargs( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: Optional[EndpointType] = None, + training_config_name: Optional[str] = None, + config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -592,6 +715,7 @@ def get_deploy_kwargs( accept_eula=accept_eula, endpoint_logging=endpoint_logging, resources=resources, + config_name=config_name, routing_config=routing_config, ) @@ -601,6 +725,10 @@ def get_deploy_kwargs( deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_config_name_to_deploy_kwargs( + kwargs=deploy_kwargs, training_config_name=training_config_name + ) + deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs.initial_instance_count = initial_instance_count or 1 @@ -646,6 +774,7 @@ def get_register_kwargs( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + config_name: Optional[str] = None, ) -> JumpStartModelRegisterKwargs: """Returns kwargs required to call `register` on `sagemaker.estimator.Model` object.""" @@ -688,6 +817,7 @@ def get_register_kwargs( sagemaker_session=sagemaker_session, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + config_name=config_name, ) register_kwargs.content_types = ( @@ -730,6 +860,7 @@ def get_init_kwargs( training_instance_type: Optional[str] = None, disable_instance_type_logging: bool = False, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" @@ -761,6 +892,7 @@ def get_init_kwargs( model_package_arn=model_package_arn, training_instance_type=training_instance_type, resources=resources, + config_name=config_name, ) model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) @@ -791,4 +923,6 @@ def get_init_kwargs( model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) + return model_init_kwargs diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 994193de3e..f72a3140dc 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -14,7 +14,8 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union, Any +from typing import Dict, List, Optional, Any, Union +import pandas as pd from 
botocore.exceptions import ClientError from sagemaker import payloads @@ -36,10 +37,18 @@ get_init_kwargs, get_register_kwargs, ) -from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.types import ( + JumpStartSerializablePayload, + DeploymentConfigMetadata, +) from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, + get_jumpstart_configs, + get_metrics_from_deployment_configs, + add_instance_rate_stats_to_benchmark_metrics, + deployment_config_response_data, + _deployment_config_lru_cache, ) from sagemaker.jumpstart.constants import JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType @@ -92,6 +101,7 @@ def __init__( git_config: Optional[Dict[str, str]] = None, model_package_arn: Optional[str] = None, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, ): """Initializes a ``JumpStartModel``. @@ -277,6 +287,8 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). + config_name (Optional[str]): The name of the JumpStart config that can be + optionally applied to the model to override corresponding fields. (Default: None). Raises: ValueError: If the model ID is not recognized by JumpStart. """ @@ -326,6 +338,7 @@ def _validate_model_id_and_type(): git_config=git_config, model_package_arn=model_package_arn, resources=resources, + config_name=config_name, ) self.orig_predictor_cls = predictor_cls @@ -338,6 +351,7 @@ def _validate_model_id_and_type(): self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + self.config_name = model_init_kwargs.config_name if self.model_type == JumpStartModelType.PROPRIETARY: self.log_subscription_warning() @@ -345,6 +359,15 @@ def _validate_model_id_and_type(): super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) self.model_package_arn = model_init_kwargs.model_package_arn + self.init_kwargs = model_init_kwargs.to_kwargs_dict(False) + + self._metadata_configs = get_jumpstart_configs( + region=self.region, + model_id=self.model_id, + model_version=self.model_version, + sagemaker_session=self.sagemaker_session, + model_type=self.model_type, + ) def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" @@ -402,6 +425,70 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) + def set_deployment_config(self, config_name: str, instance_type: str) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. + """ + self.__init__( + model_id=self.model_id, + model_version=self.model_version, + instance_type=instance_type, + config_name=config_name, + ) + + @property + def deployment_config(self) -> Optional[Dict[str, Any]]: + """The deployment config that will be applied to this model. + + Returns: + Optional[Dict[str, Any]]: Deployment config.
+ """ + if self.config_name is None: + return None + for config in self.list_deployment_configs(): + if config.get("DeploymentConfigName") == self.config_name: + return config + return None + + @property + def benchmark_metrics(self) -> pd.DataFrame: + """Benchmark Metrics for deployment configs. + + Returns: + Benchmark Metrics: Pandas DataFrame object. + """ + df = pd.DataFrame(self._get_deployment_configs_benchmarks_data()) + blank_index = [""] * len(df) + df.index = blank_index + return df + + def display_benchmark_metrics(self, **kwargs) -> None: + """Display deployment configs benchmark metrics.""" + df = self.benchmark_metrics + + instance_type = kwargs.get("instance_type") + if instance_type: + df = df[df["Instance Type"].str.contains(instance_type)] + + print(df.to_markdown(index=False, floatfmt=".2f")) + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for ``This`` model. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. + """ + return deployment_config_response_data( + self._get_deployment_configs(self.config_name, self.instance_type) + ) + def _create_sagemaker_model( self, instance_type=None, @@ -628,6 +715,7 @@ def deploy( managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, model_type=self.model_type, + config_name=self.config_name, routing_config=routing_config, ) if ( @@ -648,6 +736,7 @@ def deploy( model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, + config_name=self.config_name, ).model_subscription_link get_proprietary_model_subscription_error(e, subscription_link) raise @@ -663,6 +752,7 @@ def deploy( tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, model_type=self.model_type, + config_name=self.config_name, ) # If a predictor class was passed, do not mutate predictor @@ -773,6 +863,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + config_name=self.config_name, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) @@ -790,6 +881,89 @@ def register_deploy_wrapper(*args, **kwargs): return model_package + @_deployment_config_lru_cache + def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]: + """Deployment configs benchmark metrics. + + Returns: + Dict[str, List[str]]: Deployment config benchmark data. + """ + return get_metrics_from_deployment_configs( + self._get_deployment_configs(None, None), + ) + + @_deployment_config_lru_cache + def _get_deployment_configs( + self, selected_config_name: Optional[str], selected_instance_type: Optional[str] + ) -> List[DeploymentConfigMetadata]: + """Retrieve deployment configs metadata. + + Args: + selected_config_name (Optional[str]): The name of the selected deployment config. + selected_instance_type (Optional[str]): The selected instance type. 
+ """ + deployment_configs = [] + if not self._metadata_configs: + return deployment_configs + + err = None + for config_name, metadata_config in self._metadata_configs.items(): + if selected_config_name == config_name: + instance_type_to_use = selected_instance_type + else: + instance_type_to_use = metadata_config.resolved_config.get( + "default_inference_instance_type" + ) + + if metadata_config.benchmark_metrics: + err, metadata_config.benchmark_metrics = ( + add_instance_rate_stats_to_benchmark_metrics( + self.region, metadata_config.benchmark_metrics + ) + ) + + config_components = metadata_config.config_components.get(config_name) + image_uri = ( + ( + config_components.hosting_instance_type_variants.get("regional_aliases", {}) + .get(self.region, {}) + .get("alias_ecr_uri_1") + ) + if config_components + else self.image_uri + ) + + init_kwargs = get_init_kwargs( + config_name=config_name, + model_id=self.model_id, + instance_type=instance_type_to_use, + sagemaker_session=self.sagemaker_session, + image_uri=image_uri, + region=self.region, + model_version=self.model_version, + ) + deploy_kwargs = get_deploy_kwargs( + model_id=self.model_id, + instance_type=instance_type_to_use, + sagemaker_session=self.sagemaker_session, + region=self.region, + model_version=self.model_version, + ) + + deployment_config_metadata = DeploymentConfigMetadata( + config_name, + metadata_config, + init_kwargs, + deploy_kwargs, + ) + deployment_configs.append(deployment_config_metadata) + + if err and err["Code"] == "AccessDeniedException": + error_message = "Instance rate metrics will be omitted. Reason: %s" + JUMPSTART_LOGGER.warning(error_message, err["Message"]) + + return deployment_configs + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 83613cd71b..781548b42a 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -535,6 +535,7 @@ def get_model_url( model_version: str, region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieve web url describing pretrained model. @@ -563,5 +564,6 @@ def get_model_url( sagemaker_session=sagemaker_session, scope=JumpStartScriptScope.INFERENCE, model_type=model_type, + config_name=config_name, ) return model_specs.url diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index e511a052d1..0955ae9480 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -17,17 +17,17 @@ from typing import Optional, Tuple from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn +from sagemaker.jumpstart.utils import get_jumpstart_model_info_from_resource_arn from sagemaker.session import Session from sagemaker.utils import aws_partition -def get_model_id_version_from_endpoint( +def get_model_info_from_endpoint( endpoint_name: str, inference_component_name: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str, Optional[str]]: - """Given an endpoint and optionally inference component names, return the model ID and version. 
+) -> Tuple[str, str, Optional[str], Optional[str], Optional[str]]: + """Given an endpoint and optionally an inference component name, return the model info. Infers the model ID and version based on the resource tags. Returns a tuple of the model ID and version. A third string element is included in the tuple for any inferred inference @@ -46,7 +46,9 @@ ( model_id, model_version, - ) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 + inference_config_name, + training_config_name, + ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 inference_component_name, sagemaker_session ) @@ -54,22 +56,35 @@ ( model_id, model_version, + inference_config_name, + training_config_name, inference_component_name, - ) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 + ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 endpoint_name, sagemaker_session ) else: - model_id, model_version = _get_model_id_version_from_model_based_endpoint( + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = _get_model_info_from_model_based_endpoint( endpoint_name, inference_component_name, sagemaker_session ) - return model_id, model_version, inference_component_name + return ( + model_id, + model_version, + inference_component_name, + inference_config_name, + training_config_name, + ) -def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( +def _get_model_info_from_inference_component_endpoint_without_inference_component_name( endpoint_name: str, sagemaker_session: Session -) -> Tuple[str, str, str]: - """Given an endpoint name, derives the model ID, version, and inferred inference component name. +) -> Tuple[str, str, str, str, str]: + """Derives the model ID, version, config names, and inferred inference component name. This function assumes the endpoint corresponds to an inference-component-based endpoint. An endpoint is inference-component-based if and only if the associated endpoint config @@ -98,14 +113,14 @@ ) inference_component_name = inference_component_names[0] return ( - *_get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + *_get_model_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name, sagemaker_session ), inference_component_name, ) -def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( +def _get_model_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name: str, sagemaker_session: Session ): """Returns the model info inferred from a SageMaker inference component.
@@ -123,9 +138,12 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo f"inference-component/{inference_component_name}" ) - model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( - inference_component_arn, sagemaker_session - ) + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(inference_component_arn, sagemaker_session) if not model_id: raise ValueError( @@ -134,15 +152,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo "when retrieving default predictor for this inference component." ) - return model_id, model_version + return model_id, model_version, inference_config_name, training_config_name -def _get_model_id_version_from_model_based_endpoint( +def _get_model_info_from_model_based_endpoint( endpoint_name: str, inference_component_name: Optional[str], sagemaker_session: Session, -) -> Tuple[str, str]: - """Returns the model ID and version inferred from a model-based endpoint. +) -> Tuple[str, str, Optional[str], Optional[str]]: + """Returns the model ID, version, and config names inferred from a model-based endpoint. Raises: ValueError: If an inference component name is supplied, or if the endpoint does @@ -161,9 +179,12 @@ def _get_model_id_version_from_model_based_endpoint( endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}" - model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( - endpoint_arn, sagemaker_session - ) + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(endpoint_arn, sagemaker_session) if not model_id: raise ValueError( @@ -172,14 +193,14 @@ def _get_model_id_version_from_model_based_endpoint( "predictor for this endpoint." ) - return model_id, model_version + return model_id, model_version, inference_config_name, training_config_name -def get_model_id_version_from_training_job( +def get_model_info_from_training_job( training_job_name: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str]: - """Returns the model ID and version inferred from a training job. +) -> Tuple[str, str, Optional[str], Optional[str]]: + """Returns the model ID, version, and config names inferred from a training job. Raises: ValueError: If the training job does not have tags from which the model ID @@ -194,9 +215,12 @@ def get_model_id_version_from_training_job( f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}" ) - model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn( - training_job_arn, sagemaker_session - ) + ( + model_id, + inferred_model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session) model_version = inferred_model_version or None @@ -207,4 +231,4 @@ def get_model_id_version_from_training_job( "for this training job."
) - return model_id, model_version + return model_id, model_version, inference_config_name, training_config_name diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index bea125d423..f197421d65 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -746,7 +746,7 @@ def _get_regional_property( class JumpStartBenchmarkStat(JumpStartDataHolderType): """Data class JumpStart benchmark stat.""" - __slots__ = ["name", "value", "unit"] + __slots__ = ["name", "value", "unit", "concurrency"] def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartBenchmarkStat object. @@ -765,6 +765,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.name: str = json_obj["name"] self.value: str = json_obj["value"] self.unit: Union[int, str] = json_obj["unit"] + self.concurrency: Union[int, str] = json_obj["concurrency"] def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartBenchmarkStat object.""" @@ -950,7 +951,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key") - self.hosting_model_package_arns: Optional[Dict] = json_obj.get("hosting_model_package_arns") + model_package_arns = json_obj.get("hosting_model_package_arns") + self.hosting_model_package_arns: Optional[Dict] = ( + model_package_arns if model_package_arns is not None else {} + ) self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True) self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( @@ -1074,30 +1078,57 @@ class JumpStartMetadataConfig(JumpStartDataHolderType): __slots__ = [ "base_fields", "benchmark_metrics", + "acceleration_configs", "config_components", "resolved_metadata_config", + "config_name", + "default_inference_config", + "default_incremental_training_config", + "supported_inference_configs", + "supported_incremental_training_configs", ] def __init__( self, + config_name: str, + config: Dict[str, Any], base_fields: Dict[str, Any], config_components: Dict[str, JumpStartConfigComponent], - benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]], ): """Initializes a JumpStartMetadataConfig object from its json representation. Args: + config_name (str): Name of the config. + config (Dict[str, Any]): + Dictionary representation of the config. base_fields (Dict[str, Any]): - The default base fields that are used to construct the final resolved config. + The default base fields that are used to construct the resolved config. config_components (Dict[str, JumpStartConfigComponent]): The list of components that are used to construct the resolved config. - benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]): - The dictionary of benchmark metrics with name being the key.
""" self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components - self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics + self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( + { + stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] + for stat_name, stats in config.get("benchmark_metrics").items() + } + if config and config.get("benchmark_metrics") + else None + ) + self.acceleration_configs = config.get("acceleration_configs") self.resolved_metadata_config: Optional[Dict[str, Any]] = None + self.config_name: Optional[str] = config_name + self.default_inference_config: Optional[str] = config.get("default_inference_config") + self.default_incremental_training_config: Optional[str] = config.get( + "default_incremental_training_config" + ) + self.supported_inference_configs: Optional[List[str]] = config.get( + "supported_inference_configs" + ) + self.supported_incremental_training_configs: Optional[List[str]] = config.get( + "supported_incremental_training_configs" + ) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataConfig object.""" @@ -1121,6 +1152,12 @@ def resolved_config(self) -> Dict[str, Any]: deepcopy(component.to_json()), component.OVERRIDING_DENY_LIST, ) + + # Remove environment variables from resolved config if using model packages + hosting_model_pacakge_arns = resolved_config.get("hosting_model_package_arns") + if hosting_model_pacakge_arns is not None and hosting_model_pacakge_arns != {}: + resolved_config["inference_environment_variables"] = [] + self.resolved_metadata_config = resolved_config return resolved_config @@ -1163,6 +1200,8 @@ def get_top_config_from_ranking( ) -> Optional[JumpStartMetadataConfig]: """Gets the best the config based on config ranking. + Fallback to use the ordering in config names if + ranking is not available. Args: ranking_name (str): The ranking name that config priority is based on. @@ -1170,13 +1209,8 @@ def get_top_config_from_ranking( The instance type which the config selection is based on. Raises: - ValueError: If the config exists but missing config ranking. NotImplementedError: If the scope is unrecognized. 
""" - if self.configs and ( - not self.config_rankings or not self.config_rankings.get(ranking_name) - ): - raise ValueError(f"Config exists but missing config ranking {ranking_name}.") if self.scope == JumpStartScriptScope.INFERENCE: instance_type_attribute = "supported_inference_instance_types" @@ -1185,8 +1219,14 @@ def get_top_config_from_ranking( else: raise NotImplementedError(f"Unknown script scope {self.scope}") - rankings = self.config_rankings.get(ranking_name) - for config_name in rankings.rankings: + if self.configs and ( + not self.config_rankings or not self.config_rankings.get(ranking_name) + ): + ranked_config_names = sorted(list(self.configs.keys())) + else: + rankings = self.config_rankings.get(ranking_name) + ranked_config_names = rankings.rankings + for config_name in ranked_config_names: resolved_config = self.configs[config_name].resolved_config if instance_type and instance_type not in getattr( resolved_config, instance_type_attribute @@ -1248,6 +1288,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( { alias: JumpStartMetadataConfig( + alias, + config, json_obj, ( { @@ -1257,14 +1299,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if config and config.get("component_names") else None ), - ( - { - stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] - for stat_name, stats in config.get("benchmark_metrics").items() - } - if config and config.get("benchmark_metrics") - else None - ), ) for alias, config in json_obj["inference_configs"].items() } @@ -1300,6 +1334,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( { alias: JumpStartMetadataConfig( + alias, + config, json_obj, ( { @@ -1309,14 +1345,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if config and config.get("component_names") else None ), - ( - { - stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] - for stat_name, stats in config.get("benchmark_metrics").items() - } - if config and config.get("benchmark_metrics") - else None - ), ) for alias, config in json_obj["training_configs"].items() } @@ -1464,11 +1492,11 @@ class JumpStartKwargs(JumpStartDataHolderType): SERIALIZATION_EXCLUSION_SET: Set[str] = set() - def to_kwargs_dict(self): + def to_kwargs_dict(self, exclude_keys: bool = True): """Serializes object to dictionary to be used for kwargs for method arguments.""" kwargs_dict = {} for field in self.__slots__: - if field not in self.SERIALIZATION_EXCLUSION_SET: + if exclude_keys and field not in self.SERIALIZATION_EXCLUSION_SET or not exclude_keys: att_value = getattr(self, field) if att_value is not None: kwargs_dict[field] = getattr(self, field) @@ -1506,6 +1534,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_package_arn", "training_instance_type", "resources", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1518,6 +1547,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "region", "model_package_arn", "training_instance_type", + "config_name", } def __init__( @@ -1549,6 +1579,7 @@ def __init__( model_package_arn: Optional[str] = None, training_instance_type: Optional[str] = None, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartModelInitKwargs object.""" @@ -1579,6 +1610,7 @@ def __init__( self.model_package_arn = model_package_arn self.training_instance_type = training_instance_type self.resources = 
resources + self.config_name = config_name class JumpStartModelDeployKwargs(JumpStartKwargs): @@ -1614,6 +1646,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "endpoint_logging", "resources", "endpoint_type", + "config_name", "routing_config", ] @@ -1626,6 +1659,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "sagemaker_session", "training_instance_type", + "config_name", } def __init__( @@ -1659,6 +1693,7 @@ def __init__( endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, endpoint_type: Optional[EndpointType] = None, + config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -1692,6 +1727,7 @@ def __init__( self.endpoint_logging = endpoint_logging self.resources = resources self.endpoint_type = endpoint_type + self.config_name = config_name self.routing_config = routing_config @@ -1753,6 +1789,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "disable_output_compression", "enable_infra_check", "enable_remote_debug", + "config_name", "enable_session_tag_chaining", ] @@ -1763,6 +1800,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_id", "model_version", "model_type", + "config_name", } def __init__( @@ -1821,6 +1859,7 @@ def __init__( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -1881,6 +1920,7 @@ def __init__( self.disable_output_compression = disable_output_compression self.enable_infra_check = enable_infra_check self.enable_remote_debug = enable_remote_debug + self.config_name = config_name self.enable_session_tag_chaining = enable_session_tag_chaining @@ -1900,6 +1940,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "tolerate_deprecated_model", "tolerate_vulnerable_model", "sagemaker_session", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1910,6 +1951,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "tolerate_deprecated_model", "tolerate_vulnerable_model", "sagemaker_session", + "config_name", } def __init__( @@ -1926,6 +1968,7 @@ def __init__( tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartEstimatorFitKwargs object.""" @@ -1941,6 +1984,7 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model self.sagemaker_session = sagemaker_session + self.config_name = config_name class JumpStartEstimatorDeployKwargs(JumpStartKwargs): @@ -1986,6 +2030,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_name", "use_compiled_model", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -1995,6 +2040,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "model_id", "model_version", "sagemaker_session", + "config_name", } def __init__( @@ -2038,6 +2084,7 @@ def __init__( tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, use_compiled_model: bool = False, + config_name: Optional[str] = None, ) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object.""" @@ -2080,6 +2127,7 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model self.use_compiled_model = use_compiled_model + self.config_name = config_name class JumpStartModelRegisterKwargs(JumpStartKwargs): @@ -2114,6 +2162,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "data_input_configuration", "skip_model_validation", "source_uri", + "config_name", ] SERIALIZATION_EXCLUSION_SET = { @@ -2123,6 +2172,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "model_id", "model_version", "sagemaker_session", + "config_name", } def __init__( @@ -2155,6 +2205,7 @@ def __init__( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + config_name: Optional[str] = None, ) -> None: """Instantiates JumpStartModelRegisterKwargs object.""" @@ -2187,3 +2238,124 @@ def __init__( self.data_input_configuration = data_input_configuration self.skip_model_validation = skip_model_validation self.source_uri = source_uri + self.config_name = config_name + + +class BaseDeploymentConfigDataHolder(JumpStartDataHolderType): + """Base class for Deployment Config Data.""" + + def _convert_to_pascal_case(self, attr_name: str) -> str: + """Converts a snake_case attribute name into a camelCased string. + + Args: + attr_name (str): The snake_case attribute name. + Returns: + str: The PascalCased attribute name. + """ + return attr_name.replace("_", " ").title().replace(" ", "") + + def to_json(self) -> Dict[str, Any]: + """Represents ``This`` object as JSON.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + att = self._convert_to_pascal_case(att) + json_obj[att] = self._val_to_json(cur_val) + return json_obj + + def _val_to_json(self, val: Any) -> Any: + """Converts the given value to JSON. + + Args: + val (Any): The value to convert. + Returns: + Any: The converted json value. 
+ """ + if issubclass(type(val), JumpStartDataHolderType): + if isinstance(val, JumpStartBenchmarkStat): + val.name = val.name.replace("_", " ").title() + return val.to_json() + if isinstance(val, list): + list_obj = [] + for obj in val: + list_obj.append(self._val_to_json(obj)) + return list_obj + if isinstance(val, dict): + dict_obj = {} + for k, v in val.items(): + if isinstance(v, JumpStartDataHolderType): + dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v) + else: + dict_obj[k] = self._val_to_json(v) + return dict_obj + return val + + +class DeploymentArgs(BaseDeploymentConfigDataHolder): + """Dataclass representing a Deployment Args.""" + + __slots__ = [ + "image_uri", + "model_data", + "model_package_arn", + "environment", + "instance_type", + "compute_resource_requirements", + "model_data_download_timeout", + "container_startup_health_check_timeout", + ] + + def __init__( + self, + init_kwargs: Optional[JumpStartModelInitKwargs] = None, + deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None, + resolved_config: Optional[Dict[str, Any]] = None, + ): + """Instantiates DeploymentArgs object.""" + if init_kwargs is not None: + self.image_uri = init_kwargs.image_uri + self.model_data = init_kwargs.model_data + self.model_package_arn = init_kwargs.model_package_arn + self.instance_type = init_kwargs.instance_type + self.environment = init_kwargs.env + if init_kwargs.resources is not None: + self.compute_resource_requirements = ( + init_kwargs.resources.get_compute_resource_requirements() + ) + if deploy_kwargs is not None: + self.model_data_download_timeout = deploy_kwargs.model_data_download_timeout + self.container_startup_health_check_timeout = ( + deploy_kwargs.container_startup_health_check_timeout + ) + if resolved_config is not None: + self.default_instance_type = resolved_config.get("default_inference_instance_type") + self.supported_instance_types = resolved_config.get( + "supported_inference_instance_types" + ) + + +class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): + """Dataclass representing a Deployment Config Metadata""" + + __slots__ = [ + "deployment_config_name", + "deployment_args", + "acceleration_configs", + "benchmark_metrics", + ] + + def __init__( + self, + config_name: Optional[str] = None, + metadata_config: Optional[JumpStartMetadataConfig] = None, + init_kwargs: Optional[JumpStartModelInitKwargs] = None, + deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None, + ): + """Instantiates DeploymentConfigMetadata object.""" + self.deployment_config_name = config_name + self.deployment_args = DeploymentArgs( + init_kwargs, deploy_kwargs, metadata_config.resolved_config + ) + self.benchmark_metrics = metadata_config.benchmark_metrics + self.acceleration_configs = metadata_config.acceleration_configs diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 657ab11535..22974a3838 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -15,9 +15,11 @@ from copy import copy import logging import os +from functools import lru_cache, wraps from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 +from botocore.exceptions import ClientError from packaging.version import Version import botocore import sagemaker @@ -43,10 +45,11 @@ JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, + DeploymentConfigMetadata, ) from sagemaker.session import Session from sagemaker.config import load_sagemaker_config -from 
sagemaker.utils import resolve_value_from_config, TagsDict +from sagemaker.utils import resolve_value_from_config, TagsDict, get_instance_rate_per_hour from sagemaker.workflow import is_pipeline_variable from sagemaker.user_agent import get_user_agent_extra_suffix @@ -321,6 +324,8 @@ def add_single_jumpstart_tag( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) + or tag_key_in_array(enums.JumpStartTag.INFERENCE_CONFIG_NAME, curr_tags) + or tag_key_in_array(enums.JumpStartTag.TRAINING_CONFIG_NAME, curr_tags) ) if is_uri else False @@ -351,11 +356,13 @@ def get_jumpstart_base_name_if_jumpstart_model( return None -def add_jumpstart_model_id_version_tags( +def add_jumpstart_model_info_tags( tags: Optional[List[TagsDict]], model_id: str, model_version: str, model_type: Optional[enums.JumpStartModelType] = None, + config_name: Optional[str] = None, + scope: Optional[enums.JumpStartScriptScope] = None, ) -> List[TagsDict]: """Add custom model ID, version, and config name tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -379,6 +386,20 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) + if config_name and scope == enums.JumpStartScriptScope.INFERENCE: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.INFERENCE_CONFIG_NAME, + tags, + is_uri=False, + ) + if config_name and scope == enums.JumpStartScriptScope.TRAINING: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.TRAINING_CONFIG_NAME, + tags, + is_uri=False, + ) return tags @@ -550,6 +571,7 @@ def verify_model_region_and_return_specs( tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -572,6 +594,7 @@ def verify_model_region_and_return_specs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Raises: NotImplementedError: If the scope is not supported. @@ -637,6 +660,9 @@ def verify_model_region_and_return_specs( scope=constants.JumpStartScriptScope.TRAINING, ) + if model_specs and config_name: + model_specs.set_config(config_name, scope) + return model_specs @@ -798,52 +824,80 @@ def validate_model_id_and_get_type( return None -def get_jumpstart_model_id_version_from_resource_arn( +def _extract_value_from_list_of_tags( + tag_keys: List[str], + list_tags_result: List[TagsDict], + resource_name: str, + resource_arn: str, +): + """Extracts a tag value from a list of tags, checking for duplicate tags. + + Returns None if no value is found.
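+ + Example: + A minimal sketch (the tag key and values are illustrative): + + tags = [{"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "my-model"}] + _extract_value_from_list_of_tags( + ["sagemaker-sdk:jumpstart-model-id"], tags, "model ID", arn + ) # -> "my-model"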
+ """ + resolved_value = None + for tag_key in tag_keys: + try: + value_from_tag = get_tag_value(tag_key, list_tags_result) + except KeyError: + continue + if value_from_tag is not None: + if resolved_value is not None and value_from_tag != resolved_value: + constants.JUMPSTART_LOGGER.warning( + "Found multiple %s tags on the following resource: %s", + resource_name, + resource_arn, + ) + resolved_value = None + break + resolved_value = value_from_tag + return resolved_value + + +def get_jumpstart_model_info_from_resource_arn( resource_arn: str, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[Optional[str], Optional[str]]: - """Returns the JumpStart model ID and version if in resource tags. +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """Returns the JumpStart model ID, version and config name if in resource tags. - Returns 'None' if model ID or version cannot be inferred from tags. + Returns 'None' if model ID or version or config name cannot be inferred from tags. """ list_tags_result = sagemaker_session.list_tags(resource_arn) - model_id: Optional[str] = None - model_version: Optional[str] = None - model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS] model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS] + inference_config_name_keys = [enums.JumpStartTag.INFERENCE_CONFIG_NAME] + training_config_name_keys = [enums.JumpStartTag.TRAINING_CONFIG_NAME] + + model_id: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_id_keys, + list_tags_result=list_tags_result, + resource_name="model ID", + resource_arn=resource_arn, + ) - for model_id_key in model_id_keys: - try: - model_id_from_tag = get_tag_value(model_id_key, list_tags_result) - except KeyError: - continue - if model_id_from_tag is not None: - if model_id is not None and model_id_from_tag != model_id: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model ID tags on the following resource: %s", resource_arn - ) - model_id = None - break - model_id = model_id_from_tag + model_version: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_version_keys, + list_tags_result=list_tags_result, + resource_name="model version", + resource_arn=resource_arn, + ) - for model_version_key in model_version_keys: - try: - model_version_from_tag = get_tag_value(model_version_key, list_tags_result) - except KeyError: - continue - if model_version_from_tag is not None: - if model_version is not None and model_version_from_tag != model_version: - constants.JUMPSTART_LOGGER.warning( - "Found multiple model version tags on the following resource: %s", resource_arn - ) - model_version = None - break - model_version = model_version_from_tag + inference_config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=inference_config_name_keys, + list_tags_result=list_tags_result, + resource_name="inference config name", + resource_arn=resource_arn, + ) - return model_id, model_version + training_config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=training_config_name_keys, + list_tags_result=list_tags_result, + resource_name="training config name", + resource_arn=resource_arn, + ) + + return model_id, model_version, inference_config_name, training_config_name def get_region_fallback( @@ -893,7 +947,11 @@ def get_config_names( scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = 
enums.JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: - """Returns a list of config names for the given model ID and region.""" + """Returns a list of config names for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -908,7 +966,7 @@ def get_config_names( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") return list(metadata_configs.configs.keys()) if metadata_configs else [] @@ -922,7 +980,11 @@ def get_benchmark_stats( scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> Dict[str, List[JumpStartBenchmarkStat]]: - """Returns benchmark stats for the given model ID and region.""" + """Returns benchmark stats for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -937,7 +999,7 @@ def get_benchmark_stats( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") if not config_names: config_names = metadata_configs.configs.keys() if metadata_configs else [] @@ -945,7 +1007,7 @@ def get_benchmark_stats( benchmark_stats = {} for config_name in config_names: if config_name not in metadata_configs.configs: - raise ValueError(f"Unknown config name: '{config_name}'") + raise ValueError(f"Unknown config name: {config_name}") benchmark_stats[config_name] = metadata_configs.configs.get(config_name).benchmark_metrics return benchmark_stats @@ -959,8 +1021,12 @@ def get_jumpstart_configs( sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, -) -> Dict[str, List[JumpStartMetadataConfig]]: - """Returns metadata configs for the given model ID and region.""" +) -> Dict[str, JumpStartMetadataConfig]: + """Returns metadata configs for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. 
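+ + Example: + A minimal sketch (the model ID is illustrative): + + configs = get_jumpstart_configs( + region="us-west-2", model_id="example-model", model_version="*" + ) + for name, metadata_config in configs.items(): + print(name, metadata_config.benchmark_metrics)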
+ """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -975,10 +1041,12 @@ def get_jumpstart_configs( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") if not config_names: - config_names = metadata_configs.configs.keys() if metadata_configs else [] + config_names = ( + metadata_configs.config_rankings.get("overall").rankings if metadata_configs else [] + ) return ( {config_name: metadata_configs.configs[config_name] for config_name in config_names} @@ -1021,3 +1089,250 @@ def get_default_jumpstart_session_with_user_agent_suffix( config=botocore_config, ) return session + + +def add_instance_rate_stats_to_benchmark_metrics( + region: str, + benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]], +) -> Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]: + """Adds instance types metric stats to the given benchmark_metrics dict. + + Args: + region (str): AWS region. + benchmark_metrics (Optional[Dict[str, List[JumpStartBenchmarkStat]]]): + Returns: + Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]: + Contains Error and metrics. + """ + if not benchmark_metrics: + return None + + err_message = None + final_benchmark_metrics = {} + for instance_type, benchmark_metric_stats in benchmark_metrics.items(): + instance_type = instance_type if instance_type.startswith("ml.") else f"ml.{instance_type}" + + if not has_instance_rate_stat(benchmark_metric_stats) and not err_message: + try: + instance_type_rate = get_instance_rate_per_hour( + instance_type=instance_type, region=region + ) + + if not benchmark_metric_stats: + benchmark_metric_stats = [] + benchmark_metric_stats.append( + JumpStartBenchmarkStat({"concurrency": None, **instance_type_rate}) + ) + + final_benchmark_metrics[instance_type] = benchmark_metric_stats + except ClientError as e: + final_benchmark_metrics[instance_type] = benchmark_metric_stats + err_message = e.response["Error"] + except Exception: # pylint: disable=W0703 + final_benchmark_metrics[instance_type] = benchmark_metric_stats + else: + final_benchmark_metrics[instance_type] = benchmark_metric_stats + + return err_message, final_benchmark_metrics + + +def has_instance_rate_stat(benchmark_metric_stats: Optional[List[JumpStartBenchmarkStat]]) -> bool: + """Determines whether a benchmark metric stats contains instance rate metric stat. + + Args: + benchmark_metric_stats (Optional[List[JumpStartBenchmarkStat]]): + List of benchmark metric stats. + Returns: + bool: Whether the benchmark metric stats contains instance rate metric stat. + """ + if benchmark_metric_stats is None: + return True + for benchmark_metric_stat in benchmark_metric_stats: + if benchmark_metric_stat.name.lower() == "instance rate": + return True + return False + + +def get_metrics_from_deployment_configs( + deployment_configs: Optional[List[DeploymentConfigMetadata]], +) -> Dict[str, List[str]]: + """Extracts benchmark metrics from deployment configs metadata. + + Args: + deployment_configs (Optional[List[DeploymentConfigMetadata]]): + List of deployment configs metadata. + Returns: + Dict[str, List[str]]: Deployment configs bench metrics dict. 
+ """ + if not deployment_configs: + return {} + + data = {"Instance Type": [], "Config Name": [], "Concurrent Users": []} + instance_rate_data = {} + for index, deployment_config in enumerate(deployment_configs): + benchmark_metrics = deployment_config.benchmark_metrics + if not deployment_config.deployment_args or not benchmark_metrics: + continue + + for current_instance_type, current_instance_type_metrics in benchmark_metrics.items(): + instance_type_rate, concurrent_users = _normalize_benchmark_metrics( + current_instance_type_metrics + ) + + for concurrent_user, metrics in concurrent_users.items(): + instance_type_to_display = ( + f"{current_instance_type} (Default)" + if index == 0 + and int(concurrent_user) == 1 + and current_instance_type + == deployment_config.deployment_args.default_instance_type + else current_instance_type + ) + + data["Config Name"].append(deployment_config.deployment_config_name) + data["Instance Type"].append(instance_type_to_display) + data["Concurrent Users"].append(concurrent_user) + + if instance_type_rate: + instance_rate_column_name = ( + f"{instance_type_rate.name} ({instance_type_rate.unit})" + ) + instance_rate_data[instance_rate_column_name] = instance_rate_data.get( + instance_rate_column_name, [] + ) + instance_rate_data[instance_rate_column_name].append(instance_type_rate.value) + + for metric in metrics: + column_name = _normalize_benchmark_metric_column_name(metric.name) + data[column_name] = data.get(column_name, []) + data[column_name].append(metric.value) + + data = {**data, **instance_rate_data} + return data + + +def _normalize_benchmark_metric_column_name(name: str) -> str: + """Normalizes benchmark metric column name. + + Args: + name (str): Name of the metric. + Returns: + str: Normalized metric column name. + """ + if "latency" in name.lower(): + name = "Latency for each user (TTFT in ms)" + elif "throughput" in name.lower(): + name = "Throughput per user (token/seconds)" + return name + + +def _normalize_benchmark_metrics( + benchmark_metric_stats: List[JumpStartBenchmarkStat], +) -> Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]: + """Normalizes benchmark metrics dict. + + Args: + benchmark_metric_stats (List[JumpStartBenchmarkStat]): + List of benchmark metrics stats. + Returns: + Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]: + Normalized benchmark metrics dict. + """ + instance_type_rate = None + concurrent_users = {} + for current_instance_type_metric in benchmark_metric_stats: + if current_instance_type_metric.name.lower() == "instance rate": + instance_type_rate = current_instance_type_metric + elif current_instance_type_metric.concurrency not in concurrent_users: + concurrent_users[current_instance_type_metric.concurrency] = [ + current_instance_type_metric + ] + else: + concurrent_users[current_instance_type_metric.concurrency].append( + current_instance_type_metric + ) + + return instance_type_rate, concurrent_users + + +def deployment_config_response_data( + deployment_configs: Optional[List[DeploymentConfigMetadata]], +) -> List[Dict[str, Any]]: + """Deployment config api response data. + + Args: + deployment_configs (Optional[List[DeploymentConfigMetadata]]): + List of deployment configs metadata. + Returns: + List[Dict[str, Any]]: List of deployment config api response data. 
+ """ + configs = [] + if not deployment_configs: + return configs + + for deployment_config in deployment_configs: + deployment_config_json = deployment_config.to_json() + benchmark_metrics = deployment_config_json.get("BenchmarkMetrics") + if benchmark_metrics and deployment_config.deployment_args: + deployment_config_json["BenchmarkMetrics"] = { + deployment_config.deployment_args.instance_type: benchmark_metrics.get( + deployment_config.deployment_args.instance_type + ) + } + + configs.append(deployment_config_json) + return configs + + +def _deployment_config_lru_cache(_func=None, *, maxsize: int = 128, typed: bool = False): + """LRU cache for deployment configs.""" + + def has_instance_rate_metric(config: DeploymentConfigMetadata) -> bool: + """Determines whether metadata config contains instance rate metric stat. + + Args: + config (DeploymentConfigMetadata): Metadata config metadata. + Returns: + bool: Whether the metadata config contains instance rate metric stat. + """ + if config.benchmark_metrics is None: + return True + for benchmark_metric_stats in config.benchmark_metrics.values(): + if not has_instance_rate_stat(benchmark_metric_stats): + return False + return True + + def wrapper_cache(f): + f = lru_cache(maxsize=maxsize, typed=typed)(f) + + @wraps(f) + def wrapped_f(*args, **kwargs): + res = f(*args, **kwargs) + + # Clear cache on first call if + # - The output does not contain Instant rate metrics + # as this is caused by missing policy. + if f.cache_info().hits == 0 and f.cache_info().misses == 1: + if isinstance(res, list): + for item in res: + if isinstance( + item, DeploymentConfigMetadata + ) and not has_instance_rate_metric(item): + f.cache_clear() + break + elif isinstance(res, dict): + keys = list(res.keys()) + if len(keys) == 0 or "Instance Rate" not in keys[-1]: + f.cache_clear() + elif len(res[keys[1]]) > len(res[keys[-1]]): + del res[keys[-1]] + f.cache_clear() + return res + + wrapped_f.cache_info = f.cache_info + wrapped_f.cache_clear = f.cache_clear + return wrapped_f + + if _func is None: + return wrapper_cache + return wrapper_cache(_func) diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index c7098a1185..bcb0365f7b 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -171,6 +171,7 @@ def validate_hyperparameters( sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + config_name: Optional[str] = None, ) -> None: """Validate hyperparameters for JumpStart models. @@ -193,6 +194,7 @@ def validate_hyperparameters( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). 
Raises: JumpStartHyperparametersError: If the hyperparameters are not formatted correctly, @@ -218,6 +220,7 @@ def validate_hyperparameters( sagemaker_session=sagemaker_session, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + config_name=config_name, ) hyperparameters_specs = model_specs.hyperparameters diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 71dd26db45..0c066ff801 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -33,6 +33,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model matching the given arguments. @@ -56,6 +57,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: The default metric definitions to use for the model or None. @@ -76,4 +78,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 937180bd44..122647e536 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -34,6 +34,7 @@ def retrieve( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the model artifact Amazon S3 URI for the model matching the given arguments. @@ -57,6 +58,8 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + Returns: str: The model artifact S3 URI for the corresponding model. @@ -81,4 +84,5 @@ def retrieve( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 6f846bba65..780a1a56c8 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -18,7 +18,7 @@ from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint +from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint from sagemaker.session import Session @@ -43,6 +43,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Predictor: """Retrieves the default predictor for the model matching the given arguments. 
@@ -65,6 +66,8 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + config_name (Optional[str]): The name of the configuration to use for the + predictor. (Default: None) Returns: Predictor: The default predictor to use for the model. @@ -78,9 +81,9 @@ def retrieve_default( inferred_model_id, inferred_model_version, inferred_inference_component_name, - ) = get_model_id_version_from_endpoint( - endpoint_name, inference_component_name, sagemaker_session - ) + inferred_config_name, + _, + ) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session) if not inferred_model_id: raise ValueError( @@ -92,6 +95,7 @@ def retrieve_default( model_id = inferred_model_id model_version = model_version or inferred_model_version or "*" inference_component_name = inference_component_name or inferred_inference_component_name + config_name = config_name or inferred_config_name or None else: model_version = model_version or "*" @@ -110,4 +114,5 @@ def retrieve_default( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index df14ac558f..7808d0172a 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -37,6 +37,7 @@ def retrieve_default( model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model matching the given arguments. @@ -62,6 +63,7 @@ def retrieve_default( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get host requirements specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: ResourceRequirements: The default resource requirements to use for the model. @@ -87,4 +89,5 @@ def retrieve_default( model_type=model_type, sagemaker_session=sagemaker_session, instance_type=instance_type, + config_name=config_name, ) diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 9a1c4933d2..6e10785498 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -33,6 +33,7 @@ def retrieve( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -55,6 +56,7 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The model script URI for the corresponding model.
@@ -78,4 +80,5 @@ def retrieve( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index aefb52bd97..d197df731c 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -45,6 +45,7 @@ def retrieve_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model matching the given arguments. @@ -66,6 +67,7 @@ def retrieve_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[SimpleBaseSerializer]: The supported serializers to use for the model. @@ -85,6 +87,7 @@ def retrieve_options( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) @@ -96,6 +99,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -117,6 +121,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: SimpleBaseSerializer: The default serializer to use for the model. @@ -137,4 +142,5 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index bc31e8d323..e8ef546f7a 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -16,7 +16,7 @@ import copy from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Type +from typing import Type, Any, List, Dict, Optional import logging from sagemaker.model import Model @@ -467,8 +467,56 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration ) + def set_deployment_config(self, config_name: str, instance_type: str) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. + """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + raise Exception("Cannot set deployment config to an uninitialized model.") + + self.pysdk_model.set_deployment_config(config_name, instance_type) + + def get_deployment_config(self) -> Optional[Dict[str, Any]]: + """Gets the deployment config to apply to the model. 
+ + Returns: + Optional[Dict[str, Any]]: Deployment config to apply to this model. + """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self._build_for_jumpstart() + + return self.pysdk_model.deployment_config + + def display_benchmark_metrics(self, **kwargs): + """Display benchmark metrics for deployment configs in Markdown.""" + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self._build_for_jumpstart() + + self.pysdk_model.display_benchmark_metrics(**kwargs) + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for this model in the current region. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. + """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self._build_for_jumpstart() + + return self.pysdk_model.list_deployment_configs() + def _build_for_jumpstart(self): """Placeholder docstring""" + if hasattr(self, "pysdk_model") and self.pysdk_model is not None: + return self.pysdk_model + # we do not pickle for jumpstart. set to none self.secret_key = None self.jumpstart = True diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index 614290b132..f84d8f868d 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -151,11 +151,11 @@ def _get_hf_metadata_create_model(self) -> Type[Model]: vpc_config=self.vpc_config, ) - if self.mode == Mode.LOCAL_CONTAINER: + if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER: self.image_uri = pysdk_model.serving_image_uri( self.sagemaker_session.boto_region_name, "local" ) - else: + elif not self.image_uri: self.image_uri = pysdk_model.serving_image_uri( self.sagemaker_session.boto_region_name, self.instance_type ) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 430effefa3..a70ba9eb98 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -25,6 +25,7 @@ import tarfile import tempfile import time +from functools import lru_cache from typing import Union, Any, List, Optional, Dict import json import abc @@ -33,10 +34,12 @@ from os.path import abspath, realpath, dirname, normpath, join as joinpath from importlib import import_module + +import boto3 import botocore from botocore.utils import merge_dicts from six.moves.urllib import parse -import pandas as pd +from six import viewitems from sagemaker import deprecations from sagemaker.config import validate_sagemaker_config @@ -1603,44 +1606,80 @@ def can_model_package_source_uri_autopopulate(source_uri: str): ) -def flatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: - """Flatten a nested dictionary. +def flatten_dict( + d: Dict[str, Any], + max_flatten_depth=None, +) -> Dict[str, Any]: + """Flatten a dictionary object. - Args: - source_dict (dict): The dictionary to be flattened. - sep (str): The separator to be used in the flattened dictionary. - Returns: - transformed_dict: The flattened dictionary. + d (Dict[str, Any]): + The dict that will be flattened. + max_flatten_depth (Optional[int]): + Maximum depth to flatten. """ - flat_dict_list = pd.json_normalize(source_dict, sep=sep).to_dict(orient="records") - if flat_dict_list: - return flat_dict_list[0] - return {} + def tuple_reducer(k1, k2): + if k1 is None: + return (k2,) + return k1 + (k2,) -def unflatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: - """Unflatten a flattened dictionary back into a nested dictionary.
+ # check max_flatten_depth + if max_flatten_depth is not None and max_flatten_depth < 1: + raise ValueError("max_flatten_depth should not be less than 1.") - Args: - source_dict (dict): The input flattened dictionary. - sep (str): The separator used in the flattened keys. + reducer = tuple_reducer - Returns: - transformed_dict: The reconstructed nested dictionary. + flat_dict = {} + + def _flatten(_d, depth, parent=None): + key_value_iterable = viewitems(_d) + has_item = False + for key, value in key_value_iterable: + has_item = True + flat_key = reducer(parent, key) + if isinstance(value, dict) and (max_flatten_depth is None or depth < max_flatten_depth): + has_child = _flatten(value, depth=depth + 1, parent=flat_key) + if has_child: + continue + + if flat_key in flat_dict: + raise ValueError("duplicated key '{}'".format(flat_key)) + flat_dict[flat_key] = value + + return has_item + + _flatten(d, depth=1) + return flat_dict + + +def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None: + """Set a value to a sequence of nested keys.""" + + key = keys[0] + + if len(keys) == 1: + d[key] = value + return + if not d: + return + + d = d.setdefault(key, {}) + nested_set_dict(d, keys[1:], value) + + +def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: + """Unflatten dict-like object. + + d (Dict[str, Any]): + The dict that will be unflattened. """ - if not source_dict: - return {} - result = {} - for key, value in source_dict.items(): - keys = key.split(sep) - current = result - for k in keys[:-1]: - if k not in current: - current[k] = {} - current = current[k] if current[k] is not None else current - current[keys[-1]] = value - return result + unflattened_dict = {} + for flat_key, value in viewitems(d): + key_tuple = flat_key + nested_set_dict(unflattened_dict, key_tuple, value) + + return unflattened_dict def deep_override_dict( @@ -1686,3 +1725,75 @@ def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optiona "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS" ) return None + + +@lru_cache +def get_instance_rate_per_hour( + instance_type: str, + region: str, +) -> Optional[Dict[str, str]]: + """Gets instance rate per hour for the given instance type. + + Args: + instance_type (str): The instance type. + region (str): The region. + Returns: + Optional[Dict[str, str]]: Instance rate per hour. + Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.125'}. + + Raises: + Exception: An exception is raised if + the IAM role is not authorized to perform pricing:GetProducts, + or if an unexpected error occurs.
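+ + Example: + A minimal sketch (the rate shown is illustrative, not a real quote): + + get_instance_rate_per_hour("ml.g5.2xlarge", "us-west-2") + # -> {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.212'}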
+    """
+    # The AWS Price List API is only served from a few regional endpoints;
+    # route the query to the closest supported one.
+    region_name = "us-east-1"
+    if region.startswith("eu") or region.startswith("af"):
+        region_name = "eu-central-1"
+    elif region.startswith("ap") or region.startswith("cn"):
+        region_name = "ap-south-1"
+
+    pricing_client: boto3.client = boto3.client("pricing", region_name=region_name)
+    res = pricing_client.get_products(
+        ServiceCode="AmazonSageMaker",
+        Filters=[
+            {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type},
+            {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"},
+            {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region},
+        ],
+    )
+
+    price_list = res.get("PriceList", [])
+    if len(price_list) > 0:
+        price_data = price_list[0]
+        if isinstance(price_data, str):
+            price_data = json.loads(price_data)
+
+        instance_rate_per_hour = extract_instance_rate_per_hour(price_data)
+        if instance_rate_per_hour is not None:
+            return instance_rate_per_hour
+    raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.")
+
+
+def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[str, str]]:
+    """Extract the instance rate per hour from the given Price List JSON data.
+
+    Args:
+        price_data (Dict[str, Any]): The Price List JSON data.
+    Returns:
+        Optional[Dict[str, str]]: Instance rate per hour.
+    """
+
+    if price_data is not None:
+        # Price List entries nest as terms -> OnDemand -> <term> ->
+        # priceDimensions -> <dimension> -> pricePerUnit -> {currency: value}.
+        price_dimensions = price_data.get("terms", {}).get("OnDemand", {}).values()
+        for dimension in price_dimensions:
+            for price in dimension.get("priceDimensions", {}).values():
+                for currency in price.get("pricePerUnit", {}).keys():
+                    value = price.get("pricePerUnit", {}).get(currency)
+                    if value is not None:
+                        value = str(round(float(value), 3))
+                        return {
+                            "unit": f"{currency}/{price.get('unit', 'Hrs')}",
+                            "value": value,
+                            "name": "Instance Rate",
+                        }
+    return None
diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py
index f165a513a9..fb7ca38bad 100644
--- a/tests/unit/sagemaker/jumpstart/constants.py
+++ b/tests/unit/sagemaker/jumpstart/constants.py
@@ -7357,7 +7357,7 @@
     "training_model_package_artifact_uris": None,
     "deprecate_warn_message": None,
     "deprecated_message": None,
-    "hosting_model_package_arns": None,
+    "hosting_model_package_arns": {},
     "hosting_eula_key": None,
     "model_subscription_link": None,
     "hyperparameters": [
@@ -7662,28 +7662,44 @@
     "inference_configs": {
         "neuron-inference": {
             "benchmark_metrics": {
-                "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]
+                "ml.inf2.2xlarge": [
+                    {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+                ]
             },
             "component_names": ["neuron-inference"],
         },
         "neuron-inference-budget": {
             "benchmark_metrics": {
-                "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]
+                "ml.inf2.2xlarge": [
+                    {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+                ]
             },
             "component_names": ["neuron-base"],
         },
         "gpu-inference-budget": {
             "benchmark_metrics": {
-                "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]
+                "ml.p3.2xlarge": [
+                    {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+                ]
             },
             "component_names": ["gpu-inference-budget"],
         },
         "gpu-inference": {
             "benchmark_metrics": {
-                "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]
+                "ml.p3.2xlarge": [
+                    {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+                ]
             },
             "component_names": ["gpu-inference"],
        },
+        "gpu-inference-model-package": {
+            "benchmark_metrics": {
+                "ml.p3.2xlarge": 
[ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] + }, + "component_names": ["gpu-inference-model-package"], + }, }, "inference_config_components": { "neuron-base": { @@ -7725,6 +7741,14 @@ }, }, }, + "gpu-inference-model-package": { + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" + "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + }, + }, "gpu-inference-budget": { "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/", @@ -7748,35 +7772,70 @@ "training_configs": { "neuron-training": { "benchmark_metrics": { - "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], - "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + "ml.tr1n1.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ], + "ml.tr1n1.4xlarge": [ + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ], }, "component_names": ["neuron-training"], + "default_inference_config": "neuron-inference", + "default_incremental_training_config": "neuron-training", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, "neuron-training-budget": { "benchmark_metrics": { - "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], - "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + "ml.tr1n1.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ], + "ml.tr1n1.4xlarge": [ + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ], }, "component_names": ["neuron-training-budget"], + "default_inference_config": "neuron-inference-budget", + "default_incremental_training_config": "neuron-training-budget", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, "gpu-training": { "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}], + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "200", "unit": "Tokens/S", "concurrency": "1"} + ], }, "component_names": ["gpu-training"], + "default_inference_config": "gpu-inference", + "default_incremental_training_config": "gpu-training", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, "gpu-training-budget": { "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": "1"} + ] }, "component_names": ["gpu-training-budget"], + "default_inference_config": "gpu-inference-budget", + "default_incremental_training_config": "gpu-training-budget", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, }, "training_config_components": { "neuron-training": { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": 
["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, "training_instance_type_variants": { "regional_aliases": { "us-west-2": { @@ -7788,6 +7847,7 @@ }, }, "gpu-training": { + "default_training_instance_type": "ml.p2.xlarge", "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", "training_instance_type_variants": { @@ -7804,6 +7864,7 @@ }, }, "neuron-training-budget": { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", "training_instance_type_variants": { @@ -7817,6 +7878,7 @@ }, }, "gpu-training-budget": { + "default_training_instance_type": "ml.p2.xlarge", "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", "training_instance_type_variants": { @@ -7907,3 +7969,170 @@ }, } } + + +DEPLOYMENT_CONFIGS = [ + { + "DeploymentConfigName": "neuron-inference", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], + }, + { + "DeploymentConfigName": "neuron-inference-budget", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": 
"Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], + }, + { + "DeploymentConfigName": "gpu-inference-budget", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], + }, + { + "DeploymentConfigName": "gpu-inference", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], + }, +] + + +INIT_KWARGS = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu" + "-py310-cu121-ubuntu20.04", + "model_data": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface-textgeneration" + "-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "instance_type": "ml.p2.xlarge", + "env": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "role": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "name": "hf-textgeneration-bloom-1b1-2024-04-22-20-23-48-799", + "enable_network_isolation": True, +} diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 2e8dc1e9a2..17d0861bff 100644 --- 
a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
+++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
@@ -46,6 +46,8 @@
 from sagemaker.model import Model
 from sagemaker.predictor import Predictor
 from tests.unit.sagemaker.jumpstart.utils import (
+    get_prototype_manifest,
+    get_prototype_spec_with_configs,
     get_special_model_spec,
     overwrite_dictionary,
 )
@@ -700,7 +702,6 @@ def test_estimator_use_kwargs(self):
             "input_mode": "File",
             "output_path": "Optional[Union[str, PipelineVariable]] = None",
             "output_kms_key": "Optional[Union[str, PipelineVariable]] = None",
-            "base_job_name": "Optional[str] = None",
             "sagemaker_session": DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
             "hyperparameters": {"hyp1": "val1"},
             "tags": [],
@@ -1033,6 +1034,8 @@ def test_jumpstart_estimator_attach_eula_model(
             additional_kwargs={
                 "model_id": "gemma-model",
                 "model_version": "*",
+                "tolerate_vulnerable_model": True,
+                "tolerate_deprecated_model": True,
                 "environment": {"accept_eula": "true"},
@@ -1040,7 +1043,7 @@
     )
 
     @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
-    @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job")
+    @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job")
     @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
     @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
     @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
@@ -1048,15 +1051,17 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
         self,
         mock_get_model_specs: mock.Mock,
         mock_validate_model_id_and_get_type: mock.Mock,
-        get_model_id_version_from_training_job: mock.Mock,
+        get_model_info_from_training_job: mock.Mock,
         mock_attach: mock.Mock,
     ):
 
         mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
 
-        get_model_id_version_from_training_job.return_value = (
+        get_model_info_from_training_job.return_value = (
             "js-trainable-model-prepacked",
             "1.0.0",
+            None,
+            None,
         )
 
         mock_get_model_specs.side_effect = get_special_model_spec
@@ -1067,7 +1072,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
             training_job_name="some-training-job-name", sagemaker_session=mock_session
         )
 
-        get_model_id_version_from_training_job.assert_called_once_with(
+        get_model_info_from_training_job.assert_called_once_with(
             training_job_name="some-training-job-name",
             sagemaker_session=mock_session,
         )
@@ -1085,7 +1090,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
     )
 
     @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
-    @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job")
+    @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job")
     @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
     @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
     @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
@@ -1093,13 +1098,13 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case(
         self,
         mock_get_model_specs: mock.Mock,
         mock_validate_model_id_and_get_type: mock.Mock,
-        get_model_id_version_from_training_job: mock.Mock,
+        get_model_info_from_training_job: mock.Mock,
         mock_attach: mock.Mock,
     ):
 
         mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
 
-        
get_model_id_version_from_training_job.side_effect = ValueError() + get_model_info_from_training_job.side_effect = ValueError() mock_get_model_specs.side_effect = get_special_model_spec @@ -1110,7 +1115,7 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1137,6 +1142,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", + "config_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1158,7 +1164,9 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): js_class_deploy = JumpStartEstimator.deploy js_class_deploy_args = set(signature(js_class_deploy).parameters.keys()) - assert js_class_deploy_args - parent_class_deploy_args == model_class_init_args - { + assert js_class_deploy_args - parent_class_deploy_args - { + "inference_config_name" + } == model_class_init_args - { "model_data", "self", "name", @@ -1241,6 +1249,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=estimator.sagemaker_session, + config_name=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -1410,6 +1419,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + config_name=None, ) @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @@ -1465,6 +1475,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + config_name=None, ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @@ -1909,6 +1920,268 @@ def test_jumpstart_estimator_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_initialization_with_config_name( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator( + model_id=model_id, + config_name="gpu-training", + ) + + mock_estimator_init.assert_called_once_with( + instance_type="ml.p2.xlarge", + instance_count=1, + 
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" + "gpu-training/model/", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/transfer_learning/" + "eqa/v1.0.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={"epochs": "3", "adam-learning-rate": "2e-05", "batch-size": "4"}, + role="fake role! do not use!", + sagemaker_session=estimator.sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training"}, + ], + enable_network_isolation=False, + ) + + estimator.fit() + + mock_estimator_fit.assert_called_once_with(wait=True) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_set_config_name( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + estimator.set_training_config(config_name="gpu-training-budget") + + mock_estimator_init.assert_called_with( + instance_type="ml.p2.xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" + "gpu-training-budget/model/", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={"epochs": "3", "adam-learning-rate": "2e-05", "batch-size": "4"}, + role="fake role! 
do not use!", + sagemaker_session=estimator.sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training-budget"}, + ], + enable_network_isolation=False, + ) + + estimator.fit() + + mock_estimator_fit.assert_called_once_with(wait=True) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_default_inference_config( + self, + mock_estimator_fit: mock.Mock, + mock_estimator_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + assert estimator.config_name == "gpu-training" + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.p2.xlarge", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + predictor_cls=Predictor, + wait=True, + role="fake role! 
do not use!", + use_compiled_model=False, + enable_network_isolation=False, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference"}, + ], + ) + + @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_incremental_training_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_model_info_from_training_job: mock.Mock, + mock_attach: mock.Mock, + ): + mock_get_model_info_from_training_job.return_value = ( + "pytorch-eqa-bert-base-cased", + "1.0.0", + None, + "gpu-training-budget", + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + assert estimator.config_name == "gpu-training" + + JumpStartEstimator.attach( + training_job_name="some-training-job-name", sagemaker_session=mock_session + ) + + mock_attach.assert_called_once_with( + training_job_name="some-training-job-name", + sagemaker_session=mock_session, + model_channel_name="model", + additional_kwargs={ + "model_id": "pytorch-eqa-bert-base-cased", + "model_version": "1.0.0", + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, + "config_name": "gpu-training-budget", + }, + ) + + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_deploy_with_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + estimator = 
JumpStartEstimator(model_id=model_id, config_name="gpu-training-budget") + + assert estimator.config_name == "gpu-training-budget" + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.p2.xlarge", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + predictor_cls=Predictor, + wait=True, + role="fake role! do not use!", + use_compiled_model=False, + enable_network_isolation=False, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference-budget"}, + ], + ) + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 140b839937..25e01d5d10 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -15,6 +15,8 @@ from typing import Optional, Set from unittest import mock import unittest + +import pandas as pd from mock import MagicMock, Mock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig @@ -40,12 +42,18 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit.sagemaker.jumpstart.utils import ( + get_prototype_spec_with_configs, get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, get_special_model_spec_for_inference_component_based_endpoint, get_prototype_manifest, get_prototype_model_spec, + get_base_spec_with_prototype_configs, + get_mock_init_kwargs, + get_base_deployment_configs, + get_base_spec_with_prototype_configs_with_missing_benchmarks, + append_instance_stat_metrics, ) import boto3 @@ -60,9 +68,11 @@ class ModelTest(unittest.TestCase): - mock_session_empty_config = MagicMock(sagemaker_config={}) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -82,6 +92,7 @@ def test_non_prepacked( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -141,6 +152,9 @@ def test_non_prepacked( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -158,6 +172,7 @@ def test_non_prepacked_inference_component_based_endpoint( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -223,6 +238,9 @@ def test_non_prepacked_inference_component_based_endpoint( endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) + @mock.patch( + 
"sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -240,6 +258,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -300,6 +319,9 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -315,6 +337,7 @@ def test_prepacked( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -361,6 +384,9 @@ def test_prepacked( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.model.LOGGER.warning") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @@ -380,6 +406,7 @@ def test_no_compiled_model_warning_log_js_models( mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, mock_warning: mock.Mock(), + mock_get_jumpstart_configs: mock.Mock, ): mock_timestamp.return_value = "1234" @@ -400,6 +427,9 @@ def test_no_compiled_model_warning_log_js_models( mock_warning.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") @@ -417,6 +447,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_timestamp.return_value = "1234" @@ -464,6 +495,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ], ) + @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @@ -483,7 +515,9 @@ def test_proprietary_model_endpoint( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): + mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {} mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) @@ -523,6 +557,7 @@ def test_proprietary_model_endpoint( container_startup_health_check_timeout=600, ) + @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") 
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -534,7 +569,9 @@ def test_deprecated( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): + mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {} mock_model_deploy.return_value = default_predictor mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -550,6 +587,9 @@ def test_deprecated( JumpStartModel(model_id=model_id, tolerate_deprecated_model=True).deploy() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -561,6 +601,7 @@ def test_vulnerable( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -624,6 +665,9 @@ def test_model_use_kwargs(self): deploy_kwargs=all_deploy_kwargs_used, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -641,6 +685,7 @@ def evaluate_model_workflow_with_kwargs( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_retrieve_environment_variables: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, init_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, ): @@ -731,6 +776,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): "tolerate_deprecated_model", "instance_type", "model_package_arn", + "config_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -743,6 +789,9 @@ def test_jumpstart_model_kwargs_match_parent_class(self): assert js_class_deploy_args - parent_class_deploy_args == set() assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -751,6 +800,7 @@ def test_validate_model_id_and_get_type( mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartModel(model_id="valid_model_id") @@ -759,6 +809,9 @@ def test_validate_model_id_and_get_type( with pytest.raises(ValueError): JumpStartModel(model_id="invalid_model_id") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -776,6 +829,7 @@ def test_no_predictor_returns_default_predictor( mock_session: mock.Mock, mock_validate_model_id_and_get_type: 
mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -804,10 +858,14 @@ def test_no_predictor_returns_default_predictor( tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -825,6 +883,7 @@ def test_no_predictor_yes_async_inference_config( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -846,6 +905,9 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -863,6 +925,7 @@ def test_yes_predictor_returns_default_predictor( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -884,6 +947,9 @@ def test_yes_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -901,6 +967,7 @@ def test_model_id_not_found_refeshes_cache_inference( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.side_effect = [False, False] @@ -969,6 +1036,9 @@ def test_model_id_not_found_refeshes_cache_inference( ] ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -976,6 +1046,7 @@ def test_jumpstart_model_tags( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1005,6 +1076,9 @@ def test_jumpstart_model_tags( [{"Key": "blah", "Value": "blahagain"}] + js_tags, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") 
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1012,6 +1086,7 @@ def test_jumpstart_model_tags_disabled( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1039,6 +1114,9 @@ def test_jumpstart_model_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1046,6 +1124,7 @@ def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1073,6 +1152,9 @@ def test_jumpstart_model_package_arn( self.assertIn(tag, mock_session.create_model.call_args[1]["tags"]) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1080,6 +1162,7 @@ def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1115,6 +1198,9 @@ def test_jumpstart_model_package_arn_override( }, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -1126,6 +1212,7 @@ def test_jumpstart_model_package_arn_unsupported_region( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1143,6 +1230,9 @@ def test_jumpstart_model_package_arn_unsupported_region( "us-east-2. Please try one of the following regions: us-west-2, us-east-1." 
) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( @@ -1162,6 +1252,7 @@ def test_model_data_s3_prefix_override( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1211,6 +1302,9 @@ def test_model_data_s3_prefix_override( '"S3DataType": "S3Prefix", "CompressionType": "None"}}', ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -1228,6 +1322,7 @@ def test_model_data_s3_prefix_model( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1257,6 +1352,9 @@ def test_model_data_s3_prefix_model( mock_js_info_logger.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -1274,6 +1372,7 @@ def test_model_artifact_variant_model( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1324,6 +1423,9 @@ def test_model_artifact_variant_model( enable_network_isolation=True, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -1339,6 +1441,7 @@ def test_model_registry_accept_and_response_types( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1358,6 +1461,9 @@ def test_model_registry_accept_and_response_types( response_types=["application/json;verbose", "application/json"], ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1371,6 +1477,7 @@ def test_jumpstart_model_session( mock_deploy, mock_init, get_default_predictor, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = True @@ -1404,6 +1511,9 @@ def test_jumpstart_model_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch.dict( "sagemaker.jumpstart.cache.os.environ", { @@ -1424,6 +1534,7 @@ def test_model_local_mode( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + 
mock_get_jumpstart_configs: mock.Mock, ): mock_get_model_specs.side_effect = get_prototype_model_spec mock_get_manifest.side_effect = ( @@ -1450,6 +1561,454 @@ def test_model_local_mode( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_initialization_with_config_name( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, config_name="neuron-inference") + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + mock_get_model_specs.reset_mock() + mock_model_deploy.reset_mock() + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + model.set_deployment_config("neuron-inference", "ml.inf2.2xlarge") + + assert model.config_name == "neuron-inference" + + model.deploy() + + 
mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.2xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + ) + mock_model_deploy.reset_mock() + model.set_deployment_config("neuron-inference", "ml.inf2.xlarge") + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_model_package( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + ) + + mock_model_deploy.reset_mock() + + model.set_deployment_config( + config_name="gpu-inference-model-package", instance_type="ml.p2.xlarge" + ) + + assert ( + model.model_package_arn + == "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + ) + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference-model-package"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + 
@mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_incompatible_instance_type_or_name( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + mock_get_model_specs.reset_mock() + mock_model_deploy.reset_mock() + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + with pytest.raises(ValueError) as error: + model.set_deployment_config("neuron-inference", "ml.inf2.32xlarge") + assert ( + "Instance type ml.inf2.32xlarge is not supported for config neuron-inference." + in str(error) + ) + + with pytest.raises(ValueError) as error: + model.set_deployment_config("neuron-inference-unknown-name", "ml.inf2.32xlarge") + assert "Cannot find Jumpstart config name neuron-inference-unknown-name. " in str(error) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_list_deployment_configs( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() + ) + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + configs = model.list_deployment_configs() + + self.assertEqual(configs, get_base_deployment_configs(True)) + + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + 
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_list_deployment_configs_empty( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_special_model_spec(model_id="gemma-model") + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + configs = model.list_deployment_configs() + + self.assertTrue(len(configs) == 0) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_retrieve_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + expected = get_base_deployment_configs()[0] + config_name = expected.get("DeploymentConfigName") + instance_type = expected.get("InstanceType") + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs( + model_id, config_name + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + model.set_deployment_config(config_name, instance_type) + + self.assertEqual(model.deployment_config, expected) + + mock_get_init_kwargs.reset_mock() + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + 
@mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_display_benchmark_metrics( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() + ) + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + model.display_benchmark_metrics() + model.display_benchmark_metrics(instance_type="g5.12xlarge") + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_benchmark_metrics( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + df = model.benchmark_metrics + + self.assertTrue(isinstance(df, pd.DataFrame)) + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 70409704e6..2be4bde7e4 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -60,6 +60,9 @@ class IntelligentDefaultsModelTest(unittest.TestCase): region = "us-west-2" sagemaker_session = 
DEFAULT_JUMPSTART_SAGEMAKER_SESSION + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -77,6 +80,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -101,6 +105,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_model_init.call_args[1] + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -118,6 +125,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -147,6 +155,9 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -164,6 +175,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -193,6 +205,9 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( config_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -210,6 +225,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -241,6 +257,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -258,6 +277,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, 
mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -287,6 +307,9 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( metadata_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -304,6 +327,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -334,6 +358,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -351,6 +378,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -375,6 +403,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_model_init.call_args[1] + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -392,6 +423,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index c0a37c5b38..301afe4d53 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -1,6 +1,7 @@ from __future__ import absolute_import -import json + import datetime +import json from unittest import TestCase from unittest.mock import Mock, patch @@ -235,7 +236,7 @@ def test_list_jumpstart_models_script_filter( get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() ) patched_get_manifest.side_effect = ( - lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region) ) manifest_length = len(get_prototype_manifest()) @@ -243,7 +244,7 @@ def test_list_jumpstart_models_script_filter( for val in vals: kwargs = 
{"filter": And(f"training_supported == {val}", "model_type is open_weights")} list_jumpstart_models(**kwargs) - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -251,7 +252,7 @@ def test_list_jumpstart_models_script_filter( kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")} list_jumpstart_models(**kwargs) - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -270,7 +271,7 @@ def test_list_jumpstart_models_script_filter( ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"), ("xgboost-classification-model", "1.0.0"), ] - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -279,7 +280,7 @@ def test_list_jumpstart_models_script_filter( kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")} models = list_jumpstart_models(**kwargs) assert [] == models - assert patched_read_s3_file.call_count == manifest_length + assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 52f28f2da1..a3425a7b90 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -18,7 +18,7 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support( @@ -52,7 +52,7 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_proprietary_predictor_support( @@ -91,13 +91,13 @@ def test_proprietary_predictor_support( @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( patched_get_model_specs, patched_verify_model_region_and_return_specs, - patched_get_jumpstart_model_id_version_from_endpoint, + patched_get_model_info_from_endpoint, patched_get_default_predictor, patched_predictor, ): @@ -105,19 +105,19 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( patched_verify_model_region_and_return_specs.side_effect = 
verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_get_jumpstart_model_id_version_from_endpoint.return_value = ( + patched_get_model_info_from_endpoint.return_value = ( "predictor-specs-model", "1.2.3", None, + None, + None, ) mock_session = Mock() predictor.retrieve_default(endpoint_name="blah", sagemaker_session=mock_session) - patched_get_jumpstart_model_id_version_from_endpoint.assert_called_once_with( - "blah", None, mock_session - ) + patched_get_model_info_from_endpoint.assert_called_once_with("blah", None, mock_session) patched_get_default_predictor.assert_called_once_with( predictor=patched_predictor.return_value, @@ -128,11 +128,12 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_vulnerable_model=False, sagemaker_session=mock_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, ) @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @@ -159,7 +160,8 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( patched_get_default_predictor.assert_not_called() -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}) +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @@ -169,7 +171,8 @@ def test_jumpstart_serializable_payload_with_predictor( patched_verify_model_region_and_return_specs, patched_validate_model_id_and_get_type, patched_get_object_cached, - patched_get_model_id_version_from_endpoint, + patched_get_model_info_from_endpoint, + patched_get_jumpstart_configs, ): patched_get_object_cached.return_value = base64.b64decode("encodedimage") @@ -179,7 +182,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs.side_effect = get_special_model_spec model_id, model_version = "default_payloads", "*" - patched_get_model_id_version_from_endpoint.return_value = model_id, model_version, None + patched_get_model_info_from_endpoint.return_value = model_id, model_version, None js_predictor = predictor.retrieve_default( endpoint_name="blah", model_id=model_id, model_version=model_version diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 76ad50f31c..ce06a189bd 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -4,167 +4,202 @@ import pytest from sagemaker.jumpstart.session_utils import ( - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name, - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name, - _get_model_id_version_from_model_based_endpoint, - get_model_id_version_from_endpoint, - get_model_id_version_from_training_job, + _get_model_info_from_inference_component_endpoint_with_inference_component_name, + 
_get_model_info_from_inference_component_endpoint_without_inference_component_name, + _get_model_info_from_model_based_endpoint, + get_model_info_from_endpoint, + get_model_info_from_training_job, ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_training_job_happy_case( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) - retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_training_job_config_name( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + "model_id", + "model_version", + None, + "training_config_name", + ) + + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", None, "training_config_name") + + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session + ) + + +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_training_job_no_model_id_inferred( + mock_get_jumpstart_model_info_from_resource_arn, +): + mock_sm_session = Mock() + mock_sm_session.boto_region_name = "us-west-2" + mock_sm_session.account_id = Mock(return_value="123456789012") + + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( None, None, ) with pytest.raises(ValueError): - get_model_id_version_from_training_job("blah", sagemaker_session=mock_sm_session) + get_model_info_from_training_job("blah", sagemaker_session=mock_sm_session) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_model_based_endpoint_happy_case( 
+ mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) - retval = _get_model_id_version_from_model_based_endpoint( + retval = _get_model_info_from_model_based_endpoint( "bLaH", inference_component_name=None, sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:endpoint/blah", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_inference_component_supplied( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_model_based_endpoint_inference_component_supplied( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_model_based_endpoint_no_model_id_inferred( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + None, None, None, ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_happy_case( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - 
mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) - retval = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + retval = _get_model_info_from_inference_component_endpoint_with_inference_component_name( "bLaH", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:inference-component/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + None, + None, None, None, ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + _get_model_info_from_inference_component_endpoint_with_inference_component_name( "blah", sagemaker_session=mock_sm_session ) @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_inference_component_endpoint_without_inference_component_name_happy_case( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_inference_component_name_happy_case( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -172,10 +207,8 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc return_value=["icname"] ) - retval = ( - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( - "blahblah", mock_sm_session - ) + retval = _get_model_info_from_inference_component_endpoint_without_inference_component_name( + "blahblah", mock_sm_session ) assert retval == ("model_id", "model_version", "icname") @@ -185,14 +218,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def 
test_get_model_id_version_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -200,7 +233,7 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ return_value=[] ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -210,14 +243,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ @patch( - "sagemaker.jumpstart.session_utils._get_model_id" - "_version_from_inference_component_endpoint_with_inference_component_name" + "sagemaker.jumpstart.session_utils._get_model" + "_info_from_inference_component_endpoint_with_inference_component_name" ) def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_multiple_ics_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -227,7 +260,7 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -236,67 +269,119 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) -@patch("sagemaker.jumpstart.session_utils._get_model_id_version_from_model_based_endpoint") -def test_get_model_id_version_from_endpoint_non_inference_component_endpoint( - mock_get_model_id_version_from_model_based_endpoint, +@patch("sagemaker.jumpstart.session_utils._get_model_info_from_model_based_endpoint") +def test_get_model_info_from_endpoint_non_inference_component_endpoint( + mock_get_model_info_from_model_based_endpoint, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = False - mock_get_model_id_version_from_model_based_endpoint.return_value = ( + mock_get_model_info_from_model_based_endpoint.return_value = ( "model_id", "model_version", + None, + None, ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", None) - mock_get_model_id_version_from_model_based_endpoint.assert_called_once_with( + assert retval == ("model_id", "model_version", None, None, None) + mock_get_model_info_from_model_based_endpoint.assert_called_once_with( "blah", None, 
mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_called_once_with("blah") @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", + None, + None, ) - retval = get_model_id_version_from_endpoint( + retval = get_model_info_from_endpoint( "blah", inference_component_name="icname", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", "icname") - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( + assert retval == ("model_id", "model_version", "icname", None, None) + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( "icname", mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_not_called() @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_component_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_info_from_endpoint_inference_component_endpoint_without_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + "model_id", + "model_version", + None, + None, + "inferred-icname", + ) + + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", None, None) + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" "endpoint_without_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_without_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", + 
"inference_config_name", + None, + "inferred-icname", + ) + + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", "inference_config_name", None) + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_info_from_endpoint_inference_component_endpoint_with_training_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + "model_id", + "model_version", + None, + "training_config_name", "inferred-icname", ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "inferred-icname") - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + assert retval == ("model_id", "model_version", "inferred-icname", None, "training_config_name") + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index b2758c73ef..23fa42c09a 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -17,11 +17,14 @@ from sagemaker.jumpstart.types import ( JumpStartBenchmarkStat, JumpStartECRSpecs, + JumpStartEnvironmentVariable, JumpStartHyperparameter, JumpStartInstanceTypeVariants, JumpStartModelSpecs, JumpStartModelHeader, JumpStartConfigComponent, + DeploymentConfigMetadata, + JumpStartModelInitKwargs, ) from tests.unit.sagemaker.jumpstart.constants import ( BASE_SPEC, @@ -29,6 +32,7 @@ INFERENCE_CONFIGS, TRAINING_CONFIG_RANKINGS, TRAINING_CONFIGS, + INIT_KWARGS, ) INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants( @@ -924,6 +928,7 @@ def test_inference_configs_parsing(): "neuron-inference", "neuron-budget", "gpu-inference", + "gpu-inference-model-package", "gpu-inference-budget", ] @@ -1016,6 +1021,80 @@ def test_inference_configs_parsing(): } ), ] + assert specs1.inference_environment_variables == [ + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + } + ), + JumpStartEnvironmentVariable( + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + } + ), + 
JumpStartEnvironmentVariable( + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + } + ), + ] # Overrided fields in top config assert specs1.supported_inference_instance_types == ["ml.inf2.xlarge", "ml.inf2.2xlarge"] @@ -1024,7 +1103,9 @@ def test_inference_configs_parsing(): assert config.benchmark_metrics == { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), ] } assert len(config.config_components) == 1 @@ -1052,6 +1133,20 @@ def test_inference_configs_parsing(): ) assert list(config.config_components.keys()) == ["neuron-inference"] + config = specs1.inference_configs.configs["gpu-inference-model-package"] + assert config.config_components["gpu-inference-model-package"] == JumpStartConfigComponent( + "gpu-inference-model-package", + { + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/" + "llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + }, + }, + ) + assert config.resolved_config.get("inference_environment_variables") == [] + spec = { **BASE_SPEC, **INFERENCE_CONFIGS, @@ -1070,6 +1165,7 @@ def test_set_inference_configs(): "neuron-inference", "neuron-budget", "gpu-inference", + "gpu-inference-model-package", "gpu-inference-budget", ] @@ -1078,7 +1174,7 @@ def test_set_inference_configs(): assert "Cannot find Jumpstart config name invalid_name." 
"List of config names that is supported by the model: " "['neuron-inference', 'neuron-inference-budget', " - "'gpu-inference-budget', 'gpu-inference']" in str(error.value) + "'gpu-inference-budget', 'gpu-inference', 'gpu-inference-model-package']" in str(error.value) assert specs1.supported_inference_instance_types == ["ml.inf2.xlarge", "ml.inf2.2xlarge"] specs1.set_config("gpu-inference") @@ -1188,18 +1284,29 @@ def test_training_configs_parsing(): assert config.benchmark_metrics == { "ml.tr1n1.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), ], "ml.tr1n1.4xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "50", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ), ], } assert len(config.config_components) == 1 assert config.config_components["neuron-training"] == JumpStartConfigComponent( "neuron-training", { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, "training_instance_type_variants": { "regional_aliases": { "us-west-2": { @@ -1256,3 +1363,38 @@ def test_set_training_config(): with pytest.raises(ValueError) as error: specs1.set_config("invalid_name", scope="unknown scope") + + +def test_deployment_config_metadata(): + spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + specs = JumpStartModelSpecs(spec) + jumpstart_config = specs.inference_configs.get_top_config_from_ranking() + + deployment_config_metadata = DeploymentConfigMetadata( + jumpstart_config.config_name, + jumpstart_config, + JumpStartModelInitKwargs( + model_id=specs.model_id, + model_data=INIT_KWARGS.get("model_data"), + image_uri=INIT_KWARGS.get("image_uri"), + instance_type=INIT_KWARGS.get("instance_type"), + env=INIT_KWARGS.get("env"), + config_name=jumpstart_config.config_name, + ), + ) + + json_obj = deployment_config_metadata.to_json() + + assert isinstance(json_obj, dict) + assert json_obj["DeploymentConfigName"] == jumpstart_config.config_name + for key in json_obj["BenchmarkMetrics"]: + assert len(json_obj["BenchmarkMetrics"][key]) == len( + jumpstart_config.benchmark_metrics.get(key) + ) + assert json_obj["AccelerationConfigs"] == jumpstart_config.resolved_config.get( + "acceleration_configs" + ) + assert json_obj["DeploymentArgs"]["ImageUri"] == INIT_KWARGS.get("image_uri") + assert json_obj["DeploymentArgs"]["ModelData"] == INIT_KWARGS.get("model_data") + assert json_obj["DeploymentArgs"]["Environment"] == INIT_KWARGS.get("env") + assert json_obj["DeploymentArgs"]["InstanceType"] == INIT_KWARGS.get("instance_type") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 941e2797ea..a5a063c696 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -13,6 +13,8 @@ from __future__ import absolute_import import os from unittest import TestCase + +from botocore.exceptions import ClientError from mock.mock import Mock, patch import pytest import boto3 @@ -49,8 +51,10 @@ get_spec_from_base_spec, 
get_special_model_spec, get_prototype_manifest, + get_base_deployment_configs_metadata, + get_base_deployment_configs, ) -from mock import MagicMock, call +from mock import MagicMock MOCK_CLIENT = MagicMock() @@ -207,16 +211,16 @@ def test_is_jumpstart_model_uri(): assert utils.is_jumpstart_model_uri(random_jumpstart_s3_uri("random_key")) -def test_add_jumpstart_model_id_version_tags(): +def test_add_jumpstart_model_info_tags(): tags = None model_id = "model_id" version = "version" + inference_config_name = "inference_config_name" + training_config_name = "training_config_name" assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, @@ -228,9 +232,7 @@ def test_add_jumpstart_model_id_version_tags(): assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version_2"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "random key", "Value": "random_value"}, @@ -241,9 +243,7 @@ def test_add_jumpstart_model_id_version_tags(): {"Key": "random key", "Value": "random_value"}, {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, @@ -254,9 +254,7 @@ def test_add_jumpstart_model_id_version_tags(): assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "random key", "Value": "random_value"}, @@ -265,8 +263,58 @@ def test_add_jumpstart_model_id_version_tags(): version = None assert [ {"Key": "random key", "Value": "random_value"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + {"Key": "sagemaker-sdk:jumpstart-inference-config-name", "Value": "inference_config_name"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=inference_config_name, + scope=JumpStartScriptScope.INFERENCE, + ) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ 
+ {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + {"Key": "sagemaker-sdk:jumpstart-training-config-name", "Value": "training_config_name"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=training_config_name, + scope=JumpStartScriptScope.TRAINING, + ) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=training_config_name, ) @@ -1319,10 +1367,8 @@ def test_no_model_id_no_version_found(self): mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1336,10 +1382,8 @@ def test_model_id_no_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id", None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id", None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1353,10 +1397,66 @@ def test_no_model_id_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, "model_version"), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, "model_version", None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_no_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_inference_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, "config_name", None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_training_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, "config_name"), + ) + 
mock_list_tags.assert_called_once_with("some-arn") + + def test_both_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "inference_config_name"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "training_config_name"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, "inference_config_name", "training_config_name"), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1371,10 +1471,8 @@ def test_model_id_version_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id", "model_version"), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id", "model_version", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1391,10 +1489,8 @@ def test_multiple_model_id_versions_found(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1411,10 +1507,8 @@ def test_multiple_model_id_versions_found_aliases_consistent(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id_1", "model_version_1"), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id_1", "model_version_1", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1431,10 +1525,26 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self): ] self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None), + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_multiple_config_names_found_aliases_inconsistent(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.MODEL_ID, "Value": "model_id_1"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_2"}, + ] + + self.assertEquals( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id_1", "model_version_1", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1529,6 +1639,7 @@ def test_get_jumpstart_config_names_success( "neuron-inference-budget", "gpu-inference-budget", "gpu-inference", + "gpu-inference-model-package", ] @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1599,22 +1710,37 @@ def test_get_jumpstart_benchmark_stats_full_list( ) == { "neuron-inference": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", 
"unit": "Tokens/S", "concurrency": 1} + ) ] }, "neuron-inference-budget": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, "gpu-inference-budget": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, "gpu-inference": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ] + }, + "gpu-inference-model-package": { + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, } @@ -1634,12 +1760,16 @@ def test_get_jumpstart_benchmark_stats_partial_list( ) == { "neuron-inference-budget": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, "gpu-inference-budget": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, } @@ -1659,7 +1789,9 @@ def test_get_jumpstart_benchmark_stats_single_stat( ) == { "neuron-inference-budget": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] } } @@ -1687,6 +1819,16 @@ def test_get_jumpstart_benchmark_stats_training( ): patched_get_model_specs.side_effect = get_base_spec_with_prototype_configs + print( + utils.get_benchmark_stats( + "mock-region", + "mock-model", + "mock-model-version", + scope=JumpStartScriptScope.TRAINING, + config_names=["neuron-training", "gpu-training-budget"], + ) + ) + assert utils.get_benchmark_stats( "mock-region", "mock-model", @@ -1696,97 +1838,201 @@ def test_get_jumpstart_benchmark_stats_training( ) == { "neuron-training": { "ml.tr1n1.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ], "ml.tr1n1.4xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "50", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ) ], }, "gpu-training-budget": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": "1"} + ) ] }, } -class TestUserAgent: - @patch("sagemaker.jumpstart.utils.os.getenv") - def test_get_jumpstart_user_agent_extra_suffix(self, mock_getenv): - mock_getenv.return_value = False - assert utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) - mock_getenv.return_value = None - assert utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) - mock_getenv.return_value = "True" 
- assert not utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) - mock_getenv.return_value = True - assert not utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) +def test_extract_metrics_from_deployment_configs(): + configs = get_base_deployment_configs_metadata() + configs[0].benchmark_metrics = None + configs[2].deployment_args = None - @patch("sagemaker.jumpstart.utils.botocore.session") - @patch("sagemaker.jumpstart.utils.botocore.config.Config") - @patch("sagemaker.jumpstart.utils.get_jumpstart_user_agent_extra_suffix") - @patch("sagemaker.jumpstart.utils.boto3.Session") - @patch("sagemaker.jumpstart.utils.boto3.client") - @patch("sagemaker.jumpstart.utils.Session") - def test_get_default_jumpstart_session_with_user_agent_suffix( - self, - mock_sm_session, - mock_boto3_client, - mock_botocore_session, - mock_get_jumpstart_user_agent_extra_suffix, - mock_botocore_config, - mock_boto3_session, - ): - utils.get_default_jumpstart_session_with_user_agent_suffix("model_id", "model_version") - mock_boto3_session.get_session.assert_called_once_with() - mock_get_jumpstart_user_agent_extra_suffix.assert_called_once_with( - "model_id", "model_version" - ) - mock_botocore_config.assert_called_once_with( - user_agent_extra=mock_get_jumpstart_user_agent_extra_suffix.return_value - ) - mock_botocore_session.assert_called_once_with( - region_name=JUMPSTART_DEFAULT_REGION_NAME, - botocore_session=mock_boto3_session.get_session.return_value, - ) - mock_boto3_client.assert_has_calls( + data = utils.get_metrics_from_deployment_configs(configs) + + for key in data: + assert len(data[key]) == (len(configs) - 2) + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + } + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ], + "ml.gd4.xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ], + }, + ) + + assert err is None + for key in out: + assert len(out[key]) == 2 + for metric in out[key]: + if metric.name == "Instance Rate": + assert metric.to_json() == { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + "concurrency": None, + } + + +def test__normalize_benchmark_metrics(): + rate, metrics = utils._normalize_benchmark_metrics( + [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + JumpStartBenchmarkStat( + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76", "concurrency": None} + ), + ] + ) + + assert rate == JumpStartBenchmarkStat( + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76", "concurrency": None} + ) + assert metrics == { + 1: [ + 
JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + ], + 2: [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + ], + } + + +@pytest.mark.parametrize( + "name, expected", + [ + ("latency", "Latency for each user (TTFT in ms)"), + ("throughput", "Throughput per user (token/seconds)"), + ], +) +def test__normalize_benchmark_metric_column_name(name, expected): + out = utils._normalize_benchmark_metric_column_name(name) + + assert out == expected + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics_client_ex( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = ClientError( + { + "Error": { + "Message": "is not authorized to perform: pricing:GetProducts", + "Code": "AccessDenied", + }, + }, + "GetProducts", + ) + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ], + }, + ) + + assert err["Message"] == "is not authorized to perform: pricing:GetProducts" + assert err["Code"] == "AccessDenied" + for key in out: + assert len(out[key]) == 1 + + +@pytest.mark.parametrize( + "stats, expected", + [ + (None, True), + ( + [ + JumpStartBenchmarkStat( + { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + "concurrency": None, + } + ) + ], + True, + ), + ( [ - call( - "sagemaker", - region_name=JUMPSTART_DEFAULT_REGION_NAME, - config=mock_botocore_config.return_value, - ), - call( - "sagemaker-runtime", - region_name=JUMPSTART_DEFAULT_REGION_NAME, - config=mock_botocore_config.return_value, - ), + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": None} + ) ], - any_order=True, - ) + False, + ), + ], +) +def test_has_instance_rate_stat(stats, expected): + assert utils.has_instance_rate_stat(stats) is expected - @patch("botocore.client.BaseClient._make_request") - def test_get_default_jumpstart_session_with_user_agent_suffix_http_header( - self, - mock_make_request, - ): - session = utils.get_default_jumpstart_session_with_user_agent_suffix( - "model_id", "model_version" - ) - try: - session.sagemaker_client.list_endpoints() - except Exception: - pass - assert ( - "md/js_model_id#model_id md/js_model_ver#model_version" - in mock_make_request.call_args[0][1]["headers"]["User-Agent"] - ) +@pytest.mark.parametrize( + "data, expected", + [(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())], +) +def test_deployment_config_response_data(data, expected): + out = utils.deployment_config_response_data(data) + + print(out) + assert out == expected diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e102251060..cc4ef71cee 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -12,9 +12,10 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import import copy -from typing import List +from typing import List, Dict, Any, Optional import boto3 +from sagemaker.compute_resource_requirements import ResourceRequirements from sagemaker.jumpstart.cache import JumpStartModelsCache from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, @@ -27,6 +28,10 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, + JumpStartModelInitKwargs, + DeploymentConfigMetadata, + JumpStartModelDeployKwargs, + JumpStartBenchmarkStat, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest @@ -43,6 +48,8 @@ SPECIAL_MODEL_SPECS_DICT, TRAINING_CONFIG_RANKINGS, TRAINING_CONFIGS, + DEPLOYMENT_CONFIGS, + INIT_KWARGS, ) @@ -222,6 +229,43 @@ def get_base_spec_with_prototype_configs( return JumpStartModelSpecs(spec) +def get_base_spec_with_prototype_configs_with_missing_benchmarks( + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, +) -> JumpStartModelSpecs: + spec = copy.deepcopy(BASE_SPEC) + copy_inference_configs = copy.deepcopy(INFERENCE_CONFIGS) + copy_inference_configs["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None + + inference_configs = {**copy_inference_configs, **INFERENCE_CONFIG_RANKINGS} + training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS} + + spec.update(inference_configs) + spec.update(training_configs) + + return JumpStartModelSpecs(spec) + + +def get_prototype_spec_with_configs( + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, +) -> JumpStartModelSpecs: + spec = copy.deepcopy(PROTOTYPICAL_MODEL_SPECS_DICT[model_id]) + inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS} + + spec.update(inference_configs) + spec.update(training_configs) + + return JumpStartModelSpecs(spec) + + def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, key: JumpStartCachedS3ContentKey, @@ -280,3 +324,101 @@ def overwrite_dictionary( base_dictionary[key] = value return base_dictionary + + +def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, Any]]: + configs = copy.deepcopy(DEPLOYMENT_CONFIGS) + configs[0]["AccelerationConfigs"] = [ + {"Type": "Speculative-Decoding", "Enabled": True, "Spec": {"Version": "0.1"}} + ] + return configs + + +def get_mock_init_kwargs( + model_id: str, config_name: Optional[str] = None +) -> JumpStartModelInitKwargs: + return JumpStartModelInitKwargs( + model_id=model_id, + model_type=JumpStartModelType.OPEN_WEIGHTS, + model_data=INIT_KWARGS.get("model_data"), + image_uri=INIT_KWARGS.get("image_uri"), + instance_type=INIT_KWARGS.get("instance_type"), + env=INIT_KWARGS.get("env"), + resources=ResourceRequirements(), + config_name=config_name, + ) + + +def get_base_deployment_configs_metadata( + omit_benchmark_metrics: bool = False, +) -> List[DeploymentConfigMetadata]: + specs = ( + get_base_spec_with_prototype_configs_with_missing_benchmarks() + if omit_benchmark_metrics + else get_base_spec_with_prototype_configs() + ) + configs = [] + for config_name in specs.inference_configs.config_rankings.get("overall").rankings: + jumpstart_config = specs.inference_configs.configs.get(config_name) + benchmark_metrics = 
jumpstart_config.benchmark_metrics + + if benchmark_metrics: + for instance_type in benchmark_metrics: + benchmark_metrics[instance_type].append( + JumpStartBenchmarkStat( + { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + "concurrency": None, + } + ) + ) + + configs.append( + DeploymentConfigMetadata( + config_name=config_name, + metadata_config=jumpstart_config, + init_kwargs=get_mock_init_kwargs( + get_base_spec_with_prototype_configs().model_id, config_name + ), + deploy_kwargs=JumpStartModelDeployKwargs( + model_id=get_base_spec_with_prototype_configs().model_id, + ), + ) + ) + return configs + + +def get_base_deployment_configs( + omit_benchmark_metrics: bool = False, +) -> List[Dict[str, Any]]: + configs = [] + for config in get_base_deployment_configs_metadata(omit_benchmark_metrics): + config_json = config.to_json() + if config_json["BenchmarkMetrics"]: + config_json["BenchmarkMetrics"] = { + config.deployment_args.instance_type: config_json["BenchmarkMetrics"].get( + config.deployment_args.instance_type + ) + } + configs.append(config_json) + return configs + + +def append_instance_stat_metrics( + metrics: Dict[str, List[JumpStartBenchmarkStat]] +) -> Dict[str, List[JumpStartBenchmarkStat]]: + if metrics is not None: + for key in metrics: + metrics[key].append( + JumpStartBenchmarkStat( + { + "name": "Instance Rate", + "value": "3.76", + "unit": "USD/Hrs", + "concurrency": None, + } + ) + ) + return metrics diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 2065e86818..e38317067c 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -23,6 +23,7 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) +from tests.unit.sagemaker.serve.constants import DEPLOYMENT_CONFIGS mock_model_id = "huggingface-llm-amazon-falconlite" mock_t5_model_id = "google/flan-t5-xxl" @@ -724,3 +725,239 @@ def test_js_gated_model_ex( ValueError, lambda: builder.build(), ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_list_deployment_configs( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + configs = builder.list_deployment_configs() + + self.assertEqual(configs, DEPLOYMENT_CONFIGS) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + 
@patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_get_deployment_config( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + expected = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value.deployment_config = expected + + self.assertEqual(builder.get_deployment_config(), expected) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_set_deployment_config( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + builder.build() + builder.set_deployment_config("config-1", "ml.g5.24xlarge") + + mock_pre_trained_model.return_value.set_deployment_config.assert_called_with( + "config-1", "ml.g5.24xlarge" + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_set_deployment_config_ex( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + self.assertRaisesRegex( + Exception, + "Cannot set deployment config to an uninitialized model.", + lambda: ModelBuilder( + model="facebook/galactica-mock-model-id", schema_builder=mock_schema_builder + ).set_deployment_config("config-2", "ml.g5.24xlarge"), + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", 
side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_display_benchmark_metrics( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + builder.list_deployment_configs() + + builder.display_benchmark_metrics() + + mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once() + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_display_benchmark_metrics_initial( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + builder.display_benchmark_metrics() + + mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once() diff --git a/tests/unit/sagemaker/serve/constants.py b/tests/unit/sagemaker/serve/constants.py index db9dd623d8..5c40c1bf64 100644 --- a/tests/unit/sagemaker/serve/constants.py +++ b/tests/unit/sagemaker/serve/constants.py @@ -15,3 +15,153 @@ MOCK_IMAGE_CONFIG = {"RepositoryAccessMode": "Vpc"} MOCK_VPC_CONFIG = {"Subnets": ["subnet-1234"], "SecurityGroupIds": ["sg123"]} +DEPLOYMENT_CONFIGS = [ + { + "ConfigName": "neuron-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + 
"-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "neuron-inference-budget", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "gpu-inference-budget", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "gpu-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + 
"-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, +] diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index d81984e81f..d5214d01c3 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -51,6 +51,8 @@ _is_bad_link, custom_extractall_tarfile, can_model_package_source_uri_autopopulate, + get_instance_rate_per_hour, + extract_instance_rate_per_hour, _resolve_routing_config, ) from tests.unit.sagemaker.workflow.helpers import CustomStep @@ -1819,7 +1821,13 @@ def test_can_model_package_source_uri_autopopulate(): class TestDeepMergeDict(TestCase): def test_flatten_dict_basic(self): nested_dict = {"a": 1, "b": {"x": 2, "y": {"p": 3, "q": 4}}, "c": 5} - flattened_dict = {"a": 1, "b.x": 2, "b.y.p": 3, "b.y.q": 4, "c": 5} + flattened_dict = { + ("a",): 1, + ("b", "x"): 2, + ("b", "y", "p"): 3, + ("b", "y", "q"): 4, + ("c",): 5, + } self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) @@ -1831,13 +1839,19 @@ def test_flatten_dict_empty(self): def test_flatten_dict_no_nested(self): nested_dict = {"a": 1, "b": 2, "c": 3} - flattened_dict = {"a": 1, "b": 2, "c": 3} + flattened_dict = {("a",): 1, ("b",): 2, ("c",): 3} self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) def test_flatten_dict_with_various_types(self): nested_dict = {"a": [1, 2, 3], "b": {"x": None, "y": {"p": [], "q": ""}}, "c": 9} - flattened_dict = {"a": [1, 2, 3], "b.x": None, "b.y.p": [], "b.y.q": "", "c": 9} + flattened_dict = { + ("a",): [1, 2, 3], + ("b", "x"): None, + ("b", "y", "p"): [], + ("b", "y", "q"): "", + ("c",): 9, + } self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) @@ -1870,6 +1884,140 @@ def test_deep_override_skip_keys(self): self.assertEqual(deep_override_dict(dict1, dict2, skip_keys=["c", "d"]), expected_result) +@pytest.mark.parametrize( + "instance, region, amazon_sagemaker_price_result, expected", + [ + ( + "ml.t4g.nano", + "us-west-2", + { + "PriceList": [ + { + "terms": { + "OnDemand": { + "3WK7G7WSYVS3K492.JRTCKXETXF": { + "priceDimensions": { + "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": { + "unit": "Hrs", + "endRange": "Inf", + "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour", + "appliesTo": [], + "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7", + "beginRange": "0", + "pricePerUnit": {"USD": "0.9000000000"}, + } + } + } + } + }, + } + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9"}, + ), + ( + "ml.t4g.nano", + "eu-central-1", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + 
'"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + ), + ( + "ml.t4g.nano", + "af-south-1", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + ), + ( + "ml.t4g.nano", + "ap-northeast-2", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + ), + ], +) +@patch("boto3.client") +def test_get_instance_rate_per_hour( + mock_client, instance, region, amazon_sagemaker_price_result, expected +): + + mock_client.return_value.get_products.side_effect = ( + lambda *args, **kwargs: amazon_sagemaker_price_result + ) + instance_rate = get_instance_rate_per_hour(instance_type=instance, region=region) + + assert instance_rate == expected + + +@pytest.mark.parametrize( + "price_data, expected_result", + [ + (None, None), + ( + { + "terms": { + "OnDemand": { + "3WK7G7WSYVS3K492.JRTCKXETXF": { + "priceDimensions": { + "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": { + "unit": "Hrs", + "endRange": "Inf", + "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour", + "appliesTo": [], + "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7", + "beginRange": "0", + "pricePerUnit": {"USD": "0.9000000000"}, + } + } + } + } + } + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9"}, + ), + ], +) +def test_extract_instance_rate_per_hour(price_data, expected_result): + out = extract_instance_rate_per_hour(price_data) + + assert out == expected_result + + @pytest.mark.parametrize( "routing_config, expected", [ From 73bf4397f03a1228b7c002c9745bb515dfd43fa0 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Mon, 10 Jun 2024 08:30:24 -0700 Subject: [PATCH 02/45] feat: Model class to support AdditionalModelDataSources (#1469) * Add support for AdditionalModelDataSources * Resolve PR comments * Resolve PR comments * Resolve PR comments * fix unit tests * Resolve PR comments --------- Co-authored-by: Jonathan Makunga --- src/sagemaker/jumpstart/model.py | 8 ++++++++ src/sagemaker/model.py | 5 +++++ src/sagemaker/session.py | 6 ++++++ .../unit/sagemaker/jumpstart/estimator/test_estimator.py | 1 + tests/unit/sagemaker/jumpstart/model/test_model.py | 2 +- 5 files changed, 21 
insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index f72a3140dc..e7846f396f 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -369,6 +369,14 @@ def _validate_model_id_and_type(): model_type=self.model_type, ) + self.additional_model_data_sources = ( + self._metadata_configs.get(self.config_name).resolved_config.get( + "hosting_additional_data_sources" + ) + if self._metadata_configs.get(self.config_name) + else None + ) + def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" subscription_link = verify_model_region_and_return_specs( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 1bb6cb2e5c..9c54738ff4 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -160,6 +160,7 @@ def __init__( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, resources: Optional[ResourceRequirements] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, ): """Initialize an SageMaker ``Model``. @@ -323,9 +324,12 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). + additional_model_data_sources (Optional[Dict[str, Any]]): Additional location + of SageMaker model data (default: None). """ self.model_data = model_data + self.additional_model_data_sources = additional_model_data_sources self.image_uri = image_uri self.predictor_cls = predictor_cls self.name = name @@ -671,6 +675,7 @@ def prepare_container_def( accept_eula=( accept_eula if accept_eula is not None else getattr(self, "accept_eula", None) ), + additional_model_data_sources=self.additional_model_data_sources, ) def is_repack(self) -> bool: diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index bf2a736871..c33104fa95 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -7137,6 +7137,7 @@ def container_def( container_mode=None, image_config=None, accept_eula=None, + additional_model_data_sources=None, ): """Create a definition for executing a container as part of a SageMaker model. @@ -7159,6 +7160,8 @@ def container_def( The `accept_eula` value must be explicitly defined as `True` in order to accept the end-user license agreement (EULA) that some models require. (Default: None). + additional_model_data_sources (PipelineVariable or dict): Additional location + of SageMaker model data (default: None). 
Returns: dict[str, str]: A complete container definition object usable with the CreateModel API if @@ -7168,6 +7171,9 @@ def container_def( env = {} c_def = {"Image": image_uri, "Environment": env} + if additional_model_data_sources: + c_def["AdditionalModelDataSources"] = additional_model_data_sources + if isinstance(model_data_url, str) and ( not (model_data_url.startswith("s3://") and model_data_url.endswith("tar.gz")) or accept_eula is None diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 17d0861bff..4347d23a35 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1168,6 +1168,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "inference_config_name" } == model_class_init_args - { "model_data", + "additional_model_data_sources", "self", "name", "resources", diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 25e01d5d10..fa858d0ac1 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -759,7 +759,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): Please add the new argument to the skip set below, and reach out to JumpStart team.""" - init_args_to_skip: Set[str] = set([]) + init_args_to_skip: Set[str] = set(["additional_model_data_sources"]) deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Model.__init__ From c4529e3b72bcef3bb074a51f41d418c43901c623 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Mon, 10 Jun 2024 13:53:52 -0400 Subject: [PATCH 03/45] feat: additional hosting model data source parsing (#1467) * feat: Additional Model Data source parsing * address comments * address comments * format --- src/sagemaker/jumpstart/types.py | 235 +++++++++++++++++++ src/sagemaker/utils.py | 1 + tests/unit/sagemaker/jumpstart/constants.py | 38 +++ tests/unit/sagemaker/jumpstart/test_types.py | 78 ++++++ tests/unit/sagemaker/jumpstart/test_utils.py | 8 + tests/unit/test_utils.py | 12 + 6 files changed, 372 insertions(+) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index f197421d65..9751015f4b 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -743,6 +743,235 @@ def _get_regional_property( return alias_value +class ModelAccessConfig(JumpStartDataHolderType): + """Data class of model access config that mirrors CreateModel API.""" + + __slots__ = ["accept_eula"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a ModelAccessConfig object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. 
+ """ + self.accept_eula: bool = json_obj["accept_eula"] + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of ModelAccessConfig object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + return json_obj + + +class HubAccessConfig(JumpStartDataHolderType): + """Data class of model access config that mirrors CreateModel API.""" + + __slots__ = ["hub_content_arn"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a HubAccessConfig object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. + """ + self.hub_content_arn: bool = json_obj["accept_eula"] + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of ModelAccessConfig object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + return json_obj + + +class S3DataSource(JumpStartDataHolderType): + """Data class of S3 data source that mirrors CreateModel API.""" + + __slots__ = [ + "compression_type", + "s3_data_type", + "s3_uri", + "model_access_config", + "hub_access_config", + ] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a S3DataSource object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. + """ + self.compression_type: str = json_obj["compression_type"] + self.s3_data_type: str = json_obj["s3_data_type"] + self.s3_uri: str = json_obj["s3_uri"] + self.model_access_config: ModelAccessConfig = ( + ModelAccessConfig(json_obj["model_access_config"]) + if json_obj.get("model_access_config") + else None + ) + self.hub_access_config: HubAccessConfig = ( + HubAccessConfig(json_obj["hub_access_config"]) + if json_obj.get("hub_access_config") + else None + ) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of S3DataSource object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + else: + json_obj[att] = cur_val + return json_obj + + +class AdditionalModelDataSource(JumpStartDataHolderType): + """Data class of additional model data source mirrors Hosting API.""" + + __slots__ = ["channel_name", "s3_data_source"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a AdditionalModelDataSource object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. 
+ """ + self.channel_name: str = json_obj["channel_name"] + self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"]) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of AdditionalModelDataSource object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + else: + json_obj[att] = cur_val + return json_obj + + +class JumpStartModelDataSource(JumpStartDataHolderType): + """Data class JumpStart additional model data source.""" + + __slots__ = ["version", "additional_model_data_source"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a JumpStartModelDataSource object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. + """ + self.version: str = json_obj["artifact_version"] + self.additional_model_data_source: AdditionalModelDataSource = AdditionalModelDataSource( + json_obj + ) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of JumpStartModelDataSource object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + else: + json_obj[att] = cur_val + return json_obj + + +class JumpStartAdditionalDataSources(JumpStartDataHolderType): + """Data class of additional data sources.""" + + __slots__ = ["speculative_decoding", "scripts"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a AdditionalDataSources object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. 
+ """ + self.speculative_decoding: Optional[List[JumpStartModelDataSource]] = ( + [ + JumpStartModelDataSource(data_source) + for data_source in json_obj["speculative_decoding"] + ] + if json_obj.get("speculative_decoding") + else None + ) + self.scripts: Optional[List[JumpStartModelDataSource]] = ( + [JumpStartModelDataSource(data_source) for data_source in json_obj["scripts"]] + if json_obj.get("scripts") + else None + ) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of AdditionalDataSources object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj + + class JumpStartBenchmarkStat(JumpStartDataHolderType): """Data class JumpStart benchmark stat.""" @@ -857,6 +1086,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "default_payloads", "gated_bucket", "model_subscription_link", + "hosting_additional_data_sources", ] def __init__(self, fields: Dict[str, Any]): @@ -962,6 +1192,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("hosting_instance_type_variants") else None ) + self.hosting_additional_data_sources: Optional[JumpStartAdditionalDataSources] = ( + JumpStartAdditionalDataSources(json_obj["hosting_additional_data_sources"]) + if json_obj.get("hosting_additional_data_sources") + else None + ) if self.training_supported: self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index a70ba9eb98..5e0f47a406 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1690,6 +1690,7 @@ def deep_override_dict( skip_keys = [] flattened_dict1 = flatten_dict(dict1) + flattened_dict1 = {key: value for key, value in flattened_dict1.items() if value is not None} flattened_dict2 = flatten_dict( {key: value for key, value in dict2.items() if key not in skip_keys} ) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index fb7ca38bad..9e34a862f8 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7515,6 +7515,7 @@ "training_config_components": None, "inference_config_rankings": None, "training_config_rankings": None, + "hosting_additional_data_sources": None, } BASE_HEADER = { @@ -7700,6 +7701,14 @@ }, "component_names": ["gpu-inference-model-package"], }, + "gpu-accelerated": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] + }, + "component_names": ["gpu-accelerated"], + }, }, "inference_config_components": { "neuron-base": { @@ -7765,6 +7774,34 @@ }, }, }, + "gpu-accelerated": { + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "draft_model_name", + "artifact_version": "1.2.1", + "s3_data_source": { + "compression_type": "None", + "model_access_config": 
{"accept_eula": False}, + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/draft/model/artifact/", + }, + } + ], + }, + }, }, } @@ -7907,6 +7944,7 @@ "neuron-inference-budget", "gpu-inference", "gpu-inference-budget", + "gpu-accelerated", ], }, "performance": { diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 23fa42c09a..37161ed4f6 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -930,6 +930,7 @@ def test_inference_configs_parsing(): "gpu-inference", "gpu-inference-model-package", "gpu-inference-budget", + "gpu-accelerated", ] # Non-overrided fields in top config @@ -1167,6 +1168,7 @@ def test_set_inference_configs(): "gpu-inference", "gpu-inference-model-package", "gpu-inference-budget", + "gpu-accelerated", ] with pytest.raises(ValueError) as error: @@ -1321,6 +1323,82 @@ def test_training_configs_parsing(): assert list(config.config_components.keys()) == ["neuron-training"] +def test_additional_model_data_source_parsing(): + accelerated_first_rankings = { + "inference_config_rankings": { + "overall": { + "description": "Overall rankings of configs", + "rankings": [ + "gpu-accelerated", + "neuron-inference", + "neuron-inference-budget", + "gpu-inference", + "gpu-inference-budget", + ], + } + } + } + spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **accelerated_first_rankings} + specs1 = JumpStartModelSpecs(spec) + + config = specs1.inference_configs.get_top_config_from_ranking() + + assert config.benchmark_metrics == { + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + ] + } + assert len(config.config_components) == 1 + assert config.config_components["gpu-accelerated"] == JumpStartConfigComponent( + "gpu-accelerated", + { + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "draft_model_name", + "artifact_version": "1.2.1", + "s3_data_source": { + "compression_type": "None", + "model_access_config": {"accept_eula": False}, + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/draft/model/artifact/", + }, + } + ], + }, + }, + ) + assert list(config.config_components.keys()) == ["gpu-accelerated"] + assert config.resolved_config["hosting_additional_data_sources"] == { + "speculative_decoding": [ + { + "channel_name": "draft_model_name", + "artifact_version": "1.2.1", + "s3_data_source": { + "compression_type": "None", + "model_access_config": {"accept_eula": False}, + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/draft/model/artifact/", + }, + } + ], + } + + def test_set_inference_config(): spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} specs1 = JumpStartModelSpecs(spec) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index a5a063c696..f0bb5a6219 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1640,6 +1640,7 @@ def test_get_jumpstart_config_names_success( "gpu-inference-budget", "gpu-inference", "gpu-inference-model-package", + "gpu-accelerated", ] 
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1743,6 +1744,13 @@ def test_get_jumpstart_benchmark_stats_full_list( ) ] }, + "gpu-accelerated": { + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ] + }, } @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index d5214d01c3..b04593bfae 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1873,6 +1873,18 @@ def test_deep_override_nested_lists(self): expected_merged = {"a": [5], "b": {"c": [6, 7], "d": [8]}} self.assertDictEqual(deep_override_dict(dict1, dict2), expected_merged) + def test_deep_override_nested_lists_overriding_none(self): + dict1 = {"a": [{"c": "d"}, {"e": "f"}], "t": None} + dict2 = { + "a": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"], + "t": {"g": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"]}, + } + expected_merged = { + "a": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"], + "t": {"g": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"]}, + } + self.assertDictEqual(deep_override_dict(dict1, dict2), expected_merged) + def test_deep_override_skip_keys(self): dict1 = {"a": 1, "b": {"x": 2, "y": 3}, "c": [4, 5]} dict2 = { From 015120998d794bce04a19a754c9ee8119578c08f Mon Sep 17 00:00:00 2001 From: Jacky Lee Date: Tue, 11 Jun 2024 09:54:53 -0700 Subject: [PATCH 04/45] Add optimize to ModelBuilder (#1468) * Add optimize to ModelBuilder * Add polling for job completion * fix UTs --------- Co-authored-by: Jacky Lee --- src/sagemaker/serve/builder/model_builder.py | 142 +++++++++++++++++- src/sagemaker/serve/utils/optimize_utils.py | 58 +++++++ src/sagemaker/serve/utils/telemetry_logger.py | 25 +-- .../serve/builder/test_model_builder.py | 83 +++++++++- 4 files changed, 296 insertions(+), 12 deletions(-) create mode 100644 src/sagemaker/serve/utils/optimize_utils.py diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 44bc46b00b..297d276033 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -62,6 +62,10 @@ from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model +from sagemaker.serve.utils.optimize_utils import ( + _is_compatible_with_compilation, + _poll_optimization_job, +) from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.utils.hardware_detector import ( _get_gpu_info, @@ -83,6 +87,7 @@ from sagemaker.serve.validations.check_image_and_hardware_type import ( validate_image_uri_and_hardware, ) +from sagemaker.utils import Tags from sagemaker.workflow.entities import PipelineVariable from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata @@ -804,8 +809,15 @@ def save( This function is available for models served by DJL serving. Args: - save_path (Optional[str]): The path where you want to save resources. - s3_path (Optional[str]): The path where you want to upload resources. + save_path (Optional[str]): The path where you want to save resources. Defaults to + ``None``. + s3_path (Optional[str]): The path where you want to upload resources. Defaults to + ``None``. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. 
If not specified, the + function creates one using the default AWS configuration chain. Defaults to + ``None``. + role_arn (Optional[str]): The IAM role arn. Defaults to ``None``. """ self.sagemaker_session = sagemaker_session or Session() @@ -915,3 +927,129 @@ def _try_fetch_gpu_info(self): raise ValueError( f"Unable to determine single GPU size for instance: [{self.instance_type}]" ) + + def optimize(self, *args, **kwargs) -> Type[Model]: + """Runs a model optimization job. + + Args: + instance_type (str): Target deployment instance type that the model is optimized for. + output_path (str): Specifies where to store the compiled/quantized model. + role (Optional[str]): Execution role. Defaults to ``None``. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + ``None``. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Type[Model]: A deployable ``Model`` object. + """ + # need to get telemetry_opt_out info before telemetry decorator is called + self.serve_settings = self._get_serve_setting() + + return self._model_builder_optimize_wrapper(*args, **kwargs) + + @_capture_telemetry("optimize") + def _model_builder_optimize_wrapper( + self, + instance_type: str, + output_path: str, + role: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = None, + sagemaker_session: Optional[Session] = None, + ) -> Type[Model]: + """Runs a model optimization job. + + Args: + instance_type (str): Target deployment instance type that the model is optimized for. + output_path (str): Specifies where to store the compiled/quantized model. + role (Optional[str]): Execution role. Defaults to ``None``. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. 
+            max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
+                ``None``.
+            sagemaker_session (Optional[Session]): Session object which manages interactions
+                with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+                function creates one using the default AWS configuration chain.
+
+        Returns:
+            Type[Model]: A deployable ``Model`` object.
+        """
+        self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
+
+        # TODO: inject actual model source location based on different scenarios
+        model_source = {"S3": {"S3Uri": self.model_path, "ModelAccessConfig": {"AcceptEula": True}}}
+
+        optimization_configs = []
+        if quantization_config:
+            optimization_configs.append({"ModelQuantizationConfig": quantization_config})
+        if compilation_config:
+            if _is_compatible_with_compilation(instance_type):
+                optimization_configs.append({"ModelCompilationConfig": compilation_config})
+            else:
+                logger.warning(
+                    "Model compilation is currently only supported for Inferentia and Trainium "
+                    "instances, ignoring `compilation_config`."
+                )
+
+        output_config = {"S3OutputLocation": output_path}
+        if kms_key:
+            output_config["KmsKeyId"] = kms_key
+
+        job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
+        create_optimization_job_args = {
+            "OptimizationJobName": job_name,
+            "ModelSource": model_source,
+            "DeploymentInstanceType": instance_type,
+            "OptimizationConfigs": optimization_configs,
+            "OutputConfig": output_config,
+            "RoleArn": role or self.role_arn,
+        }
+
+        if env_vars:
+            create_optimization_job_args["OptimizationEnvironment"] = env_vars
+
+        if max_runtime_in_sec:
+            create_optimization_job_args["StoppingCondition"] = {
+                "MaxRuntimeInSeconds": max_runtime_in_sec
+            }
+
+        # TODO: tag injection if it is a JumpStart model
+        if tags:
+            create_optimization_job_args["Tags"] = tags
+
+        if vpc_config:
+            create_optimization_job_args["VpcConfig"] = vpc_config
+
+        response = self.sagemaker_session.sagemaker_client.create_optimization_job(
+            **create_optimization_job_args
+        )
+
+        if not _poll_optimization_job(job_name, self.sagemaker_session):
+            raise Exception("Optimization job timed out.")
+
+        # TODO: return model created by optimization job
+        return response
diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py
new file mode 100644
index 0000000000..32395c1478
--- /dev/null
+++ b/src/sagemaker/serve/utils/optimize_utils.py
@@ -0,0 +1,58 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Holds the util functions used for the optimize function"""
+from __future__ import absolute_import
+
+import time
+import logging
+
+from sagemaker import Session
+
+# TODO: determine how long optimization jobs take
+OPTIMIZE_POLLER_MAX_TIMEOUT_SECS = 300
+OPTIMIZE_POLLER_INTERVAL_SECS = 30
+
+logger = logging.getLogger(__name__)
+
+
+def _is_compatible_with_compilation(instance_type: str) -> bool:
+    """Checks whether an instance is compatible with compilation.
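The polling helper in the new `optimize_utils` module below is still TODO-scaffolded. One way it might be completed, sketched under the assumption that the describe call takes the `OptimizationJobName` keyword and the response carries a terminal `OptimizationJobStatus`:

    import time

    def poll_optimization_job_sketch(job_name, sagemaker_session,
                                     timeout_secs=300, interval_secs=30) -> bool:
        # Assumption: DescribeOptimizationJob accepts OptimizationJobName and
        # reports OptimizationJobStatus with the terminal states checked below.
        start_time = time.time()
        while time.time() - start_time < timeout_secs:
            result = sagemaker_session.sagemaker_client.describe_optimization_job(
                OptimizationJobName=job_name
            )
            status = result.get("OptimizationJobStatus")
            if status == "COMPLETED":
                return True
            if status in ("FAILED", "STOPPED"):
                return False
            time.sleep(interval_secs)
        return False  # timed out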
+ + Args: + instance_type (str): The instance type used for the compilation job. + + Returns: + bool: Whether the given instance type is compatible with compilation. + """ + return instance_type.startswith("ml.inf") or instance_type.startswith("ml.trn") + + +def _poll_optimization_job(job_name: str, sagemaker_session: Session) -> bool: + """Polls optimization job status until success. + + Args: + job_name (str): The name of the optimization job. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. + + Returns: + bool: Whether the optimization job was successful. + """ + logger.info("Polling status of optimization job %s", job_name) + start_time = time.time() + while time.time() - start_time < OPTIMIZE_POLLER_MAX_TIMEOUT_SECS: + result = sagemaker_session.sagemaker_client.describe_optimization_job(job_name) + # TODO: use correct condition to determine whether optimization job is complete + if result is not None: + return result + time.sleep(OPTIMIZE_POLLER_INTERVAL_SECS) diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py index 99aeb4ff26..fe99e787a0 100644 --- a/src/sagemaker/serve/utils/telemetry_logger.py +++ b/src/sagemaker/serve/utils/telemetry_logger.py @@ -80,15 +80,22 @@ def wrapper(self, *args, **kwargs): response = None caught_ex = None - image_uri_tail = self.image_uri.split("/")[1] - image_uri_option = _get_image_uri_option(self.image_uri, self._is_custom_image_uri) - extra = ( - f"{func_name}" - f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}" - f"&x-imageTag={image_uri_tail}" - f"&x-sdkVersion={SDK_VERSION}" - f"&x-defaultImageUsage={image_uri_option}" - ) + extra = f"{func_name}" + + if self.model_server: + extra += f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}" + + if self.image_uri: + image_uri_tail = self.image_uri.split("/")[1] + image_uri_option = _get_image_uri_option(self.image_uri, self._is_custom_image_uri) + + if self.image_uri: + extra += f"&x-imageTag={image_uri_tail}" + + extra += f"&x-sdkVersion={SDK_VERSION}" + + if self.image_uri: + extra += f"&x-defaultImageUsage={image_uri_option}" if self.model_server == ModelServer.DJL_SERVING or self.model_server == ModelServer.TGI: extra += f"&x-modelName={self.model}" diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 0c06b5ae8e..ed7c736633 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -44,7 +44,7 @@ mock_image_uri = "abcd/efghijk" mock_1p_dlc_image_uri = "763104351884.dkr.ecr.us-east-1.amazonaws.com" -mock_role_arn = "sample role arn" +mock_role_arn = "arn:aws:iam::123456789012:role/SageMakerRole" mock_s3_model_data_url = "sample s3 data url" mock_secret_key = "mock_secret_key" mock_instance_type = "mock instance type" @@ -2257,3 +2257,84 @@ def test_build_tensorflow_serving_non_mlflow_case( mock_role_arn, mock_session, ) + + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize(self, mock_send_telemetry, mock_get_serve_setting): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + builder = ModelBuilder( + model_path=MODEL_PATH, + schema_builder=schema_builder, + 
model=mock_fw_model, + sagemaker_session=mock_sagemaker_session, + ) + + job_name = "my-optimization-job" + instance_type = "ml.inf1.xlarge" + output_path = "s3://my-bucket/output" + quantization_config = { + "Image": "quantization-image-uri", + "OverrideEnvironment": {"ENV_VAR": "value"}, + } + compilation_config = { + "Image": "compilation-image-uri", + "OverrideEnvironment": {"ENV_VAR": "value"}, + } + env_vars = {"Var1": "value", "Var2": "value"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + expected_create_optimization_job_args = { + "ModelSource": {"S3": {"S3Uri": MODEL_PATH, "ModelAccessConfig": {"AcceptEula": True}}}, + "DeploymentInstanceType": instance_type, + "OptimizationEnvironment": env_vars, + "OptimizationConfigs": [ + {"ModelQuantizationConfig": quantization_config}, + {"ModelCompilationConfig": compilation_config}, + ], + "OutputConfig": {"S3OutputLocation": output_path, "KmsKeyId": kms_key}, + "RoleArn": mock_role_arn, + "OptimizationJobName": job_name, + "StoppingCondition": {"MaxRuntimeInSeconds": max_runtime_in_sec}, + "Tags": [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + "VpcConfig": vpc_config, + } + + mock_sagemaker_session.sagemaker_client.create_optimization_job.return_value = { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job" + } + + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role=mock_role_arn, + job_name=job_name, + quantization_config=quantization_config, + compilation_config=compilation_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + mock_send_telemetry.assert_called_once() + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + **expected_create_optimization_job_args + ) From 9a410e5f1edc5e8f03b8ef45566695c150ca81d9 Mon Sep 17 00:00:00 2001 From: Adam Kozdrowicz Date: Tue, 11 Jun 2024 14:18:40 -0400 Subject: [PATCH 05/45] feat: Added utils for extracting JS data sources (#1471) * added utils for accessing hosting data sources * added utils for accessing hosting data sources * removed other changes * fixed formatting issues * remove .keys() * updated JumpStartModelDataSource * fix slots * remove print * fix tests * update tests --- src/sagemaker/jumpstart/types.py | 43 ++++++++------------ tests/unit/sagemaker/jumpstart/constants.py | 31 ++++++++++++++ tests/unit/sagemaker/jumpstart/test_types.py | 18 ++++++++ 3 files changed, 66 insertions(+), 26 deletions(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 9751015f4b..ef203e4014 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -884,18 +884,10 @@ def to_json(self) -> Dict[str, Any]: return json_obj -class JumpStartModelDataSource(JumpStartDataHolderType): +class JumpStartModelDataSource(AdditionalModelDataSource): """Data class JumpStart additional model data source.""" - __slots__ = ["version", "additional_model_data_source"] - - def __init__(self, spec: Dict[str, Any]): - """Initializes a JumpStartModelDataSource object. 
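The refactor just below turns `JumpStartModelDataSource` into a subclass of `AdditionalModelDataSource`. Because `to_json` iterates `self.__slots__`, and attribute lookup resolves `__slots__` to the subclass's own list rather than the union with the parent's, the subclass concatenates the parent's slot names to keep every field serializable. A standalone illustration of the pattern:

    class Base:
        __slots__ = ["channel_name"]

        def to_json(self):
            # self.__slots__ resolves to the subclass's list only, so subclasses
            # must include the parent's slot names for those fields to serialize.
            return {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}

    class Child(Base):
        __slots__ = ["artifact_version"] + Base.__slots__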
- - Args: - spec (Dict[str, Any]): Dictionary representation of data source. - """ - self.from_json(spec) + __slots__ = ["artifact_version"] + AdditionalModelDataSource.__slots__ def from_json(self, json_obj: Dict[str, Any]) -> None: """Sets fields in object based on json. @@ -903,22 +895,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of data source. """ - self.version: str = json_obj["artifact_version"] - self.additional_model_data_source: AdditionalModelDataSource = AdditionalModelDataSource( - json_obj - ) - - def to_json(self) -> Dict[str, Any]: - """Returns json representation of JumpStartModelDataSource object.""" - json_obj = {} - for att in self.__slots__: - if hasattr(self, att): - cur_val = getattr(self, att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - else: - json_obj[att] = cur_val - return json_obj + super().from_json(json_obj) + self.artifact_version: str = json_obj["artifact_version"] class JumpStartAdditionalDataSources(JumpStartDataHolderType): @@ -1655,6 +1633,19 @@ def supports_incremental_training(self) -> bool: """Returns True if the model supports incremental training.""" return self.incremental_training_supported + def get_speculative_decoding_s3_data_sources(self) -> List[JumpStartAdditionalDataSources]: + """Returns data sources for speculative decoding.""" + return self.hosting_additional_data_sources.speculative_decoding or [] + + def get_additional_s3_data_sources(self) -> List[JumpStartAdditionalDataSources]: + """Returns a list of the additional S3 data sources for use by the model.""" + additional_data_sources = [] + if self.hosting_additional_data_sources: + for data_source in self.hosting_additional_data_sources.to_json(): + data_sources = getattr(self.hosting_additional_data_sources, data_source) or [] + additional_data_sources.extend(data_sources) + return additional_data_sources + class JumpStartVersionedModelId(JumpStartDataHolderType): """Data class for versioned model IDs.""" diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 9e34a862f8..ee6e320659 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7518,6 +7518,37 @@ "hosting_additional_data_sources": None, } +BASE_HOSTING_ADDITIONAL_DATA_SOURCES = { + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "speculative_decoding_channel", + "artifact_version": "version", + "s3_data_source": { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://bucket/path1", + "hub_access_config": None, + "model_access_config": None, + }, + } + ], + "scripts": [ + { + "channel_name": "scripts_channel", + "artifact_version": "version", + "s3_data_source": { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://bucket/path1", + "hub_access_config": None, + "model_access_config": None, + }, + } + ], + }, +} + BASE_HEADER = { "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "1.0.0", diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 37161ed4f6..d325cc4781 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -28,6 +28,7 @@ ) from tests.unit.sagemaker.jumpstart.constants import ( BASE_SPEC, + BASE_HOSTING_ADDITIONAL_DATA_SOURCES, INFERENCE_CONFIG_RANKINGS, 
INFERENCE_CONFIGS, TRAINING_CONFIG_RANKINGS, @@ -436,6 +437,23 @@ def test_jumpstart_model_specs(): assert specs3 == specs1 +def test_get_speculative_decoding_s3_data_sources(): + specs = JumpStartModelSpecs({**BASE_SPEC, **BASE_HOSTING_ADDITIONAL_DATA_SOURCES}) + assert ( + specs.get_speculative_decoding_s3_data_sources() + == specs.hosting_additional_data_sources.speculative_decoding + ) + + +def test_get_additional_s3_data_sources(): + specs = JumpStartModelSpecs({**BASE_SPEC, **BASE_HOSTING_ADDITIONAL_DATA_SOURCES}) + data_sources = [ + *specs.hosting_additional_data_sources.speculative_decoding, + *specs.hosting_additional_data_sources.scripts, + ] + assert specs.get_additional_s3_data_sources() == data_sources + + def test_jumpstart_image_uri_instance_variants(): assert ( From 2331dec993b077c7e77e9462e7945bdf4999fe3d Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Wed, 12 Jun 2024 14:54:48 -0400 Subject: [PATCH 06/45] fix: update passing additional model data sources to API (#1472) * feat: Added utils for extracting JS data sources (#1471) * added utils for accessing hosting data sources * added utils for accessing hosting data sources * removed other changes * fixed formatting issues * remove .keys() * updated JumpStartModelDataSource * fix slots * remove print * fix tests * update tests * fix: update passing additional model data sources to API * format * format * format * format and address comments * format * format * format --------- Co-authored-by: Adam Kozdrowicz --- src/sagemaker/jumpstart/factory/model.py | 40 ++++++++++++++++++- src/sagemaker/jumpstart/model.py | 17 ++++---- src/sagemaker/jumpstart/types.py | 26 +++++++----- src/sagemaker/utils.py | 30 ++++++++++++++ .../sagemaker/jumpstart/model/test_model.py | 2 +- tests/unit/test_utils.py | 40 +++++++++++++++++++ 6 files changed, 134 insertions(+), 21 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 6cdb3d8382..14287d5fcf 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -57,7 +57,7 @@ from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig from sagemaker.session import Session -from sagemaker.utils import name_from_base, format_tags, Tags +from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker import resource_requirements @@ -615,6 +615,40 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta return kwargs +def _add_additional_model_data_sources_to_kwargs( + kwargs: JumpStartModelInitKwargs, +) -> JumpStartModelInitKwargs: + """Sets default additional model data sources to init kwargs""" + + specs = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + config_name=kwargs.config_name, + ) + + additional_data_sources = specs.get_additional_s3_data_sources() + api_shape_additional_model_data_sources = ( + [ + camel_case_to_pascal_case(data_source.to_json()) + for data_source in 
additional_data_sources + ] + if specs.get_additional_s3_data_sources() + else None + ) + + kwargs.additional_model_data_sources = ( + kwargs.additional_model_data_sources or api_shape_additional_model_data_sources + ) + + return kwargs + + def _add_config_name_to_deploy_kwargs( kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None ) -> JumpStartModelInitKwargs: @@ -861,6 +895,7 @@ def get_init_kwargs( disable_instance_type_logging: bool = False, resources: Optional[ResourceRequirements] = None, config_name: Optional[str] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" @@ -893,6 +928,7 @@ def get_init_kwargs( training_instance_type=training_instance_type, resources=resources, config_name=config_name, + additional_model_data_sources=additional_model_data_sources, ) model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) @@ -925,4 +961,6 @@ def get_init_kwargs( model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs) + return model_init_kwargs diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index e7846f396f..ed7dbff2f1 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -102,6 +102,7 @@ def __init__( model_package_arn: Optional[str] = None, resources: Optional[ResourceRequirements] = None, config_name: Optional[str] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, ): """Initializes a ``JumpStartModel``. @@ -287,8 +288,10 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). - config_name (Optional[str]): The name of the JumpStartConfig that can be - optionally applied to the model and override corresponding fields. + config_name (Optional[str]): The name of the JumpStart config that can be + optionally applied to the model. + additional_model_data_sources (Optional[Dict[str, Any]]): Additional location + of SageMaker model data (default: None). Raises: ValueError: If the model ID is not recognized by JumpStart. 
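The factory change above reshapes metadata-style data sources into the CreateModel API shape via `camel_case_to_pascal_case`, which, despite its name, converts snake_case keys. A before/after illustration using a trimmed version of the data-source shape from the test constants added in the previous patch (None-valued fields are dropped separately by the updated `to_json`):

    metadata_shape = {
        "channel_name": "speculative_decoding_channel",
        "s3_data_source": {"compression_type": "None", "s3_data_type": "S3Prefix",
                           "s3_uri": "s3://bucket/path1"},
    }
    # Expected result of camel_case_to_pascal_case(metadata_shape):
    api_shape = {
        "ChannelName": "speculative_decoding_channel",
        "S3DataSource": {"CompressionType": "None", "S3DataType": "S3Prefix",
                         "S3Uri": "s3://bucket/path1"},
    }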
""" @@ -339,6 +342,7 @@ def _validate_model_id_and_type(): model_package_arn=model_package_arn, resources=resources, config_name=config_name, + additional_model_data_sources=additional_model_data_sources, ) self.orig_predictor_cls = predictor_cls @@ -352,6 +356,7 @@ def _validate_model_id_and_type(): self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session self.config_name = model_init_kwargs.config_name + self.additional_model_data_sources = model_init_kwargs.additional_model_data_sources if self.model_type == JumpStartModelType.PROPRIETARY: self.log_subscription_warning() @@ -369,14 +374,6 @@ def _validate_model_id_and_type(): model_type=self.model_type, ) - self.additional_model_data_sources = ( - self._metadata_configs.get(self.config_name).resolved_config.get( - "hosting_additional_data_sources" - ) - if self._metadata_configs.get(self.config_name) - else None - ) - def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" subscription_link = verify_model_region_and_return_specs( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index ef203e4014..859471004b 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -844,13 +844,15 @@ def to_json(self) -> Dict[str, Any]: cur_val = getattr(self, att) if issubclass(type(cur_val), JumpStartDataHolderType): json_obj[att] = cur_val.to_json() - else: + elif cur_val: json_obj[att] = cur_val return json_obj class AdditionalModelDataSource(JumpStartDataHolderType): - """Data class of additional model data source mirrors Hosting API.""" + """Data class of additional model data source mirrors CreateModel API.""" + + SERIALIZATION_EXCLUSION_SET: Set[str] = set() __slots__ = ["channel_name", "s3_data_source"] @@ -871,23 +873,26 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.channel_name: str = json_obj["channel_name"] self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"]) - def to_json(self) -> Dict[str, Any]: + def to_json(self, exclude_keys=True) -> Dict[str, Any]: """Returns json representation of AdditionalModelDataSource object.""" json_obj = {} for att in self.__slots__: if hasattr(self, att): - cur_val = getattr(self, att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - else: - json_obj[att] = cur_val + if exclude_keys and att not in self.SERIALIZATION_EXCLUSION_SET or not exclude_keys: + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + else: + json_obj[att] = cur_val return json_obj class JumpStartModelDataSource(AdditionalModelDataSource): """Data class JumpStart additional model data source.""" - __slots__ = ["artifact_version"] + AdditionalModelDataSource.__slots__ + SERIALIZATION_EXCLUSION_SET = {"artifact_version"} + + __slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__ def from_json(self, json_obj: Dict[str, Any]) -> None: """Sets fields in object based on json. 
@@ -1761,6 +1766,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "training_instance_type", "resources", "config_name", + "additional_model_data_sources", ] SERIALIZATION_EXCLUSION_SET = { @@ -1806,6 +1812,7 @@ def __init__( training_instance_type: Optional[str] = None, resources: Optional[ResourceRequirements] = None, config_name: Optional[str] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, ) -> None: """Instantiates JumpStartModelInitKwargs object.""" @@ -1837,6 +1844,7 @@ def __init__( self.training_instance_type = training_instance_type self.resources = resources self.config_name = config_name + self.additional_model_data_sources = additional_model_data_sources class JumpStartModelDeployKwargs(JumpStartKwargs): diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 5e0f47a406..adb286f660 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1798,3 +1798,33 @@ def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[ "name": "Instance Rate", } return None + + +def camel_case_to_pascal_case(data: Dict[str, Any]) -> Dict[str, Any]: + """Iteratively updates a dictionary to convert all keys from snake_case to PascalCase. + + Args: + data (dict): The dictionary to be updated. + + Returns: + dict: The updated dictionary with keys in PascalCase. + """ + result = {} + + def convert_key(key): + """Converts a snake_case key to PascalCase.""" + return "".join(part.capitalize() for part in key.split("_")) + + def convert_value(value): + """Recursively processes the value of a key-value pair.""" + if isinstance(value, dict): + return camel_case_to_pascal_case(value) + if isinstance(value, list): + return [convert_value(item) for item in value] + + return value + + for key, value in data.items(): + result[convert_key(key)] = convert_value(value) + + return result diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index fa858d0ac1..25e01d5d10 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -759,7 +759,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): Please add the new argument to the skip set below, and reach out to JumpStart team.""" - init_args_to_skip: Set[str] = set(["additional_model_data_sources"]) + init_args_to_skip: Set[str] = set([]) deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Model.__init__ diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index b04593bfae..731333d8ba 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -34,6 +34,7 @@ from sagemaker.experiments._run_context import _RunContext from sagemaker.session_settings import SessionSettings from sagemaker.utils import ( + camel_case_to_pascal_case, deep_override_dict, flatten_dict, get_instance_type_family, @@ -2055,3 +2056,42 @@ def test_resolve_routing_config(routing_config, expected): def test_resolve_routing_config_ex(): pytest.raises(ValueError, lambda: _resolve_routing_config({"RoutingStrategy": "Invalid"})) + + +class TestConvertToPascalCase(TestCase): + def test_simple_dict(self): + input_dict = {"first_name": "John", "last_name": "Doe"} + expected_output = {"FirstName": "John", "LastName": "Doe"} + self.assertEqual(camel_case_to_pascal_case(input_dict), expected_output) + + def camel_case_to_pascal_case_nested(self): + input_dict = { + "model_name": "my-model", + "primary_container": { + "image": "my-docker-image:latest", + 
"model_data_url": "s3://my-bucket/model.tar.gz", + "environment": {"env_var_1": "value1", "env_var_2": "value2"}, + }, + "execution_role_arn": "arn:aws:iam::123456789012:role/my-sagemaker-role", + "tags": [ + {"key": "project", "value": "my-project"}, + {"key": "environment", "value": "development"}, + ], + } + expected_output = { + "ModelName": "my-model", + "PrimaryContainer": { + "Image": "my-docker-image:latest", + "ModelDataUrl": "s3://my-bucket/model.tar.gz", + "Environment": {"EnvVar1": "value1", "EnvVar2": "value2"}, + }, + "ExecutionRoleArn": "arn:aws:iam::123456789012:role/my-sagemaker-role", + "Tags": [ + {"Key": "project", "Value": "my-project"}, + {"Key": "environment", "Value": "development"}, + ], + } + self.assertEqual(camel_case_to_pascal_case(input_dict), expected_output) + + def test_empty_input(self): + self.assertEqual(camel_case_to_pascal_case({}), {}) From 3c7b9665dd3a46d00e537b23b5338f14555d57b0 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Thu, 13 Jun 2024 15:57:26 -0400 Subject: [PATCH 07/45] fix: overriding instance specific fields in config components (#1478) * fix: instance specific variables override * format --- src/sagemaker/jumpstart/factory/model.py | 1 + src/sagemaker/utils.py | 2 -- tests/unit/sagemaker/jumpstart/constants.py | 10 +++++----- .../sagemaker/jumpstart/estimator/test_estimator.py | 12 ++++++++---- tests/unit/sagemaker/jumpstart/test_types.py | 2 +- 5 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 14287d5fcf..a7b37420df 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -677,6 +677,7 @@ def _add_config_name_to_deploy_kwargs( kwargs.config_name = _select_inference_config_from_training_config( specs=specs, training_config_name=training_config_name ) + return kwargs if specs.inference_configs: default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index adb286f660..045a214759 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1660,8 +1660,6 @@ def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None: if len(keys) == 1: d[key] = value return - if not d: - return d = d.setdefault(key, {}) nested_set_dict(d, keys[1:], value) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index ee6e320659..847710e0e4 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7758,7 +7758,7 @@ "regional_aliases": { "us-west-2": { "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" } }, "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, @@ -7772,7 +7772,7 @@ "regional_aliases": { "us-west-2": { "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" } }, "variants": { @@ -7796,7 +7796,7 @@ "regional_aliases": { "us-west-2": { "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + 
"pytorch-hosting:1.13.1-py310-sdk2.14.1-ubuntu20.04" } }, "variants": { @@ -7922,7 +7922,7 @@ "regional_aliases": { "us-west-2": { "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + "huggingface-pytorch-training:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" } }, "variants": { @@ -7953,7 +7953,7 @@ "regional_aliases": { "us-west-2": { "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + "pytorch-training:1.13.1-py310-sdk2.14.1-ubuntu20.04" } }, "variants": { diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4347d23a35..6f5f3dba05 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1953,7 +1953,8 @@ def test_estimator_initialization_with_config_name( mock_estimator_init.assert_called_once_with( instance_type="ml.p2.xlarge", instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-training:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" "gpu-training/model/", source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/transfer_learning/" @@ -2005,7 +2006,8 @@ def test_estimator_set_config_name( mock_estimator_init.assert_called_with( instance_type="ml.p2.xlarge", instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training:1.13.1-py310-sdk2.14.1-ubuntu20.04", model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" "gpu-training-budget/model/", source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" @@ -2060,7 +2062,8 @@ def test_estimator_default_inference_config( mock_estimator_deploy.assert_called_once_with( instance_type="ml.p2.xlarge", initial_instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-hosting" + ":2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", entry_point="inference.py", @@ -2167,7 +2170,8 @@ def test_estimator_deploy_with_config( mock_estimator_deploy.assert_called_once_with( instance_type="ml.p2.xlarge", initial_instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting:1.13.1-py310-sdk2.14.1-ubuntu20.04", source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", entry_point="inference.py", diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index d325cc4781..da7af3310f 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -1143,7 +1143,7 @@ def test_inference_configs_parsing(): "regional_aliases": { "us-west-2": { "neuron-ecr-uri": 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" } }, "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, From f55e3c93c564e78b75f42d03ae84e3ce2898511b Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Thu, 13 Jun 2024 17:14:10 -0700 Subject: [PATCH 08/45] Feat: Add optimize to ModelBuilder JS (#1474) * QS JS vanilla model * Use Alt config for Optimization * JS Optimize * Resolve config * inject additional tags * Inject tags * Refactoring * Refactoring * Filter Deployment config * Refactoring * Refactoring * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga --- src/sagemaker/enums.py | 6 + src/sagemaker/jumpstart/utils.py | 17 ++ src/sagemaker/model.py | 15 +- .../serve/builder/jumpstart_builder.py | 150 +++++++++++++ src/sagemaker/serve/builder/model_builder.py | 96 ++++----- src/sagemaker/serve/utils/optimize_utils.py | 133 ++++++++++-- tests/unit/sagemaker/jumpstart/test_utils.py | 15 ++ .../serve/builder/test_model_builder.py | 3 + .../serve/utils/test_optimize_utils.py | 201 ++++++++++++++++++ 9 files changed, 568 insertions(+), 68 deletions(-) create mode 100644 tests/unit/sagemaker/serve/utils/test_optimize_utils.py diff --git a/src/sagemaker/enums.py b/src/sagemaker/enums.py index f02b275cbe..b5b3931464 100644 --- a/src/sagemaker/enums.py +++ b/src/sagemaker/enums.py @@ -40,3 +40,9 @@ class RoutingStrategy(Enum): """The endpoint routes requests to the specific instances that have more capacity to process them. """ + + +class Tag(str, Enum): + """Enum class for tag keys to apply to models.""" + + OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 22974a3838..a22f95fd28 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1336,3 +1336,20 @@ def wrapped_f(*args, **kwargs): if _func is None: return wrapper_cache return wrapper_cache(_func) + + +def _extract_image_tag_and_version(image_uri: str) -> Tuple[Optional[str], Optional[str]]: + """Extract Image tag and version from image URI. + + Args: + image_uri (str): Image URI. + + Returns: + Tuple[Optional[str], Optional[str]]: The tag and version of the image. + """ + if image_uri is None: + return None, None + + tag = image_uri.split(":")[-1] + + return tag, tag.split("-")[0] diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 9c54738ff4..7e23df0c41 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -404,6 +404,18 @@ def __init__( self.content_types = None self.response_types = None self.accept_eula = None + self._tags: Optional[Tags] = None + + def add_tags(self, tags: Tags) -> None: + """Add tags to this ``Model`` + + Args: + tags (Tags): Tags to add. 
+ """ + if self._tags and tags: + self._tags.update(tags) + else: + self._tags = tags @runnable_by_pipeline def register( @@ -1457,7 +1469,8 @@ def deploy( sagemaker_session=self.sagemaker_session, ) - tags = format_tags(tags) + self.add_tags(tags) + tags = format_tags(self._tags) if ( getattr(self.sagemaker_session, "settings", None) is not None diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index e8ef546f7a..20a3738b02 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -19,6 +19,8 @@ from typing import Type, Any, List, Dict, Optional import logging +from sagemaker.jumpstart import enums +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs, get_eula_message from sagemaker.model import Model from sagemaker import model_uris from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources @@ -33,6 +35,11 @@ LocalModelLoadException, SkipTuningComboException, ) +from sagemaker.serve.utils.optimize_utils import ( + _extract_supported_deployment_config, + _is_speculation_enabled, + _is_compatible_with_optimization_job, +) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, TgiLocalModePredictor, @@ -53,6 +60,7 @@ from sagemaker.serve.utils.types import ModelServer from sagemaker.base_predictor import PredictorBase from sagemaker.jumpstart.model import JumpStartModel +from sagemaker.utils import Tags _DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py" _NO_JS_MODEL_EX = "HuggingFace JumpStart Model ID not detected. Building for HuggingFace Model ID." @@ -564,6 +572,148 @@ def _build_for_jumpstart(self): return self.pysdk_model + def _optimize_for_jumpstart( + self, + output_path: str, + instance_type: Optional[str] = None, + role: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + accept_eula: Optional[bool] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = None, + ) -> None: + """Runs a model optimization job. + + Args: + output_path (str): Specifies where to store the compiled/quantized model. + instance_type (Optional[str]): Target deployment instance type that + the model is optimized for. + role (Optional[str]): Execution role. Defaults to ``None``. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. 
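For reference, the `quantization_config` and `compilation_config` arguments follow the shape exercised by `test_optimize` earlier in this series; the image URIs and environment values below are test placeholders, not real defaults:

    quantization_config = {
        "Image": "quantization-image-uri",          # placeholder container image
        "OverrideEnvironment": {"ENV_VAR": "value"},
    }
    compilation_config = {
        "Image": "compilation-image-uri",
        "OverrideEnvironment": {"ENV_VAR": "value"},
    }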
+ kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + ``None``. + """ + model_specs = verify_model_region_and_return_specs( + region=self.sagemaker_session.boto_region_name, + model_id=self.pysdk_model.model_id, + version=self.pysdk_model.model_version, + sagemaker_session=self.sagemaker_session, + scope=enums.JumpStartScriptScope.INFERENCE, + model_type=self.pysdk_model.model_type, + ) + + if model_specs.is_gated_model() and accept_eula is not True: + raise ValueError(get_eula_message(model_specs, self.sagemaker_session.boto_region_name)) + + if not (self.pysdk_model.model_data and self.pysdk_model.model_data.get("S3DataSource")): + raise ValueError("Model Optimization Job only supports model backed by S3.") + + has_alternative_config = self.pysdk_model.deployment_config is not None + merged_env_vars = None + # TODO: Match Optimization Input Schema + model_source = { + "S3": {"S3Uri": self.pysdk_model.model_data.get("S3DataSource").get("S3Uri")}, + "SageMakerModel": {"ModelName": self.model}, + } + + if has_alternative_config: + image_uri = self.pysdk_model.deployment_config.get("DeploymentArgs").get("ImageUri") + instance_type = self.pysdk_model.deployment_config.get("InstanceType") + else: + image_uri = self.pysdk_model.image_uri + + if not _is_compatible_with_optimization_job(instance_type, image_uri) or ( + speculative_decoding_config + and not _is_speculation_enabled(self.pysdk_model.deployment_config) + ): + deployment_config = _extract_supported_deployment_config( + self.pysdk_model.list_deployment_configs(), speculative_decoding_config is None + ) + + if deployment_config: + self.pysdk_model.set_deployment_config( + config_name=deployment_config.get("DeploymentConfigName"), + instance_type=deployment_config.get("InstanceType"), + ) + merged_env_vars = self.pysdk_model.deployment_config.get("Environment") + + if speculative_decoding_config: + # TODO: Match Optimization Input Schema + s3 = { + "S3Uri": self.pysdk_model.additional_model_data_sources[ + "SpeculativeDecoding" + ][0]["S3DataSource"]["S3Uri"] + } + model_source["S3"].update(s3) + elif speculative_decoding_config: + raise ValueError("Can't find deployment config for model optimization job.") + + optimization_config = {} + if env_vars: + if merged_env_vars: + merged_env_vars.update(env_vars) + else: + merged_env_vars = env_vars + if quantization_config: + optimization_config["ModelQuantizationConfig"] = quantization_config + if compilation_config: + optimization_config["ModelCompilationConfig"] = compilation_config + + if accept_eula: + self.pysdk_model.accept_eula = accept_eula + self.pysdk_model.model_data["S3DataSource"].update( + {"ModelAccessConfig": {"AcceptEula": accept_eula}} + ) + model_source["S3"].update({"ModelAccessConfig": {"AcceptEula": accept_eula}}) + + output_config = {"S3OutputLocation": output_path} + if kms_key: + output_config["KmsKeyId"] = kms_key + + create_optimization_job_args = { + "OptimizationJobName": job_name, + "ModelSource": model_source, + "DeploymentInstanceType": instance_type, + "Environment": merged_env_vars, + "OptimizationConfigs": [optimization_config], + "OutputConfig": output_config, + "RoleArn": role, + } + + if max_runtime_in_sec: + create_optimization_job_args["StoppingCondition"] = { + "MaxRuntimeInSeconds": max_runtime_in_sec + } + if tags: + create_optimization_job_args["Tags"] = tags + if vpc_config: + 
create_optimization_job_args["VpcConfig"] = vpc_config + + self.sagemaker_session.sagemaker_client.create_optimization_job( + **create_optimization_job_args + ) + def _is_gated_model(self, model) -> bool: """Determine if ``this`` Model is Gated diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 297d276033..5062e55d8c 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -62,10 +62,7 @@ from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model -from sagemaker.serve.utils.optimize_utils import ( - _is_compatible_with_compilation, - _poll_optimization_job, -) +from sagemaker.serve.utils.optimize_utils import _poll_optimization_job, _generate_optimized_model from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.utils.hardware_detector import ( _get_gpu_info, @@ -961,13 +958,15 @@ def optimize(self, *args, **kwargs) -> Type[Model]: @_capture_telemetry("optimize") def _model_builder_optimize_wrapper( self, - instance_type: str, output_path: str, + instance_type: Optional[str] = None, role: Optional[str] = None, tags: Optional[Tags] = None, job_name: Optional[str] = None, + accept_eula: Optional[bool] = None, quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, env_vars: Optional[Dict] = None, vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, @@ -977,13 +976,20 @@ def _model_builder_optimize_wrapper( """Runs a model optimization job. Args: - instance_type (str): Target deployment instance type that the model is optimized for. output_path (str): Specifies where to store the compiled/quantized model. + instance_type (str): Target deployment instance type that the model is optimized for. role (Optional[str]): Execution role. Defaults to ``None``. tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` env_vars (Optional[Dict]): Additional environment variables to run the optimization container. Defaults to ``None``. vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. @@ -999,57 +1005,39 @@ def _model_builder_optimize_wrapper( Type[Model]: A deployable ``Model`` object. 
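Putting the pieces together, a hypothetical caller-side flow for a gated JumpStart model; the model ID, bucket, role, and quantization override are placeholders, and `accept_eula=True` is required because the check above raises for gated artifacts otherwise:

    model_builder = ModelBuilder(model="meta-textgeneration-llama-3-8b")  # placeholder JS model ID
    optimized_model = model_builder.optimize(
        output_path="s3://amzn-example-bucket/optimized/",
        accept_eula=True,  # gated models raise without explicit EULA acceptance
        quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
        role="arn:aws:iam::111122223333:role/SageMakerRole",
    )
    predictor = optimized_model.deploy()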
""" self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() + self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) + job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}" - # TODO: inject actual model source location based on different scenarios - model_source = {"S3": {"S3Uri": self.model_path, "ModelAccessConfig": {"AcceptEula": True}}} - - optimization_configs = [] - if quantization_config: - optimization_configs.append({"ModelQuantizationConfig": quantization_config}) - if compilation_config: - if _is_compatible_with_compilation(instance_type): - optimization_configs.append({"ModelCompilationConfig": compilation_config}) - else: - logger.warning( - "Model compilation is currently only supported for Inferentia and Trainium" - "instances, ignoring `compilation_config'." - ) + if self._is_jumpstart_model_id(): + self._optimize_for_jumpstart( + output_path=output_path, + instance_type=instance_type, + role=role if role else self.role_arn, + tags=tags, + job_name=job_name, + accept_eula=accept_eula, + quantization_config=quantization_config, + compilation_config=compilation_config, + speculative_decoding_config=speculative_decoding_config, + env_vars=env_vars, + vpc_config=vpc_config, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + ) - output_config = {"S3OutputLocation": output_path} - if kms_key: - output_config["KmsKeyId"] = kms_key + # TODO: use the wait for job pattern similar to + # https://quip-amazon.com/TKaPAhJck5sD/PySDK-Model-Optimization#temp:C:YcX3f2b103dabb4431090568bca2 + if not _poll_optimization_job(job_name, self.sagemaker_session): + raise Exception("Optimization job timed out.") - job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}" - create_optimization_job_args = { - "OptimizationJobName": job_name, - "ModelSource": model_source, - "DeploymentInstanceType": instance_type, - "OptimizationConfigs": optimization_configs, - "OutputConfig": output_config, - "RoleArn": role or self.role_arn, - } - - if env_vars: - create_optimization_job_args["OptimizationEnvironment"] = env_vars - - if max_runtime_in_sec: - create_optimization_job_args["StoppingCondition"] = { - "MaxRuntimeInSeconds": max_runtime_in_sec - } - - # TODO: tag injection if it is a JumpStart model - if tags: - create_optimization_job_args["Tags"] = tags - - if vpc_config: - create_optimization_job_args["VpcConfig"] = vpc_config - - response = self.sagemaker_session.sagemaker_client.create_optimization_job( - **create_optimization_job_args + describe_optimization_job_res = ( + self.sagemaker_session.sagemaker_client.describe_optimization_job( + OptimizationJobName=job_name + ) ) - if not _poll_optimization_job(job_name, self.sagemaker_session): - raise Exception("Optimization job timed out.") + self.pysdk_model = _generate_optimized_model( + self.pysdk_model, describe_optimization_job_res + ) - # TODO: return model created by optimization job - return response + return self.pysdk_model diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 32395c1478..305758e502 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -13,10 +13,15 @@ """Holds the util functions used for the optimize function""" from __future__ import absolute_import +import re import time import logging +from typing import List, Dict, Any, Optional -from sagemaker import Session +from sagemaker import Session, Model +from sagemaker.enums import Tag +from 
sagemaker.fw_utils import _is_gpu_instance +from sagemaker.jumpstart.utils import _extract_image_tag_and_version # TODO: determine how long optimization jobs take OPTIMIZE_POLLER_MAX_TIMEOUT_SECS = 300 @@ -25,18 +30,6 @@ logger = logging.getLogger(__name__) -def _is_compatible_with_compilation(instance_type: str) -> bool: - """Checks whether an instance is compatible with compilation. - - Args: - instance_type (str): The instance type used for the compilation job. - - Returns: - bool: Whether the given instance type is compatible with compilation. - """ - return instance_type.startswith("ml.inf") or instance_type.startswith("ml.trn") - - def _poll_optimization_job(job_name: str, sagemaker_session: Session) -> bool: """Polls optimization job status until success. @@ -56,3 +49,117 @@ def _poll_optimization_job(job_name: str, sagemaker_session: Session) -> bool: if result is not None: return result time.sleep(OPTIMIZE_POLLER_INTERVAL_SECS) + + +def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool: + """Checks whether an instance is compatible with Inferentia. + + Args: + instance_type (str): The instance type used for the compilation job. + + Returns: + bool: Whether the given instance type is Inferentia or Trainium. + """ + if isinstance(instance_type, str): + match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + if match: + if match[1].startswith("inf") or match[1].startswith("trn"): + return True + return False + + +def _is_compatible_with_optimization_job( + instance_type: Optional[str], image_uri: Optional[str] +) -> bool: + """Checks whether an instance is compatible with an optimization job. + + Args: + instance_type (str): The instance type used for the compilation job. + image_uri (str): The image URI of the optimization job. + + Returns: + bool: Whether the given instance type is compatible with an optimization job. + """ + image_tag, image_version = _extract_image_tag_and_version(image_uri) + if not image_tag or not image_version: + return False + + return ( + _is_gpu_instance(instance_type) and "djl-inference:" in image_uri and "-lmi" in image_tag + ) or ( + _is_inferentia_or_trainium(instance_type) + and "djl-inference:" in image_uri + and "-neuronx-s" in image_tag + ) + + +def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -> Model: + """Generates a new optimization model. + + Args: + pysdk_model (Model): A PySDK model. + optimization_response (dict): The optimization response. + + Returns: + Model: A deployable optimized model. + """ + pysdk_model.image_uri = optimization_response["RecommendedInferenceImage"] + pysdk_model.env = optimization_response["OptimizationEnvironment"] + pysdk_model.model_data["S3DataSource"]["S3Uri"] = optimization_response["ModelSource"]["S3"] + pysdk_model.instance_type = optimization_response["DeploymentInstanceType"] + pysdk_model.add_tags( + {"key": Tag.OPTIMIZATION_JOB_NAME, "value": optimization_response["OptimizationJobName"]} + ) + + return pysdk_model + + +def _is_speculation_enabled(deployment_config: Optional[Dict[str, Any]]) -> bool: + """Checks whether speculation is enabled for this deployment config. + + Args: + deployment_config (Dict[str, Any]): A deployment config. + + Returns: + bool: Whether the speculation is enabled for this deployment config. 
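The regex above accepts both dotted and underscored instance names and captures the instance family. A few checks mirroring the unit tests later in this patch:

    import re

    pattern = re.compile(r"^ml[\._]([a-z\d]+)\.?\w*$")
    for name in ("ml.inf2.xlarge", "ml.trn1.2xlarge", "ml_inf2", "ml.c7gd.4xlarge"):
        match = pattern.match(name)
        family = match[1] if match else None
        print(name, family, bool(family and family.startswith(("inf", "trn"))))
    # ml.inf2.xlarge  -> inf2  True
    # ml.trn1.2xlarge -> trn1  True
    # ml_inf2         -> inf2  True
    # ml.c7gd.4xlarge -> c7gd  False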
+ """ + if deployment_config is None: + return False + + acceleration_configs = deployment_config.get("AccelerationConfigs") + if acceleration_configs: + for acceleration_config in acceleration_configs: + if acceleration_config.get("type").lower() == "speculation" and acceleration_config.get( + "enabled" + ): + return True + return False + + +def _extract_supported_deployment_config( + deployment_configs: Optional[List[Dict[str, Any]]], + speculation_enabled: Optional[bool] = False, +) -> Optional[Dict[str, Any]]: + """Extracts supported deployment configurations. + + Args: + deployment_configs (Optional[List[Dict[str, Any]]]): A list of deployment configurations. + speculation_enabled (Optional[bool]): Whether speculation is enabled. + + Returns: + Optional[Dict[str, Any]]: Supported deployment configuration. + """ + if deployment_configs is None: + return None + + for deployment_config in deployment_configs: + image_uri: str = deployment_config.get("DeploymentArgs").get("ImageUri") + instance_type = deployment_config.get("InstanceType") + + if _is_compatible_with_optimization_job(instance_type, image_uri): + if speculation_enabled: + if _is_speculation_enabled(deployment_config): + return deployment_config + else: + return deployment_config + return None diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index f0bb5a6219..b3cdb69137 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -46,6 +46,7 @@ JumpStartModelHeader, JumpStartVersionedModelId, ) +from sagemaker.jumpstart.utils import _extract_image_tag_and_version from tests.unit.sagemaker.jumpstart.utils import ( get_base_spec_with_prototype_configs, get_spec_from_base_spec, @@ -2044,3 +2045,17 @@ def test_deployment_config_response_data(data, expected): print(out) assert out == expected + + +@pytest.mark.parametrize( + "image_uri, expected", + [ + ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124", + ("0.28.0-lmi10.0.0-cu124", "0.28.0"), + ), + (None, (None, None)), + ], +) +def test_extract_image_tag_and_version(image_uri, expected): + assert _extract_image_tag_and_version(image_uri) == expected diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index ed7c736633..b70f855486 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -13,6 +13,8 @@ from __future__ import absolute_import from unittest.mock import MagicMock, patch, Mock, mock_open +import pytest + import unittest from pathlib import Path from copy import deepcopy @@ -2258,6 +2260,7 @@ def test_build_tensorflow_serving_non_mlflow_case( mock_session, ) + @pytest.mark.skip(reason="Implementation not completed") @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") def test_optimize(self, mock_send_telemetry, mock_get_serve_setting): diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py new file mode 100644 index 0000000000..03e70b0ad8 --- /dev/null +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -0,0 +1,201 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import Mock + +import pytest + +from sagemaker.enums import Tag +from sagemaker.serve.utils.optimize_utils import ( + _generate_optimized_model, + _is_speculation_enabled, + _extract_supported_deployment_config, + _is_inferentia_or_trainium, + _is_compatible_with_optimization_job, +) + +mock_optimization_job_output = { + "OptimizationJobName": "optimization_job_name", + "RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04", + "OptimizationEnvironment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "MAX_BATCH_PREFILL_TOKENS": "8192", + "MAX_CONCURRENT_REQUESTS": "512", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ModelSource": { + "S3": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v2.0.0/" + }, + "DeploymentInstanceType": "ml.m5.xlarge", +} + + +@pytest.mark.parametrize( + "instance, expected", + [ + ("ml.trn1.2xlarge", True), + ("ml.inf2.xlarge", True), + ("ml.c7gd.4xlarge", False), + ], +) +def test_is_inferentia_or_trainium(instance, expected): + assert _is_inferentia_or_trainium(instance) == expected + + +@pytest.mark.parametrize( + "instance, image_uri, expected", + [ + ( + "ml.g5.12xlarge", + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124", + True, + ), + ( + "ml.trn1.2xlarge", + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-neuronx-sdk2.18.2", + True, + ), + ( + "ml.inf2.xlarge", + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-neuronx-sdk2.18.2", + True, + ), + ( + "ml.c7gd.4xlarge", + "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:" + "2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04", + False, + ), + ], +) +def test_is_compatible_with_optimization_job(instance, image_uri, expected): + assert _is_compatible_with_optimization_job(instance, image_uri) == expected + + +@pytest.mark.parametrize( + "deployment_configs, expected", + [ + ( + [ + { + "InstanceType": "ml.c7gd.4xlarge", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124" + }, + "AccelerationConfigs": [ + { + "type": "acceleration", + "enabled": True, + "spec": {"compiler": "a", "version": "1"}, + } + ], + } + ], + None, + ), + ( + [ + { + "InstanceType": "ml.g5.12xlarge", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124" + }, + "AccelerationConfigs": [ + { + "type": "speculation", + "enabled": True, + } + ], + } + ], + { + "InstanceType": "ml.g5.12xlarge", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124" + }, + "AccelerationConfigs": [ + { + "type": 
"speculation", + "enabled": True, + } + ], + }, + ), + (None, None), + ], +) +def test_extract_supported_deployment_config(deployment_configs, expected): + assert _extract_supported_deployment_config(deployment_configs, True) == expected + + +def test_generate_optimized_model(): + pysdk_model = Mock() + pysdk_model.model_data = {"S3DataSource": {"S3Uri": "s3://foo/bar"}} + + optimized_model = _generate_optimized_model(pysdk_model, mock_optimization_job_output) + + assert optimized_model.image_uri == mock_optimization_job_output["RecommendedInferenceImage"] + assert optimized_model.env == mock_optimization_job_output["OptimizationEnvironment"] + assert ( + optimized_model.model_data["S3DataSource"]["S3Uri"] + == mock_optimization_job_output["ModelSource"]["S3"] + ) + assert optimized_model.instance_type == mock_optimization_job_output["DeploymentInstanceType"] + pysdk_model.add_tags.assert_called_once_with( + { + "key": Tag.OPTIMIZATION_JOB_NAME, + "value": mock_optimization_job_output["OptimizationJobName"], + } + ) + + +@pytest.mark.parametrize( + "deployment_config, expected", + [ + ( + { + "AccelerationConfigs": [ + { + "type": "acceleration", + "enabled": True, + "spec": {"compiler": "a", "version": "1"}, + } + ], + }, + False, + ), + ( + { + "AccelerationConfigs": [ + { + "type": "speculation", + "enabled": True, + } + ], + }, + True, + ), + (None, False), + ], +) +def test_is_speculation_enabled(deployment_config, expected): + assert _is_speculation_enabled(deployment_config) is expected From c6581ff8efc630ebc8e43312a7b20565fe8d3e47 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Fri, 14 Jun 2024 15:25:03 -0400 Subject: [PATCH 09/45] feat: use Neo bucket in speculative decoding data source (#1479) * Use Neo bucket in speculative decoding data source * address comments * format * address comments * add buckets to regional config * remove opt-in regions for neo buckets --- src/sagemaker/jumpstart/constants.py | 20 +++ src/sagemaker/jumpstart/factory/model.py | 11 +- src/sagemaker/jumpstart/types.py | 38 ++++- src/sagemaker/jumpstart/utils.py | 30 +++- tests/unit/sagemaker/jumpstart/constants.py | 1 + .../sagemaker/jumpstart/model/test_model.py | 66 +++++++++ tests/unit/sagemaker/jumpstart/test_types.py | 35 +++++ tests/unit/sagemaker/jumpstart/test_utils.py | 132 ++++++++++-------- 8 files changed, 267 insertions(+), 66 deletions(-) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 5b0f749c64..b94fb2982c 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -44,31 +44,37 @@ region_name="us-west-2", content_bucket="jumpstart-cache-prod-us-west-2", gated_content_bucket="jumpstart-private-cache-prod-us-west-2", + neo_content_bucket="sagemaker-sd-models-prod-us-west-2", ), JumpStartLaunchedRegionInfo( region_name="us-east-1", content_bucket="jumpstart-cache-prod-us-east-1", gated_content_bucket="jumpstart-private-cache-prod-us-east-1", + neo_content_bucket="sagemaker-sd-models-prod-us-east-1", ), JumpStartLaunchedRegionInfo( region_name="us-east-2", content_bucket="jumpstart-cache-prod-us-east-2", gated_content_bucket="jumpstart-private-cache-prod-us-east-2", + neo_content_bucket="sagemaker-sd-models-prod-us-east-2", ), JumpStartLaunchedRegionInfo( region_name="eu-west-1", content_bucket="jumpstart-cache-prod-eu-west-1", gated_content_bucket="jumpstart-private-cache-prod-eu-west-1", + neo_content_bucket="sagemaker-sd-models-prod-eu-west-1", ), 
JumpStartLaunchedRegionInfo( region_name="eu-central-1", content_bucket="jumpstart-cache-prod-eu-central-1", gated_content_bucket="jumpstart-private-cache-prod-eu-central-1", + neo_content_bucket="sagemaker-sd-models-prod-eu-central-1", ), JumpStartLaunchedRegionInfo( region_name="eu-north-1", content_bucket="jumpstart-cache-prod-eu-north-1", gated_content_bucket="jumpstart-private-cache-prod-eu-north-1", + neo_content_bucket="sagemaker-sd-models-prod-eu-north-1", ), JumpStartLaunchedRegionInfo( region_name="me-south-1", @@ -84,11 +90,13 @@ region_name="ap-south-1", content_bucket="jumpstart-cache-prod-ap-south-1", gated_content_bucket="jumpstart-private-cache-prod-ap-south-1", + neo_content_bucket="sagemaker-sd-models-prod-ap-south-1", ), JumpStartLaunchedRegionInfo( region_name="eu-west-3", content_bucket="jumpstart-cache-prod-eu-west-3", gated_content_bucket="jumpstart-private-cache-prod-eu-west-3", + neo_content_bucket="sagemaker-sd-models-prod-eu-west-3", ), JumpStartLaunchedRegionInfo( region_name="af-south-1", @@ -99,6 +107,7 @@ region_name="sa-east-1", content_bucket="jumpstart-cache-prod-sa-east-1", gated_content_bucket="jumpstart-private-cache-prod-sa-east-1", + neo_content_bucket="sagemaker-sd-models-prod-sa-east-1", ), JumpStartLaunchedRegionInfo( region_name="ap-east-1", @@ -109,21 +118,25 @@ region_name="ap-northeast-2", content_bucket="jumpstart-cache-prod-ap-northeast-2", gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-2", + neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-2", ), JumpStartLaunchedRegionInfo( region_name="ap-northeast-3", content_bucket="jumpstart-cache-prod-ap-northeast-3", gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-3", + neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-3", ), JumpStartLaunchedRegionInfo( region_name="ap-southeast-3", content_bucket="jumpstart-cache-prod-ap-southeast-3", gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-3", + neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-3", ), JumpStartLaunchedRegionInfo( region_name="eu-west-2", content_bucket="jumpstart-cache-prod-eu-west-2", gated_content_bucket="jumpstart-private-cache-prod-eu-west-2", + neo_content_bucket="sagemaker-sd-models-prod-eu-west-2", ), JumpStartLaunchedRegionInfo( region_name="eu-south-1", @@ -134,26 +147,31 @@ region_name="ap-northeast-1", content_bucket="jumpstart-cache-prod-ap-northeast-1", gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-1", + neo_content_bucket="sagemaker-sd-models-prod-ap-northeast-1", ), JumpStartLaunchedRegionInfo( region_name="us-west-1", content_bucket="jumpstart-cache-prod-us-west-1", gated_content_bucket="jumpstart-private-cache-prod-us-west-1", + neo_content_bucket="sagemaker-sd-models-prod-us-west-1", ), JumpStartLaunchedRegionInfo( region_name="ap-southeast-1", content_bucket="jumpstart-cache-prod-ap-southeast-1", gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-1", + neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-1", ), JumpStartLaunchedRegionInfo( region_name="ap-southeast-2", content_bucket="jumpstart-cache-prod-ap-southeast-2", gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-2", + neo_content_bucket="sagemaker-sd-models-prod-ap-southeast-2", ), JumpStartLaunchedRegionInfo( region_name="ca-central-1", content_bucket="jumpstart-cache-prod-ca-central-1", gated_content_bucket="jumpstart-private-cache-prod-ca-central-1", + neo_content_bucket="sagemaker-sd-models-prod-ca-central-1", ), 
JumpStartLaunchedRegionInfo( region_name="cn-north-1", @@ -184,6 +202,7 @@ ) JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" +NEO_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" @@ -201,6 +220,7 @@ "AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE" ) ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE" +ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE = "AWS_NEO_CONTENT_BUCKET_OVERRIDE" JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart" diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index a7b37420df..7de6407e47 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -46,6 +46,7 @@ from sagemaker.jumpstart.utils import ( add_jumpstart_model_info_tags, get_default_jumpstart_session_with_user_agent_suffix, + get_neo_content_bucket, update_dict_if_key_not_present, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, @@ -631,14 +632,16 @@ def _add_additional_model_data_sources_to_kwargs( model_type=kwargs.model_type, config_name=kwargs.config_name, ) - - additional_data_sources = specs.get_additional_s3_data_sources() + # Append speculative decoding data source from metadata + speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources() + for data_source in speculative_decoding_data_sources: + data_source.s3_data_source.set_bucket(get_neo_content_bucket()) api_shape_additional_model_data_sources = ( [ camel_case_to_pascal_case(data_source.to_json()) - for data_source in additional_data_sources + for data_source in speculative_decoding_data_sources ] - if specs.get_additional_s3_data_sources() + if specs.get_speculative_decoding_s3_data_sources() else None ) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 859471004b..13eb9e80bb 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -15,7 +15,13 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union -from sagemaker.utils import get_instance_type_family, format_tags, Tags, deep_override_dict +from sagemaker.utils import ( + S3_PREFIX, + get_instance_type_family, + format_tags, + Tags, + deep_override_dict, +) from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -116,10 +122,14 @@ class JumpStartS3FileType(str, Enum): class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): """Data class for launched region info.""" - __slots__ = ["content_bucket", "region_name", "gated_content_bucket"] + __slots__ = ["content_bucket", "region_name", "gated_content_bucket", "neo_content_bucket"] def __init__( - self, content_bucket: str, region_name: str, gated_content_bucket: Optional[str] = None + self, + content_bucket: str, + region_name: str, + gated_content_bucket: Optional[str] = None, + neo_content_bucket: Optional[str] = None, ): """Instantiates JumpStartLaunchedRegionInfo object. @@ -128,10 +138,13 @@ def __init__( region_name (str): Name of JumpStart launched region. gated_content_bucket (Optional[str[]): Name of JumpStart gated s3 content bucket optionally associated with region. 
+            neo_content_bucket (Optional[str]): Name of Neo service s3 content bucket
+                optionally associated with region.
         """
         self.content_bucket = content_bucket
         self.gated_content_bucket = gated_content_bucket
         self.region_name = region_name
+        self.neo_content_bucket = neo_content_bucket
 
 
 class JumpStartModelHeader(JumpStartDataHolderType):
@@ -848,6 +861,21 @@ def to_json(self) -> Dict[str, Any]:
                 json_obj[att] = cur_val
         return json_obj
 
+    def set_bucket(self, bucket: str) -> None:
+        """Sets the bucket in the S3 URI, prefixing it if ``s3_uri`` is a bare key."""
+
+        if self.s3_uri.startswith(S3_PREFIX):
+            s3_path = self.s3_uri[len(S3_PREFIX) :]
+            old_bucket = s3_path.split("/")[0]
+            key = s3_path[len(old_bucket) :]
+            self.s3_uri = f"{S3_PREFIX}{bucket}{key}"  # pylint: disable=W0201
+            return
+
+        if not bucket.endswith("/"):
+            bucket += "/"
+
+        self.s3_uri = f"{S3_PREFIX}{bucket}{self.s3_uri}"  # pylint: disable=W0201
+
 
 class AdditionalModelDataSource(JumpStartDataHolderType):
     """Data class of additional model data source mirrors CreateModel API."""
@@ -1638,8 +1666,10 @@ def supports_incremental_training(self) -> bool:
         """Returns True if the model supports incremental training."""
         return self.incremental_training_supported
 
-    def get_speculative_decoding_s3_data_sources(self) -> List[JumpStartAdditionalDataSources]:
+    def get_speculative_decoding_s3_data_sources(self) -> List[JumpStartModelDataSource]:
         """Returns data sources for speculative decoding."""
+        if not self.hosting_additional_data_sources:
+            return []
         return self.hosting_additional_data_sources.speculative_decoding or []
 
     def get_additional_s3_data_sources(self) -> List[JumpStartAdditionalDataSources]:
diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py
index a22f95fd28..48b95bf887 100644
--- a/src/sagemaker/jumpstart/utils.py
+++ b/src/sagemaker/jumpstart/utils.py
@@ -170,6 +170,34 @@ def get_jumpstart_content_bucket(
     return bucket_to_return
 
 
+def get_neo_content_bucket(
+    region: str = constants.NEO_DEFAULT_REGION_NAME,
+) -> str:
+    """Returns the regionalized S3 bucket name for the Neo service.
+
+    Raises:
+        ValueError: If Neo is not launched in ``region``.
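+
+    Example (bucket name follows the launched-region constants above):
+        >>> get_neo_content_bucket("us-west-2")
+        'sagemaker-sd-models-prod-us-west-2'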
+ """ + + bucket_to_return: Optional[str] = None + if ( + constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE in os.environ + and len(os.environ[constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE]) > 0 + ): + bucket_to_return = os.environ[constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE] + info_log = f"Using Neo bucket override: '{bucket_to_return}'" + constants.JUMPSTART_LOGGER.info(info_log) + else: + try: + bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[ + region + ].neo_content_bucket + except KeyError: + raise ValueError(f"Unable to get content bucket for Neo in {region} region.") + + return bucket_to_return + + def get_formatted_manifest( manifest: List[Dict], ) -> Dict[JumpStartVersionedModelId, JumpStartModelHeader]: diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 847710e0e4..734857945a 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7806,6 +7806,7 @@ }, }, "gpu-accelerated": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "hosting_instance_type_variants": { "regional_aliases": { "us-west-2": { diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 25e01d5d10..a3d70933eb 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1687,6 +1687,72 @@ def test_model_set_deployment_config( endpoint_logging=False, ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_deployment_config_additional_model_data_source( + self, + mock_model_init: mock.Mock, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_session.return_value = sagemaker_session + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + model = JumpStartModel(model_id=model_id, config_name="gpu-accelerated") + + mock_model_init.assert_called_once_with( + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", + model_data="s3://jumpstart-cache-prod-us-west-2/pytorch-infer/" + "infer-pytorch-eqa-bert-base-cased.tar.gz", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + predictor_cls=Predictor, + role=execution_role, + sagemaker_session=sagemaker_session, + enable_network_isolation=False, + additional_model_data_sources=[ + { + "ChannelName": "draft_model_name", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://sagemaker-sd-models-prod-us-west-2/key/to/draft/model/artifact/", + "ModelAccessConfig": {"AcceptEula": False}, 
+ }, + } + ], + ) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-accelerated"}, + ], + wait=True, + endpoint_logging=False, + ) + @mock.patch( "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} ) diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index da7af3310f..06099ee066 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import copy +from unittest import TestCase import pytest from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import ( @@ -25,7 +26,9 @@ JumpStartConfigComponent, DeploymentConfigMetadata, JumpStartModelInitKwargs, + S3DataSource, ) +from sagemaker.utils import S3_PREFIX from tests.unit.sagemaker.jumpstart.constants import ( BASE_SPEC, BASE_HOSTING_ADDITIONAL_DATA_SOURCES, @@ -437,6 +440,37 @@ def test_jumpstart_model_specs(): assert specs3 == specs1 +class TestS3DataSource(TestCase): + def setUp(self): + self.s3_data_source = S3DataSource( + { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/model/artifact/", + "model_access_config": {"accept_eula": False}, + } + ) + + def test_set_bucket_with_valid_s3_uri(self): + self.s3_data_source.set_bucket("my-bucket") + self.assertEqual(self.s3_data_source.s3_uri, f"{S3_PREFIX}my-bucket/key/to/model/artifact/") + + def test_set_bucket_with_existing_s3_uri(self): + self.s3_data_source.s3_uri = "s3://my-bucket/key/to/model/artifact/" + self.s3_data_source.set_bucket("random-new-bucket") + assert self.s3_data_source.s3_uri == "s3://random-new-bucket/key/to/model/artifact/" + + def test_set_bucket_with_existing_s3_uri_empty_bucket(self): + self.s3_data_source.s3_uri = "s3://my-bucket" + self.s3_data_source.set_bucket("random-new-bucket") + assert self.s3_data_source.s3_uri == "s3://random-new-bucket" + + def test_set_bucket_with_existing_s3_uri_empty(self): + self.s3_data_source.s3_uri = "s3://" + self.s3_data_source.set_bucket("random-new-bucket") + assert self.s3_data_source.s3_uri == "s3://random-new-bucket" + + def test_get_speculative_decoding_s3_data_sources(): specs = JumpStartModelSpecs({**BASE_SPEC, **BASE_HOSTING_ADDITIONAL_DATA_SOURCES}) assert ( @@ -1372,6 +1406,7 @@ def test_additional_model_data_source_parsing(): assert config.config_components["gpu-accelerated"] == JumpStartConfigComponent( "gpu-accelerated", { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "hosting_instance_type_variants": { "regional_aliases": { "us-west-2": { diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index b3cdb69137..bb5aa93d24 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -26,6 +26,7 @@ ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE, + ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE, EXTRA_MODEL_ID_TAGS, EXTRA_MODEL_VERSION_TAGS, JUMPSTART_DEFAULT_REGION_NAME, @@ -33,6 +34,7 @@ JUMPSTART_LOGGER, 
JUMPSTART_REGION_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME, + NEO_DEFAULT_REGION_NAME, JumpStartScriptScope, ) from functools import partial @@ -65,79 +67,95 @@ def random_jumpstart_s3_uri(key): return f"s3://{random.choice(list(JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET))}/{key}" -def test_get_jumpstart_content_bucket(): - bad_region = "bad_region" - assert bad_region not in JUMPSTART_REGION_NAME_SET - with pytest.raises(ValueError): - utils.get_jumpstart_content_bucket(bad_region) - - -def test_get_jumpstart_content_bucket_no_args(): - assert ( - utils.get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) - == utils.get_jumpstart_content_bucket() - ) - +class TestBucketUtils(TestCase): + def test_get_jumpstart_content_bucket(self): + bad_region = "bad_region" + assert bad_region not in JUMPSTART_REGION_NAME_SET + with pytest.raises(ValueError): + utils.get_jumpstart_content_bucket(bad_region) -def test_get_jumpstart_content_bucket_override(): - with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}): - with patch("logging.Logger.info") as mocked_info_log: - random_region = "random_region" - assert "some-val" == utils.get_jumpstart_content_bucket(random_region) - mocked_info_log.assert_called_with("Using JumpStart bucket override: 'some-val'") + def test_get_jumpstart_content_bucket_no_args(self): + assert ( + utils.get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) + == utils.get_jumpstart_content_bucket() + ) + def test_get_jumpstart_content_bucket_override(self): + with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_jumpstart_content_bucket(random_region) + mocked_info_log.assert_called_with("Using JumpStart bucket override: 'some-val'") -def test_get_jumpstart_gated_content_bucket(): - bad_region = "bad_region" - assert bad_region not in JUMPSTART_REGION_NAME_SET - with pytest.raises(ValueError): - utils.get_jumpstart_gated_content_bucket(bad_region) + def test_get_jumpstart_gated_content_bucket(self): + bad_region = "bad_region" + assert bad_region not in JUMPSTART_REGION_NAME_SET + with pytest.raises(ValueError): + utils.get_jumpstart_gated_content_bucket(bad_region) + def test_get_jumpstart_gated_content_bucket_no_args(self): + assert ( + utils.get_jumpstart_gated_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) + == utils.get_jumpstart_gated_content_bucket() + ) -def test_get_jumpstart_gated_content_bucket_no_args(): - assert ( - utils.get_jumpstart_gated_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) - == utils.get_jumpstart_gated_content_bucket() - ) + def test_get_jumpstart_gated_content_bucket_override(self): + with patch.dict( + os.environ, {ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE: "some-val"} + ): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_jumpstart_gated_content_bucket(random_region) + mocked_info_log.assert_called_once_with( + "Using JumpStart gated bucket override: 'some-val'" + ) + def test_get_jumpstart_launched_regions_message(self): -def test_get_jumpstart_gated_content_bucket_override(): - with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE: "some-val"}): - with patch("logging.Logger.info") as mocked_info_log: - random_region = "random_region" - assert "some-val" == utils.get_jumpstart_gated_content_bucket(random_region) - 
mocked_info_log.assert_called_once_with( - "Using JumpStart gated bucket override: 'some-val'" + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is not available in any region." ) + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region"}): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is available in some_region region." + ) -def test_get_jumpstart_launched_regions_message(): + with patch( + "sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", + {"some_region1", "some_region2"}, + ): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is available in some_region1 and some_region2 regions." + ) - with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}): - assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is not available in any region." - ) + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"a", "b", "c"}): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is available in a, b, and c regions." + ) - with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region"}): - assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is available in some_region region." - ) + def test_get_neo_content_bucket(self): + bad_region = "bad_region" + assert bad_region not in JUMPSTART_REGION_NAME_SET + with pytest.raises(ValueError): + utils.get_neo_content_bucket(bad_region) - with patch( - "sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region1", "some_region2"} - ): + def test_get_neo_content_bucket_no_args(self): assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is available in some_region1 and some_region2 regions." + utils.get_neo_content_bucket(NEO_DEFAULT_REGION_NAME) == utils.get_neo_content_bucket() ) - with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"a", "b", "c"}): - assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is available in a, b, and c regions." 
- ) + def test_get_neo_content_bucket_override(self): + with patch.dict(os.environ, {ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE: "some-val"}): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_neo_content_bucket(random_region) + mocked_info_log.assert_called_with("Using Neo bucket override: 'some-val'") def test_get_formatted_manifest(): From 997e2cef3ea7eea3ad3abb3669f42523d3663724 Mon Sep 17 00:00:00 2001 From: Jacky Lee Date: Fri, 14 Jun 2024 18:10:51 -0700 Subject: [PATCH 10/45] feat: add build/deploy support for fine-tuned JS models (#1473) * feat: add support for fine-tuned JS models * Refactor * Refactor * Refactor * Refactor * pylint * pylint --------- Co-authored-by: Jacky Lee --- src/sagemaker/enums.py | 2 + .../serve/builder/jumpstart_builder.py | 58 ++++++++++++++++++- src/sagemaker/serve/builder/model_builder.py | 16 ++--- 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/enums.py b/src/sagemaker/enums.py index b5b3931464..f648fc9896 100644 --- a/src/sagemaker/enums.py +++ b/src/sagemaker/enums.py @@ -46,3 +46,5 @@ class Tag(str, Enum): """Enum class for tag keys to apply to models.""" OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name" + FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path" + FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name" diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 20a3738b02..98dec6a171 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -14,11 +14,15 @@ from __future__ import absolute_import import copy +import re from abc import ABC, abstractmethod from datetime import datetime, timedelta from typing import Type, Any, List, Dict, Optional import logging +from botocore.exceptions import ClientError + +from sagemaker.enums import Tag from sagemaker.jumpstart import enums from sagemaker.jumpstart.utils import verify_model_region_and_return_specs, get_eula_message from sagemaker.model import Model @@ -105,6 +109,7 @@ def __init__(self): self.nb_instance_type = None self.ram_usage_model_load = None self.jumpstart = None + self.model_metadata = None @abstractmethod def _prepare_for_mode(self): @@ -520,6 +525,54 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]: return self.pysdk_model.list_deployment_configs() + def _is_fine_tuned_model(self) -> bool: + """Checks whether a fine-tuned model exists.""" + return self.model_metadata and ( + self.model_metadata.get("FINE_TUNING_MODEL_PATH") + or self.model_metadata.get("FINE_TUNING_JOB_NAME") + ) + + def _update_model_data_for_fine_tuned_model(self, pysdk_model: Type[Model]) -> Type[Model]: + """Set the model path and data and add fine-tuning tags for the model.""" + # TODO: determine precedence of FINE_TUNING_MODEL_PATH and FINE_TUNING_JOB_NAME + if fine_tuning_model_path := self.model_metadata.get("FINE_TUNING_MODEL_PATH"): + if not re.match("^(https|s3)://([^/]+)/?(.*)$", fine_tuning_model_path): + raise ValueError( + f"Invalid path for FINE_TUNING_MODEL_PATH: {fine_tuning_model_path}." 
+                )
+            pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path
+            pysdk_model.add_tags(
+                {"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path}
+            )
+            return pysdk_model
+
+        if fine_tuning_job_name := self.model_metadata.get("FINE_TUNING_JOB_NAME"):
+            try:
+                response = self.sagemaker_session.sagemaker_client.describe_training_job(
+                    TrainingJobName=fine_tuning_job_name
+                )
+                fine_tuning_model_path = response["OutputDataConfig"]["S3OutputPath"]
+                pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path
+                pysdk_model.model_data["S3DataSource"]["CompressionType"] = response[
+                    "OutputDataConfig"
+                ]["CompressionType"]
+                pysdk_model.add_tags(
+                    [
+                        {"key": Tag.FINE_TUNING_JOB_NAME, "value": fine_tuning_job_name},
+                        {"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path},
+                    ]
+                )
+                return pysdk_model
+            except ClientError:
+                raise ValueError(
+                    f"Invalid job name for FINE_TUNING_JOB_NAME: {fine_tuning_job_name}."
+                )
+
+        raise ValueError(
+            "Input model not found. Please provide either `model_path`, "
+            "`FINE_TUNING_MODEL_PATH`, or `FINE_TUNING_JOB_NAME` under `model_metadata`."
+        )
+
     def _build_for_jumpstart(self):
         """Placeholder docstring"""
         if hasattr(self, "pysdk_model") and self.pysdk_model is not None:
@@ -534,6 +587,9 @@ def _build_for_jumpstart(self):
 
         logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
 
+        if self._is_fine_tuned_model():
+            pysdk_model = self._update_model_data_for_fine_tuned_model(pysdk_model)
+
         if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
             raise ValueError(
                 "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
@@ -714,7 +770,7 @@ def _optimize_for_jumpstart(
             **create_optimization_job_args
         )
 
-    def _is_gated_model(self, model) -> bool:
+    def _is_gated_model(self, model: Model) -> bool:
         """Determine if ``this`` Model is Gated
 
         Args:
diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py
index 5062e55d8c..892477d0b0 100644
--- a/src/sagemaker/serve/builder/model_builder.py
+++ b/src/sagemaker/serve/builder/model_builder.py
@@ -12,6 +12,7 @@
 # language governing permissions and limitations under the License.
 """Holds the ModelBuilder class and the ModelServer enum."""
 from __future__ import absolute_import
+
 import uuid
 from typing import Any, Type, List, Dict, Optional, Union
 from dataclasses import dataclass, field
@@ -278,8 +279,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
         default=None,
         metadata={
             "help": "Define the model metadata to override, currently supports `HF_TASK`, "
-            "`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
-            "the Hub, Adding unsupported task types will throw an exception"
+            "`MLFLOW_MODEL_PATH`, `FINE_TUNING_MODEL_PATH`, and `FINE_TUNING_JOB_NAME`. HF_TASK "
+            "should be set for new models without task metadata in the Hub. Adding unsupported "
+            "task types will throw an exception."
}, ) @@ -739,8 +741,8 @@ def build( # pylint: disable=R0911 ) self.serve_settings = self._get_serve_setting() - self._is_custom_image_uri = self.image_uri is not None + self._is_mlflow_model = self._check_if_input_is_mlflow_model() if self._is_mlflow_model: logger.warning( @@ -925,7 +927,7 @@ def _try_fetch_gpu_info(self): f"Unable to determine single GPU size for instance: [{self.instance_type}]" ) - def optimize(self, *args, **kwargs) -> Type[Model]: + def optimize(self, *args, **kwargs) -> Model: """Runs a model optimization job. Args: @@ -948,7 +950,7 @@ def optimize(self, *args, **kwargs) -> Type[Model]: function creates one using the default AWS configuration chain. Returns: - Type[Model]: A deployable ``Model`` object. + Model: A deployable ``Model`` object. """ # need to get telemetry_opt_out info before telemetry decorator is called self.serve_settings = self._get_serve_setting() @@ -972,7 +974,7 @@ def _model_builder_optimize_wrapper( kms_key: Optional[str] = None, max_runtime_in_sec: Optional[int] = None, sagemaker_session: Optional[Session] = None, - ) -> Type[Model]: + ) -> Model: """Runs a model optimization job. Args: @@ -1002,7 +1004,7 @@ def _model_builder_optimize_wrapper( function creates one using the default AWS configuration chain. Returns: - Type[Model]: A deployable ``Model`` object. + Model: A deployable ``Model`` object. """ self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) From 701b788c35dfaf7201ecf8daa97044685abfcb1e Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Mon, 17 Jun 2024 07:46:32 -0700 Subject: [PATCH 11/45] update: Add optimize to ModelBuilder JS (#1480) * Testing with Notebook * Refactoring * _poll_optimization_job refactoring * Resolve PR Comments * Refactoring * Refactoring * refactoring * Fix conflicts * Notebook testing * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga --- src/sagemaker/enums.py | 1 + .../serve/builder/jumpstart_builder.py | 206 ++++++++++++------ src/sagemaker/serve/builder/model_builder.py | 26 +-- src/sagemaker/serve/utils/optimize_utils.py | 105 ++++----- src/sagemaker/session.py | 43 ++++ .../serve/utils/test_optimize_utils.py | 75 ++----- 6 files changed, 252 insertions(+), 204 deletions(-) diff --git a/src/sagemaker/enums.py b/src/sagemaker/enums.py index f648fc9896..caa0d77175 100644 --- a/src/sagemaker/enums.py +++ b/src/sagemaker/enums.py @@ -46,5 +46,6 @@ class Tag(str, Enum): """Enum class for tag keys to apply to models.""" OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name" + SPECULATIVE_DRAFT_MODL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider" FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path" FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name" diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 98dec6a171..3011fe6a33 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -23,8 +23,6 @@ from botocore.exceptions import ClientError from sagemaker.enums import Tag -from sagemaker.jumpstart import enums -from sagemaker.jumpstart.utils import verify_model_region_and_return_specs, get_eula_message from sagemaker.model import Model from sagemaker import model_uris from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources @@ -40,9 +38,9 @@ 
SkipTuningComboException, ) from sagemaker.serve.utils.optimize_utils import ( - _extract_supported_deployment_config, - _is_speculation_enabled, _is_compatible_with_optimization_job, + _extract_model_source, + _update_environment_variables, ) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, @@ -643,7 +641,7 @@ def _optimize_for_jumpstart( vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, max_runtime_in_sec: Optional[int] = None, - ) -> None: + ) -> Dict[str, Any]: """Runs a model optimization job. Args: @@ -669,79 +667,60 @@ def _optimize_for_jumpstart( to S3. Defaults to ``None``. max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to ``None``. - """ - model_specs = verify_model_region_and_return_specs( - region=self.sagemaker_session.boto_region_name, - model_id=self.pysdk_model.model_id, - version=self.pysdk_model.model_version, - sagemaker_session=self.sagemaker_session, - scope=enums.JumpStartScriptScope.INFERENCE, - model_type=self.pysdk_model.model_type, - ) - if model_specs.is_gated_model() and accept_eula is not True: - raise ValueError(get_eula_message(model_specs, self.sagemaker_session.boto_region_name)) - - if not (self.pysdk_model.model_data and self.pysdk_model.model_data.get("S3DataSource")): - raise ValueError("Model Optimization Job only supports model backed by S3.") + Returns: + Dict[str, Any]: Model optimization job input arguments. + """ + if self._is_gated_model() and accept_eula is not True: + raise ValueError( + f"ValueError: Model '{self.model}' " + f"requires accepting end-user license agreement (EULA)." + ) - has_alternative_config = self.pysdk_model.deployment_config is not None - merged_env_vars = None - # TODO: Match Optimization Input Schema - model_source = { - "S3": {"S3Uri": self.pysdk_model.model_data.get("S3DataSource").get("S3Uri")}, - "SageMakerModel": {"ModelName": self.model}, - } + optimization_env_vars = None + pysdk_model_env_vars = None + model_source = _extract_model_source(self.pysdk_model.model_data, accept_eula) - if has_alternative_config: - image_uri = self.pysdk_model.deployment_config.get("DeploymentArgs").get("ImageUri") - instance_type = self.pysdk_model.deployment_config.get("InstanceType") + if speculative_decoding_config: + self._set_additional_model_source(speculative_decoding_config) + optimization_env_vars = self.pysdk_model.deployment_config.get("DeploymentArgs").get( + "Environment" + ) else: - image_uri = self.pysdk_model.image_uri - - if not _is_compatible_with_optimization_job(instance_type, image_uri) or ( - speculative_decoding_config - and not _is_speculation_enabled(self.pysdk_model.deployment_config) - ): - deployment_config = _extract_supported_deployment_config( - self.pysdk_model.list_deployment_configs(), speculative_decoding_config is None + image_uri = None + if quantization_config and quantization_config.get("Image"): + image_uri = quantization_config.get("Image") + elif compilation_config and compilation_config.get("Image"): + image_uri = compilation_config.get("Image") + instance_type = ( + instance_type + or self.pysdk_model.deployment_config.get("DeploymentArgs").get("InstanceType") + or _get_nb_instance() ) + if not _is_compatible_with_optimization_job(instance_type, image_uri): + deployment_config = self._find_compatible_deployment_config(None) + if deployment_config: + optimization_env_vars = deployment_config.get("DeploymentArgs").get( + "Environment" + ) + self.pysdk_model.set_deployment_config( + 
config_name=deployment_config.get("DeploymentConfigName"), + instance_type=deployment_config.get("InstanceType"), + ) - if deployment_config: - self.pysdk_model.set_deployment_config( - config_name=deployment_config.get("DeploymentConfigName"), - instance_type=deployment_config.get("InstanceType"), - ) - merged_env_vars = self.pysdk_model.deployment_config.get("Environment") - - if speculative_decoding_config: - # TODO: Match Optimization Input Schema - s3 = { - "S3Uri": self.pysdk_model.additional_model_data_sources[ - "SpeculativeDecoding" - ][0]["S3DataSource"]["S3Uri"] - } - model_source["S3"].update(s3) - elif speculative_decoding_config: - raise ValueError("Can't find deployment config for model optimization job.") + optimization_env_vars = _update_environment_variables(optimization_env_vars, env_vars) optimization_config = {} - if env_vars: - if merged_env_vars: - merged_env_vars.update(env_vars) - else: - merged_env_vars = env_vars if quantization_config: optimization_config["ModelQuantizationConfig"] = quantization_config + pysdk_model_env_vars = _update_environment_variables( + pysdk_model_env_vars, quantization_config["OverrideEnvironment"] + ) if compilation_config: optimization_config["ModelCompilationConfig"] = compilation_config - - if accept_eula: - self.pysdk_model.accept_eula = accept_eula - self.pysdk_model.model_data["S3DataSource"].update( - {"ModelAccessConfig": {"AcceptEula": accept_eula}} + pysdk_model_env_vars = _update_environment_variables( + pysdk_model_env_vars, compilation_config["OverrideEnvironment"] ) - model_source["S3"].update({"ModelAccessConfig": {"AcceptEula": accept_eula}}) output_config = {"S3OutputLocation": output_path} if kms_key: @@ -751,12 +730,13 @@ def _optimize_for_jumpstart( "OptimizationJobName": job_name, "ModelSource": model_source, "DeploymentInstanceType": instance_type, - "Environment": merged_env_vars, "OptimizationConfigs": [optimization_config], "OutputConfig": output_config, "RoleArn": role, } + if optimization_env_vars: + create_optimization_job_args["Environment"] = optimization_env_vars if max_runtime_in_sec: create_optimization_job_args["StoppingCondition"] = { "MaxRuntimeInSeconds": max_runtime_in_sec @@ -766,11 +746,10 @@ def _optimize_for_jumpstart( if vpc_config: create_optimization_job_args["VpcConfig"] = vpc_config - self.sagemaker_session.sagemaker_client.create_optimization_job( - **create_optimization_job_args - ) + self.pysdk_model.env.update(pysdk_model_env_vars) + return create_optimization_job_args - def _is_gated_model(self, model: Model) -> bool: + def _is_gated_model(self, model=None) -> bool: """Determine if ``this`` Model is Gated Args: @@ -778,10 +757,95 @@ def _is_gated_model(self, model: Model) -> bool: Returns: bool: ``True`` if ``this`` Model is Gated """ - s3_uri = model.model_data + s3_uri = model.model_data if model else self.pysdk_model.model_data if isinstance(s3_uri, dict): s3_uri = s3_uri.get("S3DataSource").get("S3Uri") if s3_uri is None: return False return "private" in s3_uri + + def _set_additional_model_source( + self, speculative_decoding_config: Optional[Dict[str, Any]] = None + ) -> None: + """Set Additional Model Source to ``this`` model. + + Args: + speculative_decoding_config (Optional[Dict[str, Any]]): Speculative decoding config. 
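+
+        Example shapes (inferred from how the config is read below; the S3
+        URI is illustrative):
+            {"ModelProvider": "SageMaker"}
+            {"ModelProvider": "Custom", "ModelSource": "s3://my-bucket/draft-model/"}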
+ """ + if speculative_decoding_config: + model_provider: str = speculative_decoding_config["ModelProvider"] + + if model_provider.lower() == "sagemaker": + if not self._is_speculation_enabled(self.pysdk_model.deployment_config): + deployment_config = self._find_compatible_deployment_config( + speculative_decoding_config + ) + if deployment_config: + self.pysdk_model.set_deployment_config( + config_name=deployment_config.get("DeploymentConfigName"), + instance_type=deployment_config.get("InstanceType"), + ) + self.pysdk_model.add_tags( + {"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "sagemaker"}, + ) + else: + raise ValueError( + "Cannot find deployment config compatible for optimization job." + ) + else: + s3_uri = speculative_decoding_config.get("ModelSource") + if not s3_uri: + raise ValueError("Custom S3 Uri cannot be none.") + + self.pysdk_model.additional_model_data_sources["speculative_decoding"][0][ + "s3_data_source" + ]["s3_uri"] = s3_uri + self.pysdk_model.add_tags( + {"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "customer"}, + ) + + def _find_compatible_deployment_config( + self, speculative_decoding_config: Optional[Dict] = None + ) -> Optional[Dict[str, Any]]: + """Finds compatible model deployment config for optimization job. + + Args: + speculative_decoding_config (Optional[Dict]): Speculative decoding config. + + Returns: + Optional[Dict[str, Any]]: A compatible model deployment config for optimization job. + """ + for deployment_config in self.pysdk_model.list_deployment_configs(): + instance_type = deployment_config.get("deployment_config").get("InstanceType") + image_uri = deployment_config.get("deployment_config").get("ImageUri") + + if _is_compatible_with_optimization_job(instance_type, image_uri): + if not speculative_decoding_config: + return deployment_config + + if self._is_speculation_enabled(deployment_config): + return deployment_config + + return None + + def _is_speculation_enabled(self, deployment_config: Optional[Dict[str, Any]]) -> bool: + """Checks whether speculative is enabled for the given deployment config. + + Args: + deployment_config (Dict[str, Any]): A deployment config. + + Returns: + bool: Whether speculative is enabled for this deployment config. 
+ """ + if deployment_config is None: + return False + + acceleration_configs = deployment_config.get("AccelerationConfigs") + if acceleration_configs: + for acceleration_config in acceleration_configs: + if acceleration_config.get( + "type", "default" + ).lower() == "speculative" and acceleration_config.get("enabled"): + return True + return False diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 892477d0b0..8f03f3aa87 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -63,7 +63,7 @@ from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model -from sagemaker.serve.utils.optimize_utils import _poll_optimization_job, _generate_optimized_model +from sagemaker.serve.utils.optimize_utils import _generate_optimized_model from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.utils.hardware_detector import ( _get_gpu_info, @@ -972,7 +972,7 @@ def _model_builder_optimize_wrapper( env_vars: Optional[Dict] = None, vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, - max_runtime_in_sec: Optional[int] = None, + max_runtime_in_sec: Optional[int] = 36000, sagemaker_session: Optional[Session] = None, ) -> Model: """Runs a model optimization job. @@ -998,7 +998,7 @@ def _model_builder_optimize_wrapper( kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading to S3. Defaults to ``None``. max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to - ``None``. + 36000 seconds. sagemaker_session (Optional[Session]): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the function creates one using the default AWS configuration chain. 
@@ -1010,8 +1010,9 @@ def _model_builder_optimize_wrapper(
         Returns:
-            Type[Model]: A deployable ``Model`` object.
+            Model: A deployable ``Model`` object.
         """
         self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
         self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
 
         job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
+        input_args = {}
         if self._is_jumpstart_model_id():
-            self._optimize_for_jumpstart(
+            input_args = self._optimize_for_jumpstart(
                 output_path=output_path,
                 instance_type=instance_type,
                 role=role if role else self.role_arn,
@@ -1027,19 +1028,8 @@
                 max_runtime_in_sec=max_runtime_in_sec,
             )
 
-        # TODO: use the wait for job pattern similar to
-        # https://quip-amazon.com/TKaPAhJck5sD/PySDK-Model-Optimization#temp:C:YcX3f2b103dabb4431090568bca2
-        if not _poll_optimization_job(job_name, self.sagemaker_session):
-            raise Exception("Optimization job timed out.")
-
-        describe_optimization_job_res = (
-            self.sagemaker_session.sagemaker_client.describe_optimization_job(
-                OptimizationJobName=job_name
-            )
-        )
-
-        self.pysdk_model = _generate_optimized_model(
-            self.pysdk_model, describe_optimization_job_res
-        )
+        self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
+        job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
+        self.pysdk_model = _generate_optimized_model(self.pysdk_model, job_status)
 
         return self.pysdk_model
diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py
index 305758e502..4a9babdbb8 100644
--- a/src/sagemaker/serve/utils/optimize_utils.py
+++ b/src/sagemaker/serve/utils/optimize_utils.py
@@ -14,43 +14,17 @@
 from __future__ import absolute_import
 
 import re
-import time
 import logging
-from typing import List, Dict, Any, Optional
+from typing import Dict, Any, Optional, Union
 
-from sagemaker import Session, Model
+from sagemaker import Model
 from sagemaker.enums import Tag
 from sagemaker.fw_utils import _is_gpu_instance
-from sagemaker.jumpstart.utils import _extract_image_tag_and_version
 
-# TODO: determine how long optimization jobs take
-OPTIMIZE_POLLER_MAX_TIMEOUT_SECS = 300
-OPTIMIZE_POLLER_INTERVAL_SECS = 30
 
 logger = logging.getLogger(__name__)
 
 
-def _poll_optimization_job(job_name: str, sagemaker_session: Session) -> bool:
-    """Polls optimization job status until success.
-
-    Args:
-        job_name (str): The name of the optimization job.
-        sagemaker_session (Session): Session object which manages interactions
-            with Amazon SageMaker APIs and any other AWS services needed.
-
-    Returns:
-        bool: Whether the optimization job was successful.
-    """
-    logger.info("Polling status of optimization job %s", job_name)
-    start_time = time.time()
-    while time.time() - start_time < OPTIMIZE_POLLER_MAX_TIMEOUT_SECS:
-        result = sagemaker_session.sagemaker_client.describe_optimization_job(job_name)
-        # TODO: use correct condition to determine whether optimization job is complete
-        if result is not None:
-            return result
-        time.sleep(OPTIMIZE_POLLER_INTERVAL_SECS)
-
-
 def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool:
     """Checks whether an instance is an Inferentia or Trainium instance.
 
@@ -80,17 +54,18 @@
     Returns:
         bool: Whether the given instance type is compatible with an optimization job.
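+
+    Example (instance/image pair taken from the unit tests in this change):
+        >>> _is_compatible_with_optimization_job(
+        ...     "ml.g5.12xlarge",
+        ...     "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124",
+        ... )
+        True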
""" - image_tag, image_version = _extract_image_tag_and_version(image_uri) - if not image_tag or not image_version: + if not instance_type: return False + compatible_image = True + if image_uri: + compatible_image = "djl-inference:" in image_uri and ( + "-lmi" in image_uri or "-neuronx-" in image_uri + ) + return ( - _is_gpu_instance(instance_type) and "djl-inference:" in image_uri and "-lmi" in image_tag - ) or ( - _is_inferentia_or_trainium(instance_type) - and "djl-inference:" in image_uri - and "-neuronx-s" in image_tag - ) + _is_gpu_instance(instance_type) or _is_inferentia_or_trainium(instance_type) + ) and compatible_image def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -> Model: @@ -136,30 +111,46 @@ def _is_speculation_enabled(deployment_config: Optional[Dict[str, Any]]) -> bool return False -def _extract_supported_deployment_config( - deployment_configs: Optional[List[Dict[str, Any]]], - speculation_enabled: Optional[bool] = False, +def _extract_model_source( + model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool] ) -> Optional[Dict[str, Any]]: - """Extracts supported deployment configurations. + """Extracts model source from model data. + + Args: + model_data (Optional[Union[Dict[str, Any], str]]): A model data. + + Returns: + Optional[Dict[str, Any]]: Model source data. + """ + if model_data is None: + raise ValueError("Model Optimization Job only supports model with S3 data source.") + + s3_uri = model_data + if isinstance(s3_uri, dict): + s3_uri = s3_uri.get("S3DataSource").get("S3Uri") + + # Todo: Inject fine-tune data source + model_source = {"S3": {"S3Uri": s3_uri}} + if accept_eula: + model_source["S3"]["ModelAccessConfig"] = {"AcceptEula": True} + return model_source + + +def _update_environment_variables( + env: Optional[Dict[str, str]], new_env: Optional[Dict[str, str]] +) -> Optional[Dict[str, str]]: + """Updates environment variables based on environment variables. Args: - deployment_configs (Optional[List[Dict[str, Any]]]): A list of deployment configurations. - speculation_enabled (Optional[bool]): Whether speculation is enabled. + env (Optional[Dict[str, str]]): The environment variables. + new_env (Optional[Dict[str, str]]): The new environment variables. Returns: - Optional[Dict[str, Any]]: Supported deployment configuration. + Optional[Dict[str, str]]: The updated environment variables. """ - if deployment_configs is None: - return None - - for deployment_config in deployment_configs: - image_uri: str = deployment_config.get("DeploymentArgs").get("ImageUri") - instance_type = deployment_config.get("InstanceType") - - if _is_compatible_with_optimization_job(instance_type, image_uri): - if speculation_enabled: - if _is_speculation_enabled(deployment_config): - return deployment_config - else: - return deployment_config - return None + if new_env: + if env: + env.update(new_env) + else: + env = new_env + return env diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index c33104fa95..596608be8a 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2591,6 +2591,24 @@ def wait_for_auto_ml_job(self, job, poll=5): _check_job_status(job, desc, "AutoMLJobStatus") return desc + def wait_for_optimization_job(self, job, poll=5): + """Wait for an Amazon SageMaker Optimization job to complete. + + Args: + job (str): Name of optimization job to wait for. + poll (int): Polling interval in seconds (default: 5). + + Returns: + (dict): Return value from the ``DescribeOptimizationJob`` API. 
+
+        Raises:
+            exceptions.CapacityError: If the optimization job fails with CapacityError.
+            exceptions.UnexpectedStatusException: If the optimization job fails.
+        """
+        desc = _wait_until(lambda: _optimization_job_status(self.sagemaker_client, job), poll)
+        _check_job_status(job, desc, "OptimizationJobStatus")
+        return desc
+
     def logs_for_auto_ml_job(  # noqa: C901 - suppress complexity warning for this method
         self, job_name, wait=False, poll=10
     ):
@@ -7609,6 +7627,31 @@ def _auto_ml_job_status(sagemaker_client, job_name):
     return desc
 
 
+def _optimization_job_status(sagemaker_client, job_name):
+    """Placeholder docstring"""
+    optimization_job_status_codes = {
+        "INPROGRESS": "!",
+        "COMPLETED": ".",
+        "FAILED": "*",
+        "STARTING": "s",
+        "STOPPING": "_",
+        "STOPPED": ",",
+    }
+    in_progress_statuses = ["INPROGRESS", "STARTING", "STOPPING"]
+
+    desc = sagemaker_client.describe_optimization_job(OptimizationJobName=job_name)
+    status = desc["OptimizationJobStatus"]
+
+    print(optimization_job_status_codes.get(status, "?"), end="")
+    sys.stdout.flush()
+
+    if status in in_progress_statuses:
+        return None
+
+    print("")
+    return desc
+
+
 def _create_model_package_status(sagemaker_client, model_package_name):
     """Placeholder docstring"""
     in_progress_statuses = ["InProgress", "Pending"]
diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py
index 03e70b0ad8..1740176d61 100644
--- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py
+++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py
@@ -20,15 +20,16 @@
 from sagemaker.serve.utils.optimize_utils import (
     _generate_optimized_model,
     _is_speculation_enabled,
-    _extract_supported_deployment_config,
     _is_inferentia_or_trainium,
     _is_compatible_with_optimization_job,
+    _update_environment_variables,
 )
 
 mock_optimization_job_output = {
     "OptimizationJobName": "optimization_job_name",
     "RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
     "huggingface-pytorch-tgi-inference:2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04",
+    "OptimizationJobStatus": "COMPLETED",
     "OptimizationEnvironment": {
         "SAGEMAKER_PROGRAM": "inference.py",
         "ENDPOINT_SERVER_TIMEOUT": "3600",
@@ -76,7 +77,7 @@ def test_is_inferentia_or_trainium(instance, expected):
         ),
         (
             "ml.inf2.xlarge",
-            "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-neuronx-sdk2.18.2",
+            None,
             True,
         ),
         (
@@ -85,68 +86,13 @@
             "2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04",
             False,
         ),
+        (None, None, False),
     ],
 )
 def test_is_compatible_with_optimization_job(instance, image_uri, expected):
     assert _is_compatible_with_optimization_job(instance, image_uri) == expected
 
 
-@pytest.mark.parametrize(
-    "deployment_configs, expected",
-    [
-        (
-            [
-                {
-                    "InstanceType": "ml.c7gd.4xlarge",
-                    "DeploymentArgs": {
-                        "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124"
-                    },
-                    "AccelerationConfigs": [
-                        {
-                            "type": "acceleration",
-                            "enabled": True,
-                            "spec": {"compiler": "a", "version": "1"},
-                        }
-                    ],
-                }
-            ],
-            None,
-        ),
-        (
-            [
-                {
-                    "InstanceType": "ml.g5.12xlarge",
-                    "DeploymentArgs": {
-                        "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124"
-                    },
-                    "AccelerationConfigs": [
-                        {
-                            "type": "speculation",
-                            "enabled": True,
-                        }
-                    ],
-                }
-            ],
-            {
-                "InstanceType": "ml.g5.12xlarge",
-                "DeploymentArgs": {
-                    "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124"
- }, - "AccelerationConfigs": [ - { - "type": "speculation", - "enabled": True, - } - ], - }, - ), - (None, None), - ], -) -def test_extract_supported_deployment_config(deployment_configs, expected): - assert _extract_supported_deployment_config(deployment_configs, True) == expected - - def test_generate_optimized_model(): pysdk_model = Mock() pysdk_model.model_data = {"S3DataSource": {"S3Uri": "s3://foo/bar"}} @@ -199,3 +145,16 @@ def test_generate_optimized_model(): ) def test_is_speculation_enabled(deployment_config, expected): assert _is_speculation_enabled(deployment_config) is expected + + +@pytest.mark.parametrize( + "env, new_env, output_env", + [ + ({"a": "1"}, {"b": "2"}, {"a": "1", "b": "2"}), + (None, {"b": "2"}, {"b": "2"}), + ({"a": "1"}, None, {"a": "1"}), + (None, None, None), + ], +) +def test_update_environment_variables(env, new_env, output_env): + assert _update_environment_variables(env, new_env) == output_env From f3b3504f25ebc69f6f3e3e08367d8cac6cde71c8 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Mon, 17 Jun 2024 19:18:10 -0700 Subject: [PATCH 12/45] update: Add optimize to ModelBuilder JS (#1485) * MB JS Optimize * UT * Refactore * UT * UT * refactore * refactore --------- Co-authored-by: Jonathan Makunga --- src/sagemaker/jumpstart/types.py | 2 + src/sagemaker/jumpstart/utils.py | 17 --- .../serve/builder/jumpstart_builder.py | 125 +++++++++--------- src/sagemaker/serve/utils/optimize_utils.py | 103 +++++++++------ tests/unit/sagemaker/jumpstart/test_utils.py | 15 --- .../serve/builder/test_js_builder.py | 10 +- .../serve/utils/test_optimize_utils.py | 93 +++++++------ 7 files changed, 179 insertions(+), 186 deletions(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 13eb9e80bb..2561dbc237 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2568,6 +2568,7 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder): "compute_resource_requirements", "model_data_download_timeout", "container_startup_health_check_timeout", + "additional_data_sources", ] def __init__( @@ -2597,6 +2598,7 @@ def __init__( self.supported_instance_types = resolved_config.get( "supported_inference_instance_types" ) + self.additional_data_sources = resolved_config.get("hosting_additional_data_sources") class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 48b95bf887..559a960588 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1364,20 +1364,3 @@ def wrapped_f(*args, **kwargs): if _func is None: return wrapper_cache return wrapper_cache(_func) - - -def _extract_image_tag_and_version(image_uri: str) -> Tuple[Optional[str], Optional[str]]: - """Extract Image tag and version from image URI. - - Args: - image_uri (str): Image URI. - - Returns: - Tuple[Optional[str], Optional[str]]: The tag and version of the image. 
- """ - if image_uri is None: - return None, None - - tag = image_uri.split(":")[-1] - - return tag, tag.split("-")[0] diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 3011fe6a33..bd1acec5e6 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -38,9 +38,11 @@ SkipTuningComboException, ) from sagemaker.serve.utils.optimize_utils import ( - _is_compatible_with_optimization_job, _extract_model_source, _update_environment_variables, + _extract_speculative_draft_model_provider, + _is_image_compatible_with_optimization_job, + _validate_optimization_inputs, ) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, @@ -628,7 +630,7 @@ def _build_for_jumpstart(self): def _optimize_for_jumpstart( self, - output_path: str, + output_path: Optional[str] = None, instance_type: Optional[str] = None, role: Optional[str] = None, tags: Optional[Tags] = None, @@ -645,7 +647,7 @@ def _optimize_for_jumpstart( """Runs a model optimization job. Args: - output_path (str): Specifies where to store the compiled/quantized model. + output_path (Optional[str]): Specifies where to store the compiled/quantized model. instance_type (Optional[str]): Target deployment instance type that the model is optimized for. role (Optional[str]): Execution role. Defaults to ``None``. @@ -673,40 +675,30 @@ def _optimize_for_jumpstart( """ if self._is_gated_model() and accept_eula is not True: raise ValueError( - f"ValueError: Model '{self.model}' " - f"requires accepting end-user license agreement (EULA)." + f"Model '{self.model}' requires accepting end-user license agreement (EULA)." ) + _validate_optimization_inputs( + output_path, instance_type, quantization_config, compilation_config + ) + optimization_env_vars = None pysdk_model_env_vars = None model_source = _extract_model_source(self.pysdk_model.model_data, accept_eula) if speculative_decoding_config: self._set_additional_model_source(speculative_decoding_config) - optimization_env_vars = self.pysdk_model.deployment_config.get("DeploymentArgs").get( - "Environment" - ) + optimization_env_vars = self.pysdk_model.deployment_config.get( + "DeploymentArgs", {} + ).get("Environment") else: - image_uri = None - if quantization_config and quantization_config.get("Image"): - image_uri = quantization_config.get("Image") - elif compilation_config and compilation_config.get("Image"): - image_uri = compilation_config.get("Image") - instance_type = ( - instance_type - or self.pysdk_model.deployment_config.get("DeploymentArgs").get("InstanceType") - or _get_nb_instance() - ) - if not _is_compatible_with_optimization_job(instance_type, image_uri): - deployment_config = self._find_compatible_deployment_config(None) - if deployment_config: - optimization_env_vars = deployment_config.get("DeploymentArgs").get( - "Environment" - ) - self.pysdk_model.set_deployment_config( - config_name=deployment_config.get("DeploymentConfigName"), - instance_type=deployment_config.get("InstanceType"), - ) + deployment_config = self._find_compatible_deployment_config(None) + if deployment_config: + optimization_env_vars = deployment_config.get("DeploymentArgs").get("Environment") + self.pysdk_model.set_deployment_config( + config_name=deployment_config.get("DeploymentConfigName"), + instance_type=deployment_config.get("InstanceType"), + ) optimization_env_vars = _update_environment_variables(optimization_env_vars, env_vars) @@ -736,7 +728,7 @@ def 
_optimize_for_jumpstart( } if optimization_env_vars: - create_optimization_job_args["Environment"] = optimization_env_vars + create_optimization_job_args["OptimizationEnvironment"] = optimization_env_vars if max_runtime_in_sec: create_optimization_job_args["StoppingCondition"] = { "MaxRuntimeInSeconds": max_runtime_in_sec @@ -766,18 +758,26 @@ def _is_gated_model(self, model=None) -> bool: return "private" in s3_uri def _set_additional_model_source( - self, speculative_decoding_config: Optional[Dict[str, Any]] = None + self, + speculative_decoding_config: Optional[Dict[str, Any]] = None, + accept_eula: Optional[bool] = None, ) -> None: """Set Additional Model Source to ``this`` model. Args: speculative_decoding_config (Optional[Dict[str, Any]]): Speculative decoding config. + accept_eula (Optional[bool]): For models that require a Model Access Config. """ if speculative_decoding_config: - model_provider: str = speculative_decoding_config["ModelProvider"] + model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) if model_provider.lower() == "sagemaker": - if not self._is_speculation_enabled(self.pysdk_model.deployment_config): + if ( + self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( + "AdditionalDataSources" + ) + is None + ): deployment_config = self._find_compatible_deployment_config( speculative_decoding_config ) @@ -786,21 +786,30 @@ def _set_additional_model_source( config_name=deployment_config.get("DeploymentConfigName"), instance_type=deployment_config.get("InstanceType"), ) - self.pysdk_model.add_tags( - {"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "sagemaker"}, - ) else: raise ValueError( "Cannot find deployment config compatible for optimization job." ) + + self.pysdk_model.add_tags( + {"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "sagemaker"}, + ) else: s3_uri = speculative_decoding_config.get("ModelSource") if not s3_uri: raise ValueError("Custom S3 Uri cannot be none.") - self.pysdk_model.additional_model_data_sources["speculative_decoding"][0][ - "s3_data_source" - ]["s3_uri"] = s3_uri + # TODO: Set correct channel name. + additional_model_data_source = { + "ChannelName": "DraftModelName", + "S3DataSource": {"S3Uri": s3_uri}, + } + if accept_eula: + additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = { + "ACCEPT_EULA": True + } + + self.pysdk_model.additional_model_data_sources = [additional_model_data_source] self.pysdk_model.add_tags( {"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "customer"}, ) @@ -816,36 +825,20 @@ def _find_compatible_deployment_config( Returns: Optional[Dict[str, Any]]: A compatible model deployment config for optimization job. 
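 
         Example:
             A matching config, as consumed by ``set_deployment_config`` above, is a
             dict shaped like the following (values illustrative)::
 
                 {
                     "DeploymentConfigName": "lmi-optimized",
                     "InstanceType": "ml.g5.12xlarge",
                     "DeploymentArgs": {
                         "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
                         "djl-inference:0.28.0-lmi10.0.0-cu124"
                     },
                 }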
""" + model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) for deployment_config in self.pysdk_model.list_deployment_configs(): - instance_type = deployment_config.get("deployment_config").get("InstanceType") - image_uri = deployment_config.get("deployment_config").get("ImageUri") - - if _is_compatible_with_optimization_job(instance_type, image_uri): - if not speculative_decoding_config: - return deployment_config + image_uri = deployment_config.get("deployment_config", {}).get("ImageUri") - if self._is_speculation_enabled(deployment_config): + if _is_image_compatible_with_optimization_job(image_uri): + if ( + model_provider == "sagemaker" + and deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources") + ) or model_provider == "custom": return deployment_config - return None - - def _is_speculation_enabled(self, deployment_config: Optional[Dict[str, Any]]) -> bool: - """Checks whether speculative is enabled for the given deployment config. + # There's no matching config from jumpstart to add sagemaker draft model location + if model_provider == "sagemaker": + return None - Args: - deployment_config (Dict[str, Any]): A deployment config. - - Returns: - bool: Whether speculative is enabled for this deployment config. - """ - if deployment_config is None: - return False - - acceleration_configs = deployment_config.get("AccelerationConfigs") - if acceleration_configs: - for acceleration_config in acceleration_configs: - if acceleration_config.get( - "type", "default" - ).lower() == "speculative" and acceleration_config.get("enabled"): - return True - return False + # fall back to the default jumpstart model deployment config for optimization job + return self.pysdk_model.deployment_config diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 4a9babdbb8..826699beee 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -19,7 +19,6 @@ from sagemaker import Model from sagemaker.enums import Tag -from sagemaker.fw_utils import _is_gpu_instance logger = logging.getLogger(__name__) @@ -42,30 +41,19 @@ def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool: return False -def _is_compatible_with_optimization_job( - instance_type: Optional[str], image_uri: Optional[str] -) -> bool: +def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool: """Checks whether an instance is compatible with an optimization job. Args: - instance_type (str): The instance type used for the compilation job. image_uri (str): The image URI of the optimization job. Returns: bool: Whether the given instance type is compatible with an optimization job. """ - if not instance_type: - return False - - compatible_image = True - if image_uri: - compatible_image = "djl-inference:" in image_uri and ( - "-lmi" in image_uri or "-neuronx-" in image_uri - ) - - return ( - _is_gpu_instance(instance_type) or _is_inferentia_or_trainium(instance_type) - ) and compatible_image + # TODO: Use specific container type instead. 
+    if image_uri is None:
+        return True
+    return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri)
 
 
 def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -> Model:
@@ -89,28 +77,6 @@ def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -
     return pysdk_model
 
 
-def _is_speculation_enabled(deployment_config: Optional[Dict[str, Any]]) -> bool:
-    """Checks whether speculation is enabled for this deployment config.
-
-    Args:
-        deployment_config (Dict[str, Any]): A deployment config.
-
-    Returns:
-        bool: Whether the speculation is enabled for this deployment config.
-    """
-    if deployment_config is None:
-        return False
-
-    acceleration_configs = deployment_config.get("AccelerationConfigs")
-    if acceleration_configs:
-        for acceleration_config in acceleration_configs:
-            if acceleration_config.get("type").lower() == "speculation" and acceleration_config.get(
-                "enabled"
-            ):
-                return True
-    return False
-
-
 def _extract_model_source(
     model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool]
 ) -> Optional[Dict[str, Any]]:
@@ -129,7 +95,6 @@
     if isinstance(s3_uri, dict):
         s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
 
-    # Todo: Inject fine-tune data source
     model_source = {"S3": {"S3Uri": s3_uri}}
     if accept_eula:
         model_source["S3"]["ModelAccessConfig"] = {"AcceptEula": True}
     return model_source
@@ -154,3 +119,61 @@
         else:
             env = new_env
     return env
+
+
+def _extract_speculative_draft_model_provider(
+    speculative_decoding_config: Optional[Dict] = None,
+) -> Optional[str]:
+    """Extracts the speculative draft model provider from a speculative decoding config.
+
+    Args:
+        speculative_decoding_config (Optional[Dict]): A speculative decoding config.
+
+    Returns:
+        Optional[str]: The speculative draft model provider.
+    """
+    if speculative_decoding_config is None:
+        return None
+
+    if speculative_decoding_config.get(
+        "ModelProvider"
+    ) == "Custom" or speculative_decoding_config.get("ModelSource"):
+        return "custom"
+
+    return "sagemaker"
+
+
+def _validate_optimization_inputs(
+    output_path: Optional[str] = None,
+    instance_type: Optional[str] = None,
+    quantization_config: Optional[Dict] = None,
+    compilation_config: Optional[Dict] = None,
+) -> None:
+    """Validates optimization inputs.
+
+    Args:
+        output_path (Optional[str]): The output path.
+        instance_type (Optional[str]): The instance type.
+        quantization_config (Optional[Dict]): The quantization config.
+        compilation_config (Optional[Dict]): The compilation config.
+
+    Raises:
+        ValueError: If an optimization input is invalid.
+    """
+    if quantization_config and compilation_config:
+        raise ValueError("Quantization config and compilation config are mutually exclusive.")
+
+    instance_type_msg = "Please provide an instance type for {} optimization job."
+    output_path_msg = "Please provide an output path for {} optimization job."
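+    # e.g. instance_type_msg.format("quantization") renders
+    # "Please provide an instance type for quantization optimization job."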
+ + if quantization_config: + if not instance_type: + raise ValueError(instance_type_msg.format("quantization")) + if not output_path: + raise ValueError(output_path_msg.format("quantization")) + + if compilation_config: + if not instance_type: + raise ValueError(instance_type_msg.format("compilation")) + if not output_path: + raise ValueError(output_path_msg.format("compilation")) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index bb5aa93d24..6cb8fbaa14 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -48,7 +48,6 @@ JumpStartModelHeader, JumpStartVersionedModelId, ) -from sagemaker.jumpstart.utils import _extract_image_tag_and_version from tests.unit.sagemaker.jumpstart.utils import ( get_base_spec_with_prototype_configs, get_spec_from_base_spec, @@ -2063,17 +2062,3 @@ def test_deployment_config_response_data(data, expected): print(out) assert out == expected - - -@pytest.mark.parametrize( - "image_uri, expected", - [ - ( - "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124", - ("0.28.0-lmi10.0.0-cu124", "0.28.0"), - ), - (None, (None, None)), - ], -) -def test_extract_image_tag_and_version(image_uri, expected): - assert _extract_image_tag_and_version(image_uri) == expected diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index e38317067c..48d12acb3e 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -304,7 +304,7 @@ def test_tune_for_tgi_js_local_container_sharding_not_supported( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalDeepPingException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalDeepPingException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_deep_ping_ex( self, @@ -354,7 +354,7 @@ def test_tune_for_tgi_js_local_container_deep_ping_ex( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelLoadException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelLoadException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_load_ex( self, @@ -404,7 +404,7 @@ def test_tune_for_tgi_js_local_container_load_ex( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelOutOfMemoryException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelOutOfMemoryException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_oom_ex( self, @@ -454,7 +454,7 @@ def test_tune_for_tgi_js_local_container_oom_ex( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_invoke_ex( self, @@ -569,7 +569,7 @@ def test_tune_for_djl_js_local_container( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")}, ) def test_tune_for_djl_js_local_container_invoke_ex( self, diff --git 
a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index 1740176d61..c0eea1ed68 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -19,10 +19,11 @@ from sagemaker.enums import Tag from sagemaker.serve.utils.optimize_utils import ( _generate_optimized_model, - _is_speculation_enabled, _is_inferentia_or_trainium, - _is_compatible_with_optimization_job, _update_environment_variables, + _is_image_compatible_with_optimization_job, + _extract_speculative_draft_model_provider, + _validate_optimization_inputs, ) mock_optimization_job_output = { @@ -63,34 +64,30 @@ def test_is_inferentia_or_trainium(instance, expected): @pytest.mark.parametrize( - "instance, image_uri, expected", + "image_uri, expected", [ ( - "ml.g5.12xlarge", "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124", True, ), ( - "ml.trn1.2xlarge", "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-neuronx-sdk2.18.2", True, ), ( - "ml.inf2.xlarge", None, True, ), ( - "ml.c7gd.4xlarge", "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:" "2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04", False, ), - (None, None, False), + (None, True), ], ) -def test_is_compatible_with_optimization_job(instance, image_uri, expected): - assert _is_compatible_with_optimization_job(instance, image_uri) == expected +def test_is_image_compatible_with_optimization_job(image_uri, expected): + assert _is_image_compatible_with_optimization_job(image_uri) == expected def test_generate_optimized_model(): @@ -114,39 +111,6 @@ def test_generate_optimized_model(): ) -@pytest.mark.parametrize( - "deployment_config, expected", - [ - ( - { - "AccelerationConfigs": [ - { - "type": "acceleration", - "enabled": True, - "spec": {"compiler": "a", "version": "1"}, - } - ], - }, - False, - ), - ( - { - "AccelerationConfigs": [ - { - "type": "speculation", - "enabled": True, - } - ], - }, - True, - ), - (None, False), - ], -) -def test_is_speculation_enabled(deployment_config, expected): - assert _is_speculation_enabled(deployment_config) is expected - - @pytest.mark.parametrize( "env, new_env, output_env", [ @@ -158,3 +122,46 @@ def test_is_speculation_enabled(deployment_config, expected): ) def test_update_environment_variables(env, new_env, output_env): assert _update_environment_variables(env, new_env) == output_env + + +@pytest.mark.parametrize( + "speculative_decoding_config, expected_model_provider", + [ + ({"ModelProvider": "SageMaker"}, "sagemaker"), + ({"ModelProvider": "Custom"}, "custom"), + ({"ModelSource": "s3://"}, "custom"), + (None, None), + ], +) +def test_extract_speculative_draft_model_provider( + speculative_decoding_config, expected_model_provider +): + assert ( + _extract_speculative_draft_model_provider(speculative_decoding_config) + == expected_model_provider + ) + + +@pytest.mark.parametrize( + "output_path, instance, quantization_config, compilation_config", + [ + ( + None, + None, + {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}, + {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}, + ), + (None, None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}, None), + (None, None, None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}), + ("output_path", None, None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}), + (None, "instance_type", None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}), + ], +) 
+def test_validate_optimization_inputs( + output_path, instance, quantization_config, compilation_config +): + + with pytest.raises(ValueError): + _validate_optimization_inputs( + output_path, instance, quantization_config, compilation_config + ) From 1f6f876c30318d82d898ebc1943e097332c11bf8 Mon Sep 17 00:00:00 2001 From: Jacky Lee Date: Tue, 18 Jun 2024 12:47:28 -0700 Subject: [PATCH 13/45] feat: add quicksilver telemetry (#1482) * feat: add quicksilver telemetry fields * pylint * add UTs * pylint * Refactor * add gated and fine-tuned to telemetry * fix: typo * fix: jumpstart var * refactor model_hub * pylint * update TEI/TGI to remove jumpstart field * reorder telemetry schema * refactor --------- Co-authored-by: Jacky Lee --- src/sagemaker/enums.py | 2 +- .../serve/builder/jumpstart_builder.py | 16 ++-- src/sagemaker/serve/builder/model_builder.py | 12 ++- src/sagemaker/serve/builder/tei_builder.py | 1 - src/sagemaker/serve/builder/tgi_builder.py | 1 - src/sagemaker/serve/utils/telemetry_logger.py | 39 ++++++++- src/sagemaker/serve/utils/types.py | 22 +++++ .../serve/utils/test_telemetry_logger.py | 86 +++++++++++++++++-- 8 files changed, 156 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/enums.py b/src/sagemaker/enums.py index caa0d77175..f8c618620b 100644 --- a/src/sagemaker/enums.py +++ b/src/sagemaker/enums.py @@ -46,6 +46,6 @@ class Tag(str, Enum): """Enum class for tag keys to apply to models.""" OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name" - SPECULATIVE_DRAFT_MODL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider" + SPECULATIVE_DRAFT_MODEL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider" FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path" FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name" diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index bd1acec5e6..b32266df43 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -108,8 +108,10 @@ def __init__(self): self.schema_builder = None self.nb_instance_type = None self.ram_usage_model_load = None - self.jumpstart = None + self.model_hub = None self.model_metadata = None + self.is_fine_tuned = None + self.is_gated = None @abstractmethod def _prepare_for_mode(self): @@ -580,7 +582,6 @@ def _build_for_jumpstart(self): # we do not pickle for jumpstart. 
set to none self.secret_key = None - self.jumpstart = True pysdk_model = self._create_pre_trained_js_model() image_uri = pysdk_model.image_uri @@ -588,6 +589,7 @@ def _build_for_jumpstart(self): logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri) if self._is_fine_tuned_model(): + self.is_fine_tuned = True pysdk_model = self._update_model_data_for_fine_tuned_model(pysdk_model) if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT: @@ -754,8 +756,10 @@ def _is_gated_model(self, model=None) -> bool: s3_uri = s3_uri.get("S3DataSource").get("S3Uri") if s3_uri is None: - return False - return "private" in s3_uri + self.is_gated = False + else: + self.is_gated = "private" in s3_uri + return self.is_gated def _set_additional_model_source( self, @@ -792,7 +796,7 @@ def _set_additional_model_source( ) self.pysdk_model.add_tags( - {"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "sagemaker"}, + {"key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "value": "sagemaker"}, ) else: s3_uri = speculative_decoding_config.get("ModelSource") @@ -811,7 +815,7 @@ def _set_additional_model_source( self.pysdk_model.additional_model_data_sources = [additional_model_data_source] self.pysdk_model.add_tags( - {"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "customer"}, + {"key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "value": "customer"}, ) def _find_compatible_deployment_config( diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 8f03f3aa87..9d37782794 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -78,7 +78,7 @@ from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve from sagemaker.serve.model_server.triton.triton_builder import Triton from sagemaker.serve.utils.telemetry_logger import _capture_telemetry -from sagemaker.serve.utils.types import ModelServer +from sagemaker.serve.utils.types import ModelServer, ModelHub from sagemaker.serve.validations.check_image_uri import is_1p_image_uri from sagemaker.serve.save_retrive.version_1_0_0.save.save_handler import SaveHandler from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import get_metadata @@ -400,7 +400,7 @@ def _prepare_for_mode(self): self.serve_settings.s3_model_data_url, self.sagemaker_session, self.image_uri, - self.jumpstart if hasattr(self, "jumpstart") else False, + getattr(self, "model_hub", None) == ModelHub.JUMPSTART, ) self.env_vars.update(env_vars_sagemaker) return self.s3_upload_path, env_vars_sagemaker @@ -754,10 +754,14 @@ def build( # pylint: disable=R0911 if isinstance(self.model, str): model_task = None - if self.model_metadata: - model_task = self.model_metadata.get("HF_TASK") if self._is_jumpstart_model_id(): + self.model_hub = ModelHub.JUMPSTART return self._build_for_jumpstart() + self.model_hub = ModelHub.HUGGINGFACE + + if self.model_metadata: + model_task = self.model_metadata.get("HF_TASK") + if self._is_djl(): return self._build_for_djl() else: diff --git a/src/sagemaker/serve/builder/tei_builder.py b/src/sagemaker/serve/builder/tei_builder.py index 6aba3c9da2..f4e5e67891 100644 --- a/src/sagemaker/serve/builder/tei_builder.py +++ b/src/sagemaker/serve/builder/tei_builder.py @@ -63,7 +63,6 @@ def __init__(self): self.nb_instance_type = None self.ram_usage_model_load = None self.secret_key = None - self.jumpstart = None self.role_arn = None @abstractmethod diff --git a/src/sagemaker/serve/builder/tgi_builder.py 
b/src/sagemaker/serve/builder/tgi_builder.py index 23cc7e2202..9f8762c27e 100644 --- a/src/sagemaker/serve/builder/tgi_builder.py +++ b/src/sagemaker/serve/builder/tgi_builder.py @@ -90,7 +90,6 @@ def __init__(self): self.nb_instance_type = None self.ram_usage_model_load = None self.secret_key = None - self.jumpstart = None self.role_arn = None @abstractmethod diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py index fe99e787a0..9a74f4b828 100644 --- a/src/sagemaker/serve/utils/telemetry_logger.py +++ b/src/sagemaker/serve/utils/telemetry_logger.py @@ -29,7 +29,12 @@ MLFLOW_REGISTRY_PATH, ) from sagemaker.serve.utils.lineage_utils import _get_mlflow_model_path_type -from sagemaker.serve.utils.types import ModelServer, ImageUriOption +from sagemaker.serve.utils.types import ( + ModelServer, + ImageUriOption, + ModelHub, + SpeculativeDecodingDraftModelSource, +) from sagemaker.serve.validations.check_image_uri import is_1p_image_uri from sagemaker.user_agent import SDK_VERSION @@ -69,6 +74,16 @@ MLFLOW_REGISTRY_PATH: 5, } +MODEL_HUB_TO_CODE = { + str(ModelHub.JUMPSTART): 1, + str(ModelHub.HUGGINGFACE): 2, +} + +SD_DRAFT_MODEL_SOURCE_TO_CODE = { + str(SpeculativeDecodingDraftModelSource.SAGEMAKER): 1, + str(SpeculativeDecodingDraftModelSource.CUSTOM): 2, +} + def _capture_telemetry(func_name: str): """Placeholder docstring""" @@ -108,6 +123,28 @@ def wrapper(self, *args, **kwargs): mlflow_model_path_type = _get_mlflow_model_path_type(mlflow_model_path) extra += f"&x-mlflowModelPathType={MLFLOW_MODEL_PATH_CODE[mlflow_model_path_type]}" + if getattr(self, "model_hub", False): + extra += f"&x-modelHub={MODEL_HUB_TO_CODE[str(self.model_hub)]}" + + if getattr(self, "is_fine_tuned", False): + extra += "&x-fineTuned=1" + if getattr(self, "is_gated", False): + extra += "&x-gated=1" + + if kwargs.get("compilation_config"): + extra += "&x-compiled=1" + if kwargs.get("quantization_config"): + extra += "&x-quantized=1" + if kwargs.get("speculative_decoding_config"): + model_provider = kwargs["speculative_decoding_config"]["ModelProvider"] + model_provider_enum = ( + SpeculativeDecodingDraftModelSource.SAGEMAKER + if model_provider.lower() == "sagemaker" + else SpeculativeDecodingDraftModelSource.CUSTOM + ) + model_provider_value = SD_DRAFT_MODEL_SOURCE_TO_CODE[str(model_provider_enum)] + extra += f"&x-sdDraftModelSource={model_provider_value}" + start_timer = perf_counter() try: response = func(self, *args, **kwargs) diff --git a/src/sagemaker/serve/utils/types.py b/src/sagemaker/serve/utils/types.py index 3ac80aa7ea..2e5e4f40d7 100644 --- a/src/sagemaker/serve/utils/types.py +++ b/src/sagemaker/serve/utils/types.py @@ -57,3 +57,25 @@ def __str__(self) -> str: CUSTOM_IMAGE = 1 CUSTOM_1P_IMAGE = 2 DEFAULT_IMAGE = 3 + + +class ModelHub(Enum): + """Enum type for model hub source""" + + def __str__(self) -> str: + """Convert enum to string""" + return str(self.name) + + JUMPSTART = 1 + HUGGINGFACE = 2 + + +class SpeculativeDecodingDraftModelSource(Enum): + """Enum type for speculative decoding draft model source""" + + def __str__(self) -> str: + """Convert enum to string""" + return str(self.name) + + SAGEMAKER = 1 + CUSTOM = 2 diff --git a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py index 33af575e8f..563e0f8f20 100644 --- a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py +++ b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py @@ -12,7 +12,7 @@ # language 
governing permissions and limitations under the License. from __future__ import absolute_import import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock from sagemaker.serve import Mode, ModelServer from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH from sagemaker.serve.utils.telemetry_logger import ( @@ -25,7 +25,8 @@ from sagemaker.user_agent import SDK_VERSION MOCK_SESSION = Mock() -MOCK_FUNC_NAME = "Mock.deploy" +MOCK_DEPLOY_FUNC_NAME = "Mock.deploy" +MOCK_OPTIMIZE_FUNC_NAME = "Mock.optimize" MOCK_DJL_CONTAINER = ( "763104351884.dkr.ecr.us-west-2.amazonaws.com/" "djl-inference:0.25.0-deepspeed0.11.0-cu118" ) @@ -47,11 +48,15 @@ def __init__(self): self.serve_settings = Mock() self.sagemaker_session = MOCK_SESSION - @_capture_telemetry(MOCK_FUNC_NAME) + @_capture_telemetry(MOCK_DEPLOY_FUNC_NAME) def mock_deploy(self, mock_exception_func=None): if mock_exception_func: mock_exception_func() + @_capture_telemetry(MOCK_OPTIMIZE_FUNC_NAME) + def mock_optimize(self, *args, **kwargs): + pass + class TestTelemetryLogger(unittest.TestCase): @patch("sagemaker.serve.utils.telemetry_logger._requests_helper") @@ -88,7 +93,7 @@ def test_capture_telemetry_decorator_djl_success(self, mock_send_telemetry): args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=4" "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118" f"&x-sdkVersion={SDK_VERSION}" @@ -118,7 +123,7 @@ def test_capture_telemetry_decorator_djl_success_with_custom_image(self, mock_se args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=4" "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118" f"&x-sdkVersion={SDK_VERSION}" @@ -148,7 +153,7 @@ def test_capture_telemetry_decorator_tgi_success(self, mock_send_telemetry): args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=6" "&x-imageTag=huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04" f"&x-sdkVersion={SDK_VERSION}" @@ -196,7 +201,7 @@ def test_capture_telemetry_decorator_handle_exception_success(self, mock_send_te args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=4" "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118" f"&x-sdkVersion={SDK_VERSION}" @@ -243,7 +248,7 @@ def test_construct_url_with_failure_reason_and_extra_info(self): f"&x-failureType={mock_failure_type}" f"&x-extra={mock_extra_info}" ) - self.assertEquals(ret_url, expected_base_url) + self.assertEqual(ret_url, expected_base_url) @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry): @@ -262,7 +267,7 @@ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry): args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=1" "&x-imageTag=pytorch-inference:2.0.1-cpu-py310" f"&x-sdkVersion={SDK_VERSION}" @@ -275,3 +280,66 @@ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry): 
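         # The "extra" payload is a query string of "&x-<field>=<code>" pairs; the
         # numeric codes are defined by the *_TO_CODE maps in telemetry_logger.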
mock_send_telemetry.assert_called_once_with( "1", 3, MOCK_SESSION, None, None, expected_extra_str ) + + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_capture_telemetry_decorator_optimize_with_default_configs(self, mock_send_telemetry): + mock_model_builder = ModelBuilderMock() + mock_model_builder.serve_settings.telemetry_opt_out = False + mock_model_builder.image_uri = None + mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT + mock_model_builder.model_server = ModelServer.TORCHSERVE + mock_model_builder.sagemaker_session.endpoint_arn = None + + mock_model_builder.mock_optimize() + + args = mock_send_telemetry.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_OPTIMIZE_FUNC_NAME}" + "&x-modelServer=1" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-latency={latency}" + ) + + mock_send_telemetry.assert_called_once_with( + "1", 3, MOCK_SESSION, None, None, expected_extra_str + ) + + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_capture_telemetry_decorator_optimize_with_custom_configs(self, mock_send_telemetry): + mock_model_builder = ModelBuilderMock() + mock_model_builder.serve_settings.telemetry_opt_out = False + mock_model_builder.image_uri = None + mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT + mock_model_builder.model_server = ModelServer.TORCHSERVE + mock_model_builder.sagemaker_session.endpoint_arn = None + mock_model_builder.is_fine_tuned = True + mock_model_builder.is_gated = True + + mock_speculative_decoding_config = MagicMock() + mock_config = {"ModelProvider": "sagemaker"} + mock_speculative_decoding_config.__getitem__.side_effect = mock_config.__getitem__ + + mock_model_builder.mock_optimize( + quantization_config=Mock(), + compilation_config=Mock(), + speculative_decoding_config=mock_speculative_decoding_config, + ) + + args = mock_send_telemetry.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_OPTIMIZE_FUNC_NAME}" + "&x-modelServer=1" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-fineTuned=1" + f"&x-gated=1" + f"&x-compiled=1" + f"&x-quantized=1" + f"&x-sdDraftModelSource=1" + f"&x-latency={latency}" + ) + + mock_send_telemetry.assert_called_once_with( + "1", 3, MOCK_SESSION, None, None, expected_extra_str + ) From b07f21056ad4b6436540ea3a4adc12e5330e49a4 Mon Sep 17 00:00:00 2001 From: Jacky Lee Date: Tue, 18 Jun 2024 12:47:51 -0700 Subject: [PATCH 14/45] unit: tests for fine tuned JS model support (#1481) * UTs * flake8 --------- Co-authored-by: Jacky Lee --- .../serve/builder/test_js_builder.py | 120 +++++++++++++++++- 1 file changed, 119 insertions(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 48d12acb3e..09f89a1571 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -11,10 +11,12 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, Mock import unittest +from sagemaker.enums import Tag +from sagemaker.serve import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.exceptions import ( @@ -961,3 +963,119 @@ def test_display_benchmark_metrics_initial( builder.display_benchmark_metrics() mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once() + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_fine_tuned_model_with_fine_tuning_model_path( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri + mock_fine_tuning_model_path = "s3://test" + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species of turtle native to the brackish " + "coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a species of turtle native to the " + "brackish coastal tidal marshes of the east coast." + } + ] + builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + model_metadata={ + "FINE_TUNING_MODEL_PATH": mock_fine_tuning_model_path, + }, + ) + model = builder.build() + + model.model_data["S3DataSource"].__setitem__.assert_called_with( + "S3Uri", mock_fine_tuning_model_path + ) + mock_pre_trained_model.return_value.add_tags.assert_called_with( + {"key": Tag.FINE_TUNING_MODEL_PATH, "value": mock_fine_tuning_model_path} + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_fine_tuned_model_with_fine_tuning_job_name( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_serve_settings, + mock_telemetry, + ): + mock_fine_tuning_model_path = "s3://test" + mock_sagemaker_session = Mock() + mock_sagemaker_session.sagemaker_client.describe_training_job.return_value = { + "OutputDataConfig": { + "S3OutputPath": mock_fine_tuning_model_path, + "CompressionType": "None", + } + } + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri + mock_fine_tuning_job_name = "mock-job" + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species of turtle native to the brackish " + "coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + 
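+        # SchemaBuilder derives the endpoint's request/response serialization
+        # contract from these example payloads.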
sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a species of turtle native to the " + "brackish coastal tidal marshes of the east coast." + } + ] + builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + model_metadata={"FINE_TUNING_JOB_NAME": mock_fine_tuning_job_name}, + sagemaker_session=mock_sagemaker_session, + ) + model = builder.build(sagemaker_session=mock_sagemaker_session) + + mock_sagemaker_session.sagemaker_client.describe_training_job.assert_called_once_with( + TrainingJobName=mock_fine_tuning_job_name + ) + + model.model_data["S3DataSource"].__setitem__.assert_any_call( + "S3Uri", mock_fine_tuning_model_path + ) + mock_pre_trained_model.return_value.add_tags.assert_called_with( + [ + {"key": Tag.FINE_TUNING_JOB_NAME, "value": mock_fine_tuning_job_name}, + {"key": Tag.FINE_TUNING_MODEL_PATH, "value": mock_fine_tuning_model_path}, + ] + ) From 262a5eb95a68f89452a7ace4344af7dcbfb68963 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Thu, 20 Jun 2024 16:01:10 -0400 Subject: [PATCH 15/45] fix: use current session and role when setting config (#1493) * fix: use current session and role when setting config * format --- src/sagemaker/jumpstart/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index ed7dbff2f1..df139e56b3 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -355,6 +355,7 @@ def _validate_model_id_and_type(): self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + self.role = role self.config_name = model_init_kwargs.config_name self.additional_model_data_sources = model_init_kwargs.additional_model_data_sources @@ -446,6 +447,8 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None: model_version=self.model_version, instance_type=instance_type, config_name=config_name, + sagemaker_session=self.sagemaker_session, + role=self.role, ) @property From 99345d80b06f1ddad076cdbde4ba28a61267719c Mon Sep 17 00:00:00 2001 From: Jacky Lee Date: Thu, 20 Jun 2024 17:40:30 -0700 Subject: [PATCH 16/45] fix: training arn support (#1494) * fix: training job ARN * pylint --------- Co-authored-by: Jacky Lee --- src/sagemaker/serve/builder/jumpstart_builder.py | 13 +++++++++---- .../unit/sagemaker/serve/builder/test_js_builder.py | 5 ++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index b32266df43..9749d39f4e 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -546,6 +546,10 @@ def _update_model_data_for_fine_tuned_model(self, pysdk_model: Type[Model]) -> T pysdk_model.add_tags( {"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path} ) + logger.info( + "FINE_TUNING_MODEL_PATH detected. 
Using fine-tuned model found in %s.", + fine_tuning_model_path, + ) return pysdk_model if fine_tuning_job_name := self.model_metadata.get("FINE_TUNING_JOB_NAME"): @@ -553,17 +557,18 @@ def _update_model_data_for_fine_tuned_model(self, pysdk_model: Type[Model]) -> T response = self.sagemaker_session.sagemaker_client.describe_training_job( TrainingJobName=fine_tuning_job_name ) - fine_tuning_model_path = response["OutputDataConfig"]["S3OutputPath"] + fine_tuning_model_path = response["ModelArtifacts"]["S3ModelArtifacts"] pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path - pysdk_model.model_data["S3DataSource"]["CompressionType"] = response[ - "OutputDataConfig" - ]["CompressionType"] pysdk_model.add_tags( [ {"key": Tag.FINE_TUNING_JOB_NAME, "value": fine_tuning_job_name}, {"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path}, ] ) + logger.info( + "FINE_TUNING_JOB_NAME detected. Using fine-tuned model found in %s.", + fine_tuning_model_path, + ) return pysdk_model except ClientError: raise ValueError( diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 09f89a1571..d98cdde896 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -1039,9 +1039,8 @@ def test_fine_tuned_model_with_fine_tuning_job_name( mock_fine_tuning_model_path = "s3://test" mock_sagemaker_session = Mock() mock_sagemaker_session.sagemaker_client.describe_training_job.return_value = { - "OutputDataConfig": { - "S3OutputPath": mock_fine_tuning_model_path, - "CompressionType": "None", + "ModelArtifacts": { + "S3ModelArtifacts": mock_fine_tuning_model_path, } } mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri From 9a3f6ca4b9f35939a0bb0a157a05352b47cb9dc3 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Thu, 20 Jun 2024 17:41:43 -0700 Subject: [PATCH 17/45] Bug bash fixes (#1492) * HF Optimized * Revert "HF Optimized" * MB HF Optimize support * Refactoring * HF only s3 upload if optimize * reuse role if provided in MB * Refactoring * New requirements * Draft * Refactoring * Refactoring * Bug Bash fixes * UT * UT * Fix for parsing optimization output * Tag fix * UT * UT --------- Co-authored-by: Jonathan Makunga --- src/sagemaker/model.py | 6 +- src/sagemaker/serve/builder/djl_builder.py | 5 + .../serve/builder/jumpstart_builder.py | 72 +++++++------- src/sagemaker/serve/builder/model_builder.py | 36 +++++-- src/sagemaker/serve/builder/tei_builder.py | 9 +- .../serve/builder/tf_serving_builder.py | 4 + src/sagemaker/serve/builder/tgi_builder.py | 9 +- .../serve/builder/transformers_builder.py | 9 +- .../serve/mode/sagemaker_endpoint_mode.py | 9 +- .../model_server/triton/triton_builder.py | 4 + src/sagemaker/serve/utils/optimize_utils.py | 97 +++++++++++++++++-- src/sagemaker/session.py | 8 +- src/sagemaker/utils.py | 47 +++++++++ .../serve/builder/test_js_builder.py | 2 +- .../serve/builder/test_model_builder.py | 14 +-- .../serve/utils/test_optimize_utils.py | 78 +++++++++++---- tests/unit/test_utils.py | 29 ++++++ 17 files changed, 337 insertions(+), 101 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 7e23df0c41..abe4889174 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -67,6 +67,7 @@ format_tags, Tags, _resolve_routing_config, + _validate_new_tags, ) from sagemaker.async_inference import AsyncInferenceConfig from 
sagemaker.predictor_async import AsyncPredictor @@ -412,10 +413,7 @@ def add_tags(self, tags: Tags) -> None: Args: tags (Tags): Tags to add. """ - if self._tags and tags: - self._tags.update(tags) - else: - self._tags = tags + self._tags = _validate_new_tags(tags, self._tags) @runnable_by_pipeline def register( diff --git a/src/sagemaker/serve/builder/djl_builder.py b/src/sagemaker/serve/builder/djl_builder.py index e89c1b8e9c..646b9fa611 100644 --- a/src/sagemaker/serve/builder/djl_builder.py +++ b/src/sagemaker/serve/builder/djl_builder.py @@ -100,6 +100,7 @@ def __init__(self): self.env_vars = None self.nb_instance_type = None self.ram_usage_model_load = None + self.role_arn = None @abstractmethod def _prepare_for_mode(self): @@ -499,4 +500,8 @@ def _build_for_djl(self): self.pysdk_model = self._build_for_hf_djl() self.pysdk_model.tune = self._tune_for_hf_djl + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 9749d39f4e..89ad4f21ab 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -38,11 +38,13 @@ SkipTuningComboException, ) from sagemaker.serve.utils.optimize_utils import ( - _extract_model_source, + _generate_model_source, _update_environment_variables, _extract_speculative_draft_model_provider, _is_image_compatible_with_optimization_job, - _validate_optimization_inputs, + _extracts_and_validates_speculative_model_source, + _generate_channel_name, + _generate_additional_model_data_sources, ) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, @@ -110,6 +112,7 @@ def __init__(self): self.ram_usage_model_load = None self.model_hub = None self.model_metadata = None + self.role_arn = None self.is_fine_tuned = None self.is_gated = None @@ -544,7 +547,7 @@ def _update_model_data_for_fine_tuned_model(self, pysdk_model: Type[Model]) -> T ) pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path pysdk_model.add_tags( - {"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path} + {"Key": Tag.FINE_TUNING_MODEL_PATH, "Value": fine_tuning_model_path} ) logger.info( "FINE_TUNING_MODEL_PATH detected. Using fine-tuned model found in %s.", @@ -633,6 +636,10 @@ def _build_for_jumpstart(self): "with djl-inference, tgi-inference, or mms-inference container." ) + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model def _optimize_for_jumpstart( @@ -650,7 +657,7 @@ def _optimize_for_jumpstart( vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, max_runtime_in_sec: Optional[int] = None, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: """Runs a model optimization job. Args: @@ -685,13 +692,9 @@ def _optimize_for_jumpstart( f"Model '{self.model}' requires accepting end-user license agreement (EULA)." 
) - _validate_optimization_inputs( - output_path, instance_type, quantization_config, compilation_config - ) - optimization_env_vars = None pysdk_model_env_vars = None - model_source = _extract_model_source(self.pysdk_model.model_data, accept_eula) + model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula) if speculative_decoding_config: self._set_additional_model_source(speculative_decoding_config) @@ -745,8 +748,12 @@ def _optimize_for_jumpstart( if vpc_config: create_optimization_job_args["VpcConfig"] = vpc_config - self.pysdk_model.env.update(pysdk_model_env_vars) - return create_optimization_job_args + if pysdk_model_env_vars: + self.pysdk_model.env.update(pysdk_model_env_vars) + + if quantization_config or compilation_config: + return create_optimization_job_args + return None def _is_gated_model(self, model=None) -> bool: """Determine if ``this`` Model is Gated @@ -779,14 +786,13 @@ def _set_additional_model_source( """ if speculative_decoding_config: model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) + channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) if model_provider.lower() == "sagemaker": - if ( - self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( - "AdditionalDataSources" - ) - is None - ): + additional_model_data_sources = self.pysdk_model.deployment_config.get( + "DeploymentArgs", {} + ).get("AdditionalDataSources") + if additional_model_data_sources is None: deployment_config = self._find_compatible_deployment_config( speculative_decoding_config ) @@ -801,28 +807,26 @@ def _set_additional_model_source( ) self.pysdk_model.add_tags( - {"key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "value": "sagemaker"}, + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"}, ) else: - s3_uri = speculative_decoding_config.get("ModelSource") - if not s3_uri: - raise ValueError("Custom S3 Uri cannot be none.") - - # TODO: Set correct channel name. 
- additional_model_data_source = { - "ChannelName": "DraftModelName", - "S3DataSource": {"S3Uri": s3_uri}, - } - if accept_eula: - additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = { - "ACCEPT_EULA": True - } - - self.pysdk_model.additional_model_data_sources = [additional_model_data_source] + s3_uri = _extracts_and_validates_speculative_model_source( + speculative_decoding_config + ) + + self.pysdk_model.additional_model_data_sources = ( + _generate_additional_model_data_sources(s3_uri, channel_name, accept_eula) + ) self.pysdk_model.add_tags( - {"key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "value": "customer"}, + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "customer"}, ) + speculative_draft_model = f"/opt/ml/additional-model-data-sources/{channel_name}" + self.pysdk_model.env = _update_environment_variables( + self.pysdk_model.env, + {"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model}, + ) + def _find_compatible_deployment_config( self, speculative_decoding_config: Optional[Dict] = None ) -> Optional[Dict[str, Any]]: diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 9d37782794..17d1418eaa 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -63,7 +63,10 @@ from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model -from sagemaker.serve.utils.optimize_utils import _generate_optimized_model +from sagemaker.serve.utils.optimize_utils import ( + _generate_optimized_model, + _validate_optimization_inputs, +) from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.utils.hardware_detector import ( _get_gpu_info, @@ -87,7 +90,9 @@ ) from sagemaker.utils import Tags from sagemaker.workflow.entities import PipelineVariable -from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata +from sagemaker.huggingface.llm_utils import ( + get_huggingface_model_metadata, +) logger = logging.getLogger(__name__) @@ -383,7 +388,7 @@ def _get_serve_setting(self): sagemaker_session=self.sagemaker_session, ) - def _prepare_for_mode(self): + def _prepare_for_mode(self, should_upload_artifacts: bool = False): """Placeholder docstring""" # TODO: move mode specific prepare steps under _model_builder_deploy_wrapper self.s3_upload_path = None @@ -401,6 +406,7 @@ def _prepare_for_mode(self): self.sagemaker_session, self.image_uri, getattr(self, "model_hub", None) == ModelHub.JUMPSTART, + should_upload=should_upload_artifacts, ) self.env_vars.update(env_vars_sagemaker) return self.s3_upload_path, env_vars_sagemaker @@ -479,6 +485,10 @@ def _create_model(self): self.pysdk_model.mode = self.mode self.pysdk_model.modes = self.modes self.pysdk_model.serve_settings = self.serve_settings + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session # dynamically generate a method to direct model.deploy() logic based on mode # unique method to models created via ModelBuilder() @@ -935,8 +945,9 @@ def optimize(self, *args, **kwargs) -> Model: """Runs a model optimization job. Args: - instance_type (str): Target deployment instance type that the model is optimized for. - output_path (str): Specifies where to store the compiled/quantized model. 
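
The customer-source branch mounts the draft model as an additional model data source and points the serving container at it through `OPTION_SPECULATIVE_DRAFT_MODEL`. The wiring, runnable standalone with illustrative values:

```python
channel_name = "draft-model"  # normally produced by _generate_channel_name
env = {"OPTION_DTYPE": "fp16"}  # pre-existing container environment

# Additional model data sources are mounted under this container prefix.
speculative_draft_model = f"/opt/ml/additional-model-data-sources/{channel_name}"
env.update({"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model})

assert env["OPTION_SPECULATIVE_DRAFT_MODEL"] == "/opt/ml/additional-model-data-sources/draft-model"
```
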
+ instance_type (Optional[str]): Target deployment instance type that the + model is optimized for. + output_path (Optional[str]): Specifies where to store the compiled/quantized model. role (Optional[str]): Execution role. Defaults to ``None``. tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. @@ -964,7 +975,7 @@ def optimize(self, *args, **kwargs) -> Model: @_capture_telemetry("optimize") def _model_builder_optimize_wrapper( self, - output_path: str, + output_path: Optional[str] = None, instance_type: Optional[str] = None, role: Optional[str] = None, tags: Optional[Tags] = None, @@ -1010,11 +1021,15 @@ def _model_builder_optimize_wrapper( Returns: Model: A deployable ``Model`` object. """ + _validate_optimization_inputs( + output_path, instance_type, quantization_config, compilation_config + ) + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}" - input_args = {} + input_args = None if self._is_jumpstart_model_id(): input_args = self._optimize_for_jumpstart( output_path=output_path, @@ -1032,8 +1047,9 @@ def _model_builder_optimize_wrapper( max_runtime_in_sec=max_runtime_in_sec, ) - self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args) - job_status = self.sagemaker_session.wait_for_optimization_job(job_name) - self.pysdk_model = _generate_optimized_model(self.pysdk_model, job_status) + if input_args: + self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args) + job_status = self.sagemaker_session.wait_for_optimization_job(job_name) + self.pysdk_model = _generate_optimized_model(self.pysdk_model, job_status) return self.pysdk_model diff --git a/src/sagemaker/serve/builder/tei_builder.py b/src/sagemaker/serve/builder/tei_builder.py index f4e5e67891..79b0d276b7 100644 --- a/src/sagemaker/serve/builder/tei_builder.py +++ b/src/sagemaker/serve/builder/tei_builder.py @@ -162,10 +162,7 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = self._prepare_for_mode() - self.env_vars.update(env_vars) - self.pysdk_model.env.update(self.env_vars) + self._prepare_for_mode() # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: @@ -220,4 +217,8 @@ def _build_for_tei(self): self._set_to_tei() self.pysdk_model = self._build_for_hf_tei() + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model diff --git a/src/sagemaker/serve/builder/tf_serving_builder.py b/src/sagemaker/serve/builder/tf_serving_builder.py index 42c548f4e4..9b171b1d98 100644 --- a/src/sagemaker/serve/builder/tf_serving_builder.py +++ b/src/sagemaker/serve/builder/tf_serving_builder.py @@ -102,6 +102,10 @@ def _create_tensorflow_model(self): self.pysdk_model.mode = self.mode self.pysdk_model.modes = self.modes self.pysdk_model.serve_settings = self.serve_settings + if hasattr(self, "role_arn") and self.role_arn: + self.pysdk_model.role = self.role_arn + if hasattr(self, "sagemaker_session") and self.sagemaker_session: + self.pysdk_model.sagemaker_session = 
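
With `output_path` now optional, the optimize wrapper only creates a SageMaker optimization job when the JumpStart path actually produced job arguments. A self-contained sketch of that gate, with a fake client standing in for the SageMaker client:

```python
from typing import Optional


class FakeClient:
    """Stands in for the SageMaker client; records the job it was asked to create."""

    def __init__(self):
        self.created = None

    def create_optimization_job(self, **kwargs):
        self.created = kwargs

    def wait_for_optimization_job(self, job_name):
        return {"OptimizationJobName": job_name, "OptimizationJobStatus": "COMPLETED"}


def run_optimize(client: FakeClient, model: dict, input_args: Optional[dict], job_name: str) -> dict:
    # Mirrors the wrapper's gate: no job arguments, no job.
    if input_args:
        client.create_optimization_job(**input_args)
        status = client.wait_for_optimization_job(job_name)
        model["optimization_job"] = status["OptimizationJobName"]
    return model


client = FakeClient()
model = run_optimize(client, {}, None, "modelbuilderjob-example")
assert client.created is None and "optimization_job" not in model
```
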
self.sagemaker_session self._original_deploy = self.pysdk_model.deploy self.pysdk_model.deploy = self._model_builder_deploy_wrapper diff --git a/src/sagemaker/serve/builder/tgi_builder.py b/src/sagemaker/serve/builder/tgi_builder.py index 9f8762c27e..13755b1a43 100644 --- a/src/sagemaker/serve/builder/tgi_builder.py +++ b/src/sagemaker/serve/builder/tgi_builder.py @@ -201,10 +201,7 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = self._prepare_for_mode() - self.env_vars.update(env_vars) - self.pysdk_model.env.update(self.env_vars) + self._prepare_for_mode() # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: @@ -472,4 +469,8 @@ def _build_for_tgi(self): self.pysdk_model = self._build_for_hf_tgi() self.pysdk_model.tune = self._tune_for_hf_tgi + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index f84d8f868d..47ea8189b2 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -223,10 +223,7 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = self._prepare_for_mode() - self.env_vars.update(env_vars) - self.pysdk_model.env.update(self.env_vars) + self._prepare_for_mode() if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True @@ -303,4 +300,8 @@ def _build_for_transformers(self): self._build_transformers_env() + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py index b8f1d0529b..35f782b685 100644 --- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py +++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py @@ -59,6 +59,7 @@ def prepare( sagemaker_session: Session = None, image: str = None, jumpstart: bool = False, + should_upload: bool = False, ): """Placeholder docstring""" try: @@ -96,7 +97,7 @@ def prepare( image=image, ) - if self.model_server == ModelServer.TGI: + if self.model_server == ModelServer.TGI and should_upload: upload_artifacts = self._upload_tgi_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, @@ -105,7 +106,7 @@ def prepare( jumpstart=jumpstart, ) - if self.model_server == ModelServer.MMS: + if self.model_server == ModelServer.MMS and should_upload: upload_artifacts = self._upload_server_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, @@ -122,7 +123,7 @@ def prepare( image=image, ) - if self.model_server == ModelServer.TEI: + if self.model_server == ModelServer.TEI and should_upload: upload_artifacts = self._tei_serving._upload_tei_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, @@ -130,7 +131,7 @@ def prepare( image=image, ) - if upload_artifacts: + if upload_artifacts or isinstance(self.model_server, ModelServer): return 
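
The new `should_upload` flag gates artifact upload for the TGI, MMS, and TEI servers, so `prepare()` no longer uploads on every call. A reduced sketch of the gate, with an illustrative enum and a stand-in for the upload helper:

```python
from enum import Enum
from typing import Optional


class FakeModelServer(Enum):
    TGI = "TGI"
    MMS = "MMS"
    TEI = "TEI"


def prepare(model_server: FakeModelServer, should_upload: bool = False) -> Optional[str]:
    upload_artifacts = None
    if model_server is FakeModelServer.TGI and should_upload:
        upload_artifacts = "tgi-artifacts"  # stand-in for _upload_tgi_artifacts(...)
    return upload_artifacts


assert prepare(FakeModelServer.TGI) is None
assert prepare(FakeModelServer.TGI, should_upload=True) == "tgi-artifacts"
```
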
upload_artifacts raise ValueError("%s model server is not supported" % self.model_server) diff --git a/src/sagemaker/serve/model_server/triton/triton_builder.py b/src/sagemaker/serve/model_server/triton/triton_builder.py index ed0ec49204..a19235767f 100644 --- a/src/sagemaker/serve/model_server/triton/triton_builder.py +++ b/src/sagemaker/serve/model_server/triton/triton_builder.py @@ -428,6 +428,10 @@ def _create_triton_model(self) -> Type[Model]: self.pysdk_model.mode = self.mode self.pysdk_model.modes = self.modes self.pysdk_model.serve_settings = self.serve_settings + if hasattr(self, "role_arn") and self.role_arn: + self.pysdk_model.role = self.role_arn + if hasattr(self, "sagemaker_session") and self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session # dynamically generate a method to direct model.deploy() logic based on mode # unique method to models created via ModelBuilder() diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 826699beee..2d7125b17d 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -15,7 +15,8 @@ import re import logging -from typing import Dict, Any, Optional, Union +import uuid +from typing import Dict, Any, Optional, Union, List from sagemaker import Model from sagemaker.enums import Tag @@ -66,18 +67,27 @@ def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) - Returns: Model: A deployable optimized model. """ - pysdk_model.image_uri = optimization_response["RecommendedInferenceImage"] - pysdk_model.env = optimization_response["OptimizationEnvironment"] - pysdk_model.model_data["S3DataSource"]["S3Uri"] = optimization_response["ModelSource"]["S3"] - pysdk_model.instance_type = optimization_response["DeploymentInstanceType"] + recommended_image_uri = optimization_response["OptimizationOutput"]["RecommendedInferenceImage"] + optimized_environment = optimization_response["OptimizationEnvironment"] + s3_uri = optimization_response["ModelSource"]["S3"] + deployment_instance_type = optimization_response["DeploymentInstanceType"] + + if recommended_image_uri: + pysdk_model.image_uri = recommended_image_uri + if optimized_environment: + pysdk_model.env = optimized_environment + if s3_uri: + pysdk_model.model_data["S3DataSource"]["S3Uri"] = s3_uri + if deployment_instance_type: + pysdk_model.instance_type = deployment_instance_type + pysdk_model.add_tags( - {"key": Tag.OPTIMIZATION_JOB_NAME, "value": optimization_response["OptimizationJobName"]} + {"Key": Tag.OPTIMIZATION_JOB_NAME, "Value": optimization_response["OptimizationJobName"]} ) - return pysdk_model -def _extract_model_source( +def _generate_model_source( model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool] ) -> Optional[Dict[str, Any]]: """Extracts model source from model data. @@ -143,6 +153,27 @@ def _extract_speculative_draft_model_provider( return "sagemaker" +def _extracts_and_validates_speculative_model_source( + speculative_decoding_config: Dict, +) -> str: + """Extracts model source from speculative decoding config. + + Args: + speculative_decoding_config (Optional[Dict]): A speculative decoding config. + + Returns: + str: Model source. + + Raises: + ValueError: If model source is none. 
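
`_generate_optimized_model` now reads the recommended image from `OptimizationOutput` and applies each response field only when present. A standalone sketch of that folding step, using a response shaped like the test fixture in this patch (the image URI is a placeholder):

```python
def apply_optimization_output(model: dict, response: dict) -> dict:
    image = response["OptimizationOutput"]["RecommendedInferenceImage"]
    env = response.get("OptimizationEnvironment")
    instance_type = response.get("DeploymentInstanceType")

    if image:
        model["image_uri"] = image
    if env:
        model["env"] = env
    if instance_type:
        model["instance_type"] = instance_type
    return model


model = apply_optimization_output(
    {},
    {
        "OptimizationOutput": {"RecommendedInferenceImage": "<ecr-image-uri>"},
        "OptimizationEnvironment": {"SAGEMAKER_PROGRAM": "inference.py"},
        "DeploymentInstanceType": "ml.g5.2xlarge",
    },
)
assert model["instance_type"] == "ml.g5.2xlarge"
```
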
+ """ + s3_uri: str = speculative_decoding_config.get("ModelSource") + + if not s3_uri: + raise ValueError("ModelSource must be provided in speculative decoding config.") + return s3_uri + + def _validate_optimization_inputs( output_path: Optional[str] = None, instance_type: Optional[str] = None, @@ -177,3 +208,53 @@ def _validate_optimization_inputs( raise ValueError(instance_type_msg.format("compilation")) if not output_path: raise ValueError(output_path_msg.format("compilation")) + + +def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) -> str: + """Generates a channel name. + + Args: + additional_model_data_sources (Optional[List[Dict]]): The additional model data sources. + + Returns: + str: The channel name. + """ + channel_name = f"model-builder-channel-{uuid.uuid4().hex}" + if additional_model_data_sources and len(additional_model_data_sources) > 0: + channel_name = additional_model_data_sources[0].get("ChannelName", channel_name) + + return channel_name + + +def _generate_additional_model_data_sources( + model_source: str, + channel_name: str, + accept_eula: bool = False, + s3_data_type: Optional[str] = "S3Prefix", + compression_type: Optional[str] = "None", +) -> List[Dict]: + """Generates additional model data sources. + + Args: + model_source (Optional[str]): The model source. + channel_name (Optional[str]): The channel name. + accept_eula (Optional[bool]): Whether to accept eula or not. + s3_data_type (Optional[str]): The S3 data type, defaults to 'S3Prefix'. + compression_type (Optional[str]): The compression type, defaults to None. + + Returns: + List[Dict]: The additional model data sources. + """ + + additional_model_data_source = { + "ChannelName": channel_name, + "S3DataSource": { + "S3Uri": model_source, + "S3DataType": s3_data_type, + "CompressionType": compression_type, + }, + } + if accept_eula: + additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"ACCEPT_EULA": True} + + return [additional_model_data_source] diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 596608be8a..6593751b58 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -7630,12 +7630,12 @@ def _auto_ml_job_status(sagemaker_client, job_name): def _optimization_job_status(sagemaker_client, job_name): """Placeholder docstring""" optimization_job_status_codes = { - "INPROGRESS": "!", - "COMPLETED": ".", + "INPROGRESS": ".", + "COMPLETED": "!", "FAILED": "*", - "STARTING": "s", + "STARTING": ".", "STOPPING": "_", - "STOPPED": ",", + "STOPPED": "s", } in_progress_statuses = ["INPROGRESS", "STARTING", "STOPPING"] diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 045a214759..d20e72fc1f 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1826,3 +1826,50 @@ def convert_value(value): result[convert_key(key)] = convert_value(value) return result + + +def tag_exists(tag: TagsDict, curr_tags: Optional[Tags]) -> bool: + """Returns True if ``tag`` already exists. + + Args: + tag (TagsDict): The tag dictionary. + curr_tags (Optional[Tags]): The current tags. + + Returns: + bool: True if the tag exists. + """ + if curr_tags is None: + return False + + for curr_tag in curr_tags: + if tag["Key"] == curr_tag["Key"]: + return True + + return False + + +def _validate_new_tags(new_tags: Optional[Tags], curr_tags: Optional[Tags]) -> Optional[Tags]: + """Validates new tags against existing tags. + + Args: + new_tags (Optional[Tags]): The new tags. + curr_tags (Optional[Tags]): The current tags. 
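
`_generate_additional_model_data_sources` assembles the payload below, with EULA acceptance riding along under `ModelAccessConfig`. The same construction, runnable standalone with placeholder bucket names:

```python
from typing import Dict, List


def make_additional_sources(s3_uri: str, channel_name: str, accept_eula: bool = False) -> List[Dict]:
    source = {
        "ChannelName": channel_name,
        "S3DataSource": {
            "S3Uri": s3_uri,
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        },
    }
    if accept_eula:
        source["S3DataSource"]["ModelAccessConfig"] = {"ACCEPT_EULA": True}
    return [source]


sources = make_additional_sources("s3://example-bucket/draft/", "draft-model", accept_eula=True)
assert sources[0]["S3DataSource"]["ModelAccessConfig"] == {"ACCEPT_EULA": True}
```
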
+ + Returns: + Optional[Tags]: The updated tags. + """ + if curr_tags is None: + return new_tags + + if curr_tags and isinstance(curr_tags, dict): + curr_tags = [curr_tags] + + if isinstance(new_tags, dict): + if not tag_exists(new_tags, curr_tags): + curr_tags.append(new_tags) + elif isinstance(new_tags, list): + for new_tag in new_tags: + if not tag_exists(new_tag, curr_tags): + curr_tags.append(new_tag) + + return curr_tags diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index d98cdde896..6c2e03b683 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -1011,7 +1011,7 @@ def test_fine_tuned_model_with_fine_tuning_model_path( "S3Uri", mock_fine_tuning_model_path ) mock_pre_trained_model.return_value.add_tags.assert_called_with( - {"key": Tag.FINE_TUNING_MODEL_PATH, "value": mock_fine_tuning_model_path} + {"Key": Tag.FINE_TUNING_MODEL_PATH, "Value": mock_fine_tuning_model_path} ) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index b70f855486..3671d2382e 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -181,7 +181,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -284,7 +284,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -391,7 +391,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( if inference_spec == mock_inference_spec and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -487,7 +487,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, 
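
A usage sketch of the tag-merging helpers added to `utils.py`: dict or list inputs are normalized, and a tag is appended only when its `Key` is not already present. This is a standalone re-statement of the logic above, not the SDK functions themselves:

```python
from typing import Dict, List, Optional, Union

Tag = Dict[str, str]


def tag_exists(tag: Tag, curr_tags: Optional[List[Tag]]) -> bool:
    return any(tag["Key"] == curr["Key"] for curr in curr_tags or [])


def merge_tags(new_tags: Union[Tag, List[Tag]], curr_tags: Optional[Union[Tag, List[Tag]]]):
    if curr_tags is None:
        return new_tags
    if isinstance(curr_tags, dict):
        curr_tags = [curr_tags]
    incoming = [new_tags] if isinstance(new_tags, dict) else list(new_tags)
    for tag in incoming:
        if not tag_exists(tag, curr_tags):
            curr_tags.append(tag)
    return curr_tags


merged = merge_tags({"Key": "stage", "Value": "prod"}, [{"Key": "stage", "Value": "dev"}])
assert merged == [{"Key": "stage", "Value": "dev"}]  # duplicate Key is not re-added
```
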
s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -591,7 +591,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -817,7 +817,7 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo if inference_spec == mock_inference_spec and model_server == ModelServer.TORCHSERVE else None ) - mock_sagemaker_endpoint_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_sagemaker_endpoint_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -939,7 +939,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index c0eea1ed68..c6d753c4fc 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -24,30 +24,56 @@ _is_image_compatible_with_optimization_job, _extract_speculative_draft_model_provider, _validate_optimization_inputs, + _extracts_and_validates_speculative_model_source, ) mock_optimization_job_output = { - "OptimizationJobName": "optimization_job_name", - "RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-tgi-inference:2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04", + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:" + "optimization-job/modelbuilderjob-6b09ffebeb0741b8a28b85623fd9c968", "OptimizationJobStatus": "COMPLETED", + "OptimizationJobName": "modelbuilderjob-6b09ffebeb0741b8a28b85623fd9c968", + "ModelSource": { + "S3": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/" + } + }, "OptimizationEnvironment": { - "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", + "HF_MODEL_ID": "/opt/ml/model", "MODEL_CACHE_ROOT": "/opt/ml/model", "SAGEMAKER_ENV": "1", - "HF_MODEL_ID": "/opt/ml/model", - "MAX_INPUT_LENGTH": "4095", - "MAX_TOTAL_TOKENS": "4096", - "MAX_BATCH_PREFILL_TOKENS": "8192", - "MAX_CONCURRENT_REQUESTS": "512", "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "SAGEMAKER_PROGRAM": "inference.py", }, - "ModelSource": { - "S3": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/" - 
"meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v2.0.0/" + "DeploymentInstanceType": "ml.g5.48xlarge", + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124", + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + } + } + ], + "OutputConfig": { + "S3OutputLocation": "s3://dont-delete-ss-jarvis-integ-test-312206380606-us-west-2/" + }, + "OptimizationOutput": { + "RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124" + }, + "RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "StoppingCondition": {"MaxRuntimeInSeconds": 36000}, + "ResponseMetadata": { + "RequestId": "17ae151f-b51d-4194-8ba9-edbba068c90b", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "17ae151f-b51d-4194-8ba9-edbba068c90b", + "content-type": "application/x-amz-json-1.1", + "content-length": "1380", + "date": "Thu, 20 Jun 2024 19:25:53 GMT", + }, + "RetryAttempts": 0, }, - "DeploymentInstanceType": "ml.m5.xlarge", } @@ -92,11 +118,19 @@ def test_is_image_compatible_with_optimization_job(image_uri, expected): def test_generate_optimized_model(): pysdk_model = Mock() - pysdk_model.model_data = {"S3DataSource": {"S3Uri": "s3://foo/bar"}} + pysdk_model.model_data = { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/" + } + } optimized_model = _generate_optimized_model(pysdk_model, mock_optimization_job_output) - assert optimized_model.image_uri == mock_optimization_job_output["RecommendedInferenceImage"] + assert ( + optimized_model.image_uri + == mock_optimization_job_output["OptimizationOutput"]["RecommendedInferenceImage"] + ) assert optimized_model.env == mock_optimization_job_output["OptimizationEnvironment"] assert ( optimized_model.model_data["S3DataSource"]["S3Uri"] @@ -105,8 +139,8 @@ def test_generate_optimized_model(): assert optimized_model.instance_type == mock_optimization_job_output["DeploymentInstanceType"] pysdk_model.add_tags.assert_called_once_with( { - "key": Tag.OPTIMIZATION_JOB_NAME, - "value": mock_optimization_job_output["OptimizationJobName"], + "Key": Tag.OPTIMIZATION_JOB_NAME, + "Value": mock_optimization_job_output["OptimizationJobName"], } ) @@ -165,3 +199,13 @@ def test_validate_optimization_inputs( _validate_optimization_inputs( output_path, instance, quantization_config, compilation_config ) + + +def test_extract_speculative_draft_model_s3_uri(): + res = _extracts_and_validates_speculative_model_source({"ModelSource": "s3://"}) + assert res == "s3://" + + +def test_extract_speculative_draft_model_s3_uri_ex(): + with pytest.raises(ValueError): + _extracts_and_validates_speculative_model_source({"ModelSource": None}) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 731333d8ba..63263a7920 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -55,6 +55,8 @@ get_instance_rate_per_hour, extract_instance_rate_per_hour, _resolve_routing_config, + tag_exists, + _validate_new_tags, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -2095,3 +2097,30 @@ def camel_case_to_pascal_case_nested(self): def test_empty_input(self): self.assertEqual(camel_case_to_pascal_case({}), {}) + + +class TestTags(TestCase): + def 
test_tag_exists(self): + curr_tags = [{"Key": "project", "Value": "my-project"}] + self.assertTrue(tag_exists({"Key": "project", "Value": "my-project"}, curr_tags=curr_tags)) + + def test_does_not_tag_exists(self): + curr_tags = [{"Key": "project", "Value": "my-project"}] + self.assertFalse( + tag_exists({"Key": "project-2", "Value": "my-project-2"}, curr_tags=curr_tags) + ) + + def test_add_tags(self): + curr_tags = [{"Key": "project", "Value": "my-project"}] + new_tag = {"Key": "project-2", "Value": "my-project-2"} + expected = [ + {"Key": "project", "Value": "my-project"}, + {"Key": "project-2", "Value": "my-project-2"}, + ] + + self.assertEqual(_validate_new_tags(new_tag, curr_tags), expected) + + def test_new_add_tags(self): + new_tag = {"Key": "project-2", "Value": "my-project-2"} + + self.assertEqual(_validate_new_tags(new_tag, None), new_tag) From 114a716f77d876a9bd47c0e0528182b0a83306b6 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Mon, 24 Jun 2024 14:26:50 -0700 Subject: [PATCH 18/45] Bug fixes (#1496) * Bug fixes * refcatore * ENV update * Remove code duplication * Fix Integ tests * Fix MB EULA bug --------- Co-authored-by: Jonathan Makunga --- .../serve/builder/jumpstart_builder.py | 34 +++--- src/sagemaker/serve/builder/model_builder.py | 9 +- .../serve/mode/sagemaker_endpoint_mode.py | 2 +- src/sagemaker/serve/utils/optimize_utils.py | 39 +++---- .../serve/utils/test_optimize_utils.py | 104 +++++++++++++----- 5 files changed, 127 insertions(+), 61 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 89ad4f21ab..105d55c4ed 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -45,6 +45,7 @@ _extracts_and_validates_speculative_model_source, _generate_channel_name, _generate_additional_model_data_sources, + _is_s3_uri, ) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, @@ -750,6 +751,8 @@ def _optimize_for_jumpstart( if pysdk_model_env_vars: self.pysdk_model.env.update(pysdk_model_env_vars) + if accept_eula: + self.pysdk_model.accept_eula = accept_eula if quantization_config or compilation_config: return create_optimization_job_args @@ -787,8 +790,9 @@ def _set_additional_model_source( if speculative_decoding_config: model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) + speculative_draft_model = f"/opt/ml/additional-model-data-sources/{channel_name}" - if model_provider.lower() == "sagemaker": + if model_provider == "sagemaker": additional_model_data_sources = self.pysdk_model.deployment_config.get( "DeploymentArgs", {} ).get("AdditionalDataSources") @@ -805,27 +809,31 @@ def _set_additional_model_source( raise ValueError( "Cannot find deployment config compatible for optimization job." 
) - - self.pysdk_model.add_tags( - {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"}, - ) else: - s3_uri = _extracts_and_validates_speculative_model_source( + model_source = _extracts_and_validates_speculative_model_source( speculative_decoding_config ) - self.pysdk_model.additional_model_data_sources = ( - _generate_additional_model_data_sources(s3_uri, channel_name, accept_eula) - ) - self.pysdk_model.add_tags( - {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "customer"}, - ) + if _is_s3_uri(model_source): + self.pysdk_model.additional_model_data_sources = ( + _generate_additional_model_data_sources( + model_source, channel_name, accept_eula + ) + ) + else: + speculative_draft_model = model_source - speculative_draft_model = f"/opt/ml/additional-model-data-sources/{channel_name}" self.pysdk_model.env = _update_environment_variables( self.pysdk_model.env, {"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model}, ) + self.pysdk_model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": model_provider}, + ) + if accept_eula and isinstance(self.pysdk_model.model_data, dict): + self.pysdk_model.model_data["S3DataSource"]["ModelAccessConfig"] = { + "AcceptEula": True + } def _find_compatible_deployment_config( self, speculative_decoding_config: Optional[Dict] = None diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 17d1418eaa..036310561f 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -408,7 +408,8 @@ def _prepare_for_mode(self, should_upload_artifacts: bool = False): getattr(self, "model_hub", None) == ModelHub.JUMPSTART, should_upload=should_upload_artifacts, ) - self.env_vars.update(env_vars_sagemaker) + if env_vars_sagemaker: + self.env_vars.update(env_vars_sagemaker) return self.s3_upload_path, env_vars_sagemaker if self.mode == Mode.LOCAL_CONTAINER: # init the LocalContainerMode object @@ -1026,6 +1027,12 @@ def _model_builder_optimize_wrapper( ) self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() + + if instance_type: + self.instance_type = instance_type + if role: + self.role = role + self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}" diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py index 35f782b685..d0022ae74c 100644 --- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py +++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py @@ -70,7 +70,7 @@ def prepare( + "session to be created or supply `sagemaker_session` into @serve.invoke." ) from e - upload_artifacts = None + upload_artifacts = None, None if self.model_server == ModelServer.TORCHSERVE: upload_artifacts = self._upload_torchserve_artifacts( model_path=model_path, diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 2d7125b17d..ea7d6d3cb4 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -25,23 +25,6 @@ logger = logging.getLogger(__name__) -def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool: - """Checks whether an instance is compatible with Inferentia. - - Args: - instance_type (str): The instance type used for the compilation job. - - Returns: - bool: Whether the given instance type is Inferentia or Trainium. 
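
When the EULA is accepted, the patch now also stamps the flag onto the model's own S3 data source. Note the `AcceptEula` casing here versus `ACCEPT_EULA` on additional model data sources, which follows the two API shapes used in this patch. A standalone sketch with a placeholder URI:

```python
model_data = {"S3DataSource": {"S3Uri": "s3://example-bucket/model/"}}
accept_eula = True

if accept_eula and isinstance(model_data, dict):
    model_data["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True}

assert model_data["S3DataSource"]["ModelAccessConfig"] == {"AcceptEula": True}
```
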
- """ - if isinstance(instance_type, str): - match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) - if match: - if match[1].startswith("inf") or match[1].startswith("trn"): - return True - return False - - def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool: """Checks whether an instance is compatible with an optimization job. @@ -69,13 +52,16 @@ def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) - """ recommended_image_uri = optimization_response["OptimizationOutput"]["RecommendedInferenceImage"] optimized_environment = optimization_response["OptimizationEnvironment"] - s3_uri = optimization_response["ModelSource"]["S3"] + s3_uri = optimization_response["OutputConfig"]["S3OutputLocation"] deployment_instance_type = optimization_response["DeploymentInstanceType"] if recommended_image_uri: pysdk_model.image_uri = recommended_image_uri if optimized_environment: - pysdk_model.env = optimized_environment + if pysdk_model.env: + pysdk_model.env.update(optimized_environment) + else: + pysdk_model.env = optimized_environment if s3_uri: pysdk_model.model_data["S3DataSource"]["S3Uri"] = s3_uri if deployment_instance_type: @@ -258,3 +244,18 @@ def _generate_additional_model_data_sources( additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"ACCEPT_EULA": True} return [additional_model_data_source] + + +def _is_s3_uri(s3_uri: Optional[str]) -> bool: + """Checks whether an S3 URI is valid. + + Args: + s3_uri (Optional[str]): The S3 URI. + + Returns: + bool: Whether the S3 URI is valid. + """ + if s3_uri is None: + return False + + return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index c6d753c4fc..2e0a2914d8 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -19,19 +19,21 @@ from sagemaker.enums import Tag from sagemaker.serve.utils.optimize_utils import ( _generate_optimized_model, - _is_inferentia_or_trainium, _update_environment_variables, _is_image_compatible_with_optimization_job, _extract_speculative_draft_model_provider, _validate_optimization_inputs, _extracts_and_validates_speculative_model_source, + _is_s3_uri, + _generate_additional_model_data_sources, + _generate_channel_name, ) mock_optimization_job_output = { - "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:" - "optimization-job/modelbuilderjob-6b09ffebeb0741b8a28b85623fd9c968", + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:optimization-job/" + "modelbuilderjob-3cbf9c40b63c455d85b60033f9a01691", "OptimizationJobStatus": "COMPLETED", - "OptimizationJobName": "modelbuilderjob-6b09ffebeb0741b8a28b85623fd9c968", + "OptimizationJobName": "modelbuilderjob-3cbf9c40b63c455d85b60033f9a01691", "ModelSource": { "S3": { "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" @@ -46,7 +48,7 @@ "SAGEMAKER_MODEL_SERVER_WORKERS": "1", "SAGEMAKER_PROGRAM": "inference.py", }, - "DeploymentInstanceType": "ml.g5.48xlarge", + "DeploymentInstanceType": "ml.g5.2xlarge", "OptimizationConfigs": [ { "ModelQuantizationConfig": { @@ -55,40 +57,26 @@ } } ], - "OutputConfig": { - "S3OutputLocation": "s3://dont-delete-ss-jarvis-integ-test-312206380606-us-west-2/" - }, + "OutputConfig": {"S3OutputLocation": "s3://quicksilver-model-data/llama-3-8b/quantized-1/"}, "OptimizationOutput": { "RecommendedInferenceImage": 
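
The `_is_s3_uri` check added above, reproduced as a runnable snippet: a URI qualifies only if it has the `s3://` scheme followed by a bucket segment:

```python
import re
from typing import Optional


def is_s3_uri(s3_uri: Optional[str]) -> bool:
    if s3_uri is None:
        return False
    return re.match(r"^s3://([^/]+)/?(.*)$", s3_uri) is not None


assert is_s3_uri("s3://example-bucket/prefix/")
assert not is_s3_uri("invalid://")
assert not is_s3_uri(None)
```
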
"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124" }, - "RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20240116T151132", "StoppingCondition": {"MaxRuntimeInSeconds": 36000}, "ResponseMetadata": { - "RequestId": "17ae151f-b51d-4194-8ba9-edbba068c90b", + "RequestId": "a95253d5-c045-4708-8aac-9f0d327515f7", "HTTPStatusCode": 200, "HTTPHeaders": { - "x-amzn-requestid": "17ae151f-b51d-4194-8ba9-edbba068c90b", + "x-amzn-requestid": "a95253d5-c045-4708-8aac-9f0d327515f7", "content-type": "application/x-amz-json-1.1", - "content-length": "1380", - "date": "Thu, 20 Jun 2024 19:25:53 GMT", + "content-length": "1371", + "date": "Fri, 21 Jun 2024 04:27:42 GMT", }, "RetryAttempts": 0, }, } -@pytest.mark.parametrize( - "instance, expected", - [ - ("ml.trn1.2xlarge", True), - ("ml.inf2.xlarge", True), - ("ml.c7gd.4xlarge", False), - ], -) -def test_is_inferentia_or_trainium(instance, expected): - assert _is_inferentia_or_trainium(instance) == expected - - @pytest.mark.parametrize( "image_uri, expected", [ @@ -124,6 +112,7 @@ def test_generate_optimized_model(): "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/" } } + pysdk_model.env = {"OPTION_QUANTIZE": "awq"} optimized_model = _generate_optimized_model(pysdk_model, mock_optimization_job_output) @@ -131,10 +120,13 @@ def test_generate_optimized_model(): optimized_model.image_uri == mock_optimization_job_output["OptimizationOutput"]["RecommendedInferenceImage"] ) - assert optimized_model.env == mock_optimization_job_output["OptimizationEnvironment"] + assert optimized_model.env == { + "OPTION_QUANTIZE": "awq", + **mock_optimization_job_output["OptimizationEnvironment"], + } assert ( optimized_model.model_data["S3DataSource"]["S3Uri"] - == mock_optimization_job_output["ModelSource"]["S3"] + == mock_optimization_job_output["OutputConfig"]["S3OutputLocation"] ) assert optimized_model.instance_type == mock_optimization_job_output["DeploymentInstanceType"] pysdk_model.add_tags.assert_called_once_with( @@ -209,3 +201,61 @@ def test_extract_speculative_draft_model_s3_uri(): def test_extract_speculative_draft_model_s3_uri_ex(): with pytest.raises(ValueError): _extracts_and_validates_speculative_model_source({"ModelSource": None}) + + +def test_generate_channel_name(): + assert _generate_channel_name(None) is not None + + additional_model_data_sources = _generate_additional_model_data_sources( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", True + ) + + assert _generate_channel_name(additional_model_data_sources) == "channel_name" + + +def test_generate_additional_model_data_sources(): + model_source = _generate_additional_model_data_sources( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", True + ) + + assert model_source == [ + { + "ChannelName": "channel_name", + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"ACCEPT_EULA": True}, + }, + } + ] + + model_source = _generate_additional_model_data_sources( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", False + ) + + assert model_source == [ + { + "ChannelName": "channel_name", + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", + 
"S3DataType": "S3Prefix", + "CompressionType": "None", + }, + } + ] + + +@pytest.mark.parametrize( + "s3_uri, expected", + [ + ( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/", + True, + ), + ("invalid://", False), + ], +) +def test_is_s3_uri(s3_uri, expected): + assert _is_s3_uri(s3_uri) == expected From 0ac601425fad07b76b7f7d4c11c7b86e60506cd4 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Mon, 24 Jun 2024 14:45:37 -0700 Subject: [PATCH 19/45] JS Optimize api ref --- .../serve/builder/jumpstart_builder.py | 18 ++++++++++-------- src/sagemaker/serve/builder/model_builder.py | 12 ++++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 105d55c4ed..71b1abb217 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -109,7 +109,7 @@ def __init__(self): self.prepared_for_djl = None self.prepared_for_mms = None self.schema_builder = None - self.nb_instance_type = None + self.instance_type = None self.ram_usage_model_load = None self.model_hub = None self.model_metadata = None @@ -138,7 +138,9 @@ def _is_jumpstart_model_id(self) -> bool: def _create_pre_trained_js_model(self) -> Type[Model]: """Placeholder docstring""" - pysdk_model = JumpStartModel(self.model, vpc_config=self.vpc_config) + pysdk_model = JumpStartModel( + self.model, vpc_config=self.vpc_config, instance_type=self.instance_type + ) pysdk_model.sagemaker_session = self.sagemaker_session self._original_deploy = pysdk_model.deploy @@ -234,8 +236,8 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - if hasattr(self, "nb_instance_type"): - kwargs.update({"instance_type": self.nb_instance_type}) + if self.instance_type: + kwargs.update({"instance_type": self.instance_type}) if "mode" in kwargs: del kwargs["mode"] @@ -268,7 +270,7 @@ def _build_for_djl_jumpstart(self): ) self._prepare_for_mode() elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_djl"): - self.nb_instance_type = _get_nb_instance() + self.instance_type = self.instance_type or _get_nb_instance() self.pysdk_model.model_data, env = self._prepare_for_mode() self.pysdk_model.env.update(env) @@ -647,7 +649,7 @@ def _optimize_for_jumpstart( self, output_path: Optional[str] = None, instance_type: Optional[str] = None, - role: Optional[str] = None, + role_arn: Optional[str] = None, tags: Optional[Tags] = None, job_name: Optional[str] = None, accept_eula: Optional[bool] = None, @@ -665,7 +667,7 @@ def _optimize_for_jumpstart( output_path (Optional[str]): Specifies where to store the compiled/quantized model. instance_type (Optional[str]): Target deployment instance type that the model is optimized for. - role (Optional[str]): Execution role. Defaults to ``None``. + role_arn (Optional[str]): Execution role. Defaults to ``None``. tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. 
accept_eula (bool): For models that require a Model Access Config, specify True or @@ -735,7 +737,7 @@ def _optimize_for_jumpstart( "DeploymentInstanceType": instance_type, "OptimizationConfigs": [optimization_config], "OutputConfig": output_config, - "RoleArn": role, + "RoleArn": role_arn, } if optimization_env_vars: diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 036310561f..829ffff717 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -949,7 +949,7 @@ def optimize(self, *args, **kwargs) -> Model: instance_type (Optional[str]): Target deployment instance type that the model is optimized for. output_path (Optional[str]): Specifies where to store the compiled/quantized model. - role (Optional[str]): Execution role. Defaults to ``None``. + role_arn (Optional[str]): Execution role. Defaults to ``None``. tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. @@ -978,7 +978,7 @@ def _model_builder_optimize_wrapper( self, output_path: Optional[str] = None, instance_type: Optional[str] = None, - role: Optional[str] = None, + role_arn: Optional[str] = None, tags: Optional[Tags] = None, job_name: Optional[str] = None, accept_eula: Optional[bool] = None, @@ -996,7 +996,7 @@ def _model_builder_optimize_wrapper( Args: output_path (str): Specifies where to store the compiled/quantized model. instance_type (str): Target deployment instance type that the model is optimized for. - role (Optional[str]): Execution role. Defaults to ``None``. + role_arn (Optional[str]): Execution role arn. Defaults to ``None``. tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. 
accept_eula (bool): For models that require a Model Access Config, specify True or @@ -1030,8 +1030,8 @@ def _model_builder_optimize_wrapper( if instance_type: self.instance_type = instance_type - if role: - self.role = role + if role_arn: + self.role_arn = role_arn self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}" @@ -1041,7 +1041,7 @@ def _model_builder_optimize_wrapper( input_args = self._optimize_for_jumpstart( output_path=output_path, instance_type=instance_type, - role=role if role else self.role_arn, + role_arn=self.role_arn, tags=tags, job_name=job_name, accept_eula=accept_eula, From 80fb96ac43f3c642410e13eed21714ed737c2c4c Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 25 Jun 2024 09:27:07 -0700 Subject: [PATCH 20/45] Refactoring --- src/sagemaker/jumpstart/types.py | 4 +++ .../serve/builder/jumpstart_builder.py | 29 +++++++++------ src/sagemaker/serve/builder/model_builder.py | 8 ++--- src/sagemaker/serve/utils/optimize_utils.py | 36 ------------------- .../serve/utils/test_optimize_utils.py | 26 -------------- 5 files changed, 26 insertions(+), 77 deletions(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 2561dbc237..7e8ce1a239 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2569,6 +2569,8 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder): "model_data_download_timeout", "container_startup_health_check_timeout", "additional_data_sources", + "neuron_model_id", + "neuron_model_version", ] def __init__( @@ -2599,6 +2601,8 @@ def __init__( "supported_inference_instance_types" ) self.additional_data_sources = resolved_config.get("hosting_additional_data_sources") + self.neuron_model_id = resolved_config.get("hosting_neuron_model_id") + self.neuron_model_version = resolved_config.get("hosting_neuron_model_version") class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 71b1abb217..7472256901 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -110,6 +110,7 @@ def __init__(self): self.prepared_for_mms = None self.schema_builder = None self.instance_type = None + self.nb_instance_type = None self.ram_usage_model_load = None self.model_hub = None self.model_metadata = None @@ -236,8 +237,8 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - if self.instance_type: - kwargs.update({"instance_type": self.instance_type}) + if hasattr(self, "nb_instance_type"): + kwargs.update({"instance_type": self.nb_instance_type}) if "mode" in kwargs: del kwargs["mode"] @@ -270,7 +271,7 @@ def _build_for_djl_jumpstart(self): ) self._prepare_for_mode() elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_djl"): - self.instance_type = self.instance_type or _get_nb_instance() + self.nb_instance_type = self.instance_type or _get_nb_instance() self.pysdk_model.model_data, env = self._prepare_for_mode() self.pysdk_model.env.update(env) @@ -695,25 +696,29 @@ def _optimize_for_jumpstart( f"Model '{self.model}' requires accepting end-user license agreement (EULA)." 
) - optimization_env_vars = None - pysdk_model_env_vars = None - model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula) + if compilation_config: + neuro_model_id = self.pysdk_model.deployment_config.get("DeploymentArgs").get( + "NeuronModelId" + ) + self.model = neuro_model_id + self.pysdk_model = self._create_pre_trained_js_model() if speculative_decoding_config: self._set_additional_model_source(speculative_decoding_config) - optimization_env_vars = self.pysdk_model.deployment_config.get( - "DeploymentArgs", {} - ).get("Environment") else: deployment_config = self._find_compatible_deployment_config(None) if deployment_config: - optimization_env_vars = deployment_config.get("DeploymentArgs").get("Environment") self.pysdk_model.set_deployment_config( config_name=deployment_config.get("DeploymentConfigName"), instance_type=deployment_config.get("InstanceType"), ) + model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula) + optimization_env_vars = self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( + "Environment" + ) optimization_env_vars = _update_environment_variables(optimization_env_vars, env_vars) + pysdk_model_env_vars = env_vars optimization_config = {} if quantization_config: @@ -730,6 +735,10 @@ def _optimize_for_jumpstart( output_config = {"S3OutputLocation": output_path} if kms_key: output_config["KmsKeyId"] = kms_key + if not instance_type: + instance_type = self.pysdk_model.deployment_config.get("DeploymentArgs").get( + "InstanceType" + ) create_optimization_job_args = { "OptimizationJobName": job_name, diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 829ffff717..d9fc6909d8 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -65,7 +65,6 @@ from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model from sagemaker.serve.utils.optimize_utils import ( _generate_optimized_model, - _validate_optimization_inputs, ) from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.utils.hardware_detector import ( @@ -238,7 +237,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, metadata={"help": "Define the s3 location where you want to upload the model package"}, ) instance_type: Optional[str] = field( - default="ml.c5.xlarge", + default=None, metadata={"help": "Define the instance_type of the endpoint"}, ) schema_builder: Optional[SchemaBuilder] = field( @@ -1022,9 +1021,8 @@ def _model_builder_optimize_wrapper( Returns: Model: A deployable ``Model`` object. 
""" - _validate_optimization_inputs( - output_path, instance_type, quantization_config, compilation_config - ) + if quantization_config and compilation_config: + raise ValueError("Quantization config and compilation config are mutually exclusive.") self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index ea7d6d3cb4..13438e467f 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -160,42 +160,6 @@ def _extracts_and_validates_speculative_model_source( return s3_uri -def _validate_optimization_inputs( - output_path: Optional[str] = None, - instance_type: Optional[str] = None, - quantization_config: Optional[Dict] = None, - compilation_config: Optional[Dict] = None, -) -> None: - """Validates optimization inputs. - - Args: - output_path (Optional[str]): The output path. - instance_type (Optional[str]): The instance type. - quantization_config (Optional[Dict]): The quantization config. - compilation_config (Optional[Dict]): The compilation config. - - Raises: - ValueError: If an optimization input is invalid. - """ - if quantization_config and compilation_config: - raise ValueError("Quantization config and compilation config are mutually exclusive.") - - instance_type_msg = "Please provide an instance type for %s optimization job." - output_path_msg = "Please provide an output path for %s optimization job." - - if quantization_config: - if not instance_type: - raise ValueError(instance_type_msg.format("quantization")) - if not output_path: - raise ValueError(output_path_msg.format("quantization")) - - if compilation_config: - if not instance_type: - raise ValueError(instance_type_msg.format("compilation")) - if not output_path: - raise ValueError(output_path_msg.format("compilation")) - - def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) -> str: """Generates a channel name. 
diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index 2e0a2914d8..f0e18186b7 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -22,7 +22,6 @@ _update_environment_variables, _is_image_compatible_with_optimization_job, _extract_speculative_draft_model_provider, - _validate_optimization_inputs, _extracts_and_validates_speculative_model_source, _is_s3_uri, _generate_additional_model_data_sources, @@ -168,31 +167,6 @@ def test_extract_speculative_draft_model_provider( ) -@pytest.mark.parametrize( - "output_path, instance, quantization_config, compilation_config", - [ - ( - None, - None, - {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}, - {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}, - ), - (None, None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}, None), - (None, None, None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}), - ("output_path", None, None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}), - (None, "instance_type", None, {"OverrideEnvironment": {"TENSOR_PARALLEL_DEGREE": 4}}), - ], -) -def test_validate_optimization_inputs( - output_path, instance, quantization_config, compilation_config -): - - with pytest.raises(ValueError): - _validate_optimization_inputs( - output_path, instance, quantization_config, compilation_config - ) - - def test_extract_speculative_draft_model_s3_uri(): res = _extracts_and_validates_speculative_model_source({"ModelSource": "s3://"}) assert res == "s3://" From c41a7cad8f4132c8a7feca0e9be274b7d7d7c47f Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 25 Jun 2024 12:51:02 -0700 Subject: [PATCH 21/45] Refactoring --- src/sagemaker/jumpstart/types.py | 10 +++-- .../serve/builder/jumpstart_builder.py | 42 +++++++++++++------ src/sagemaker/serve/builder/model_builder.py | 4 +- 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 7e8ce1a239..65cd4d274c 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1098,6 +1098,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "gated_bucket", "model_subscription_link", "hosting_additional_data_sources", + "hosting_neuron_model_id", + "hosting_neuron_model_version", ] def __init__(self, fields: Dict[str, Any]): @@ -1208,6 +1210,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("hosting_additional_data_sources") else None ) + self.hosting_neuron_model_id: Optional[str] = json_obj.get("hosting_neuron_model_id") + self.hosting_neuron_model_version: Optional[str] = json_obj.get( + "hosting_neuron_model_version" + ) if self.training_supported: self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( @@ -2569,8 +2575,6 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder): "model_data_download_timeout", "container_startup_health_check_timeout", "additional_data_sources", - "neuron_model_id", - "neuron_model_version", ] def __init__( @@ -2601,8 +2605,6 @@ def __init__( "supported_inference_instance_types" ) self.additional_data_sources = resolved_config.get("hosting_additional_data_sources") - self.neuron_model_id = resolved_config.get("hosting_neuron_model_id") - self.neuron_model_version = resolved_config.get("hosting_neuron_model_version") class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): diff --git 
a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 7472256901..bd27570964 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -139,9 +139,7 @@ def _is_jumpstart_model_id(self) -> bool: def _create_pre_trained_js_model(self) -> Type[Model]: """Placeholder docstring""" - pysdk_model = JumpStartModel( - self.model, vpc_config=self.vpc_config, instance_type=self.instance_type - ) + pysdk_model = JumpStartModel(self.model, vpc_config=self.vpc_config) pysdk_model.sagemaker_session = self.sagemaker_session self._original_deploy = pysdk_model.deploy @@ -696,12 +694,12 @@ def _optimize_for_jumpstart( f"Model '{self.model}' requires accepting end-user license agreement (EULA)." ) + optimization_env_vars = env_vars + pysdk_model_env_vars = env_vars + if compilation_config: - neuro_model_id = self.pysdk_model.deployment_config.get("DeploymentArgs").get( - "NeuronModelId" - ) - self.model = neuro_model_id - self.pysdk_model = self._create_pre_trained_js_model() + neuron_env = self._get_neuron_model_env_vars(instance_type) + optimization_env_vars = _update_environment_variables(neuron_env, optimization_env_vars) if speculative_decoding_config: self._set_additional_model_source(speculative_decoding_config) @@ -714,11 +712,6 @@ def _optimize_for_jumpstart( ) model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula) - optimization_env_vars = self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( - "Environment" - ) - optimization_env_vars = _update_environment_variables(optimization_env_vars, env_vars) - pysdk_model_env_vars = env_vars optimization_config = {} if quantization_config: @@ -874,3 +867,26 @@ def _find_compatible_deployment_config( # fall back to the default jumpstart model deployment config for optimization job return self.pysdk_model.deployment_config + + def _get_neuron_model_env_vars( + self, instance_type: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """Gets Neuron model env vars. + + Args: + instance_type (Optional[str]): Instance type. + + Returns: + Optional[Dict[str, Any]]: Neuron Model environment variables. 
+ """ + metadata_config = self.pysdk_model._metadata_configs.get(self.pysdk_model.config_name) + resolve_config = metadata_config.resolved_config + if instance_type not in resolve_config.get("supported_inference_instance_types", []): + neuro_model_id = resolve_config.get("hosting_neuron_model_id") + neuro_model_version = resolve_config.get("hosting_neuron_model_version") + if neuro_model_id: + job_model = JumpStartModel( + neuro_model_id, model_version=neuro_model_version, vpc_config=self.vpc_config + ) + return job_model.env + return None diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index d9fc6909d8..164bfe894c 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -237,7 +237,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, metadata={"help": "Define the s3 location where you want to upload the model package"}, ) instance_type: Optional[str] = field( - default=None, + default="ml.c5.xlarge", metadata={"help": "Define the instance_type of the endpoint"}, ) schema_builder: Optional[SchemaBuilder] = field( @@ -1055,6 +1055,6 @@ def _model_builder_optimize_wrapper( if input_args: self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args) job_status = self.sagemaker_session.wait_for_optimization_job(job_name) - self.pysdk_model = _generate_optimized_model(self.pysdk_model, job_status) + return _generate_optimized_model(self.pysdk_model, job_status) return self.pysdk_model From f2062a7ee7d9a932780fcd8cb4522846ed789ac4 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 25 Jun 2024 15:37:01 -0700 Subject: [PATCH 22/45] Fix issues --- src/sagemaker/serve/utils/optimize_utils.py | 3 +-- tests/unit/sagemaker/jumpstart/constants.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 13438e467f..959f54d19f 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -15,7 +15,6 @@ import re import logging -import uuid from typing import Dict, Any, Optional, Union, List from sagemaker import Model @@ -169,7 +168,7 @@ def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) Returns: str: The channel name. 
""" - channel_name = f"model-builder-channel-{uuid.uuid4().hex}" + channel_name = "model-builder-channel" if additional_model_data_sources and len(additional_model_data_sources) > 0: channel_name = additional_model_data_sources[0].get("ChannelName", channel_name) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 734857945a..eb2598b357 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7516,6 +7516,8 @@ "inference_config_rankings": None, "training_config_rankings": None, "hosting_additional_data_sources": None, + "hosting_neuron_model_id": None, + "hosting_neuron_model_version": None, } BASE_HOSTING_ADDITIONAL_DATA_SOURCES = { From d3f4274e7c50fcf2827634ae2cd201fd8496ae4a Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 25 Jun 2024 15:48:39 -0700 Subject: [PATCH 23/45] Channel name --- src/sagemaker/serve/utils/optimize_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 959f54d19f..c66ec5a991 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -168,7 +168,7 @@ def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) Returns: str: The channel name. """ - channel_name = "model-builder-channel" + channel_name = "draft-model" if additional_model_data_sources and len(additional_model_data_sources) > 0: channel_name = additional_model_data_sources[0].get("ChannelName", channel_name) From 3e93e95aa1259b3c5dd1c37f678148c3bbaf8bfc Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 25 Jun 2024 17:21:03 -0700 Subject: [PATCH 24/45] Channel name --- src/sagemaker/serve/utils/optimize_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index c66ec5a991..ccf6756f3c 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -168,7 +168,7 @@ def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) Returns: str: The channel name. """ - channel_name = "draft-model" + channel_name = "draft_model" if additional_model_data_sources and len(additional_model_data_sources) > 0: channel_name = additional_model_data_sources[0].get("ChannelName", channel_name) From 271a8621e492a68b9d09e905e1ca9085b6ac63e8 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 25 Jun 2024 18:36:04 -0700 Subject: [PATCH 25/45] Optimization output --- src/sagemaker/serve/utils/optimize_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index ccf6756f3c..e4313a4321 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -49,10 +49,12 @@ def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) - Returns: Model: A deployable optimized model. 
""" - recommended_image_uri = optimization_response["OptimizationOutput"]["RecommendedInferenceImage"] - optimized_environment = optimization_response["OptimizationEnvironment"] - s3_uri = optimization_response["OutputConfig"]["S3OutputLocation"] - deployment_instance_type = optimization_response["DeploymentInstanceType"] + recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get( + "RecommendedInferenceImage" + ) + optimized_environment = optimization_response.get("OptimizationEnvironment") + s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation") + deployment_instance_type = optimization_response.get("DeploymentInstanceType") if recommended_image_uri: pysdk_model.image_uri = recommended_image_uri From e3995b0783937bac0a5d9d08f6264634cf3632f8 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Wed, 26 Jun 2024 09:20:28 -0700 Subject: [PATCH 26/45] neuron model env --- .../serve/builder/jumpstart_builder.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index bd27570964..e051e4340d 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -879,14 +879,20 @@ def _get_neuron_model_env_vars( Returns: Optional[Dict[str, Any]]: Neuron Model environment variables. """ - metadata_config = self.pysdk_model._metadata_configs.get(self.pysdk_model.config_name) - resolve_config = metadata_config.resolved_config - if instance_type not in resolve_config.get("supported_inference_instance_types", []): - neuro_model_id = resolve_config.get("hosting_neuron_model_id") - neuro_model_version = resolve_config.get("hosting_neuron_model_version") - if neuro_model_id: - job_model = JumpStartModel( - neuro_model_id, model_version=neuro_model_version, vpc_config=self.vpc_config - ) - return job_model.env + metadata_configs = self.pysdk_model._metadata_configs + if metadata_configs: + metadata_config = metadata_configs.get(self.pysdk_model.config_name) + resolve_config = metadata_config.resolved_config if metadata_config else None + if resolve_config and instance_type not in resolve_config.get( + "supported_inference_instance_types", [] + ): + neuro_model_id = resolve_config.get("hosting_neuron_model_id") + neuro_model_version = resolve_config.get("hosting_neuron_model_version", "*") + if neuro_model_id: + job_model = JumpStartModel( + neuro_model_id, + model_version=neuro_model_version, + vpc_config=self.vpc_config, + ) + return job_model.env return None From 31d70e6a7f940e576ab6092e68f735350be21b8a Mon Sep 17 00:00:00 2001 From: Adam Kozdrowicz Date: Mon, 1 Jul 2024 13:21:43 -0400 Subject: [PATCH 27/45] Merge master into master-benchmark-feature (#1502) * prepare release v2.222.0 * update development version to v2.222.1.dev0 * fix: estimator.deploy not respecting instance type (#4724) * fix: estimator.deploy not respecting instance type * chore: add inline comment about using user supplied instance type * First changes (#4723) Co-authored-by: Bryannah Hernandez * prepare release v2.222.1 * update development version to v2.222.2.dev0 * change: update image_uri_configs 06-12-2024 07:17:03 PST * fix: Fix ci unit-tests (#4728) * Implement custom telemetry logging in SDK (#4721) * Fix Sniping bug fix (#4730) * Python SDK bucket sniping fix bug * Python SDK bucket sniping fix bug * Minor fixes to default bucket function and fixing unit tests * fix - Fixes from Pylint 
failures * fix - Fixes from Flake8 failures * fix - More Flake8 fixes * fix - Remove Whitespace from blankline * fix - Fix black recommendations * fix - Adjust tabbing --------- Co-authored-by: Jiao Liu Co-authored-by: liujiaor <128006184+liujiaorr@users.noreply.github.com> * feature: add 'ModelCard' property to Register step (#4726) * feature: add 'ModelCard' property to RegisterModel step * Updated ModelCard content type * fix: ModelCard Object integ Test fix --------- Co-authored-by: Gokul A <166456257+nargokul@users.noreply.github.com> * prepare release v2.223.0 * update development version to v2.223.1.dev0 * Fix Dependabot Issues - MLFlow Version (#4731) * fix - Address Dependapot issues * fix -Update MLFLOW Version * Fix: AttributeError: 'NoneType' object has no attribute 'len' error in session.py (#4735) * fix - Address Dependapot issues * fix -Update MLFLOW Version * fix - Update fetching Length for NoneType Error * change: Enable telemetry logging for Remote function (#4729) * change: Enhance telemetry logging module and feature coverage * Fix default session issue * fix unit-tests * chore: use ml.g5.2xlarge for integ test (#4741) * feat: JumpStartModel attach (#4680) * feat: JumpStartModel attach * fix: unit tests * chore: change order of kwargs to pass unit tests * chore: update docstrings, add tests * fix: docstring * fix: integ tests * chore: address PR comments --------- Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * Upgrading to PT 2.3 for release (#4732) * upgrading to PT 2.3 for release * reverting mistake in modifying dataparallel --------- Co-authored-by: Andrew Tian Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * feat(sagemaker-mlflow): New features for SageMaker MLflow (#4744) * feat: add support for mlflow inputs (#1441) * feat: add support for mlflow inputs * fix: typo * fix: doc * fix: S3 regex * fix: refactor * fix: refactor typo * fix: pylint * fix: pylint * fix: black and pylint --------- Co-authored-by: Jacky Lee * fix: lineage tracking bug (#1447) * fix: lineage bug * fix: lineage * fix: add validation for tracking ARN input with MLflow input type * fix: bug * fix: unit tests * fix: mock * fix: args --------- Co-authored-by: Jacky Lee * [Fix] regex for RunId to handle empty artifact path and change mlflow plugin name (#1455) * [Fix] run id regex pattern such that empty artifact path is handled * Change mlflow plugin name as per legal team requirement * Update describe_mlflow_tracking_server call to align with api changes (#1466) * feat: (sagemaker-mlflow) Adding Presigned Url function to SDK (#1462) (#1477) * mlflow presigned url changes * addressing design feedback * test changes * change: mlflow plugin name (#1489) Co-authored-by: Jacky Lee --------- Co-authored-by: Jacky Lee Co-authored-by: Jacky Lee Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * prepare release v2.224.0 * update development version to v2.224.1.dev0 * fix: Model server override logic (#4733) * fix: Model server override logic * Fix formatting --------- Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * chore(deps): bump apache-airflow from 2.9.1 to 2.9.2 in /requirements/extras (#4740) Bumps [apache-airflow](https://github.com/apache/airflow) from 2.9.1 to 2.9.2. 
- [Release notes](https://github.com/apache/airflow/releases) - [Changelog](https://github.com/apache/airflow/blob/main/RELEASE_NOTES.rst) - [Commits](https://github.com/apache/airflow/compare/2.9.1...2.9.2) --- updated-dependencies: - dependency-name: apache-airflow dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * fix: Update tox.ini (#4747) * change: Update README.rst to show conda-forge version of SageMaker SDK (#4749) * JumpStart CuratedHub Launch (#4748) * Implement CuratedHub APIs (#1449) * Implement CuratedHub Admin APIs * making some parameters optional in create_hub_content_reference as per the API design * add describe_hub and list_hubs APIs * implement delete_hub API * Implement list_hub_contents API * create CuratedHub class and supported utils * implement list_models and address comments * Add unit tests * add describe_model function * cache retrieval for describeHubContent changes * fix curated hub class unit tests * add utils needed for curatedHub * Cache retrieval * implement get_hub_model_reference() * cleanup HUB type datatype * cleanup constants * rename list_public_models to list_jumpstart_service_hub_models * implement describe_model_reference * Rename CuratedHub to Hub * address nit * address nits and fix failing tests --------- Co-authored-by: Malav Shastri * feat: implement list_jumpstart_service_hub_models function to fetch JumpStart public hub models (#1456) * Implement CuratedHub Admin APIs * making some parameters optional in create_hub_content_reference as per the API design * add describe_hub and list_hubs APIs * implement delete_hub API * Implement list_hub_contents API * create CuratedHub class and supported utils * implement list_models and address comments * Add unit tests * add describe_model function * cache retrieval for describeHubContent changes * fix curated hub class unit tests * add utils needed for curatedHub * Cache retrieval * implement get_hub_model_reference() * cleanup HUB type datatype * cleanup constants * rename list_public_models to list_jumpstart_service_hub_models * implement describe_model_reference * Rename CuratedHub to Hub * address nit * address nits and fix failing tests * implement list_jumpstart_service_hub_models function --------- Co-authored-by: Malav Shastri * Feat/Curated Hub hub_arn and hub_content_type support (#1453) * get_model_spec() changes to support hub_arn and hub_content_type * implement get_hub_model_reference() * support hub_arn and hub_content_type for specs retrieval * add support for hub_arn and hub_content_type for serializers, deserializers, estimators, models, predictors and various spec retrieval functionalities * address nits and test failures * remove hub_content_type support --------- Co-authored-by: Malav Shastri * feat: implement curated hub parser and bug bash fixes (#1457) * implement HubContentDocument parser * modify the parser to remove aliases for hubcontent documents * bug fix * update boto3 * Bug Fix in the parser * Improve Hub Class and related functionalities * Bug Fix and parser updates * add missing hub_arn support * Add model reference deployment support and other minor bug fixes * fix: retrieve correct image_uri (parser update) * fix: retrieve correct model URI and model data path from HubContentDocument (parser update) * Add model reference deployment support * Model accessor and cache retrival bug 
fixes * fix: curated hub model training workflow * fix: pass sagemaker sessions object to retrieve model specs from describe_hub_content call * fix: fix payload retrieval for curated hub models * modify constants, enums * fix: update parser * Address nits in the parser * Add unit tests for parser * implement pagination for list_models utility * feat: support wildcard chars for model versions * Address nits and comments * Add Hub Content Arn Tag to training and hosting * Add Hub Content Arn Tag to training and hosting * fix: HubContentDocument schema version * fix broken unit tests * fix prepare_container_def unit tests to include ModelReferenceArn * fix unit tests for test_session.py * revert boto version changes * Fix unit tests * support wildcard model versions for training workflow * Add test cases for get_model_versions * Add/fix unit tests --------- Co-authored-by: Malav Shastri * address unit tests failures in codebuild * change list_jumpstart_service_hub_models to list_sagemaker_public_hub_models() * fix: Changing list input output shapes * fix: gated model training bug * run black -l 100 * flake 8 * address formatting issues * black -l * DocStyle issues * address flake8, pylint * blake -l * pass model type down * disabling pylint for release * disable pylint --------- Co-authored-by: Malav Shastri Co-authored-by: chrstfu Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * prepare release v2.224.1 * update development version to v2.224.2.dev0 * fix: list_models() for python3.8 (#4756) * fix: list_models() for python3.8 * fix linting --------- Co-authored-by: Malav Shastri * Update DJLModel class for latest container releases (#4754) * simplify and refactor djl model for latest container releases * update model builder for new DJLModel implementation * fix formatting/linting suggestions * update DJLModel documentation on docs site * address reviewer feedback * Feature: Update model card on model package request (#4739) * Feature: Update model card on model package request * Feature: Update model card on model package request * fix: update_model_card input types * Feature: register proprietary models from jumpstart (#4753) * Feature: register proprietary models from jumpstart Feature: register proprietary models from jumpstart * fix: register jumpstart models on model registry * fixed get_model_id_version_from_endpoint naming issue * fixed issues with model builder * cleanup types.py file * fixed jumpstart unit tests * fixed issue in model_builder --------- Signed-off-by: dependabot[bot] Co-authored-by: ci Co-authored-by: evakravi <69981223+evakravi@users.noreply.github.com> Co-authored-by: bryannahm1 <110491182+bryannahm1@users.noreply.github.com> Co-authored-by: Bryannah Hernandez Co-authored-by: sagemaker-bot Co-authored-by: Kalyani Nikure <110067132+knikure@users.noreply.github.com> Co-authored-by: Gokul A <166456257+nargokul@users.noreply.github.com> Co-authored-by: Jiao Liu Co-authored-by: liujiaor <128006184+liujiaorr@users.noreply.github.com> Co-authored-by: selvask-aws Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Co-authored-by: adtian2 <55163384+adtian2@users.noreply.github.com> Co-authored-by: Andrew Tian Co-authored-by: ananth102 Co-authored-by: Jacky Lee Co-authored-by: Jacky Lee Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> Co-authored-by: Samrudhi Sharma <154457034+samruds@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> 
Co-authored-by: Malav Shastri <57682969+malav-shastri@users.noreply.github.com> Co-authored-by: Malav Shastri Co-authored-by: chrstfu Co-authored-by: Siddharth Venkatesan --- CHANGELOG.md | 66 + README.rst | 11 + VERSION | 2 +- .../djl/sagemaker.djl_inference.rst | 28 +- doc/frameworks/djl/using_djl.rst | 234 +- requirements/extras/test_requirements.txt | 3 +- src/sagemaker/accept_types.py | 28 +- src/sagemaker/chainer/model.py | 10 + src/sagemaker/content_types.py | 28 +- src/sagemaker/deserializers.py | 28 +- src/sagemaker/djl_inference/__init__.py | 5 +- src/sagemaker/djl_inference/defaults.py | 59 - src/sagemaker/djl_inference/djl_predictor.py | 58 + src/sagemaker/djl_inference/model.py | 1241 +--- src/sagemaker/environment_variables.py | 16 +- src/sagemaker/estimator.py | 4 + src/sagemaker/huggingface/model.py | 10 + src/sagemaker/hyperparameters.py | 8 + .../image_uri_config/pytorch-smp.json | 28 +- src/sagemaker/image_uri_config/pytorch.json | 44 +- src/sagemaker/image_uris.py | 5 + src/sagemaker/instance_types.py | 20 +- src/sagemaker/jumpstart/accessors.py | 35 +- .../artifacts/environment_variables.py | 12 + .../jumpstart/artifacts/hyperparameters.py | 4 + .../jumpstart/artifacts/image_uris.py | 13 + .../artifacts/incremental_training.py | 4 + .../jumpstart/artifacts/instance_types.py | 10 +- src/sagemaker/jumpstart/artifacts/kwargs.py | 16 + .../jumpstart/artifacts/metric_definitions.py | 4 + .../jumpstart/artifacts/model_packages.py | 8 + .../jumpstart/artifacts/model_uris.py | 11 + src/sagemaker/jumpstart/artifacts/payloads.py | 4 + .../jumpstart/artifacts/predictors.py | 32 + .../jumpstart/artifacts/resource_names.py | 4 + .../artifacts/resource_requirements.py | 4 + .../jumpstart/artifacts/script_uris.py | 10 +- src/sagemaker/jumpstart/cache.py | 139 +- src/sagemaker/jumpstart/constants.py | 5 + src/sagemaker/jumpstart/enums.py | 25 + src/sagemaker/jumpstart/estimator.py | 19 +- src/sagemaker/jumpstart/factory/estimator.py | 50 +- src/sagemaker/jumpstart/factory/model.py | 93 +- src/sagemaker/jumpstart/hub/__init__.py | 0 src/sagemaker/jumpstart/hub/constants.py | 16 + src/sagemaker/jumpstart/hub/hub.py | 307 + src/sagemaker/jumpstart/hub/interfaces.py | 831 +++ src/sagemaker/jumpstart/hub/parser_utils.py | 56 + src/sagemaker/jumpstart/hub/parsers.py | 262 + src/sagemaker/jumpstart/hub/types.py | 35 + src/sagemaker/jumpstart/hub/utils.py | 219 + src/sagemaker/jumpstart/model.py | 92 +- src/sagemaker/jumpstart/types.py | 531 +- src/sagemaker/jumpstart/utils.py | 25 + src/sagemaker/jumpstart/validators.py | 2 + src/sagemaker/local/local_session.py | 9 + src/sagemaker/metric_definitions.py | 4 + src/sagemaker/mlflow/__init__.py | 12 + src/sagemaker/mlflow/tracking_server.py | 50 + src/sagemaker/model.py | 95 +- src/sagemaker/model_card/__init__.py | 1 + src/sagemaker/model_card/model_card.py | 28 +- src/sagemaker/model_uris.py | 4 + src/sagemaker/multidatamodel.py | 2 + src/sagemaker/mxnet/model.py | 10 + src/sagemaker/payloads.py | 14 +- src/sagemaker/pipeline.py | 8 + src/sagemaker/predictor.py | 4 + src/sagemaker/pytorch/model.py | 10 + src/sagemaker/remote_function/client.py | 3 + src/sagemaker/resource_requirements.py | 16 +- src/sagemaker/script_uris.py | 16 +- src/sagemaker/serializers.py | 28 +- src/sagemaker/serve/builder/djl_builder.py | 142 +- src/sagemaker/serve/builder/model_builder.py | 219 +- src/sagemaker/serve/builder/tei_builder.py | 3 +- src/sagemaker/serve/builder/tgi_builder.py | 3 +- .../serve/builder/transformers_builder.py | 2 +- 
.../serve/model_format/mlflow/constants.py | 5 +- .../serve/model_format/mlflow/utils.py | 22 - .../model_server/djl_serving/inference.py | 63 - .../serve/model_server/djl_serving/prepare.py | 157 +- .../serve/model_server/djl_serving/server.py | 1 + .../serve/model_server/djl_serving/utils.py | 69 +- .../serve/model_server/tei/server.py | 1 + .../serve/model_server/tgi/server.py | 1 + src/sagemaker/serve/utils/hf_utils.py | 53 + .../serve/utils/lineage_constants.py | 2 + src/sagemaker/serve/utils/lineage_utils.py | 71 +- src/sagemaker/serve/utils/tuning.py | 4 +- src/sagemaker/serve/utils/types.py | 12 - src/sagemaker/session.py | 489 +- src/sagemaker/sklearn/model.py | 10 + src/sagemaker/telemetry/__init__.py | 16 + src/sagemaker/telemetry/constants.py | 42 + src/sagemaker/telemetry/telemetry_logging.py | 256 + src/sagemaker/tensorflow/model.py | 10 + src/sagemaker/workflow/_utils.py | 5 + src/sagemaker/workflow/step_collections.py | 5 +- src/sagemaker/xgboost/model.py | 10 + tests/conftest.py | 4 +- .../mlflow/pytorch/requirements.txt | 2 +- .../mlflow/tensorflow/requirements.txt | 2 +- .../mlflow/xgboost/requirements.txt | 2 +- tests/integ/sagemaker/conftest.py | 4 +- .../estimator/test_jumpstart_estimator.py | 1 + .../jumpstart/model/test_jumpstart_model.py | 63 + .../test_model_create_and_registration.py | 431 ++ tests/integ/test_model_package.py | 221 +- .../jumpstart/test_accept_types.py | 10 +- .../jumpstart/test_content_types.py | 9 +- .../jumpstart/test_deserializers.py | 4 + .../jumpstart/test_default.py | 8 + .../hyperparameters/jumpstart/test_default.py | 6 + .../jumpstart/test_validate.py | 6 + .../image_uris/jumpstart/test_common.py | 8 + .../unit/sagemaker/image_uris/test_smp_v2.py | 2 +- .../jumpstart/test_instance_types.py | 8 + tests/unit/sagemaker/jumpstart/constants.py | 5324 +++++++++++------ .../jumpstart/estimator/test_estimator.py | 12 + .../unit/sagemaker/jumpstart/hub/__init__.py | 0 .../unit/sagemaker/jumpstart/hub/test_hub.py | 235 + .../jumpstart/hub/test_interfaces.py | 981 +++ .../sagemaker/jumpstart/hub/test_utils.py | 256 + .../sagemaker/jumpstart/model/test_model.py | 75 +- tests/unit/sagemaker/jumpstart/test_cache.py | 8 +- .../jumpstart/test_notebook_utils.py | 13 +- .../sagemaker/jumpstart/test_predictor.py | 1 + tests/unit/sagemaker/jumpstart/test_types.py | 6 + tests/unit/sagemaker/jumpstart/utils.py | 54 +- .../sagemaker/local/test_local_session.py | 53 +- .../jumpstart/test_default.py | 4 + tests/unit/sagemaker/mlflow/__init__.py | 0 .../sagemaker/mlflow/test_tracking_server.py | 42 + tests/unit/sagemaker/model/test_deploy.py | 12 +- tests/unit/sagemaker/model/test_model.py | 15 +- .../sagemaker/model/test_model_package.py | 52 +- .../model_uris/jumpstart/test_common.py | 8 + .../sagemaker/remote_function/test_client.py | 7 +- .../jumpstart/test_resource_requirements.py | 7 + .../script_uris/jumpstart/test_common.py | 8 + .../serializers/jumpstart/test_serializers.py | 4 + .../serve/builder/test_djl_builder.py | 237 +- .../serve/builder/test_model_builder.py | 448 +- .../model_format/mlflow/test_mlflow_utils.py | 12 - .../djl_serving/test_djl_prepare.py | 130 +- .../serve/model_server/tei/test_server.py | 1 + .../serve/utils/test_lineage_utils.py | 82 +- .../telemetry/test_telemetry_logging.py | 302 + tests/unit/test_default_bucket.py | 13 +- tests/unit/test_djl_inference.py | 845 +-- tests/unit/test_estimator.py | 17 +- tests/unit/test_session.py | 204 + tox.ini | 8 +- 154 files changed, 11921 insertions(+), 5147 deletions(-) delete mode 100644 
src/sagemaker/djl_inference/defaults.py create mode 100644 src/sagemaker/djl_inference/djl_predictor.py create mode 100644 src/sagemaker/jumpstart/hub/__init__.py create mode 100644 src/sagemaker/jumpstart/hub/constants.py create mode 100644 src/sagemaker/jumpstart/hub/hub.py create mode 100644 src/sagemaker/jumpstart/hub/interfaces.py create mode 100644 src/sagemaker/jumpstart/hub/parser_utils.py create mode 100644 src/sagemaker/jumpstart/hub/parsers.py create mode 100644 src/sagemaker/jumpstart/hub/types.py create mode 100644 src/sagemaker/jumpstart/hub/utils.py create mode 100644 src/sagemaker/mlflow/__init__.py create mode 100644 src/sagemaker/mlflow/tracking_server.py delete mode 100644 src/sagemaker/serve/model_server/djl_serving/inference.py create mode 100644 src/sagemaker/serve/utils/hf_utils.py create mode 100644 src/sagemaker/telemetry/__init__.py create mode 100644 src/sagemaker/telemetry/constants.py create mode 100644 src/sagemaker/telemetry/telemetry_logging.py create mode 100644 tests/unit/sagemaker/jumpstart/hub/__init__.py create mode 100644 tests/unit/sagemaker/jumpstart/hub/test_hub.py create mode 100644 tests/unit/sagemaker/jumpstart/hub/test_interfaces.py create mode 100644 tests/unit/sagemaker/jumpstart/hub/test_utils.py create mode 100644 tests/unit/sagemaker/mlflow/__init__.py create mode 100644 tests/unit/sagemaker/mlflow/test_tracking_server.py create mode 100644 tests/unit/sagemaker/telemetry/test_telemetry_logging.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 63e5114f10..b844c5e357 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,71 @@ # Changelog +## v2.224.1 (2024-06-21) + +### Bug Fixes and Other Changes + + * JumpStart CuratedHub Launch + * Update README.rst to show conda-forge version of SageMaker SDK + * Update tox.ini + * chore(deps): bump apache-airflow from 2.9.1 to 2.9.2 in /requirements/extras + * Model server override logic + +## v2.224.0 (2024-06-19) + +### Features + + * JumpStartModel attach + +### Bug Fixes and Other Changes + + * feat(sagemaker-mlflow): New features for SageMaker MLflow + * Upgrading to PT 2.3 for release + * chore: use ml.g5.2xlarge for integ test + * Enable telemetry logging for Remote function + * Fix Dependabot Issues - MLFlow Version + +## v2.223.0 (2024-06-13) + +### Features + + * add 'ModelCard' property to Register step + +### Bug Fixes and Other Changes + + * Fix Sniping bug fix + * Implement custom telemetry logging in SDK + * Fix ci unit-tests + * update image_uri_configs 06-12-2024 07:17:03 PST + +## v2.222.1 (2024-06-12) + +### Bug Fixes and Other Changes + + * First changes + * estimator.deploy not respecting instance type + +## v2.222.0 (2024-06-07) + +### Features + + * jumpstart telemetry + +### Bug Fixes and Other Changes + + * update image_uri_configs 06-06-2024 07:17:31 PST + * bump requests from 2.31.0 to 2.32.2 in /requirements/extras + * chore: add HF LLM neuronx 0.0.23 image + * Updates for DJL 0.28.0 release + * chore(deps): bump mlflow from 2.11.1 to 2.12.1 in /tests/data/serve_resources/mlflow/tensorflow + * chore(deps): bump mlflow from 2.11.1 to 2.12.1 in /tests/data/serve_resources/mlflow/xgboost + * chore(deps): bump mlflow from 2.10.2 to 2.12.1 in /tests/data/serve_resources/mlflow/pytorch + * chore(deps): bump apache-airflow from 2.9.0 to 2.9.1 in /requirements/extras + * chore(deps): bump requests from 2.31.0 to 2.32.2 in /tests/data/serve_resources/mlflow/pytorch + * Fix ci unit-tests + * Making project name in workflow files dynamic + * update image_uri_configs 05-29-2024 07:17:35 
PST + * Update: SM Endpoint Routing Strategy Support. + ## v2.221.1 (2024-05-22) ### Bug Fixes and Other Changes diff --git a/README.rst b/README.rst index e59b2da9c5..68cf79c55b 100644 --- a/README.rst +++ b/README.rst @@ -10,6 +10,10 @@ SageMaker Python SDK :target: https://pypi.python.org/pypi/sagemaker :alt: Latest Version +.. image:: https://img.shields.io/conda/vn/conda-forge/sagemaker-python-sdk.svg + :target: https://anaconda.org/conda-forge/sagemaker-python-sdk + :alt: Conda-Forge Version + .. image:: https://img.shields.io/pypi/pyversions/sagemaker.svg :target: https://pypi.python.org/pypi/sagemaker :alt: Supported Python Versions @@ -95,6 +99,13 @@ SageMaker Python SDK is tested on: - Python 3.10 - Python 3.11 +Telemetry +~~~~~~~~~~~~~~~ + +The ``sagemaker`` library has telemetry enabled to help us better understand user needs, diagnose issues, and deliver new features. This telemetry tracks the usage of various SageMaker functions. + +If you prefer to opt out of telemetry, you can easily do so by setting the ``TelemetryOptOut`` parameter to ``true`` in the SDK defaults configuration. For detailed instructions, please visit `Configuring and using defaults with the SageMaker Python SDK `__. + AWS Permissions ~~~~~~~~~~~~~~~ diff --git a/VERSION b/VERSION index e55266069e..83a9987ea1 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.221.2.dev0 +2.224.2.dev0 diff --git a/doc/frameworks/djl/sagemaker.djl_inference.rst b/doc/frameworks/djl/sagemaker.djl_inference.rst index fd34ae1a23..5b4d138776 100644 --- a/doc/frameworks/djl/sagemaker.djl_inference.rst +++ b/doc/frameworks/djl/sagemaker.djl_inference.rst @@ -5,31 +5,7 @@ DJL Classes DJLModel --------------------------- -.. autoclass:: sagemaker.djl_inference.model.DJLModel - :members: - :undoc-members: - :show-inheritance: - -DeepSpeedModel ---------------------------- - -.. autoclass:: sagemaker.djl_inference.model.DeepSpeedModel - :members: - :undoc-members: - :show-inheritance: - -HuggingFaceAccelerateModel ---------------------------- - -.. autoclass:: sagemaker.djl_inference.model.HuggingFaceAccelerateModel - :members: - :undoc-members: - :show-inheritance: - -FasterTransformerModel ---------------------------- - -.. autoclass:: sagemaker.djl_inference.model.FasterTransformerModel +.. autoclass:: sagemaker.djl_inference.DJLModel :members: :undoc-members: :show-inheritance: @@ -37,7 +13,7 @@ FasterTransformerModel DJLPredictor --------------------------- -.. autoclass:: sagemaker.djl_inference.model.DJLPredictor +.. autoclass:: sagemaker.djl_inference.DJLPredictor :members: :undoc-members: :show-inheritance: diff --git a/doc/frameworks/djl/using_djl.rst b/doc/frameworks/djl/using_djl.rst index 217f5ed7dd..63b8acd684 100644 --- a/doc/frameworks/djl/using_djl.rst +++ b/doc/frameworks/djl/using_djl.rst @@ -2,14 +2,11 @@ Use DJL with the SageMaker Python SDK ####################################### -With the SageMaker Python SDK, you can use Deep Java Library to host models on Amazon SageMaker. - `Deep Java Library (DJL) Serving `_ is a high performance universal stand-alone model serving solution powered by `DJL `_. DJL Serving supports loading models trained with a variety of different frameworks. With the SageMaker Python SDK you can -use DJL Serving to host large models using backends like DeepSpeed and HuggingFace Accelerate. +use DJL Serving to host large language models for text-generation and text-embedding use-cases. -For information about supported versions of DJL Serving, see the `AWS documentation `_. 
-We recommend that you use the latest supported version because that's where we focus our development efforts. +You can learn more about Large Model Inference using DJLServing on the `docs site `_. For general information about using the SageMaker Python SDK, see :ref:`overview:Using the SageMaker Python SDK`. @@ -19,238 +16,57 @@ For general information about using the SageMaker Python SDK, see :ref:`overview Deploy DJL models ******************* -With the SageMaker Python SDK, you can use DJL Serving to host models that have been saved in the HuggingFace pretrained format. +With the SageMaker Python SDK, you can use DJL Serving to host text-generation and text-embedding models that have been saved in the HuggingFace pretrained format. These can either be models you have trained/fine-tuned yourself, or models available publicly from the HuggingFace Hub. -DJL Serving in the SageMaker Python SDK supports hosting models for the popular HuggingFace NLP tasks, as well as Stable Diffusion. - -You can either deploy your model using DeepSpeed, FasterTransformer, or HuggingFace Accelerate, or let DJL Serving determine the best backend based on your model architecture and configuration. .. code:: python - # Create a DJL Model, backend is chosen automatically + # DJLModel will infer which container to use, and apply some starter configuration djl_model = DJLModel( - "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id - "my_sagemaker_role", - dtype="fp16", + model_id="<HuggingFace Hub model id or S3 URI>", + role="my_sagemaker_role", task="text-generation", - number_of_partitions=2 # number of gpus to partition the model across ) # Deploy the model to an Amazon SageMaker Endpoint and get a Predictor predictor = djl_model.deploy("ml.g5.12xlarge", initial_instance_count=1) -If you want to use a specific backend, then you can create an instance of the corresponding model directly. +Alternatively, you can provide full specifications to the DJLModel to have full control over the model configuration: .. code:: python - # Create a model using the DeepSpeed backend - deepspeed_model = DeepSpeedModel( - "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id - "my_sagemaker_role", - dtype="bf16", - task="text-generation", - tensor_parallel_degree=2, # number of gpus to partition the model across using tensor parallelism - ) - - # Create a model using the HuggingFace Accelerate backend - - hf_accelerate_model = HuggingFaceAccelerateModel( - "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id - "my_sagemaker_role", - dtype="fp16", - task="text-generation", - number_of_partitions=2, # number of gpus to partition the model across - ) - - # Create a model using the FasterTransformer backend - - fastertransformer_model = FasterTransformerModel( - "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id - "my_sagemaker_role", - data_type="fp16", + djl_model = DJLModel( + model_id="<HuggingFace Hub model id or S3 URI>", + role="my_sagemaker_role", task="text-generation", - tensor_parallel_degree=2, # number of gpus to partition the model across + engine="Python", + env={ + "OPTION_ROLLING_BATCH": "lmi-dist", + "TENSOR_PARALLEL_DEGREE": "2", + "OPTION_DTYPE": "bf16", + "OPTION_MAX_ROLLING_BATCH_SIZE": "64", + }, + image_uri="<LMI container image URI>", ) - # Deploy the model to an Amazon SageMaker Endpoint and get a Predictor - deepspeed_predictor = deepspeed_model.deploy("ml.g5.12xlarge", - initial_instance_count=1) - hf_accelerate_predictor = hf_accelerate_model.deploy("ml.g5.12xlarge", - initial_instance_count=1) - fastertransformer_predictor = fastertransformer_model.deploy("ml.g5.12xlarge", - initial_instance_count=1) -Regardless of which way you choose to create your model, a ``Predictor`` object is returned. You can use this ``Predictor`` -to do inference on the endpoint hosting your DJLModel. + predictor = djl_model.deploy("ml.g5.12xlarge", + initial_instance_count=1) +Regardless of how you create your model, a ``Predictor`` object is returned. Each ``Predictor`` provides a ``predict`` method, which can do inference with json data, numpy arrays, or Python lists. Inference data are serialized and sent to the DJL Serving model server by an ``InvokeEndpoint`` SageMaker operation. The ``predict`` method returns the result of inference against your model. By default, the inference data is serialized to a json string, and the inference result is a Python dictionary.
code:: - - model = DJLModel( - "EleutherAI/gpt-j-6B", - "my_sagemaker_role", - dtype="fp16", - number_of_partitions=2 - ) - - predictor = model.deploy("ml.g5.12xlarge") - -Uncompressed Model Artifacts stored in a S3 bucket -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -For models that are larger than 20GB (total checkpoint size), we recommend that you store the model in S3. -Download times will be much faster compared to downloading from the HuggingFace Hub at runtime. -DJL Serving Models expect a different model structure than most of the other frameworks in the SageMaker Python SDK. -Specifically, DJLModels do not support loading models stored in tar.gz format. -This is because DJL Serving is optimized for large models, and it implements a fast downloading mechanism for large models that require the artifacts be uncompressed. - -For example, lets say you want to deploy the EleutherAI/gpt-j-6B model available on the HuggingFace Hub. -You can download the model and upload to S3 like this: - -.. code:: - - # Requires Git LFS - git clone https://huggingface.co/EleutherAI/gpt-j-6B - - # Upload to S3 - aws s3 sync gpt-j-6B s3://my_bucket/gpt-j-6B - -You would then pass "s3://my_bucket/gpt-j-6B" as ``model_id`` to the ``DJLModel`` like this: - -.. code:: - - model = DJLModel( - "s3://my_bucket/gpt-j-6B", - "my_sagemaker_role", - dtype="fp16", - number_of_partitions=2 - ) - - predictor = model.deploy("ml.g5.12xlarge") - -For language models we expect that the model weights, model config, and tokenizer config are provided in S3. The model -should be loadable from the HuggingFace Transformers AutoModelFor.from_pretrained API, where task -is the NLP task you want to host the model for. The weights must be stored as PyTorch compatible checkpoints. - -Example: - -.. code:: - - my_bucket/my_model/ - |- config.json - |- added_tokens.json - |- config.json - |- pytorch_model-*-of-*.bin # model weights can be partitioned into multiple checkpoints - |- tokenizer.json - |- tokenizer_config.json - |- vocab.json - -For Stable Diffusion models, the model should be loadable from the HuggingFace Diffusers DiffusionPipeline.from_pretrained API. - -Inference code and Model Server Properties ------------------------------------------- - -You can provide custom inference code and model server configuration by specifying the ``source_dir`` and -``entry_point`` arguments of the ``DJLModel``. These are not required. The model server configuration can be generated -based on the arguments passed to the constructor, and we provide default inference handler code for DeepSpeed, -HuggingFaceAccelerate, and Stable Diffusion. You can find these handler implementations in the `DJL Serving Github repository. `_ - -You can find documentation for the model server configurations on the `DJL Serving Docs website `_. - -The code and configuration you want to deploy can either be stored locally or in S3. These files will be bundled into -a tar.gz file that will be uploaded to SageMaker. - -For example: - -.. code:: - - sourcedir/ - |- script.py # Inference handler code - |- serving.properties # Model Server configuration file - |- requirements.txt # Additional Python requirements that will be installed at runtime via PyPi - -In the above example, sourcedir will be bundled and compressed into a tar.gz file and uploaded as part of creating the Inference Endpoint. - -The DJL Serving Model Server -============================ - -The endpoint you create with ``deploy`` runs the DJL Serving model server. 
-The model server loads the model from S3 and performs inference on the model in response to SageMaker ``InvokeEndpoint`` API calls. - -DJL Serving is highly customizable. You can control aspects of both model loading and model serving. Most of the model server -configuration are exposed through the ``DJLModel`` API. The SageMaker Python SDK will use the values it is passed to -create the proper configuration file used when creating the inference endpoint. You can optionally provide your own -``serving.properties`` file via the ``source_dir`` argument. You can find documentation about serving.properties in the -`DJL Serving Documentation for model specific settings. `_ - -Within the SageMaker Python SDK, DJL Serving is used in Python mode. This allows users to provide their inference script, -and data processing scripts in python. For details on how to write custom inference and data processing code, please -see the `DJL Serving Documentation on Python Mode. `_ - -For more information about DJL Serving, see the `DJL Serving documentation. `_ - -************************** -Ahead of time partitioning -************************** - -To optimize the deployment of large models that do not fit in a single GPU, the model’s tensor weights are partitioned at -runtime and each partition is loaded in individual GPU. But runtime partitioning takes significant amount of time and -memory on model loading. So, DJLModel offers an ahead of time partitioning capability for DeepSpeed and FasterTransformer -engines, which lets you partition your model weights and save them before deployment. HuggingFace does not support -tensor parallelism, so ahead of time partitioning cannot be done for it. In our experiment with GPT-J model, loading -this model with partitioned checkpoints increased the model loading time by 40%. - -`partition` method invokes an Amazon SageMaker Training job to partition the model and upload those partitioned -checkpoints to S3 bucket. You can either provide your desired S3 bucket to upload the partitioned checkpoints or it will be -uploaded to the default SageMaker S3 bucket. Please note that this S3 bucket will be remembered for deployment. When you -call `deploy` method after partition, DJLServing downloads the partitioned model checkpoints directly from the uploaded -s3 url, if available. - -.. code:: - - # partitions the model using Amazon Sagemaker Training Job. - djl_model.partition("ml.g5.12xlarge") +************************************** +DJL Serving for Large Model Inference +************************************** - predictor = deepspeed_model.deploy("ml.g5.12xlarge", - initial_instance_count=1) +You can learn more about using DJL Serving for Large Model Inference use-cases on our `documentation site `_. -*********************** -SageMaker DJL Classes -*********************** -For information about the different DJL Serving related classes in the SageMaker Python SDK, see https://sagemaker.readthedocs.io/en/stable/frameworks/djl/sagemaker.djl_inference.html. 
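To make the ``predict`` call described above concrete, here is a minimal sketch that assumes the ``djl_model`` from the earlier text-generation example; the ``inputs``/``parameters`` payload keys follow a common LMI convention and may differ for other containers:

.. code:: python

    # Deploy as shown above; deploy() returns a Predictor for the endpoint.
    predictor = djl_model.deploy("ml.g5.12xlarge", initial_instance_count=1)

    # JSON in, JSON out: the payload is serialized to a json string and the
    # response is deserialized into a Python dictionary by default.
    result = predictor.predict(
        {
            "inputs": "What is Deep Java Library?",
            "parameters": {"max_new_tokens": 128},
        }
    )
    print(result)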
******************************** SageMaker DJL Serving Containers diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 7dae26fcac..60904c51b0 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -12,7 +12,7 @@ awslogs==0.14.0 black==24.3.0 stopit==1.1.2 # Update tox.ini to have correct version of airflow constraints file -apache-airflow==2.9.1 +apache-airflow==2.9.2 apache-airflow-providers-amazon==7.2.1 attrs>=23.1.0,<24 fabric==2.6.0 @@ -37,3 +37,4 @@ nbformat>=5.9,<6 accelerate>=0.24.1,<=0.27.0 schema==0.7.5 tensorflow>=2.1,<=2.16 +mlflow>=2.12.2,<2.13 diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 7541425868..b48adda44c 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -24,6 +24,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -37,6 +38,8 @@ def retrieve_options( retrieve the supported accept types. (Default: None). model_version (str): The version of the model for which to retrieve the supported accept types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -60,11 +63,12 @@ def retrieve_options( ) return artifacts._retrieve_supported_accept_types( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -73,6 +77,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -88,6 +93,8 @@ def retrieve_default( retrieve the default accept type. (Default: None). model_version (str): The version of the model for which to retrieve the default accept type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -112,11 +119,12 @@ def retrieve_default( ) return artifacts._retrieve_default_accept_type( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, config_name=config_name, diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 9fce051454..963eaaa474 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -28,6 +28,10 @@ from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.chainer import defaults from sagemaker.deserializers import NumpyDeserializer +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer from sagemaker.utils import to_string @@ -175,6 +179,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -226,6 +231,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -266,6 +273,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def prepare_container_def( @@ -274,6 +282,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition with framework configuration set in model environment. @@ -325,6 +334,7 @@ def prepare_container_def( self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 627feca0d6..16c81d6d77 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -24,6 +24,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -37,6 +38,8 @@ def retrieve_options( retrieve the supported content types. (Default: None). model_version (str): The version of the model for which to retrieve the supported content types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -60,11 +63,12 @@ def retrieve_options( ) return artifacts._retrieve_supported_content_types( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -73,6 +77,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -88,6 +93,8 @@ def retrieve_default( retrieve the default content type. (Default: None). model_version (str): The version of the model for which to retrieve the default content type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -112,11 +119,12 @@ def retrieve_default( ) return artifacts._retrieve_default_content_type( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, config_name=config_name, diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 02e61149ec..957a9dfb0c 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -43,6 +43,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -56,6 +57,8 @@ def retrieve_options( retrieve the supported deserializers. (Default: None). model_version (str): The version of the model for which to retrieve the supported deserializers. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
@@ -80,11 +83,12 @@ def retrieve_options(
     )
 
     return artifacts._retrieve_deserializer_options(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
     )
 
 
@@ -93,6 +97,7 @@ def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -108,6 +113,8 @@
             retrieve the default deserializer. (Default: None).
         model_version (str): The version of the model for which to retrieve
             the default deserializer. (Default: None).
+        hub_arn (str): The ARN of the SageMaker Hub from which to retrieve
+            model details. (Default: None).
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
             specifications should be tolerated (exception not raised).
            If False, raises an exception if the script used by this version of the model has dependencies with known
@@ -133,11 +140,12 @@
     )
 
     return artifacts._retrieve_default_deserializer(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
         model_type=model_type,
         config_name=config_name,
diff --git a/src/sagemaker/djl_inference/__init__.py b/src/sagemaker/djl_inference/__init__.py
index 0f6b867318..dd8b005d1e 100644
--- a/src/sagemaker/djl_inference/__init__.py
+++ b/src/sagemaker/djl_inference/__init__.py
@@ -13,8 +13,5 @@
 """Placeholder docstring"""
 from __future__ import absolute_import
 
-from sagemaker.djl_inference.model import DJLPredictor  # noqa: F401
+from sagemaker.djl_inference.djl_predictor import DJLPredictor  # noqa: F401
 from sagemaker.djl_inference.model import DJLModel  # noqa: F401
-from sagemaker.djl_inference.model import DeepSpeedModel  # noqa: F401
-from sagemaker.djl_inference.model import HuggingFaceAccelerateModel  # noqa: F401
-from sagemaker.djl_inference.model import FasterTransformerModel  # noqa: F401
diff --git a/src/sagemaker/djl_inference/defaults.py b/src/sagemaker/djl_inference/defaults.py
deleted file mode 100644
index 64699de8f9..0000000000
--- a/src/sagemaker/djl_inference/defaults.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-#     http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""Placeholder docstring""" -from __future__ import absolute_import - -STABLE_DIFFUSION_MODEL_TYPE = "stable-diffusion" - -VALID_MODEL_CONFIG_FILES = ["config.json", "model_index.json"] - -DEEPSPEED_RECOMMENDED_ARCHITECTURES = { - "bloom", - "opt", - "gpt_neox", - "gptj", - "gpt_neo", - "gpt2", - "xlm-roberta", - "roberta", - "bert", - STABLE_DIFFUSION_MODEL_TYPE, -} - -FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES = { - "t5", -} - -FASTER_TRANSFORMER_SUPPORTED_ARCHITECTURES = { - "bert", - "gpt2", - "bloom", - "opt", - "gptj", - "gpt_neox", - "gpt_neo", - "t5", -} - -ALLOWED_INSTANCE_FAMILIES = { - "ml.g4dn", - "ml.g5", - "ml.p3", - "ml.p3dn", - "ml.p4", - "ml.p4d", - "ml.p4de", - "local_gpu", -} - -REVISION_MAPPING = {"fp16": "float16", "fp32": "float32"} diff --git a/src/sagemaker/djl_inference/djl_predictor.py b/src/sagemaker/djl_inference/djl_predictor.py new file mode 100644 index 0000000000..e6ab10f676 --- /dev/null +++ b/src/sagemaker/djl_inference/djl_predictor.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Default Predictor for JSON inputs/outputs used with DJL LMI containers""" +from __future__ import absolute_import +from sagemaker.predictor import Predictor +from sagemaker import Session +from sagemaker.serializers import BaseSerializer, JSONSerializer +from sagemaker.deserializers import BaseDeserializer, JSONDeserializer + + +class DJLPredictor(Predictor): + """A Predictor for inference against DJL Model Endpoints. + + This is able to serialize Python lists, dictionaries, and numpy arrays to + multidimensional tensors for DJL inference. + """ + + def __init__( + self, + endpoint_name: str, + sagemaker_session: Session = None, + serializer: BaseSerializer = JSONSerializer(), + deserializer: BaseDeserializer = JSONDeserializer(), + component_name=None, + ): + """Initialize a ``DJLPredictor`` + + Args: + endpoint_name (str): The name of the endpoint to perform inference + on. + sagemaker_session (sagemaker.session.Session): Session object that + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, the estimator creates one + using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to json format. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses the response from json format to dictionary. + component_name (str): Optional. Name of the Amazon SageMaker inference + component corresponding the predictor. 
+ """ + super(DJLPredictor, self).__init__( + endpoint_name, + sagemaker_session, + serializer=serializer, + deserializer=deserializer, + component_name=component_name, + ) diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index efbb44460c..3fa523c605 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -13,308 +13,52 @@ """Placeholder docstring""" from __future__ import absolute_import -import json import logging -import os.path -import urllib.request -from json import JSONDecodeError -from urllib.error import HTTPError, URLError -from enum import Enum -from typing import Optional, Union, Dict, Any, List +from typing import Optional, Dict, Any -import sagemaker -from sagemaker import s3, Predictor, image_uris, fw_utils -from sagemaker.deserializers import JSONDeserializer, BaseDeserializer -from sagemaker.djl_inference import defaults -from sagemaker.model import FrameworkModel -from sagemaker.s3_utils import s3_path_join -from sagemaker.serializers import JSONSerializer, BaseSerializer +from sagemaker import image_uris +from sagemaker.model import Model from sagemaker.session import Session -from sagemaker.utils import _tmpdir, _create_or_update_code_dir, format_tags -from sagemaker.workflow.entities import PipelineVariable -from sagemaker.estimator import Estimator -from sagemaker.s3 import S3Uploader -logger = logging.getLogger("sagemaker") +from sagemaker.djl_inference.djl_predictor import DJLPredictor -# DJL Serving uses log4j, so we convert python logging level to log4j equivalent -_LOG_LEVEL_MAP = { - logging.INFO: "info", - logging.DEBUG: "debug", - logging.WARNING: "warn", - logging.ERROR: "error", - logging.FATAL: "fatal", - logging.CRITICAL: "fatal", - logging.NOTSET: "off", -} +logger = logging.getLogger(__name__) -class DJLServingEngineEntryPointDefaults(Enum): - """Enum describing supported engines and corresponding default inference handler modules.""" +def _set_env_var_from_property( + property_value: Optional[Any], env_key: str, env: dict, override_env_var=False +) -> dict: + """Utility method to set an environment variable configuration""" + if not property_value: + return env + if override_env_var or env_key not in env: + env[env_key] = str(property_value) + return env - DEEPSPEED = ("DeepSpeed", "djl_python.deepspeed") - HUGGINGFACE_ACCELERATE = ("Python", "djl_python.huggingface") - STABLE_DIFFUSION = ("DeepSpeed", "djl_python.stable-diffusion") - FASTER_TRANSFORMER = ("FasterTransformer", "djl_python.fastertransformer") - -class DJLPredictor(Predictor): - """A Predictor for inference against DJL Model Endpoints. - - This is able to serialize Python lists, dictionaries, and numpy arrays to - multidimensional tensors for DJL inference. - """ - - def __init__( - self, - endpoint_name: str, - sagemaker_session: Session = None, - serializer: BaseSerializer = JSONSerializer(), - deserializer: BaseDeserializer = JSONDeserializer(), - component_name=None, - ): - """Initialize a ``DJLPredictor`` - - Args: - endpoint_name (str): The name of the endpoint to perform inference - on. - sagemaker_session (sagemaker.session.Session): Session object that - manages interactions with Amazon SageMaker APIs and any other - AWS services needed. If not specified, the estimator creates one - using the default AWS configuration chain. - serializer (sagemaker.serializers.BaseSerializer): Optional. Default - serializes input data to json format. - deserializer (sagemaker.deserializers.BaseDeserializer): Optional. 
- Default parses the response from json format to dictionary. - component_name (str): Optional. Name of the Amazon SageMaker inference - component corresponding the predictor. - """ - super(DJLPredictor, self).__init__( - endpoint_name, - sagemaker_session, - serializer=serializer, - deserializer=deserializer, - component_name=component_name, - ) - - -def _determine_engine_for_model(model_type: str, num_partitions: int, num_heads: int): - """Placeholder docstring""" - - # Tensor Parallelism is only possible if attention heads can be split evenly - # across devices - if num_heads is not None and num_partitions is not None and num_heads % num_partitions: - return HuggingFaceAccelerateModel - if model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES: - return DeepSpeedModel - if model_type in defaults.FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES: - return FasterTransformerModel - return HuggingFaceAccelerateModel - - -def _validate_engine_for_model_type(cls, model_type: str, num_partitions: int, num_heads: int): - """Placeholder docstring""" - - if cls == DeepSpeedModel: - if num_heads is not None and num_partitions is not None and num_heads % num_partitions: - raise ValueError( - "The number of attention heads is not evenly divisible by the number of partitions." - "Please set the number of partitions such that the number of attention heads can be" - "evenly split across the partitions." - ) - if cls == FasterTransformerModel: - if model_type not in defaults.FASTER_TRANSFORMER_SUPPORTED_ARCHITECTURES: - raise ValueError( - f"The model architecture {model_type} is currently not supported by " - f"FasterTransformer. Please use a different engine, or use the DJLModel" - f"to let SageMaker pick a recommended engine for this model." - ) - return cls - - -def _read_existing_serving_properties(directory: str): - """Placeholder docstring""" - - serving_properties_path = os.path.join(directory, "serving.properties") - properties = {} - if os.path.exists(serving_properties_path): - with open(serving_properties_path, "r") as f: - for line in f: - if line.startswith("#") or len(line.strip()) == 0: - continue - key, val = line.split("=", 1) - properties[key] = val - return properties - - -def _get_model_config_properties_from_s3(model_s3_uri: str, sagemaker_session: Session): - """Placeholder docstring""" - - s3_files = s3.S3Downloader.list(model_s3_uri, sagemaker_session=sagemaker_session) - model_config = None - for config in defaults.VALID_MODEL_CONFIG_FILES: - config_file = os.path.join(model_s3_uri, config) - if config_file in s3_files: - model_config = json.loads( - s3.S3Downloader.read_file(config_file, sagemaker_session=sagemaker_session) - ) - break - if not model_config: - raise ValueError( - f"Did not find a config.json or model_index.json file in {model_s3_uri}. 
Please make " - f"sure a config.json exists (or model_index.json for Stable Diffusion Models) in" - f"the provided s3 location" - ) - return model_config - - -def _get_model_config_properties_from_hf(model_id: str, hf_hub_token: str = None): - """Placeholder docstring""" - - config_url_prefix = f"https://huggingface.co/{model_id}/raw/main/" - model_config = None - for config in defaults.VALID_MODEL_CONFIG_FILES: - config_file_url = config_url_prefix + config - try: - if hf_hub_token: - config_file_url = urllib.request.Request( - config_file_url, None, {"Authorization": "Bearer " + hf_hub_token} - ) - with urllib.request.urlopen(config_file_url) as response: - model_config = json.load(response) - break - except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e: - if "HTTP Error 401: Unauthorized" in str(e): - raise ValueError( - "Trying to access a gated/private HuggingFace model without valid credentials. " - "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars" - ) - logger.warning( - "Exception encountered while trying to read config file %s. " "Details: %s", - config_file_url, - e, - ) - if not model_config: - raise ValueError( - f"Did not find a config.json or model_index.json file in huggingface hub for " - f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable " - f"Diffusion Models) for this model in the huggingface hub" - ) - return model_config - - -def _create_estimator( - instance_type: str, - s3_output_uri: str, - image_uri: str, - role: str, - sagemaker_session: Optional[Session], - volume_size: int, - vpc_config: Optional[ - Dict[ - str, - List[str], - ] - ] = None, - volume_kms_key=None, - output_kms_key=None, - use_spot_instances: bool = False, - max_wait: int = None, - enable_network_isolation: bool = False, -): - """Placeholder docstring""" - - subnets = None - security_group_ids = None - if vpc_config: - subnets = vpc_config.get("Subnets") - security_group_ids = vpc_config.get("SecurityGroupIds") - - return Estimator( - image_uri=image_uri, - role=role, - instance_count=1, - instance_type=instance_type, - volume_size=volume_size, - volume_kms_key=volume_kms_key, - output_path=s3_output_uri, - output_kms_key=output_kms_key, - sagemaker_session=sagemaker_session, - subnets=subnets, - security_group_ids=security_group_ids, - use_spot_instances=use_spot_instances, - max_wait=max_wait, - enable_network_isolation=enable_network_isolation, - ) - - -class DJLModel(FrameworkModel): +class DJLModel(Model): """A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``.""" - def __new__( - cls, - model_id: str, - *args, - **kwargs, - ): # pylint: disable=W0613 - """Create a specific subclass of DJLModel for a given engine""" - if model_id.endswith("tar.gz"): - raise ValueError( - "DJLModel does not support model artifacts in tar.gz format." 
- "Please store the model in uncompressed format and provide the s3 uri of the " - "containing folder" - ) - if model_id.startswith("s3://"): - sagemaker_session = kwargs.get("sagemaker_session") - model_config = _get_model_config_properties_from_s3(model_id, sagemaker_session) - else: - hf_hub_token = kwargs.get("hf_hub_token") - model_config = _get_model_config_properties_from_hf(model_id, hf_hub_token) - if model_config.get("_class_name") == "StableDiffusionPipeline": - model_type = defaults.STABLE_DIFFUSION_MODEL_TYPE - num_heads = 0 - else: - model_type = model_config.get("model_type") - num_heads = model_config.get("n_head") or model_config.get("num_attention_heads") - number_of_partitions = kwargs.get("number_of_partitions") or kwargs.get( - "tensor_parallel_degree" - ) - cls_to_create = ( - _validate_engine_for_model_type(cls, model_type, number_of_partitions, num_heads) - if cls is not DJLModel - else _determine_engine_for_model(model_type, number_of_partitions, num_heads) - ) - instance = super().__new__(cls_to_create) - if model_type == defaults.STABLE_DIFFUSION_MODEL_TYPE: - instance.engine = DJLServingEngineEntryPointDefaults.STABLE_DIFFUSION - elif isinstance(instance, DeepSpeedModel): - instance.engine = DJLServingEngineEntryPointDefaults.DEEPSPEED - elif isinstance(instance, FasterTransformerModel): - instance.engine = DJLServingEngineEntryPointDefaults.FASTER_TRANSFORMER - else: - instance.engine = DJLServingEngineEntryPointDefaults.HUGGINGFACE_ACCELERATE - return instance - def __init__( self, - model_id: str, - role: str, - djl_version: Optional[str] = None, + model_id: Optional[str] = None, + engine: Optional[str] = None, + djl_version: str = "0.28.0", + djl_framework: Optional[str] = None, task: Optional[str] = None, - dtype: str = "fp32", - number_of_partitions: Optional[int] = None, + dtype: Optional[str] = None, + tensor_parallel_degree: Optional[int] = None, min_workers: Optional[int] = None, max_workers: Optional[int] = None, job_queue_size: Optional[int] = None, parallel_loading: bool = False, model_loading_timeout: Optional[int] = None, prediction_timeout: Optional[int] = None, - entry_point: Optional[str] = None, - image_uri: Optional[Union[str, PipelineVariable]] = None, predictor_cls: callable = DJLPredictor, + huggingface_hub_token: Optional[str] = None, **kwargs, ): - """Initialize a DJLModel. + """Initialize a SageMaker model using one of the DJL Model Serving Containers. Args: model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location @@ -322,24 +66,23 @@ def __init__( The model artifacts are expected to be in HuggingFace pre-trained model format (i.e. model should be loadable from the huggingface transformers from_pretrained api, and should also include tokenizer configs if applicable). - role (str): An AWS IAM role specified with either the name or full ARN. The Amazon - SageMaker training jobs and APIs that create Amazon SageMaker - endpoints use this role to access model artifacts. After the endpoint is created, - the inference code might use the IAM role, if it needs to access an AWS resource. + model artifact location must be specified using either the model_id parameter, + model_data parameter, or HF_MODEL_ID environment variable in the env parameter + engine (str): The DJL inference engine to use for your model. Defaults to None. + If not provided, the engine is inferred based on the task. If no task is provided, + the Python engine is used. 
djl_version (str): DJL Serving version you want to use for serving your model for inference. Defaults to None. If not provided, the latest available version of DJL Serving is used. This is not used if ``image_uri`` is provided. + djl_framework (str): The DJL container to use. This is used along with djl_version + to fetch the image_uri of the djl inference container. This is not used if + ``image_uri`` is provided. task (str): The HuggingFace/NLP task you want to launch this model for. Defaults to None. If not provided, the task will be inferred from the model architecture by DJL. - dtype (str): The data type to use for loading your model. Accepted values are - "fp32", "fp16", "bf16", "int8". Defaults to "fp32". - number_of_partitions (int): The number of GPUs to partition the model across. The - partitioning strategy is determined by the selected backend. If DeepSpeed is - selected, this is tensor parallelism. - If HuggingFace Accelerate is selected, this is a naive sharding strategy - that splits the model layers across the available resources. Defaults to None. If - not provided, no model partitioning is done. + tensor_parallel_degree (int): The number of accelerators to partition the model across + using tensor parallelism. Defaults to None. If not provided, the maximum number + of available accelerators will be used. min_workers (int): The minimum number of worker processes. Defaults to None. If not provided, dJL Serving will automatically detect the minimum workers. max_workers (int): The maximum number of worker processes. Defaults to None. If not @@ -354,58 +97,26 @@ def __init__( None. If not provided, the default is 240 seconds. prediction_timeout (int): The worker predict call (handler) timeout in seconds. Defaults to None. If not provided, the default is 120 seconds. - entry_point (str): This can either be the absolute or relative path to the Python source - file that should be executed as the entry point to model - hosting, or a python module that is installed in the container. If ``source_dir`` - is specified, then ``entry_point`` - must point to a file located at the root of ``source_dir``. Defaults to None. - image_uri (str): A docker image URI. Defaults to None. If not specified, a default - image for DJL Serving will be used based on ``djl_version``. If ``djl_version`` - is not specified, the latest available container version will be used. predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. + huggingface_hub_token (str): The HuggingFace Hub token to use for downloading the model + artifacts for a model stored on the huggingface hub. + Defaults to None. If not provided, the token must be specified in the + HF_TOKEN environment variable in the env parameter. **kwargs: Keyword arguments passed to the superclass :class:`~sagemaker.model.FrameworkModel` and, subsequently, its superclass :class:`~sagemaker.model.Model`. - - .. tip:: - - Instantiating a DJLModel will return an instance of either - :class:`~sagemaker.djl_inference.DeepSpeedModel` or - :class:`~sagemaker.djl_inference.HuggingFaceAccelerateModel` based on our framework - recommendation for the model type. - - If you want to use a specific framework to deploy your model with, we recommend - instantiating that specific - model class directly. 
The available framework specific classes are - :class:`~sagemaker.djl_inference.DeepSpeedModel` or - :class:`~sagemaker.djl_inference.HuggingFaceAccelerateModel` """ - if "hf_hub_token" in kwargs: - kwargs.pop("hf_hub_token") - if kwargs.get("model_data"): - logger.warning( - "DJLModels do not use model_data parameter. model_data parameter will be ignored." - "You only need to set model_id and ensure it points to uncompressed model " - "artifacts in s3, or a valid HuggingFace Hub model_id." - ) - data_type = kwargs.pop("data_type", None) - if data_type: - logger.warning( - "data_type is being deprecated in favor of dtype. Please migrate use of data_type" - " to dtype. Support for data_type will be removed in a future release" - ) - dtype = dtype or data_type - super(DJLModel, self).__init__( - None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs - ) + super(DJLModel, self).__init__(predictor_cls=predictor_cls, **kwargs) self.model_id = model_id self.djl_version = djl_version + self.djl_framework = djl_framework + self.engine = engine self.task = task self.dtype = dtype - self.number_of_partitions = number_of_partitions + self.tensor_parallel_degree = tensor_parallel_degree self.min_workers = min_workers self.max_workers = max_workers self.job_queue_size = job_queue_size @@ -413,7 +124,85 @@ def __init__( self.model_loading_timeout = model_loading_timeout self.prediction_timeout = prediction_timeout self.sagemaker_session = self.sagemaker_session or Session() - self.save_mp_checkpoint_path = None + self.hub_token = huggingface_hub_token + self._initialize_model() + + def _initialize_model(self): + """Placeholder docstring""" + self._validate_model_artifacts() + self.engine = self._infer_engine() + self.env = self._configure_environment_variables() + self.image_uri = self._infer_image_uri() + + def _validate_model_artifacts(self): + """Placeholder docstring""" + if self.model_id is not None and self.model_data is not None: + raise ValueError( + "both model_id and model_data are provided. 
Please only provide one of them" + ) + + def _infer_engine(self) -> Optional[str]: + """Placeholder docstring""" + if self.engine is not None: + logger.info("Using provided engine %s", self.engine) + return self.engine + + if self.task == "text-embedding": + return "OnnxRuntime" + return "Python" + + def _infer_image_uri(self): + """Placeholder docstring""" + if self.image_uri is not None: + return self.image_uri + if self.djl_framework is None: + self.djl_framework = "djl-lmi" + return image_uris.retrieve( + framework=self.djl_framework, + region=self.sagemaker_session.boto_region_name, + version=self.djl_version, + ) + + def _configure_environment_variables(self) -> Dict[str, str]: + """Placeholder docstring""" + env = self.env.copy() if self.env else {} + env = _set_env_var_from_property(self.model_id, "HF_MODEL_ID", env) + env = _set_env_var_from_property(self.task, "HF_TASK", env) + env = _set_env_var_from_property(self.dtype, "OPTION_DTYPE", env) + env = _set_env_var_from_property(self.min_workers, "SERVING_MIN_WORKERS", env) + env = _set_env_var_from_property(self.max_workers, "SERVING_MAX_WORKERS", env) + env = _set_env_var_from_property(self.job_queue_size, "SERVING_JOB_QUEUE_SIZE", env) + env = _set_env_var_from_property(self.parallel_loading, "OPTION_PARALLEL_LOADING", env) + env = _set_env_var_from_property( + self.model_loading_timeout, "OPTION_MODEL_LOADING_TIMEOUT", env + ) + env = _set_env_var_from_property(self.prediction_timeout, "OPTION_PREDICT_TIMEOUT", env) + env = _set_env_var_from_property(self.hub_token, "HF_TOKEN", env) + env = _set_env_var_from_property(self.engine, "OPTION_ENGINE", env) + if "TENSOR_PARALLEL_DEGREE" not in env or "OPTION_TENSOR_PARALLEL_DEGREE" not in env: + if self.tensor_parallel_degree is not None: + env["TENSOR_PARALLEL_DEGREE"] = str(self.tensor_parallel_degree) + return env + + def serving_image_uri( + self, + region_name, + instance_type=None, + accelerator_type=None, + serverless_inference_config=None, + ): + """Placeholder docstring""" + if self.image_uri: + return self.image_uri + return image_uris.retrieve( + framework=self.djl_framework, + region=region_name, + version=self.djl_version, + instance_type=instance_type, + accelerator_type=accelerator_type, + image_scope="inference", + serverless_inference_config=serverless_inference_config, + ) def package_for_edge(self, **_): """Not implemented. @@ -460,791 +249,3 @@ def right_size(self, **_): raise NotImplementedError( "DJLModels do not currently support Inference Recommendation Jobs" ) - - def partition( - self, - instance_type: str, - s3_output_uri: str = None, - s3_output_prefix: str = "aot-partitioned-checkpoints", - job_name: Optional[str] = None, - volume_size: int = 30, - volume_kms_key: Optional[str] = None, - output_kms_key: Optional[str] = None, - use_spot_instances: bool = False, - max_wait: int = None, - enable_network_isolation: bool = False, - ): - """Partitions the model using SageMaker Training Job. This is a synchronous API call. - - Args: - instance_type (str): The EC2 instance type to partition this Model. - For example, 'ml.p4d.24xlarge'. - s3_output_uri (str): S3 location for saving the training result (model - artifacts and output files). If not specified, results are - stored to a default bucket. If the bucket with the specific name - does not exist, it will be created. - s3_output_prefix (str): Name of the prefix where all the partitioned - checkpoints to be uploaded. If not provided, the default value is - aot-partitioned-checkpoints. 
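Taken together, the rewritten ``DJLModel`` drops the engine-specific subclasses and the serving.properties repack flow in favor of plain container environment variables: ``_set_env_var_from_property`` writes a key only when the property is set and not already present, so values supplied through ``env`` always win. A hedged end-to-end sketch of the new surface; the model ID, role, and instance type below are placeholders, not SDK defaults:

    from sagemaker.djl_inference import DJLModel

    # The constructor maps these settings to HF_MODEL_ID, OPTION_ENGINE,
    # OPTION_DTYPE, TENSOR_PARALLEL_DEGREE, etc., and resolves a djl-lmi image.
    model = DJLModel(
        model_id="TheBloke/Llama-2-7B-fp16",  # or an s3:// prefix with uncompressed artifacts
        role="arn:aws:iam::123456789012:role/SageMakerExecutionRole",
        dtype="fp16",
        tensor_parallel_degree=4,
    )
    predictor = model.deploy(
        initial_instance_count=1,
        instance_type="ml.g5.12xlarge",
    )
    result = predictor.predict({"inputs": "What is DJL Serving?"})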
- job_name (str): Training job name. If not specified, a unique training job - name will be created. - volume_size (int): Size in GB of the storage volume to use for - storing input and output data during training (default: 30). - volume_kms_key (str): Optional. KMS key ID for encrypting EBS - volume attached to the training instance (default: None). - output_kms_key (str): Optional. KMS key ID for encrypting the - training output (default: None). - use_spot_instances (bool): Specifies whether to use SageMaker - Managed Spot instances for training. If enabled then the - ``max_wait`` arg should also be set. - - More information: - https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html - (default: ``False``). - max_wait (int): Timeout in seconds waiting for spot training - job (default: None). After this amount of time Amazon - SageMaker will stop waiting for managed spot training job to - complete (default: None). - enable_network_isolation (bool): Specifies whether container will - run in network isolation mode (default: ``False``). Network - isolation mode restricts the container access to outside networks - (such as the Internet). The container does not make any inbound or - outbound network calls. Also known as Internet-free mode. - Returns: - None - """ - - if not self.image_uri: - region_name = self.sagemaker_session.boto_session.region_name - self.image_uri = self.serving_image_uri(region_name) - - if s3_output_uri is None: - deploy_key_prefix = fw_utils.model_code_key_prefix( - self.key_prefix, self.name, self.image_uri - ) - - bucket, deploy_key_prefix = s3.determine_bucket_and_prefix( - bucket=self.bucket, - key_prefix=deploy_key_prefix, - sagemaker_session=self.sagemaker_session, - ) - s3_output_uri = s3_path_join("s3://", bucket, deploy_key_prefix) - - self.save_mp_checkpoint_path = s3_path_join(s3_output_uri, s3_output_prefix) - - container_def = self._upload_model_to_s3(upload_as_tar=False) - estimator = _create_estimator( - instance_type=instance_type, - s3_output_uri=s3_output_uri, - image_uri=self.image_uri, - role=self.role, - sagemaker_session=self.sagemaker_session, - volume_size=volume_size, - vpc_config=self.vpc_config, - volume_kms_key=volume_kms_key, - output_kms_key=output_kms_key, - use_spot_instances=use_spot_instances, - max_wait=max_wait, - enable_network_isolation=enable_network_isolation, - ) - - # creates a training job to do partitions - estimator.fit( - inputs=container_def["ModelDataUrl"], - wait=True, - logs="All", - job_name=job_name, - experiment_config=None, - ) - - self.model_id = self.save_mp_checkpoint_path - # reset save_mp_checkpoint_path since partition is completed. - self.save_mp_checkpoint_path = None - - def deploy( - self, - instance_type, - initial_instance_count=1, - serializer=None, - deserializer=None, - endpoint_name=None, - tags=None, - kms_key=None, - wait=True, - data_capture_config=None, - volume_size=None, - model_data_download_timeout=None, - container_startup_health_check_timeout=None, - **kwargs, - ): - """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. - - Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an - ``Endpoint`` from this ``Model``. If ``self.predictor_cls`` is not None, - this method returns the result of invoking ``self.predictor_cls`` on - the created endpoint name. 
- - The name of the created model is accessible in the ``name`` field of - this ``Model`` after deploy returns - - The name of the created endpoint is accessible in the - ``endpoint_name`` field of this ``Model`` after deploy returns. - - Args: - instance_type (str): The EC2 instance type to deploy this Model to. - For example, 'ml.p4d.24xlarge'. - initial_instance_count (int): The initial number of instances to run - in the ``Endpoint`` created from this ``Model``. It needs to be at least 1 ( - default: 1) - serializer (:class:`~sagemaker.serializers.BaseSerializer`): A - serializer object, used to encode data for an inference endpoint - (default: None). If ``serializer`` is not None, then - ``serializer`` will override the default serializer. The - default serializer is set by the ``predictor_cls``. - deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A - deserializer object, used to decode data from an inference - endpoint (default: None). If ``deserializer`` is not None, then - ``deserializer`` will override the default deserializer. The - default deserializer is set by the ``predictor_cls``. - endpoint_name (str): The name of the endpoint to create (default: - None). If not specified, a unique endpoint name will be created. - tags (Optional[Tags]): The list of tags to attach to this - specific endpoint. - kms_key (str): The ARN of the KMS key that is used to encrypt the - data on the storage volume attached to the instance hosting the - endpoint. - wait (bool): Whether the call should wait until the deployment of - this model completes (default: True). - data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies - configuration related to Endpoint data capture for use with - Amazon SageMaker Model Monitoring. Default: None. - volume_size (int): The size, in GB, of the ML storage volume attached to individual - inference instance associated with the production variant. Currenly only Amazon EBS - gp2 storage volumes are supported. - model_data_download_timeout (int): The timeout value, in seconds, to download and - extract model data from Amazon S3 to the individual inference instance associated - with this production variant. - container_startup_health_check_timeout (int): The timeout value, in seconds, for your - inference container to pass health check by SageMaker Hosting. For more information - about health check see: - https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code - .html#your-algorithms-inference-algo-ping-requests - - Returns: - callable[string, sagemaker.session.Session] or None: Invocation of - ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` - is not None. Otherwise, return None. - """ - - instance_family = instance_type.rsplit(".", 1)[0] - if instance_family not in defaults.ALLOWED_INSTANCE_FAMILIES: - raise ValueError( - f"Invalid instance type. DJLModels only support deployment to instances" - f"with GPUs. 
Supported instance families are {defaults.ALLOWED_INSTANCE_FAMILIES}" - ) - - return super(DJLModel, self).deploy( - initial_instance_count=initial_instance_count, - instance_type=instance_type, - serializer=serializer, - deserializer=deserializer, - endpoint_name=endpoint_name, - tags=format_tags(tags), - kms_key=kms_key, - wait=wait, - data_capture_config=data_capture_config, - volume_size=volume_size, - model_data_download_timeout=model_data_download_timeout, - container_startup_health_check_timeout=container_startup_health_check_timeout, - **kwargs, - ) - - def _upload_model_to_s3(self, upload_as_tar: bool = True): - """Placeholder docstring""" - - if not self.image_uri: - region_name = self.sagemaker_session.boto_session.region_name - self.image_uri = self.serving_image_uri(region_name) - - environment = self._get_container_env() - - local_download_dir = ( - None - if self.sagemaker_session.settings is None - or self.sagemaker_session.settings.local_download_dir is None - else self.sagemaker_session.settings.local_download_dir - ) - with _tmpdir(directory=local_download_dir) as tmp: - if self.source_dir or self.entry_point: - # Below method downloads from s3, or moves local files to tmp/code - _create_or_update_code_dir( - tmp, - self.entry_point, - self.source_dir, - self.dependencies, - self.sagemaker_session, - tmp, - ) - tmp_code_dir = os.path.join(tmp, "code") - existing_serving_properties = _read_existing_serving_properties(tmp_code_dir) - kwargs_serving_properties = self.generate_serving_properties() - existing_serving_properties.update(kwargs_serving_properties) - - if not os.path.exists(tmp_code_dir): - os.mkdir(tmp_code_dir) - with open(os.path.join(tmp_code_dir, "serving.properties"), "w+") as f: - for key, val in existing_serving_properties.items(): - f.write(f"{key}={val}\n") - - deploy_key_prefix = fw_utils.model_code_key_prefix( - self.key_prefix, self.name, self.image_uri - ) - bucket, deploy_key_prefix = s3.determine_bucket_and_prefix( - bucket=self.bucket, - key_prefix=deploy_key_prefix, - sagemaker_session=self.sagemaker_session, - ) - if upload_as_tar: - uploaded_code = fw_utils.tar_and_upload_dir( - self.sagemaker_session.boto_session, - bucket, - deploy_key_prefix, - self.entry_point, - directory=tmp_code_dir, - dependencies=self.dependencies, - kms_key=self.model_kms_key, - ) - model_data_url = uploaded_code.s3_prefix - else: - model_data_url = S3Uploader.upload( - tmp_code_dir, - s3_path_join("s3://", bucket, deploy_key_prefix, "aot-model"), - self.model_kms_key, - self.sagemaker_session, - ) - return sagemaker.container_def( - self.image_uri, model_data_url=model_data_url, env=environment - ) - - def prepare_container_def( - self, - instance_type=None, - accelerator_type=None, - serverless_inference_config=None, - accept_eula=None, - ): # pylint: disable=unused-argument - """A container definition with framework configuration set in model environment variables. - - Returns: - dict[str, str]: A container definition object usable with the - CreateModel API. - """ - - if not self.model_data and not isinstance(self.model_data, dict): - return self._upload_model_to_s3(upload_as_tar=True) - return super().prepare_container_def( - instance_type, accelerator_type, serverless_inference_config - ) - - def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]: - """Generates the DJL Serving configuration to use for the model. - - The configuration is generated using the arguments passed to the Model during - initialization. 
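For contrast with the environment-variable flow above, the code being deleted here materialized the same settings as a serving.properties file: the dict returned by ``generate_serving_properties`` was written out as ``key=value`` lines. A representative, purely illustrative rendering for the DeepSpeed engine:

    engine=DeepSpeed
    option.entryPoint=djl_python.deepspeed
    option.model_id=s3://my-bucket/my-model/
    option.tensor_parallel_degree=4
    option.dtype=fp16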
If a serving.properties file is found in ``self.source_dir``, - those configuration as merged with the Model parameters, with Model parameters taking - priority. - - Args: - serving_properties: Dictionary containing existing model server configuration - obtained from ``self.source_dir``. Defaults to None. - - Returns: - dict: The model server configuration to use when deploying this model to SageMaker. - """ - if not serving_properties: - serving_properties = {} - serving_properties["engine"] = self.engine.value[0] # pylint: disable=E1101 - serving_properties["option.entryPoint"] = self.engine.value[1] # pylint: disable=E1101 - serving_properties["option.model_id"] = self.model_id - if self.number_of_partitions: - serving_properties["option.tensor_parallel_degree"] = self.number_of_partitions - if self.entry_point: - serving_properties["option.entryPoint"] = self.entry_point - if self.task: - serving_properties["option.task"] = self.task - if self.dtype: - serving_properties["option.dtype"] = self.dtype - if self.min_workers: - serving_properties["minWorkers"] = self.min_workers - if self.max_workers: - serving_properties["maxWorkers"] = self.max_workers - if self.job_queue_size: - serving_properties["job_queue_size"] = self.job_queue_size - if self.parallel_loading: - serving_properties["option.parallel_loading"] = self.parallel_loading - if self.model_loading_timeout: - serving_properties["option.model_loading_timeout"] = self.model_loading_timeout - if self.prediction_timeout: - serving_properties["option.prediction_timeout"] = self.prediction_timeout - if self.save_mp_checkpoint_path: - serving_properties["option.save_mp_checkpoint_path"] = self.save_mp_checkpoint_path - return serving_properties - - def serving_image_uri(self, region_name): - """Create a URI for the serving image. - - Args: - region_name (str): AWS region where the image is uploaded. - - Returns: - str: The appropriate image URI based on the given parameters. - """ - if not self.djl_version: - self.djl_version = "0.24.0" - - return image_uris.retrieve( - self._framework(), - region_name, - version=self.djl_version, - ) - - def _get_container_env(self): - """Placeholder docstring""" - - if not self.container_log_level: - return self.env - - if self.container_log_level not in _LOG_LEVEL_MAP: - logger.warning("Ignoring invalid container log level: %s", self.container_log_level) - return self.env - - self.env["SERVING_OPTS"] = ( - f'"-Dai.djl.logging.level={_LOG_LEVEL_MAP[self.container_log_level]}"' - ) - return self.env - - -class DeepSpeedModel(DJLModel): - """A DJL DeepSpeed SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``""" - - _framework_name = "djl-deepspeed" - - def __init__( - self, - model_id: str, - role: str, - tensor_parallel_degree: Optional[int] = None, - max_tokens: Optional[int] = None, - low_cpu_mem_usage: bool = False, - enable_cuda_graph: bool = False, - triangular_masking: bool = True, - return_tuple: bool = True, - **kwargs, - ): - """Initialize a DeepSpeedModel - - Args: - model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location - containing the uncompressed model artifacts (i.e. not a tar.gz file). - The model artifacts are expected to be in HuggingFace pre-trained model - format (i.e. model should be loadable from the huggingface transformers - from_pretrained api, and should also include tokenizer configs if applicable). - role (str): An AWS IAM role specified with either the name or full ARN. 
The Amazon - SageMaker training jobs and APIs that create Amazon SageMaker - endpoints use this role to access model artifacts. After the endpoint is created, - the inference code - might use the IAM role, if it needs to access an AWS resource. - tensor_parallel_degree (int): The number of gpus to shard a single instance of the - model across via tensor_parallelism. This should be set to greater than 1 if the - size of the model is larger than the memory available on a single GPU on the - instance. Defaults to None. If not set, no tensor parallel sharding is done. - max_tokens (int): The maximum number of tokens (input + output tokens) the DeepSpeed - engine is configured for. Defaults to None. If not set, the DeepSpeed default of - 1024 is used. - low_cpu_mem_usage (bool): Whether to limit CPU memory usage to 1x model size during - model loading. This is an experimental feature in HuggingFace. This is useful when - loading multiple instances of your model in parallel. Defaults to False. - enable_cuda_graph (bool): Whether to enable CUDA graph replay to accelerate inference - passes. This cannot be used with tensor parallelism greater than 1. - Defaults to False. - triangular_masking (bool): Whether to use triangular attention mask. This is - application specific. Defaults to True. - return_tuple (bool): Whether the transformer layers need to return a tuple or a - Tensor. Defaults to True. - **kwargs: Keyword arguments passed to the superclasses - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model` - - .. tip:: - - You can find additional parameters for initializing this class at - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model`. - """ - if "hf_hub_token" in kwargs: - kwargs.pop("hf_hub_token") - super(DeepSpeedModel, self).__init__( - model_id, - role, - **kwargs, - ) - if self.number_of_partitions and tensor_parallel_degree: - logger.warning( - "Both number_of_partitions and tensor_parallel_degree have been set for " - "DeepSpeedModel." - "These mean the same thing for DeepSpeedModel. Please only set " - "tensor_parallel_degree." - "number_of_partitions will be ignored" - ) - self.number_of_partitions = tensor_parallel_degree or self.number_of_partitions - self.max_tokens = max_tokens - self.low_cpu_mem_usage = low_cpu_mem_usage - self.enable_cuda_graph = enable_cuda_graph - self.triangular_masking = triangular_masking - self.return_tuple = return_tuple - self.save_mp_checkpoint_path = None - self.checkpoint = None - - def generate_serving_properties(self, serving_properties=None) -> Dict[str, Any]: - """Generates the DJL Serving configuration to use for the model. - - The configuration is generated using the arguments passed to the Model during - initialization. If a serving.properties file is found in ``self.source_dir``, - those configuration as merged with the Model parameters, with Model parameters taking - priority. - - Args: - serving_properties: Dictionary containing existing model server configuration - obtained from ``self.source_dir``. Defaults to None. - - Returns: - dict: The model server configuration to use when deploying this model to SageMaker. 
- """ - - serving_properties = super(DeepSpeedModel, self).generate_serving_properties( - serving_properties=serving_properties - ) - if self.max_tokens: - serving_properties["option.max_tokens"] = self.max_tokens - if self.low_cpu_mem_usage: - serving_properties["option.low_cpu_mem_usage"] = self.low_cpu_mem_usage - if self.enable_cuda_graph: - if self.number_of_partitions > 1: - raise ValueError( - "enable_cuda_graph is not supported when tensor_parallel_degree > 1" - ) - serving_properties["option.enable_cuda_graph"] = self.enable_cuda_graph - if self.triangular_masking: - serving_properties["option.triangular_masking"] = self.triangular_masking - if self.return_tuple: - serving_properties["option.return_tuple"] = self.return_tuple - if self.save_mp_checkpoint_path: - serving_properties["option.save_mp_checkpoint_path"] = self.save_mp_checkpoint_path - if self.checkpoint: - serving_properties["option.checkpoint"] = self.checkpoint - - return serving_properties - - def partition( - self, - instance_type: str, - s3_output_uri: str = None, - s3_output_prefix: str = "aot-partitioned-checkpoints", - job_name: Optional[str] = None, - volume_size: int = 30, - volume_kms_key: Optional[str] = None, - output_kms_key: Optional[str] = None, - use_spot_instances: bool = False, - max_wait: int = None, - enable_network_isolation: bool = False, - ): - """Partitions the model using SageMaker Training Job. This is a synchronous API call. - - Args: - instance_type (str): The EC2 instance type to partition this Model. - For example, 'ml.p4d.24xlarge'. - s3_output_uri (str): S3 location for saving the training result (model - artifacts and output files). If not specified, results are - stored to a default bucket. If the bucket with the specific name - does not exist, it will be created. - s3_output_prefix (str): Name of the prefix where all the partitioned - checkpoints to be uploaded. If not provided, the default value is - aot-partitioned-checkpoints. - job_name (str): Training job name. If not specified, a unique training job - name will be created. - volume_size (int): Size in GB of the storage volume to use for - storing input and output data during training (default: 30). - volume_kms_key (str): Optional. KMS key ID for encrypting EBS - volume attached to the training instance (default: None). - output_kms_key (str): Optional. KMS key ID for encrypting the - training output (default: None). - use_spot_instances (bool): Specifies whether to use SageMaker - Managed Spot instances for training. If enabled then the - ``max_wait`` arg should also be set. - - More information: - https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html - (default: ``False``). - max_wait (int): Timeout in seconds waiting for spot training - job (default: None). After this amount of time Amazon - SageMaker will stop waiting for managed spot training job to - complete (default: None). - enable_network_isolation (bool): Specifies whether container will - run in network isolation mode (default: ``False``). Network - isolation mode restricts the container access to outside networks - (such as the Internet). The container does not make any inbound or - outbound network calls. Also known as Internet-free mode. 
- Returns: - None - """ - - super(DeepSpeedModel, self).partition( - instance_type, - s3_output_uri, - s3_output_prefix=s3_output_prefix, - job_name=job_name, - volume_size=volume_size, - volume_kms_key=volume_kms_key, - output_kms_key=output_kms_key, - use_spot_instances=use_spot_instances, - max_wait=max_wait, - enable_network_isolation=enable_network_isolation, - ) - - self.checkpoint = "ds_inference_config.json" - - -class HuggingFaceAccelerateModel(DJLModel): - """A DJL Hugging Face SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``.""" - - _framework_name = "djl-deepspeed" - - def __init__( - self, - model_id: str, - role: str, - number_of_partitions: Optional[int] = None, - device_id: Optional[int] = None, - device_map: Optional[Union[str, Dict[str, str]]] = None, - load_in_8bit: bool = False, - low_cpu_mem_usage: bool = False, - **kwargs, - ): - """Initialize a HuggingFaceAccelerateModel. - - Args: - model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location - containing the uncompressed model artifacts (i.e. not a tar.gz file). - The model artifacts are expected to be in HuggingFace pre-trained model - format (i.e. model should be loadable from the huggingface transformers - from_pretrained api, and should also include tokenizer configs if applicable). - role (str): An AWS IAM role specified with either the name or full ARN. The Amazon - SageMaker training jobs and APIs that create Amazon SageMaker - endpoints use this role to access model artifacts. After the endpoint is created, - the inference code - might use the IAM role, if it needs to access an AWS resource. - number_of_partitions (int): The number of GPUs to partition the model across. The - partitioning strategy is determined by the device_map setting. If device_map is - not specified, the default HuggingFace strategy will be used. - device_id (int): The device_id to use for instantiating the model. If provided, - the model will only be instantiated once on the indicated device. Do not set this - if you have also specified data_parallel_degree. Defaults to None. - device_map (str or dict): The HuggingFace accelerate device_map to use. Defaults to - None. - load_in_8bit (bool): Whether to load the model in int8 precision using bits and bytes - quantization. This is only supported for select model architectures. - Defaults to False. If ``dtype`` is int8, then this is set to True. - low_cpu_mem_usage (bool): Whether to limit CPU memory usage to 1x model size during - model loading. This is an experimental feature in HuggingFace. This is useful when - loading multiple instances of your model in parallel. Defaults to False. - **kwargs: Keyword arguments passed to the superclasses - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model` - - .. tip:: - - You can find additional parameters for initializing this class at - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model`. 
- """ - if "hf_hub_token" in kwargs: - kwargs.pop("hf_hub_token") - super(HuggingFaceAccelerateModel, self).__init__( - model_id, - role, - number_of_partitions=number_of_partitions, - **kwargs, - ) - self.device_id = device_id - self.device_map = device_map - self.load_in_8bit = load_in_8bit - self.low_cpu_mem_usage = low_cpu_mem_usage - - def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]: - """Generates the DJL Serving configuration to use for the model. - - The configuration is generated using the arguments passed to the Model during - initialization. If a serving.properties file is found in ``self.source_dir``, - those configuration as merged with the Model parameters, with Model parameters taking - priority. - - Args: - serving_properties: Dictionary containing existing model server configuration - obtained from ``self.source_dir``. Defaults to None. - - Returns: - dict: The model server configuration to use when deploying this model to SageMaker. - """ - serving_properties = super(HuggingFaceAccelerateModel, self).generate_serving_properties( - serving_properties=serving_properties - ) - if self.device_id: - if self.number_of_partitions > 1: - raise ValueError("device_id cannot be set when number_of_partitions is > 1") - serving_properties["option.device_id"] = self.device_id - if self.device_map: - serving_properties["option.device_map"] = self.device_map - if self.load_in_8bit: - if self.dtype != "int8": - raise ValueError("Set dtype='int8' to use load_in_8bit") - serving_properties["option.load_in_8bit"] = self.load_in_8bit - if self.dtype == "int8": - serving_properties["option.load_in_8bit"] = True - if self.low_cpu_mem_usage: - serving_properties["option.low_cpu_mem_usage"] = self.low_cpu_mem_usage - # This is a workaround due to a bug in our built in handler for huggingface - # TODO: Remove this logic whenever 0.20.0 image is out of service - if ( - serving_properties["option.entryPoint"] == "djl_python.huggingface" - and self.dtype - and self.dtype != "auto" - and self.djl_version - and int(self.djl_version.split(".")[1]) < 21 - ): - serving_properties["option.dtype"] = "auto" - serving_properties.pop("option.load_in_8bit", None) - return serving_properties - - def partition( - self, - instance_type: str, - s3_output_uri: str = None, - s3_output_prefix: str = "aot-partitioned-checkpoints", - job_name: Optional[str] = None, - volume_size: int = 30, - volume_kms_key: Optional[str] = None, - output_kms_key: Optional[str] = None, - use_spot_instances: bool = False, - max_wait: int = None, - enable_network_isolation: bool = False, - ): - """Partitions the model using SageMaker Training Job. This is a synchronous API call. - - Args: - instance_type (str): The EC2 instance type to partition this Model. - For example, 'ml.p4d.24xlarge'. - s3_output_uri (str): S3 location for saving the training result (model - artifacts and output files). If not specified, results are - stored to a default bucket. If the bucket with the specific name - does not exist, it will be created. - s3_output_prefix (str): Name of the prefix where all the partitioned - checkpoints to be uploaded. If not provided, the default value is - aot-partitioned-checkpoints. - job_name (str): Training job name. If not specified, a unique training job - name will be created. - volume_size (int): Size in GB of the storage volume to use for - storing input and output data during training (default: 30). - volume_kms_key (str): Optional. 
KMS key ID for encrypting EBS - volume attached to the training instance (default: None). - output_kms_key (str): Optional. KMS key ID for encrypting the - training output (default: None). - use_spot_instances (bool): Specifies whether to use SageMaker - Managed Spot instances for training. If enabled then the - ``max_wait`` arg should also be set. - - More information: - https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html - (default: ``False``). - max_wait (int): Timeout in seconds waiting for spot training - job (default: None). After this amount of time Amazon - SageMaker will stop waiting for managed spot training job to - complete (default: None). - enable_network_isolation (bool): Specifies whether container will - run in network isolation mode (default: ``False``). Network - isolation mode restricts the container access to outside networks - (such as the Internet). The container does not make any inbound or - outbound network calls. Also known as Internet-free mode. - Returns: - None - """ - - logger.warning( - "HuggingFace engine does not currently support tensor parallelism. " - "Hence ahead of time partitioning is skipped" - ) - - -class FasterTransformerModel(DJLModel): - """A DJL FasterTransformer SageMaker ``Model`` - - This can be deployed to a SageMaker ``Endpoint``. - """ - - _framework_name = "djl-fastertransformer" - - def __init__( - self, - model_id: str, - role: str, - tensor_parallel_degree: Optional[int] = None, - **kwargs, - ): - """Initialize a FasterTransformerModel. - - Args: - model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location - containing the uncompressed model artifacts (i.e. not a tar.gz file). - The model artifacts are expected to be in HuggingFace pre-trained model - format (i.e. model should be loadable from the huggingface transformers - from_pretrained api, and should also include tokenizer configs if applicable). - role (str): An AWS IAM role specified with either the name or full ARN. The Amazon - SageMaker training jobs and APIs that create Amazon SageMaker - endpoints use this role to access model artifacts. After the endpoint is created, - the inference code - might use the IAM role, if it needs to access an AWS resource. - tensor_parllel_degree (int): The number of gpus to shard a single instance of the - model across via tensor_parallelism. This should be set to greater than 1 if the - size of the model is larger than the memory available on a single GPU on the - instance. Defaults to None. If not set, no tensor parallel sharding is done. - **kwargs: Keyword arguments passed to the superclasses - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model` - - .. tip:: - - You can find additional parameters for initializing this class at - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model`. - """ - if "hf_hub_token" in kwargs: - kwargs.pop("hf_hub_token") - super(FasterTransformerModel, self).__init__( - model_id, - role, - **kwargs, - ) - if self.number_of_partitions and tensor_parallel_degree: - logger.warning( - "Both number_of_partitions and tensor_parallel_degree have been set for " - "FasterTransformerModel." - "These mean the same thing for FasterTransformerModel. Please only set " - "tensor_parallel_degree." 
- "number_of_partitions will be ignored" - ) - self.number_of_partitions = tensor_parallel_degree or self.number_of_partitions diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index 8fa52c3ec8..173266e1a1 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -30,6 +30,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, @@ -47,6 +48,8 @@ def retrieve_default( retrieve the default environment variables. (Default: None). model_version (str): Optional. The version of the model for which to retrieve the default environment variables. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to + retrieve model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -80,12 +83,13 @@ def retrieve_default( ) return artifacts._retrieve_default_environment_variables( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, - include_aws_sdk_env_vars, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + include_aws_sdk_env_vars=include_aws_sdk_env_vars, sagemaker_session=sagemaker_session, instance_type=instance_type, script=script, diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 58a5fabc2f..b6af6cf5de 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1724,6 +1724,7 @@ def register( data_input_configuration=None, skip_model_validation=None, source_uri=None, + model_card=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1772,6 +1773,8 @@ def register( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. 
@@ -1817,6 +1820,7 @@ def register(
             data_input_configuration=data_input_configuration,
             skip_model_validation=skip_model_validation,
             source_uri=source_uri,
+            model_card=model_card,
         )

     @property
diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py
index 662baecae6..533a427747 100644
--- a/src/sagemaker/huggingface/model.py
+++ b/src/sagemaker/huggingface/model.py
@@ -26,6 +26,10 @@
 )
 from sagemaker.metadata_properties import MetadataProperties
 from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
+from sagemaker.model_card import (
+    ModelCard,
+    ModelPackageModelCard,
+)
 from sagemaker.predictor import Predictor
 from sagemaker.serializers import JSONSerializer
 from sagemaker.session import Session
@@ -362,6 +366,7 @@ def register(
         data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
         skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
         source_uri: Optional[Union[str, PipelineVariable]] = None,
+        model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
     ):
         """Creates a model package for creating SageMaker models or listing on Marketplace.

@@ -414,6 +419,8 @@
                 validation. Values can be "All" or "None" (default: None).
             source_uri (str or PipelineVariable): The URI of the source for the model
                 package (default: None).
+            model_card (ModelCard or ModelPackageModelCard): A document that contains
+                qualitative and quantitative information about a model (default: None).

         Returns:
             A `sagemaker.model.ModelPackage` instance.
@@ -462,6 +469,7 @@
             data_input_configuration=data_input_configuration,
             skip_model_validation=skip_model_validation,
             source_uri=source_uri,
+            model_card=model_card,
         )

     def prepare_container_def(
@@ -471,6 +479,7 @@ def prepare_container_def(
         serverless_inference_config=None,
         inference_tool=None,
         accept_eula=None,
+        model_reference_arn=None,
     ):
         """A container definition with framework configuration set in model environment
         variables.
@@ -525,6 +534,7 @@
             self.repacked_model_data or self.model_data,
             deploy_env,
             accept_eula=accept_eula,
+            model_reference_arn=model_reference_arn,
         )

     def serving_image_uri(
diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py
index 5c22409c50..f1353cc8ff 100644
--- a/src/sagemaker/hyperparameters.py
+++ b/src/sagemaker/hyperparameters.py
@@ -31,6 +31,7 @@ def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     instance_type: Optional[str] = None,
     include_container_hyperparameters: bool = False,
     tolerate_vulnerable_model: bool = False,
@@ -47,6 +48,8 @@ def retrieve_default(
         retrieve the default hyperparameters. (Default: None).
     model_version (str): The version of the model for which to
         retrieve the default hyperparameters. (Default: None).
+    hub_arn (str): The ARN of the SageMaker Hub from which to retrieve
+        model details. (Default: None).
     instance_type (str): An instance type to optionally supply in order to get
         hyperparameters specific for the instance type.
include_container_hyperparameters (bool): ``True`` if the container hyperparameters @@ -82,6 +85,7 @@ def retrieve_default( return artifacts._retrieve_default_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, include_container_hyperparameters=include_container_hyperparameters, @@ -95,6 +99,7 @@ def retrieve_default( def validate( region: Optional[str] = None, model_id: Optional[str] = None, + hub_arn: Optional[str] = None, model_version: Optional[str] = None, hyperparameters: Optional[dict] = None, validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, @@ -110,6 +115,8 @@ def validate( (Default: None). model_version (str): The version of the model for which to validate hyperparameters. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). hyperparameters (dict): Hyperparameters to validate. (Default: None). validation_mode (HyperparameterValidationMode): Method of validation to use with @@ -151,6 +158,7 @@ def validate( return validate_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, hyperparameters=hyperparameters, validation_mode=validation_mode, region=region, diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json index 518da5f15d..61971e5128 100644 --- a/src/sagemaker/image_uri_config/pytorch-smp.json +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -7,7 +7,8 @@ "2.0": "2.0.1", "2.1": "2.1.2", "2.2": "2.3.1", - "2.2.0": "2.3.1" + "2.2.0": "2.3.1", + "2.3": "2.4.0" }, "versions": { "2.0.1": { @@ -134,6 +135,31 @@ "us-west-2": "658645717510" }, "repository": "smdistributed-modelparallel" + }, + "2.4.0": { + "py_versions": [ + "py311" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" } } } diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index e3a2944340..e73eed051b 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -82,7 +82,8 @@ "1.13": "1.13.1", "2.0": "2.0.1", "2.1": "2.1.0", - "2.2": "2.2.0" + "2.2": "2.2.0", + "2.3": "2.3.0" }, "versions": { "0.4.0": { @@ -1054,6 +1055,47 @@ "us-west-2": "763104351884" }, "repository": "pytorch-inference" + }, + "2.3.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + 
"eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference" } } }, diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 743f6b1f99..65497927e9 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -64,6 +64,7 @@ def retrieve( training_compiler_config=None, model_id=None, model_version=None, + hub_arn=None, tolerate_vulnerable_model=False, tolerate_deprecated_model=False, sdk_version=None, @@ -105,6 +106,8 @@ def retrieve( (default: None). model_version (str): The version of the JumpStart model for which to retrieve the image URI (default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications should be tolerated without an exception raised. If ``False``, raises an exception if the script used by this version of the model has dependencies with known security @@ -151,6 +154,7 @@ def retrieve( model_id, model_version, image_scope, + hub_arn, framework, region, version, @@ -689,6 +693,7 @@ def get_training_image_uri( "p5" in instance_type or "2.1" in framework_version or "2.2" in framework_version + or "2.3" in framework_version ): container_version = "cu121" else: diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index c4af4b2036..1b664fc9ae 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -30,6 +30,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -47,6 +48,8 @@ def retrieve_default( retrieve the default instance type. (Default: None). model_version (str): The version of the model for which to retrieve the default instance type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -84,6 +87,7 @@ def retrieve_default( model_id, model_version, scope, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, @@ -98,6 +102,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -113,6 +118,8 @@ def retrieve( retrieve the supported instance types. (Default: None). model_version (str): The version of the model for which to retrieve the supported instance types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -145,12 +152,13 @@ def retrieve( raise ValueError("Must specify scope for instance types.") return artifacts._retrieve_instance_types( - model_id, - model_version, - scope, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + scope=scope, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, ) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 35df030ddc..66003c9f03 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -10,6 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +# pylint: skip-file """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import import functools @@ -17,10 +18,16 @@ import boto3 from sagemaker.deprecations import deprecated -from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs +from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs, HubContentType from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import cache +from sagemaker.jumpstart.hub.utils import ( + construct_hub_model_arn_from_inputs, + construct_hub_model_reference_arn_from_inputs, +) from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.session import Session +from sagemaker.jumpstart import constants class SageMakerSettings(object): @@ -253,8 +260,10 @@ def get_model_specs( region: str, model_id: str, version: str, + hub_arn: Optional[str] = None, s3_client: Optional[boto3.client] = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. 
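The hunk that follows routes hub-backed lookups through this same accessor, first as a hub ``Model`` and then, on failure, as a ``ModelReference``. As a rough caller-side sketch (the region, model ID, version, and hub ARN below are placeholders, not values from this patch):

    # Hedged sketch: resolve specs for a model stored in a SageMaker Hub.
    from sagemaker.jumpstart.accessors import JumpStartModelsAccessor

    specs = JumpStartModelsAccessor.get_model_specs(
        region="us-west-2",
        model_id="example-model",
        version="1.0.0",
        hub_arn="arn:aws:sagemaker:us-west-2:123456789012:hub/example-hub",
    )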
@@ -270,10 +279,34 @@ def get_model_specs( if s3_client is not None: additional_kwargs.update({"s3_client": s3_client}) + if hub_arn: + additional_kwargs.update({"sagemaker_session": sagemaker_session}) + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs} ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + + if hub_arn: + try: + hub_model_arn = construct_hub_model_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + model_specs = JumpStartModelsAccessor._cache.get_hub_model( + hub_model_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL) + return model_specs + except: # noqa: E722 + hub_model_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + model_specs = JumpStartModelsAccessor._cache.get_hub_model_reference( + hub_model_reference_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE) + return model_specs + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, version_str=version, model_type=model_type ) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index fcb3ce3bf2..f10bfe4a5d 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -32,6 +32,7 @@ def _retrieve_default_environment_variables( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -48,6 +49,8 @@ def _retrieve_default_environment_variables( retrieve the default environment variables. model_version (str): Version of the JumpStart model for which to retrieve the default environment variables. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default environment variables. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -81,6 +84,7 @@ def _retrieve_default_environment_variables( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -119,6 +123,7 @@ def _retrieve_default_environment_variables( lambda instance_type: _retrieve_gated_model_uri_env_var_value( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -166,6 +171,7 @@ def _retrieve_default_environment_variables( def _retrieve_gated_model_uri_env_var_value( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -180,6 +186,8 @@ def _retrieve_gated_model_uri_env_var_value( retrieve the gated model env var URI. model_version (str): Version of the JumpStart model for which to retrieve the gated model env var URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve the gated model env var URI. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -212,6 +220,7 @@ def _retrieve_gated_model_uri_env_var_value( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -228,4 +237,7 @@ def _retrieve_gated_model_uri_env_var_value( if s3_key is None: return None + if hub_arn: + return s3_key + return f"s3://{get_jumpstart_gated_content_bucket(region)}/{s3_key}" diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index 67db7d260f..4383a17bf9 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -30,6 +30,7 @@ def _retrieve_default_hyperparameters( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, @@ -45,6 +46,8 @@ def _retrieve_default_hyperparameters( retrieve the default hyperparameters. model_version (str): Version of the JumpStart model for which to retrieve the default hyperparameters. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (str): Region for which to retrieve default hyperparameters. (Default: None). include_container_hyperparameters (bool): True if container hyperparameters @@ -79,6 +82,7 @@ def _retrieve_default_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 72633320f5..0d4a61d112 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -33,6 +33,7 @@ def _retrieve_image_uri( model_id: str, model_version: str, image_scope: str, + hub_arn: Optional[str] = None, framework: Optional[str] = None, region: Optional[str] = None, version: Optional[str] = None, @@ -58,6 +59,8 @@ def _retrieve_image_uri( model_id (str): JumpStart model ID for which to retrieve image URI. model_version (str): Version of the JumpStart model for which to retrieve the image URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). image_scope (str): The image type, i.e. what it is used for. Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. 
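The body changes in the next hunks short-circuit to the ECR URI stored on the hub content when ``hub_arn`` is set, instead of reconstructing the URI from ECR specs. From the caller's side, a hedged sketch of the public wrapper (every identifier below is a placeholder):

    # Hedged sketch: resolve an image URI for a hub-backed JumpStart model.
    # framework=None is the usual pattern for JumpStart model lookups; the
    # model id, version, instance type, and hub ARN are hypothetical.
    from sagemaker import image_uris

    image_uri = image_uris.retrieve(
        framework=None,
        region="us-west-2",
        image_scope="inference",
        model_id="example-model",
        model_version="1.0.0",
        instance_type="ml.g5.2xlarge",
        hub_arn="arn:aws:sagemaker:us-west-2:123456789012:hub/example-hub",
    )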
@@ -113,6 +116,7 @@ def _retrieve_image_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=image_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -129,6 +133,10 @@ def _retrieve_image_uri( ) if image_uri is not None: return image_uri + if hub_arn: + ecr_uri = model_specs.hosting_ecr_uri + return ecr_uri + ecr_specs = model_specs.hosting_ecr_specs if ecr_specs is None: raise ValueError( @@ -144,6 +152,10 @@ def _retrieve_image_uri( ) if image_uri is not None: return image_uri + if hub_arn: + ecr_uri = model_specs.training_ecr_uri + return ecr_uri + ecr_specs = model_specs.training_ecr_specs if ecr_specs is None: raise ValueError( @@ -197,6 +209,7 @@ def _retrieve_image_uri( version=version_override or ecr_specs.framework_version, py_version=ecr_specs.py_version, instance_type=instance_type, + hub_arn=hub_arn, accelerator_type=accelerator_type, image_scope=image_scope, container_version=container_version, diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 8bbe089354..80b5aa8ef5 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -30,6 +30,7 @@ def _model_supports_incremental_training( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -44,6 +45,8 @@ def _model_supports_incremental_training( support status for incremental training. region (Optional[str]): Region for which to retrieve the support status for incremental training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -67,6 +70,7 @@ def _model_supports_incremental_training( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index f4bf212c1c..25119266cf 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -34,6 +34,7 @@ def _retrieve_default_instance_type( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -51,6 +52,8 @@ def _retrieve_default_instance_type( default instance type. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default instance type. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -85,6 +88,7 @@ def _retrieve_default_instance_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -126,6 +130,7 @@ def _retrieve_instance_types( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -142,6 +147,8 @@ def _retrieve_instance_types( supported instance types. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve supported instance types. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -176,6 +183,7 @@ def _retrieve_instance_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -202,7 +210,7 @@ def _retrieve_instance_types( elif scope == JumpStartScriptScope.TRAINING: if training_instance_type is not None: - raise ValueError("Cannot use `training_instance_type` argument " "with training scope.") + raise ValueError("Cannot use `training_instance_type` argument with training scope.") instance_types = model_specs.supported_training_instance_types else: raise NotImplementedError( diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index ceb88d9b26..eb7980b88f 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -32,6 +32,7 @@ def _retrieve_model_init_kwargs( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -46,6 +47,8 @@ def _retrieve_model_init_kwargs( retrieve the kwargs. model_version (str): Version of the JumpStart model for which to retrieve the kwargs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -71,6 +74,7 @@ def _retrieve_model_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -92,6 +96,7 @@ def _retrieve_model_deploy_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -108,6 +113,8 @@ def _retrieve_model_deploy_kwargs( kwargs. instance_type (str): Instance type of the hosting endpoint, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -134,6 +141,7 @@ def _retrieve_model_deploy_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -153,6 +161,7 @@ def _retrieve_estimator_init_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -168,6 +177,8 @@ def _retrieve_estimator_init_kwargs( kwargs. instance_type (str): Instance type of the training job, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -193,6 +204,7 @@ def _retrieve_estimator_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -215,6 +227,7 @@ def _retrieve_estimator_init_kwargs( def _retrieve_estimator_fit_kwargs( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -228,6 +241,8 @@ def _retrieve_estimator_fit_kwargs( retrieve the kwargs. model_version (str): Version of the JumpStart model for which to retrieve the kwargs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -254,6 +269,7 @@ def _retrieve_estimator_fit_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index f23b66aed4..5e5c0d79a0 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -31,6 +31,7 @@ def _retrieve_default_training_metric_definitions( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -46,6 +47,8 @@ def _retrieve_default_training_metric_definitions( default training metric definitions. region (Optional[str]): Region for which to retrieve default training metric definitions. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -71,6 +74,7 @@ def _retrieve_default_training_metric_definitions( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 67459519f3..7aa5be7507 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -32,6 +32,7 @@ def _retrieve_model_package_arn( model_version: str, instance_type: Optional[str], region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -49,6 +50,8 @@ def _retrieve_model_package_arn( instance_type (Optional[str]): An instance type to optionally supply in order to get an arn specific for the instance type. region (Optional[str]): Region for which to retrieve the model package arn. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (Optional[str]): Scope for which to retrieve the model package arn. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -74,6 +77,7 @@ def _retrieve_model_package_arn( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -120,6 +124,7 @@ def _retrieve_model_package_model_artifact_s3_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -135,6 +140,8 @@ def _retrieve_model_package_model_artifact_s3_uri( model package artifact. region (Optional[str]): Region for which to retrieve the model package artifact. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (Optional[str]): Scope for which to retrieve the model package artifact. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -165,6 +172,7 @@ def _retrieve_model_package_model_artifact_s3_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 00c6d8b9aa..5fac979b14 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -89,6 +89,7 @@ def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_t def _retrieve_model_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, region: Optional[str] = None, @@ -106,6 +107,8 @@ def _retrieve_model_uri( the model artifact S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model artifact S3 URI. 
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -139,6 +142,7 @@ def _retrieve_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=model_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -153,6 +157,9 @@ def _retrieve_model_uri( is_prepacked = not model_specs.use_inference_script_uri() + if hub_arn: + model_artifact_uri = model_specs.hosting_artifact_uri + return model_artifact_uri model_artifact_key = ( _retrieve_hosting_prepacked_artifact_key(model_specs, instance_type) if is_prepacked @@ -183,6 +190,7 @@ def _model_supports_training_model_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -197,6 +205,8 @@ def _model_supports_training_model_uri( support status for model uri with training. region (Optional[str]): Region for which to retrieve the support status for model uri with training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -220,6 +230,7 @@ def _model_supports_training_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 2f4a8bb0ac..c217495ede 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -33,6 +33,7 @@ def _retrieve_example_payloads( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -48,6 +49,8 @@ def _retrieve_example_payloads( example payloads. region (Optional[str]): Region for which to retrieve the example payloads. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -72,6 +75,7 @@ def _retrieve_example_payloads( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 635f063e05..352a4384f8 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -73,6 +73,7 @@ def _retrieve_deserializer_from_accept_type( def _retrieve_default_deserializer( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -87,6 +88,8 @@ def _retrieve_default_deserializer( retrieve the default deserializer. model_version (str): Version of the JumpStart model for which to retrieve the default deserializer. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default deserializer. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -108,6 +111,7 @@ def _retrieve_default_deserializer( default_accept_type = _retrieve_default_accept_type( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -122,6 +126,7 @@ def _retrieve_default_deserializer( def _retrieve_default_serializer( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -136,6 +141,8 @@ def _retrieve_default_serializer( retrieve the default serializer. model_version (str): Version of the JumpStart model for which to retrieve the default serializer. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default serializer. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -156,6 +163,7 @@ def _retrieve_default_serializer( default_content_type = _retrieve_default_content_type( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -170,6 +178,7 @@ def _retrieve_default_serializer( def _retrieve_deserializer_options( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -184,6 +193,8 @@ def _retrieve_deserializer_options( retrieve the supported deserializers. model_version (str): Version of the JumpStart model for which to retrieve the supported deserializers. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve deserializer options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -204,6 +215,7 @@ def _retrieve_deserializer_options( supported_accept_types = _retrieve_supported_accept_types( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -232,6 +244,7 @@ def _retrieve_deserializer_options( def _retrieve_serializer_options( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -245,6 +258,8 @@ def _retrieve_serializer_options( retrieve the supported serializers. model_version (str): Version of the JumpStart model for which to retrieve the supported serializers. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve serializer options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -265,6 +280,7 @@ def _retrieve_serializer_options( supported_content_types = _retrieve_supported_content_types( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -292,6 +308,7 @@ def _retrieve_serializer_options( def _retrieve_default_content_type( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -306,6 +323,8 @@ def _retrieve_default_content_type( retrieve the default content type. model_version (str): Version of the JumpStart model for which to retrieve the default content type. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default content type. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -330,6 +349,7 @@ def _retrieve_default_content_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -346,6 +366,7 @@ def _retrieve_default_content_type( def _retrieve_default_accept_type( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -360,6 +381,8 @@ def _retrieve_default_accept_type( retrieve the default accept type. model_version (str): Version of the JumpStart model for which to retrieve the default accept type. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default accept type. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -384,6 +407,7 @@ def _retrieve_default_accept_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -401,6 +425,7 @@ def _retrieve_default_accept_type( def _retrieve_supported_accept_types( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -415,6 +440,8 @@ def _retrieve_supported_accept_types( retrieve the supported accept types. model_version (str): Version of the JumpStart model for which to retrieve the supported accept types. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve accept type options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -439,6 +466,7 @@ def _retrieve_supported_accept_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -456,6 +484,7 @@ def _retrieve_supported_accept_types( def _retrieve_supported_content_types( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -470,6 +499,8 @@ def _retrieve_supported_content_types( retrieve the supported content types. model_version (str): Version of the JumpStart model for which to retrieve the supported content types. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve content type options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -494,6 +525,7 @@ def _retrieve_supported_content_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index b4fdac770b..8c47750061 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -31,6 +31,7 @@ def _retrieve_resource_name_base( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, @@ -47,6 +48,8 @@ def _retrieve_resource_name_base( default resource name. region (Optional[str]): Region for which to retrieve the default resource name. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -70,6 +73,7 @@ def _retrieve_resource_name_base( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 49126da336..74523be1de 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -48,6 +48,7 @@ def _retrieve_default_resources( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -65,6 +66,8 @@ def _retrieve_default_resources( default resource requirements. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default resource requirements. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -98,6 +101,7 @@ def _retrieve_default_resources( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index 97313ec626..5029f53cfb 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -32,6 +32,7 @@ def _retrieve_script_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -48,6 +49,8 @@ def _retrieve_script_uri( retrieve the script S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. @@ -80,6 +83,7 @@ def _retrieve_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -107,7 +111,8 @@ def _retrieve_script_uri( def _model_supports_inference_script_uri( model_id: str, model_version: str, - region: Optional[str], + hub_arn: Optional[str] = None, + region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -120,6 +125,8 @@ def _model_supports_inference_script_uri( retrieve the support status for script uri with inference. model_version (str): Version of the JumpStart model for which to retrieve the support status for script uri with inference. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). 
region (Optional[str]): Region for which to retrieve the support status for script uri with inference. tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -144,6 +151,7 @@ def _model_supports_inference_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e9a34a21a8..257a9e71af 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -15,13 +15,14 @@ import datetime from difflib import get_close_matches import os -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import json import boto3 import botocore from packaging.version import Version from packaging.specifiers import SpecifierSet, InvalidSpecifier from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, @@ -42,16 +43,23 @@ JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelHeader, JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, + HubContentType, +) +from sagemaker.jumpstart.hub import utils as hub_utils +from sagemaker.jumpstart.hub.interfaces import DescribeHubContentResponse +from sagemaker.jumpstart.hub.parsers import ( + make_model_specs_from_describe_hub_content_response, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache +from sagemaker.session import Session class JumpStartModelsCache: @@ -77,6 +85,7 @@ def __init__( s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, + sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> None: """Initialize a ``JumpStartModelsCache`` instance. @@ -98,13 +107,15 @@ def __init__( s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. Default: None (no config). s3_client (Optional[boto3.client]): s3 client to use. Default: None. + sagemaker_session: sagemaker session object to use. + Default: session object from default region us-west-2. """ self._region = region or utils.get_region_fallback( s3_bucket_name=s3_bucket_name, s3_client=s3_client ) - self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( + self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue]( max_cache_items=max_s3_cache_items, expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, @@ -139,6 +150,7 @@ def __init__( if s3_client_config else boto3.client("s3", region_name=self._region) ) + self._sagemaker_session = sagemaker_session def set_region(self, region: str) -> None: """Set region for cache. 
Clears cache after new region is set.""" @@ -230,8 +242,8 @@ def _model_id_retrieval_function( model_id, version = key.model_id, key.version sm_version = utils.get_sagemaker_version() - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -392,53 +404,96 @@ def _get_json_file_from_local_override( def _retrieval_function( self, - key: JumpStartCachedS3ContentKey, - value: Optional[JumpStartCachedS3ContentValue], - ) -> JumpStartCachedS3ContentValue: - """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``. + key: JumpStartCachedContentKey, + value: Optional[JumpStartCachedContentValue], + ) -> JumpStartCachedContentValue: + """Return s3 content given a file type and s3_key in ``JumpStartCachedContentKey``. If a manifest file is being fetched, we only download the object if the md5 hash in ``head_object`` does not match the current md5 hash for the stored value. This prevents unnecessarily downloading the full manifest when it hasn't changed. Args: - key (JumpStartCachedS3ContentKey): key for which to fetch s3 content. + key (JumpStartCachedContentKey): key for which to fetch s3 content. value (Optional[JumpStartVersionedModelId]): Current value of old cached s3 content. This is used for the manifest file, so that it is only downloaded when its content changes. """ - file_type, s3_key = key.file_type, key.s3_key - if file_type in { + data_type, id_info = key.data_type, key.id_info + + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartS3FileType.PROPRIETARY_MANIFEST, }: if value is not None and not self._is_local_metadata_mode(): - etag = self._get_json_md5_hash(s3_key) + etag = self._get_json_md5_hash(id_info) if etag == value.md5_hash: return value - formatted_body, etag = self._get_json_file(s3_key, file_type) - return JumpStartCachedS3ContentValue( + formatted_body, etag = self._get_json_file(id_info, data_type) + return JumpStartCachedContentValue( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - if file_type in { + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_SPECS, JumpStartS3FileType.PROPRIETARY_SPECS, }: - formatted_body, _ = self._get_json_file(s3_key, file_type) + formatted_body, _ = self._get_json_file(id_info, data_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedS3ContentValue(formatted_content=model_specs) - raise ValueError(self._file_type_error_msg(file_type)) + return JumpStartCachedContentValue(formatted_content=model_specs) + + if data_type == HubContentType.NOTEBOOK: + hub_name, _, notebook_name, notebook_version = hub_utils.get_info_from_hub_resource_arn( + id_info + ) + response: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=notebook_name, + hub_content_version=notebook_version, + hub_content_type=data_type, + ) + hub_notebook_description = DescribeHubContentResponse(response) + return JumpStartCachedContentValue(formatted_content=hub_notebook_description) + + if data_type in { + HubContentType.MODEL, + HubContentType.MODEL_REFERENCE, + }: + + hub_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(id_info) + + model_version: str = hub_utils.get_hub_model_version( + hub_model_name=hub_arn_extracted_info.hub_content_name, + 
hub_model_type=data_type.value, + hub_name=hub_arn_extracted_info.hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=hub_arn_extracted_info.hub_content_version, + ) + + hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=hub_arn_extracted_info.hub_name, + hub_content_name=hub_arn_extracted_info.hub_content_name, + hub_content_version=model_version, + hub_content_type=data_type.value, + ) + + model_specs = make_model_specs_from_describe_hub_content_response( + DescribeHubContentResponse(hub_model_description), + ) + + return JumpStartCachedContentValue(formatted_content=model_specs) + + raise ValueError(self._file_type_error_msg(data_type)) def get_manifest( self, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" - manifest_dict = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest_dict = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -525,8 +580,8 @@ def _get_header_impl( JumpStartVersionedModelId(model_id, semantic_version_str) )[0] - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -556,8 +611,8 @@ def get_specs( """ header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key - specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) + specs, cache_hit = self._content_cache.get( + JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) ) if not cache_hit and "*" in version_str: @@ -566,8 +621,38 @@ def get_specs( ) return specs.formatted_content + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model. + + Args: + hub_model_arn (str): Arn for the Hub model to get specs for. + """ + + details, _ = self._content_cache.get( + JumpStartCachedContentKey( + HubContentType.MODEL, + hub_model_arn, + ) + ) + return details.formatted_content + + def get_hub_model_reference(self, hub_model_reference_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model reference. + + Args: + hub_model_reference_arn (str): Arn for the Hub model reference to get specs for. + """ + + details, _ = self._content_cache.get( + JumpStartCachedContentKey( + HubContentType.MODEL_REFERENCE, + hub_model_reference_arn, + ) + ) + return details.formatted_content + def clear(self) -> None: """Clears the model ID/version and s3 cache.""" - self._s3_cache.clear() + self._content_cache.clear() self._open_weight_model_id_manifest_key_cache.clear() self._proprietary_model_id_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index b94fb2982c..795585204d 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -204,9 +204,14 @@ JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" NEO_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" +JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub" + JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
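The retrieval function earlier in this diff re-downloads a manifest only when the md5/ETag recorded in the cache goes stale: a cheap ``head_object`` call decides whether the full ``get_object`` is needed. A minimal standalone sketch of that pattern, using plain boto3 (the ``fetch_manifest`` helper and its arguments are illustrative, not the SDK's API):

    import json

    import boto3

    s3 = boto3.client("s3")

    def fetch_manifest(bucket, key, cached=None):
        """Return (manifest, etag), re-downloading only when the ETag changed."""
        etag = s3.head_object(Bucket=bucket, Key=key)["ETag"]
        if cached is not None and cached[1] == etag:
            return cached  # hash unchanged: reuse the cached manifest
        body = json.loads(s3.get_object(Bucket=bucket, Key=key)["Body"].read())
        return body, etag

``_retrieval_function`` applies the same short-circuit before rebuilding the formatted manifest; spec files skip the check because they are keyed by exact version.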
+HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" +HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" + INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 9666ce828f..a83964e394 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -15,6 +15,7 @@ from __future__ import absolute_import from enum import Enum +from typing import List class ModelFramework(str, Enum): @@ -96,6 +97,8 @@ class JumpStartTag(str, Enum): INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name" TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name" + HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" + class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" @@ -129,6 +132,28 @@ def from_suffixed_type(mime_type_with_suffix: str) -> "MIMEType": return MIMEType(base_type) +class NamingConventionType(str, Enum): + """Enum class for naming conventions.""" + + SNAKE_CASE = "snake_case" + UPPER_CAMEL_CASE = "upper_camel_case" + DEFAULT = UPPER_CAMEL_CASE + + +class ModelSpecKwargType(str, Enum): + """Enum class for types of kwargs for model hub content document and model specs.""" + + FIT = "fit_kwargs" + MODEL = "model_kwargs" + ESTIMATOR = "estimator_kwargs" + DEPLOY = "deploy_kwargs" + + @classmethod + def arg_keys(cls) -> List[str]: + """Returns a list of kwargs keys that each type can have""" + return [member.value for member in cls] + + class JumpStartConfigRankingName(str, Enum): """Enum class for ranking of JumpStart config.""" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 5f7e0ed82c..8b30317a52 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -28,6 +28,7 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG @@ -60,6 +61,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -127,6 +129,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies @@ -513,6 +516,12 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. 
""" + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + def _validate_model_id_and_get_type_hook(): return validate_model_id_and_get_type( model_id=model_id, @@ -520,18 +529,20 @@ def _validate_model_id_and_get_type_hook(): region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, + hub_arn=hub_arn, ) self.model_type = _validate_model_id_and_get_type_hook() if not self.model_type: JumpStartModelsAccessor.reset_cache() self.model_type = _validate_model_id_and_get_type_hook() - if not self.model_type: + if not self.model_type and not hub_arn: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=self.model_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -590,6 +601,7 @@ def _validate_model_id_and_get_type_hook(): enable_session_tag_chaining=enable_session_tag_chaining, ) + self.hub_arn = estimator_init_kwargs.hub_arn self.model_id = estimator_init_kwargs.model_id self.model_version = estimator_init_kwargs.model_version self.instance_type = estimator_init_kwargs.instance_type @@ -668,6 +680,7 @@ def fit( estimator_fit_kwargs = get_fit_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, inputs=inputs, wait=wait, @@ -688,6 +701,7 @@ def attach( training_job_name: str, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_channel_name: str = "model", config_name: Optional[str] = None, @@ -758,6 +772,7 @@ def attach( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=sagemaker_session.boto_region_name, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated @@ -1065,6 +1080,7 @@ def deploy( estimator_deploy_kwargs = get_deploy_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, @@ -1117,6 +1133,7 @@ def deploy( predictor=predictor, model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index c936b2f5eb..8540f53ca4 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -60,6 +60,7 @@ JumpStartModelInitKwargs, ) from sagemaker.jumpstart.utils import ( + add_hub_content_arn_tags, add_jumpstart_model_info_tags, get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, @@ -78,6 +79,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, @@ -138,6 +140,7 @@ def get_init_kwargs( 
estimator_init_kwargs: JumpStartEstimatorInitKwargs = JumpStartEstimatorInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, role=role, region=region, @@ -217,6 +220,7 @@ def get_init_kwargs( def get_fit_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, wait: Optional[bool] = None, @@ -233,6 +237,7 @@ def get_fit_kwargs( estimator_fit_kwargs: JumpStartEstimatorFitKwargs = JumpStartEstimatorFitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, inputs=inputs, wait=wait, @@ -256,6 +261,7 @@ def get_fit_kwargs( def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -263,6 +269,7 @@ def get_deploy_kwargs( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, + inference_component_name: Optional[str] = None, tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, @@ -302,6 +309,7 @@ def get_deploy_kwargs( model_deploy_kwargs: JumpStartModelDeployKwargs = model.get_deploy_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -309,6 +317,7 @@ def get_deploy_kwargs( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, + inference_component_name=inference_component_name, tags=format_tags(tags), kms_key=kms_key, wait=wait, @@ -331,7 +340,13 @@ def get_deploy_kwargs( model_id=model_id, model_from_estimator=True, model_version=model_version, - instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None, + hub_arn=hub_arn, + instance_type=( + model_deploy_kwargs.instance_type + if training_instance_type is None + or instance_type is not None # always use supplied inference instance type + else None + ), region=region, image_uri=image_uri, source_dir=source_dir, @@ -359,6 +374,7 @@ def get_deploy_kwargs( estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( model_id=model_init_kwargs.model_id, model_version=model_init_kwargs.model_version, + hub_arn=hub_arn, instance_type=model_init_kwargs.instance_type, initial_instance_count=model_deploy_kwargs.initial_instance_count, region=model_init_kwargs.region, @@ -428,6 +444,20 @@ def _add_model_version_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: kwargs.model_version = kwargs.model_version or "*" + if kwargs.hub_arn: + hub_content_version = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.TRAINING, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).version + kwargs.model_version = hub_content_version + return kwargs @@ -455,6 +485,7 @@ def _add_instance_type_and_count_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, 
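+ # hub_arn is threaded into the lookup so that, for hub-backed models, the + # default training instance type comes from the hub's content document.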
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -478,6 +509,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima full_model_version = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, region=kwargs.region, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -494,6 +526,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima config_name=kwargs.config_name, scope=JumpStartScriptScope.TRAINING, ) + + if kwargs.hub_arn: + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn) + return kwargs @@ -506,6 +542,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE image_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -522,6 +559,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE if _model_supports_training_model_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -545,6 +583,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE and not _model_supports_incremental_training( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -581,6 +620,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart script_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, region=kwargs.region, @@ -599,6 +639,7 @@ def _add_env_to_kwargs( extra_env_vars = environment_variables.retrieve_default( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -612,6 +653,7 @@ def _add_env_to_kwargs( model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -640,6 +682,7 @@ def _add_env_to_kwargs( model_specs = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -674,6 +717,7 @@ def _add_training_job_name_to_kwargs( default_training_job_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -702,6 +746,7 @@ def 
_add_hyperparameters_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -736,6 +781,7 @@ def _add_metric_definitions_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -765,6 +811,7 @@ def _add_estimator_extra_kwargs( estimator_kwargs_to_add = _retrieve_estimator_init_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -791,6 +838,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim fit_kwargs_to_add = _retrieve_estimator_fit_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 7de6407e47..e759adec5e 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -33,17 +33,21 @@ JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, ) +from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard +from sagemaker.jumpstart.hub.utils import construct_hub_model_reference_arn_from_inputs from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.types import ( + HubContentType, JumpStartModelDeployKwargs, JumpStartModelInitKwargs, JumpStartModelRegisterKwargs, JumpStartModelSpecs, ) from sagemaker.jumpstart.utils import ( + add_hub_content_arn_tags, add_jumpstart_model_info_tags, get_default_jumpstart_session_with_user_agent_suffix, get_neo_content_bucket, @@ -69,6 +73,7 @@ def get_default_predictor( predictor: Predictor, model_id: str, model_version: str, + hub_arn: Optional[str], region: str, tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, @@ -92,6 +97,7 @@ def get_default_predictor( predictor.serializer = serializers.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -102,6 +108,7 @@ def get_default_predictor( predictor.deserializer = deserializers.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -112,6 +119,7 @@ def get_default_predictor( predictor.accept = accept_types.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -122,6 +130,7 @@ def get_default_predictor( predictor.content_type = 
content_types.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -176,6 +185,20 @@ def _add_model_version_to_kwargs( kwargs.model_version = kwargs.model_version or "*" + if kwargs.hub_arn: + hub_content_version = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).version + kwargs.model_version = hub_content_version + return kwargs @@ -200,6 +223,7 @@ def _add_instance_type_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -234,6 +258,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel image_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -244,6 +269,32 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel return kwargs +def _add_model_reference_arn_to_kwargs( + kwargs: JumpStartModelInitKwargs, +) -> JumpStartModelInitKwargs: + """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" + hub_content_type = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).hub_content_type + kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None + + if hub_content_type == HubContentType.MODEL_REFERENCE: + kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version + ) + else: + kwargs.model_reference_arn = None + return kwargs + + def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets model data based on default or override, returns full kwargs.""" @@ -255,6 +306,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode model_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -297,6 +349,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode if _model_supports_inference_script_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, 
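+ # Whether a separate inference script is supported is read from the specs, + # which may now be resolved from a hub via hub_arn.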
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -307,6 +360,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode script_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -331,6 +385,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod if _model_supports_inference_script_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -360,6 +415,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw extra_env_vars = environment_variables.retrieve_default( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -391,6 +447,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, scope=JumpStartScriptScope.INFERENCE, region=kwargs.region, @@ -411,6 +468,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI model_kwargs_to_add = _retrieve_model_init_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -448,6 +506,7 @@ def _add_endpoint_name_to_kwargs( default_endpoint_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -471,6 +530,7 @@ def _add_model_name_to_kwargs( default_model_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -492,6 +552,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: full_model_version = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, region=kwargs.region, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -511,6 +572,9 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: scope=JumpStartScriptScope.INFERENCE, ) + if kwargs.hub_arn: + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn) + return kwargs @@ -520,6 +584,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] deploy_kwargs_to_add = _retrieve_model_deploy_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ 
-543,6 +608,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -692,6 +758,7 @@ def _add_config_name_to_deploy_kwargs( def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, @@ -700,6 +767,7 @@ def get_deploy_kwargs( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, + inference_component_name: Optional[str] = None, tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, @@ -715,6 +783,7 @@ def get_deploy_kwargs( tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, @@ -728,6 +797,7 @@ def get_deploy_kwargs( deploy_kwargs: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, region=region, initial_instance_count=initial_instance_count, @@ -736,6 +806,7 @@ def get_deploy_kwargs( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, + inference_component_name=inference_component_name, tags=format_tags(tags), kms_key=kms_key, wait=wait, @@ -751,6 +822,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, endpoint_logging=endpoint_logging, resources=resources, config_name=config_name, @@ -786,6 +858,8 @@ def get_deploy_kwargs( def get_register_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, @@ -813,12 +887,16 @@ def get_register_kwargs( skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, config_name: Optional[str] = None, + model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, + accept_eula: Optional[bool] = None, ) -> JumpStartModelRegisterKwargs: """Returns kwargs required to call `register` on `sagemaker.estimator.Model` object.""" register_kwargs = JumpStartModelRegisterKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, + model_type=model_type, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -845,11 +923,15 @@ def get_register_kwargs( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + accept_eula=accept_eula, ) model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, + model_type=model_type, region=region, scope=JumpStartScriptScope.INFERENCE, 
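+ # With hub_arn and model_type set, registration validates against the hub's + # copy of the model specs rather than the public manifest.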
sagemaker_session=sagemaker_session, @@ -872,6 +954,7 @@ def get_init_kwargs( model_id: str, model_from_estimator: bool = False, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, @@ -906,6 +989,7 @@ def get_init_kwargs( model_init_kwargs: JumpStartModelInitKwargs = JumpStartModelInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, instance_type=instance_type, region=region, @@ -935,13 +1019,12 @@ def get_init_kwargs( additional_model_data_sources=additional_model_data_sources, ) - model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_instance_type_to_kwargs( @@ -950,6 +1033,12 @@ def get_init_kwargs( model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs) + if hub_arn: + model_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=model_init_kwargs) + else: + model_init_kwargs.model_reference_arn = None + model_init_kwargs.hub_content_type = None + # we use the model artifact from the training job output if not model_from_estimator: model_init_kwargs = _add_model_data_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/hub/__init__.py b/src/sagemaker/jumpstart/hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/jumpstart/hub/constants.py b/src/sagemaker/jumpstart/hub/constants.py new file mode 100644 index 0000000000..e3a6b7752a --- /dev/null +++ b/src/sagemaker/jumpstart/hub/constants.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores constants related to SageMaker JumpStart Hub.""" +from __future__ import absolute_import + +LATEST_VERSION_WILDCARD = "*" diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py new file mode 100644 index 0000000000..d208220965 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -0,0 +1,307 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +# pylint: skip-file +"""This module provides the JumpStart Hub class.""" +from __future__ import absolute_import +from datetime import datetime +import logging +from typing import Optional, Dict, List, Any, Union +from botocore import exceptions + +from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME +from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.session import Session + +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_LOGGER, +) +from sagemaker.jumpstart.types import ( + HubContentType, +) +from sagemaker.jumpstart.filters import Constant, Operator, BooleanValues +from sagemaker.jumpstart.hub.utils import ( + get_hub_model_version, + get_info_from_hub_resource_arn, + create_hub_bucket_if_it_does_not_exist, + generate_default_hub_bucket_name, + create_s3_object_reference_from_uri, + construct_hub_arn_from_name, +) + +from sagemaker.jumpstart.notebook_utils import ( + list_jumpstart_models, +) + +from sagemaker.jumpstart.hub.types import ( + S3ObjectLocation, +) +from sagemaker.jumpstart.hub.interfaces import ( + DescribeHubResponse, + DescribeHubContentResponse, +) +from sagemaker.jumpstart.hub.constants import ( + LATEST_VERSION_WILDCARD, +) +from sagemaker.jumpstart import utils + + +class Hub: + """Class for creating and managing a curated JumpStart hub""" + + # Setting LOGGER for backward compatibility, in case users import it... + logger = LOGGER = logging.getLogger("sagemaker") + + _list_hubs_cache: List[Dict[str, Any]] = [] + + def __init__( + self, + hub_name: str, + bucket_name: Optional[str] = None, + sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) -> None: + """Instantiates a SageMaker ``Hub``. + + Args: + hub_name (str): The name of the Hub to create. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. + """ + self.hub_name = hub_name + self.region = sagemaker_session.boto_region_name + self._sagemaker_session = sagemaker_session + self.hub_storage_location = self._generate_hub_storage_location(bucket_name) + + def _fetch_hub_bucket_name(self) -> str: + """Retrieves hub bucket name from Hub config if exists""" + try: + hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name) + hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath") + if hub_output_location: + location = create_s3_object_reference_from_uri(hub_output_location) + return location.bucket + default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + JUMPSTART_LOGGER.warning( + "There is not a Hub bucket associated with %s. Using %s", + self.hub_name, + default_bucket_name, + ) + return default_bucket_name + except exceptions.ClientError: + hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + JUMPSTART_LOGGER.warning( + "There is not a Hub bucket associated with %s. 
Using %s", + self.hub_name, + hub_bucket_name, + ) + return hub_bucket_name + + def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: + """Generates an ``S3ObjectLocation`` given a Hub name.""" + hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() + curr_timestamp = datetime.now().timestamp() + return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}") + + def _get_latest_model_version(self, model_id: str) -> str: + """Populates the lastest version of a model from specs no matter what is passed. + + Returns model ({ model_id: str, version: str }) + """ + model_specs = utils.verify_model_region_and_return_specs( + model_id, LATEST_VERSION_WILDCARD, JumpStartScriptScope.INFERENCE, self.region + ) + return model_specs.version + + def create( + self, + description: str, + display_name: Optional[str] = None, + search_keywords: Optional[str] = None, + tags: Optional[str] = None, + ) -> Dict[str, str]: + """Creates a hub with the given description""" + + create_hub_bucket_if_it_does_not_exist( + self.hub_storage_location.bucket, self._sagemaker_session + ) + + return self._sagemaker_session.create_hub( + hub_name=self.hub_name, + hub_description=description, + hub_display_name=display_name, + hub_search_keywords=search_keywords, + s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()}, + tags=tags, + ) + + def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: + """Returns descriptive information about the Hub""" + + hub_description: DescribeHubResponse = self._sagemaker_session.describe_hub( + hub_name=self.hub_name if not hub_name else hub_name + ) + + return hub_description + + def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]]: + """List and paginate models from Hub.""" + next_token: Optional[str] = None + first_iteration: bool = True + hub_model_summaries: List[Dict[str, Any]] = [] + + while first_iteration or next_token: + first_iteration = False + list_hub_content_response = self._sagemaker_session.list_hub_contents(**kwargs) + hub_model_summaries.extend(list_hub_content_response.get("HubContentSummaries", [])) + next_token = list_hub_content_response.get("NextToken") + + return hub_model_summaries + + def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]: + """Lists the models and model references in this SageMaker Hub. + + This function caches the models in local memory + + **kwargs: Passed to invocation of ``Session:list_hub_contents``. + """ + response = {} + + if clear_cache: + self._list_hubs_cache = None + if self._list_hubs_cache is None: + + hub_model_reference_summaries = self._list_and_paginate_models( + **{ + "hub_name": self.hub_name, + "hub_content_type": HubContentType.MODEL_REFERENCE.value, + **kwargs, + } + ) + + hub_model_summaries = self._list_and_paginate_models( + **{ + "hub_name": self.hub_name, + "hub_content_type": HubContentType.MODEL.value, + **kwargs, + } + ) + response["hub_content_summaries"] = hub_model_reference_summaries + hub_model_summaries + response["next_token"] = None # Temporary until pagination is implemented + return response + + def list_sagemaker_public_hub_models( + self, + filter: Union[Operator, str] = Constant(BooleanValues.TRUE), + next_token: Optional[str] = None, + ) -> Dict[str, Any]: + """Lists the models and model arns from AmazonSageMakerJumpStart Public Hub. + + Args: + filter (Union[Operator, str]): Optional. The filter to apply to list models. 
This can be + either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``), + or simply a string filter which will get serialized into an Identity filter + (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. + (Default: Constant(BooleanValues.TRUE)). + next_token (str): Optional. A token to resume pagination. + This is currently not implemented. + """ + + response = {} + + jumpstart_public_hub_arn = construct_hub_arn_from_name( + JUMPSTART_MODEL_HUB_NAME, self.region, self._sagemaker_session + ) + + hub_content_summaries = [] + models = list_jumpstart_models(filter=filter, list_versions=True) + for model in models: + if len(model[0]) <= 63:  # hub content names are capped at 63 characters + info = get_info_from_hub_resource_arn(jumpstart_public_hub_arn) + hub_model_arn = ( + f"arn:{info.partition}:" + f"sagemaker:{info.region}:" + f"aws:hub-content/{info.hub_name}/" + f"{HubContentType.MODEL}/{model[0]}" + ) + hub_content_summary = { + "hub_content_name": model[0], + "hub_content_arn": hub_model_arn, + } + hub_content_summaries.append(hub_content_summary) + response["hub_content_summaries"] = hub_content_summaries + + response["next_token"] = None # Temporary until pagination is implemented for this function + + return response + + def delete(self) -> None: + """Deletes this SageMaker Hub.""" + return self._sagemaker_session.delete_hub(self.hub_name) + + def create_model_reference( + self, model_arn: str, model_name: Optional[str] = None, min_version: Optional[str] = None + ): + """Adds model reference to this SageMaker Hub.""" + return self._sagemaker_session.create_hub_content_reference( + hub_name=self.hub_name, + source_hub_content_arn=model_arn, + hub_content_name=model_name, + min_version=min_version, + ) + + def delete_model_reference(self, model_name: str) -> None: + """Deletes model reference from this SageMaker Hub.""" + return self._sagemaker_session.delete_hub_content_reference( + hub_name=self.hub_name, + hub_content_type=HubContentType.MODEL_REFERENCE.value, + hub_content_name=model_name, + ) + + def describe_model( + self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None + ) -> DescribeHubContentResponse: + """Describes a model in the SageMaker Hub.""" + try: + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL.value, + hub_name=self.hub_name if not hub_name else hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=model_version, + ) + + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=self.hub_name if not hub_name else hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL.value, + ) + + except Exception as ex: + logging.info("Received exception while calling APIs for ContentType Model: " + str(ex)) + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL_REFERENCE.value, + hub_name=self.hub_name if not hub_name else hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=model_version, + ) + + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=self.hub_name if not hub_name else hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL_REFERENCE.value, + ) + + return DescribeHubContentResponse(hub_content_description) diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py new file mode
100644 index 0000000000..2748409927 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -0,0 +1,831 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores types related to SageMaker JumpStart HubAPI requests and responses.""" +from __future__ import absolute_import + +import re +import json +import datetime + +from typing import Any, Dict, List, Union, Optional +from sagemaker.jumpstart.types import ( + HubContentType, + HubArnExtractedInfo, + JumpStartPredictorSpecs, + JumpStartHyperparameter, + JumpStartDataHolderType, + JumpStartEnvironmentVariable, + JumpStartSerializablePayload, + JumpStartInstanceTypeVariants, +) +from sagemaker.jumpstart.hub.parser_utils import ( + snake_to_upper_camel, + walk_and_apply_json, +) + + +class HubDataHolderType(JumpStartDataHolderType): + """Base class for many Hub API interfaces.""" + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of object.""" + json_obj = {} + for att in self.__slots__: + if att in self._non_serializable_slots: + continue + if hasattr(self, att): + cur_val = getattr(self, att) + # Do not serialize null values. + if cur_val is None: + continue + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + elif isinstance(cur_val, datetime.datetime): + json_obj[att] = str(cur_val) + else: + json_obj[att] = cur_val + return json_obj + + def __str__(self) -> str: + """Returns string representation of object. + + Example: "{'content_bucket': 'bucket', 'region_name': 'us-west-2'}" + """ + + att_dict = walk_and_apply_json(self.to_json(), snake_to_upper_camel) + return f"{json.dumps(att_dict, default=lambda o: o.to_json())}" + + +class CreateHubResponse(HubDataHolderType): + """Data class for the Hub from session.create_hub()""" + + __slots__ = [ + "hub_arn", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates CreateHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of session.create_hub() response. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.hub_arn: str = json_obj["HubArn"] + + +class HubContentDependency(HubDataHolderType): + """Data class for any dependencies related to hub content. + + Content can be scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["dependency_copy_path", "dependency_origin_path", "dependency_type"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentDependency object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. 
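+ + Example of an illustrative json_obj (field values are made up): + {"DependencyCopyPath": "s3://bucket/copy/", + "DependencyOriginPath": "s3://bucket/origin/", "DependencyType": "Script"}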
+ """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.dependency_copy_path: Optional[str] = json_obj.get("DependencyCopyPath", "") + self.dependency_origin_path: Optional[str] = json_obj.get("DependencyOriginPath", "") + self.dependency_type: Optional[str] = json_obj.get("DependencyType", "") + + +class DescribeHubContentResponse(HubDataHolderType): + """Data class for the Hub Content from session.describe_hub_contents()""" + + __slots__ = [ + "creation_time", + "document_schema_version", + "failure_reason", + "hub_arn", + "hub_content_arn", + "hub_content_dependencies", + "hub_content_description", + "hub_content_display_name", + "hub_content_document", + "hub_content_markdown", + "hub_content_name", + "hub_content_search_keywords", + "hub_content_status", + "hub_content_type", + "hub_content_version", + "reference_min_version", + "hub_name", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubContentResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.creation_time: datetime.datetime = json_obj["CreationTime"] + self.document_schema_version: str = json_obj["DocumentSchemaVersion"] + self.failure_reason: Optional[str] = json_obj.get("FailureReason") + self.hub_arn: str = json_obj["HubArn"] + self.hub_content_arn: str = json_obj["HubContentArn"] + self.hub_content_dependencies = [] + if "Dependencies" in json_obj: + self.hub_content_dependencies: Optional[List[HubContentDependency]] = [ + HubContentDependency(dep) for dep in json_obj.get(["Dependencies"]) + ] + self.hub_content_description: str = json_obj.get("HubContentDescription") + self.hub_content_display_name: str = json_obj.get("HubContentDisplayName") + hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn) + self._region = hub_region + self.hub_content_type: str = json_obj.get("HubContentType") + hub_content_document = json.loads(json_obj["HubContentDocument"]) + if self.hub_content_type == HubContentType.MODEL: + self.hub_content_document: HubContentDocument = HubModelDocument( + json_obj=hub_content_document, + region=self._region, + dependencies=self.hub_content_dependencies, + ) + elif self.hub_content_type == HubContentType.MODEL_REFERENCE: + self.hub_content_document: HubContentDocument = HubModelDocument( + json_obj=hub_content_document, + region=self._region, + dependencies=self.hub_content_dependencies, + ) + elif self.hub_content_type == HubContentType.NOTEBOOK: + self.hub_content_document: HubContentDocument = HubNotebookDocument( + json_obj=hub_content_document, region=self._region + ) + else: + raise ValueError( + f"[{self.hub_content_type}] is not a valid HubContentType." + f"Should be one of: {[item.name for item in HubContentType]}." 
+ ) + + self.hub_content_markdown: str = json_obj.get("HubContentMarkdown") + self.hub_content_name: str = json_obj["HubContentName"] + self.hub_content_search_keywords: List[str] = json_obj.get("HubContentSearchKeywords") + self.hub_content_status: str = json_obj["HubContentStatus"] + self.hub_content_version: str = json_obj["HubContentVersion"] + self.hub_name: str = json_obj["HubName"] + + def get_hub_region(self) -> Optional[str]: + """Returns the region hub is in.""" + return self._region + + +class HubS3StorageConfig(HubDataHolderType): + """Data class for the hub S3 storage configuration. + + Points to where hub content dependencies (scripts, model artifacts, + datasets, or notebooks) are stored. + """ + + __slots__ = ["s3_output_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubS3StorageConfig object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub S3 storage config. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub S3 storage config. + """ + + self.s3_output_path: Optional[str] = json_obj.get("S3OutputPath", "") + + +class DescribeHubResponse(HubDataHolderType): + """Data class for the Hub from session.describe_hub()""" + + __slots__ = [ + "creation_time", + "failure_reason", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + "s3_storage_config", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + + self.creation_time: datetime.datetime = json_obj["CreationTime"] + self.failure_reason: Optional[str] = json_obj.get("FailureReason") + self.hub_arn: str = json_obj["HubArn"] + hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn) + self._region = hub_region + self.hub_description: str = json_obj["HubDescription"] + self.hub_display_name: str = json_obj["HubDisplayName"] + self.hub_name: str = json_obj["HubName"] + self.hub_search_keywords: List[str] = json_obj["HubSearchKeywords"] + self.hub_status: str = json_obj["HubStatus"] + self.last_modified_time: datetime.datetime = json_obj["LastModifiedTime"] + self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig(json_obj["S3StorageConfig"]) + + def get_hub_region(self) -> Optional[str]: + """Returns the region hub is in.""" + return self._region + + +class ImportHubResponse(HubDataHolderType): + """Data class for the Hub from session.import_hub()""" + + __slots__ = [ + "hub_arn", + "hub_content_arn", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates ImportHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description.
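+ + Example (illustrative values): + {"HubArn": "arn:aws:sagemaker:us-west-2:123456789012:hub/my-hub", + "HubContentArn": "arn:aws:sagemaker:us-west-2:123456789012:hub-content/my-hub/Model/my-model/1.0.0"}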
+ """ + self.hub_arn: str = json_obj["HubArn"] + self.hub_content_arn: str = json_obj["HubContentArn"] + + +class HubSummary(HubDataHolderType): + """Data class for the HubSummary from session.list_hubs()""" + + __slots__ = [ + "creation_time", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubSummary object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.creation_time: datetime.datetime = datetime.datetime(json_obj["CreationTime"]) + self.hub_arn: str = json_obj["HubArn"] + self.hub_description: str = json_obj["HubDescription"] + self.hub_display_name: str = json_obj["HubDisplayName"] + self.hub_name: str = json_obj["HubName"] + self.hub_search_keywords: List[str] = json_obj["HubSearchKeywords"] + self.hub_status: str = json_obj["HubStatus"] + self.last_modified_time: datetime.datetime = datetime.datetime(json_obj["LastModifiedTime"]) + + +class ListHubsResponse(HubDataHolderType): + """Data class for the Hub from session.list_hubs()""" + + __slots__ = [ + "hub_summaries", + "next_token", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates ListHubsResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response. + """ + self.hub_summaries: List[HubSummary] = [ + HubSummary(item) for item in json_obj["HubSummaries"] + ] + self.next_token: str = json_obj["NextToken"] + + +class EcrUri(HubDataHolderType): + """Data class for ECR image uri.""" + + __slots__ = ["account", "region_name", "repository", "tag"] + + def __init__(self, uri: str): + """Instantiates EcrUri object.""" + self.from_ecr_uri(uri) + + def from_ecr_uri(self, uri: str) -> None: + """Parse a given aws ecr image uri into its various components.""" + uri_regex = ( + r"^(?:(?P[a-zA-Z0-9][\w-]*)\.dkr\.ecr\.(?P[a-zA-Z0-9][\w-]*)" + r"\.(?P[a-zA-Z0-9\.-]+))\/(?P([a-z0-9]+" + r"(?:[._-][a-z0-9]+)*\/)*[a-z0-9]+(?:[._-][a-z0-9]+)*)(:*)(?P.*)?" + ) + + parsed_image_uri = re.compile(uri_regex).match(uri) + + account = parsed_image_uri.group("account_id") + region = parsed_image_uri.group("region") + repository = parsed_image_uri.group("repository_name") + tag = parsed_image_uri.group("image_tag") + + self.account = account + self.region_name = region + self.repository = repository + self.tag = tag + + +class NotebookLocationUris(HubDataHolderType): + """Data class for Notebook Location uri.""" + + __slots__ = ["demo_notebook", "model_fit", "model_deploy"] + + def __init__(self, json_obj: Dict[str, Any]): + """Instantiates EcrUri object.""" + self.from_json(json_obj) + + def from_json(self, json_obj: str) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. 
+ """ + self.demo_notebook = json_obj.get("demo_notebook") + self.model_fit = json_obj.get("model_fit") + self.model_deploy = json_obj.get("model_deploy") + + +class HubModelDocument(HubDataHolderType): + """Data class for model type HubContentDocument from session.describe_hub_content().""" + + SCHEMA_VERSION = "2.2.0" + + __slots__ = [ + "url", + "min_sdk_version", + "training_supported", + "incremental_training_supported", + "dynamic_container_deployment_supported", + "hosting_ecr_uri", + "hosting_artifact_s3_data_type", + "hosting_artifact_compression_type", + "hosting_artifact_uri", + "hosting_prepacked_artifact_uri", + "hosting_prepacked_artifact_version", + "hosting_script_uri", + "hosting_use_script_uri", + "hosting_eula_uri", + "hosting_model_package_arn", + "training_artifact_s3_data_type", + "training_artifact_compression_type", + "training_model_package_artifact_uri", + "hyperparameters", + "inference_environment_variables", + "training_script_uri", + "training_prepacked_script_uri", + "training_prepacked_script_version", + "training_ecr_uri", + "training_metrics", + "training_artifact_uri", + "inference_dependencies", + "training_dependencies", + "default_inference_instance_type", + "supported_inference_instance_types", + "default_training_instance_type", + "supported_training_instance_types", + "sage_maker_sdk_predictor_specifications", + "inference_volume_size", + "training_volume_size", + "inference_enable_network_isolation", + "training_enable_network_isolation", + "fine_tuning_supported", + "validation_supported", + "default_training_dataset_uri", + "resource_name_base", + "gated_bucket", + "default_payloads", + "hosting_resource_requirements", + "hosting_instance_type_variants", + "training_instance_type_variants", + "notebook_location_uris", + "model_provider_icon_uri", + "task", + "framework", + "datatype", + "license", + "contextual_help", + "model_data_download_timeout", + "container_startup_health_check_timeout", + "encrypt_inter_container_traffic", + "max_runtime_in_seconds", + "disable_output_compression", + "model_dir", + "dependencies", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__( + self, + json_obj: Dict[str, Any], + region: str, + dependencies: List[HubContentDependency] = None, + ) -> None: + """Instantiates HubModelDocument object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content document. + + Raises: + ValueError: When one of (json_obj) or (model_specs and studio_specs) is not provided. + """ + self._region = region + self.dependencies = dependencies or [] + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub model document. 
+ """ + self.url: str = json_obj["Url"] + self.min_sdk_version: str = json_obj["MinSdkVersion"] + self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"] + self.hosting_artifact_uri = json_obj["HostingArtifactUri"] + self.hosting_script_uri = json_obj["HostingScriptUri"] + self.inference_dependencies: List[str] = json_obj["InferenceDependencies"] + self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [ + JumpStartEnvironmentVariable(env_variable, is_hub_content=True) + for env_variable in json_obj["InferenceEnvironmentVariables"] + ] + self.training_supported: bool = bool(json_obj["TrainingSupported"]) + self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"]) + self.dynamic_container_deployment_supported: Optional[bool] = ( + bool(json_obj.get("DynamicContainerDeploymentSupported")) + if json_obj.get("DynamicContainerDeploymentSupported") + else None + ) + self.hosting_artifact_s3_data_type: Optional[str] = json_obj.get( + "HostingArtifactS3DataType" + ) + self.hosting_artifact_compression_type: Optional[str] = json_obj.get( + "HostingArtifactCompressionType" + ) + self.hosting_prepacked_artifact_uri: Optional[str] = json_obj.get( + "HostingPrepackedArtifactUri" + ) + self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get( + "HostingPrepackedArtifactVersion" + ) + self.hosting_use_script_uri: Optional[bool] = ( + bool(json_obj.get("HostingUseScriptUri")) + if json_obj.get("HostingUseScriptUri") is not None + else None + ) + self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") + self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + self.default_inference_instance_type: Optional[str] = json_obj.get( + "DefaultInferenceInstanceType" + ) + self.supported_inference_instance_types: Optional[str] = json_obj.get( + "SupportedInferenceInstanceTypes" + ) + self.sage_maker_sdk_predictor_specifications: Optional[JumpStartPredictorSpecs] = ( + JumpStartPredictorSpecs( + json_obj.get("SageMakerSdkPredictorSpecifications"), + is_hub_content=True, + ) + if json_obj.get("SageMakerSdkPredictorSpecifications") + else None + ) + self.inference_volume_size: Optional[int] = json_obj.get("InferenceVolumeSize") + self.inference_enable_network_isolation: Optional[str] = json_obj.get( + "InferenceEnableNetworkIsolation", False + ) + self.fine_tuning_supported: Optional[bool] = ( + bool(json_obj.get("FineTuningSupported")) + if json_obj.get("FineTuningSupported") + else None + ) + self.validation_supported: Optional[bool] = ( + bool(json_obj.get("ValidationSupported")) + if json_obj.get("ValidationSupported") + else None + ) + self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri") + self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase") + self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False)) + self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( + { + alias: JumpStartSerializablePayload(payload, is_hub_content=True) + for alias, payload in json_obj.get("DefaultPayloads").items() + } + if json_obj.get("DefaultPayloads") + else None + ) + self.hosting_resource_requirements: Optional[Dict[str, int]] = json_obj.get( + "HostingResourceRequirements", None + ) + self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( + JumpStartInstanceTypeVariants( + json_obj.get("HostingInstanceTypeVariants"), + is_hub_content=True, + ) + if json_obj.get("HostingInstanceTypeVariants") + else 
None + ) + self.notebook_location_uris: Optional[NotebookLocationUris] = ( + NotebookLocationUris(json_obj.get("NotebookLocationUris")) + if json_obj.get("NotebookLocationUris") + else None + ) + self.model_provider_icon_uri: Optional[str] = None # Not needed for private beta + self.task: Optional[str] = json_obj.get("Task") + self.framework: Optional[str] = json_obj.get("Framework") + self.datatype: Optional[str] = json_obj.get("Datatype") + self.license: Optional[str] = json_obj.get("License") + self.contextual_help: Optional[str] = json_obj.get("ContextualHelp") + self.model_dir: Optional[str] = json_obj.get("ModelDir") + # Deploy kwargs + self.model_data_download_timeout: Optional[str] = json_obj.get("ModelDataDownloadTimeout") + self.container_startup_health_check_timeout: Optional[str] = json_obj.get( + "ContainerStartupHealthCheckTimeout" + ) + + if self.training_supported: + self.training_model_package_artifact_uri: Optional[str] = json_obj.get( + "TrainingModelPackageArtifactUri" + ) + self.training_artifact_compression_type: Optional[str] = json_obj.get( + "TrainingArtifactCompressionType" + ) + self.training_artifact_s3_data_type: Optional[str] = json_obj.get( + "TrainingArtifactS3DataType" + ) + self.hyperparameters: List[JumpStartHyperparameter] = [] + hyperparameters: Any = json_obj.get("Hyperparameters") + if hyperparameters is not None: + self.hyperparameters.extend( + [ + JumpStartHyperparameter(hyperparameter, is_hub_content=True) + for hyperparameter in hyperparameters + ] + ) + + self.training_script_uri: Optional[str] = json_obj.get("TrainingScriptUri") + self.training_prepacked_script_uri: Optional[str] = json_obj.get( + "TrainingPrepackedScriptUri" + ) + self.training_prepacked_script_version: Optional[str] = json_obj.get( + "TrainingPrepackedScriptVersion" + ) + self.training_ecr_uri: Optional[str] = json_obj.get("TrainingEcrUri") + self._non_serializable_slots.append("training_ecr_specs") + self.training_metrics: Optional[List[Dict[str, str]]] = json_obj.get( + "TrainingMetrics", None + ) + self.training_artifact_uri: Optional[str] = json_obj.get("TrainingArtifactUri") + self.training_dependencies: Optional[str] = json_obj.get("TrainingDependencies") + self.default_training_instance_type: Optional[str] = json_obj.get( + "DefaultTrainingInstanceType" + ) + self.supported_training_instance_types: Optional[str] = json_obj.get( + "SupportedTrainingInstanceTypes" + ) + self.training_volume_size: Optional[int] = json_obj.get("TrainingVolumeSize") + self.training_enable_network_isolation: Optional[str] = json_obj.get( + "TrainingEnableNetworkIsolation", False + ) + self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( + JumpStartInstanceTypeVariants( + json_obj.get("TrainingInstanceTypeVariants"), + is_hub_content=True, + ) + if json_obj.get("TrainingInstanceTypeVariants") + else None + ) + # Estimator kwargs + self.encrypt_inter_container_traffic: Optional[bool] = ( + bool(json_obj.get("EncryptInterContainerTraffic")) + if json_obj.get("EncryptInterContainerTraffic") + else None + ) + self.max_runtime_in_seconds: Optional[str] = json_obj.get("MaxRuntimeInSeconds") + self.disable_output_compression: Optional[bool] = ( + bool(json_obj.get("DisableOutputCompression")) + if json_obj.get("DisableOutputCompression") + else None + ) + + def get_schema_version(self) -> str: + """Returns schema version.""" + return self.SCHEMA_VERSION + + def get_region(self) -> str: + """Returns hub region.""" + return self._region + + +class 
HubNotebookDocument(HubDataHolderType):
+    """Data class for notebook type HubContentDocument from session.describe_hub_content()."""
+
+    SCHEMA_VERSION = "1.0.0"
+
+    __slots__ = ["notebook_location", "dependencies", "_region"]
+
+    _non_serializable_slots = ["_region"]
+
+    def __init__(self, json_obj: Dict[str, Any], region: str) -> None:
+        """Instantiates HubNotebookDocument object.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub content document.
+        """
+        self._region = region
+        self.from_json(json_obj)
+
+    def from_json(self, json_obj: Dict[str, Any]) -> None:
+        """Sets fields in object based on json.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub content document.
+        """
+        self.notebook_location = json_obj["NotebookLocation"]
+        self.dependencies: List[HubContentDependency] = [
+            HubContentDependency(dep) for dep in json_obj["Dependencies"]
+        ]
+
+    def get_schema_version(self) -> str:
+        """Returns schema version."""
+        return self.SCHEMA_VERSION
+
+    def get_region(self) -> str:
+        """Returns hub region."""
+        return self._region
+
+
+HubContentDocument = Union[HubModelDocument, HubNotebookDocument]
+
+
+class HubContentInfo(HubDataHolderType):
+    """Data class for the HubContentInfo from session.list_hub_contents()."""
+
+    __slots__ = [
+        "creation_time",
+        "document_schema_version",
+        "hub_content_arn",
+        "hub_content_name",
+        "hub_content_status",
+        "hub_content_type",
+        "hub_content_version",
+        "hub_content_description",
+        "hub_content_display_name",
+        "hub_content_search_keywords",
+        "_region",
+    ]
+
+    _non_serializable_slots = ["_region"]
+
+    def __init__(self, json_obj: Dict[str, Any]) -> None:
+        """Instantiates HubContentInfo object.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub content description.
+        """
+        self.from_json(json_obj)
+
+    def from_json(self, json_obj: Dict[str, Any]) -> None:
+        """Sets fields in object based on json.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub content description.
+        """
+        self.creation_time: str = json_obj["CreationTime"]
+        self.document_schema_version: str = json_obj["DocumentSchemaVersion"]
+        self.hub_content_arn: str = json_obj["HubContentArn"]
+        self.hub_content_name: str = json_obj["HubContentName"]
+        self.hub_content_status: str = json_obj["HubContentStatus"]
+        self.hub_content_type: HubContentType = HubContentType(json_obj["HubContentType"])
+        self.hub_content_version: str = json_obj["HubContentVersion"]
+        self.hub_content_description: Optional[str] = json_obj.get("HubContentDescription")
+        self.hub_content_display_name: Optional[str] = json_obj.get("HubContentDisplayName")
+        self._region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(
+            self.hub_content_arn
+        )
+        self.hub_content_search_keywords: Optional[List[str]] = json_obj.get(
+            "HubContentSearchKeywords"
+        )
+
+    def get_hub_region(self) -> Optional[str]:
+        """Returns the region the hub is in."""
+        return self._region
+
+
+class ListHubContentsResponse(HubDataHolderType):
+    """Data class for the Hub from session.list_hub_contents()"""
+
+    __slots__ = [
+        "hub_content_summaries",
+        "next_token",
+    ]
+
+    def __init__(self, json_obj: Dict[str, Any]) -> None:
+        """Instantiates ListHubContentsResponse object.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of the list_hub_contents response.
+        """
+        self.from_json(json_obj)
+
+    def from_json(self, json_obj: Dict[str, Any]) -> None:
+        """Sets fields in object based on json.
+ + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.hub_content_summaries: List[HubContentInfo] = [ + HubContentInfo(item) for item in json_obj["HubContentSummaries"] + ] + self.next_token: str = json_obj["NextToken"] diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py new file mode 100644 index 0000000000..140c089b11 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -0,0 +1,56 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains utilities related to SageMaker JumpStart Hub.""" +from __future__ import absolute_import + +import re +from typing import Any, Dict + + +def camel_to_snake(camel_case_string: str) -> str: + """Converts camelCaseString or UpperCamelCaseString to snake_case_string.""" + snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower() + + +def snake_to_upper_camel(snake_case_string: str) -> str: + """Converts snake_case_string to UpperCamelCaseString.""" + upper_camel_case_string = "".join(word.title() for word in snake_case_string.split("_")) + return upper_camel_case_string + + +def walk_and_apply_json(json_obj: Dict[Any, Any], apply) -> Dict[Any, Any]: + """Recursively walks a json object and applies a given function to the keys.""" + + def _walk_and_apply_json(json_obj, new): + if isinstance(json_obj, dict) and isinstance(new, dict): + for key, value in json_obj.items(): + new_key = apply(key) + if isinstance(value, dict): + new[new_key] = {} + _walk_and_apply_json(value, new=new[new_key]) + elif isinstance(value, list): + new[new_key] = [] + for item in value: + _walk_and_apply_json(item, new=new[new_key]) + else: + new[new_key] = value + elif isinstance(json_obj, dict) and isinstance(new, list): + new.append(_walk_and_apply_json(json_obj, new={})) + elif isinstance(json_obj, list) and isinstance(new, dict): + new.update(json_obj) + elif isinstance(json_obj, list) and isinstance(new, list): + new.append(json_obj) + return new + + return _walk_and_apply_json(json_obj, new={}) diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py new file mode 100644 index 0000000000..8226a380fd --- /dev/null +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -0,0 +1,262 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
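+#
+# Illustrative usage sketch (not part of this module): the parser_utils helpers
+# above translate between the UpperCamelCase keys used by Hub content documents
+# and the snake_case keys used by the SDK data classes, for example:
+#
+#     from sagemaker.jumpstart.hub.parser_utils import (
+#         camel_to_snake,
+#         snake_to_upper_camel,
+#         walk_and_apply_json,
+#     )
+#
+#     camel_to_snake("HubContentVersion")          # -> "hub_content_version"
+#     snake_to_upper_camel("hub_content_version")  # -> "HubContentVersion"
+#     walk_and_apply_json(
+#         {"HostingEcrUri": {"RegionName": "us-west-2"}}, camel_to_snake
+#     )  # -> {"hosting_ecr_uri": {"region_name": "us-west-2"}}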
+# pylint: skip-file +"""This module stores Hub converter utilities for JumpStart.""" +from __future__ import absolute_import + +from typing import Any, Dict, List +from sagemaker.jumpstart.enums import ModelSpecKwargType, NamingConventionType +from sagemaker.s3 import parse_s3_url +from sagemaker.jumpstart.types import ( + JumpStartModelSpecs, + HubContentType, + JumpStartDataHolderType, +) +from sagemaker.jumpstart.hub.interfaces import ( + DescribeHubContentResponse, + HubModelDocument, +) +from sagemaker.jumpstart.hub.parser_utils import ( + camel_to_snake, + snake_to_upper_camel, + walk_and_apply_json, +) + + +def _to_json(dictionary: Dict[Any, Any]) -> Dict[Any, Any]: + """Convert a nested dictionary of JumpStartDataHolderType into json with UpperCamelCase keys""" + for key, value in dictionary.items(): + if issubclass(type(value), JumpStartDataHolderType): + dictionary[key] = walk_and_apply_json(value.to_json(), snake_to_upper_camel) + elif isinstance(value, list): + new_value = [] + for value_in_list in value: + new_value_in_list = value_in_list + if issubclass(type(value_in_list), JumpStartDataHolderType): + new_value_in_list = walk_and_apply_json( + value_in_list.to_json(), snake_to_upper_camel + ) + new_value.append(new_value_in_list) + dictionary[key] = new_value + elif isinstance(value, dict): + for key_in_dict, value_in_dict in value.items(): + if issubclass(type(value_in_dict), JumpStartDataHolderType): + value[key_in_dict] = walk_and_apply_json( + value_in_dict.to_json(), snake_to_upper_camel + ) + return dictionary + + +def get_model_spec_arg_keys( + arg_type: ModelSpecKwargType, + naming_convention: NamingConventionType = NamingConventionType.DEFAULT, +) -> List[str]: + """Returns a list of arg keys for a specific model spec arg type. + + Args: + arg_type (ModelSpecKwargType): Type of the model spec's kwarg. + naming_convention (NamingConventionType): Type of naming convention to return. + + Raises: + ValueError: If the naming convention is not valid. + """ + arg_keys: List[str] = [] + if arg_type == ModelSpecKwargType.DEPLOY: + arg_keys = ["ModelDataDownloadTimeout", "ContainerStartupHealthCheckTimeout"] + elif arg_type == ModelSpecKwargType.ESTIMATOR: + arg_keys = [ + "EncryptInterContainerTraffic", + "MaxRuntimeInSeconds", + "DisableOutputCompression", + "ModelDir", + ] + elif arg_type == ModelSpecKwargType.MODEL: + arg_keys = [] + elif arg_type == ModelSpecKwargType.FIT: + arg_keys = [] + + if naming_convention == NamingConventionType.SNAKE_CASE: + arg_keys = [camel_to_snake(key) for key in arg_keys] + elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE: + return arg_keys + else: + raise ValueError("Please provide a valid naming convention.") + return arg_keys + + +def get_model_spec_kwargs_from_hub_model_document( + arg_type: ModelSpecKwargType, + hub_content_document: Dict[str, Any], + naming_convention: NamingConventionType = NamingConventionType.UPPER_CAMEL_CASE, +) -> Dict[str, Any]: + """Returns a map of arg type to arg keys for a given hub content document. + + Args: + arg_type (ModelSpecKwargType): Type of the model spec's kwarg. + hub_content_document: A dictionary representation of hub content document. + naming_convention (NamingConventionType): Type of naming convention to return. 
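+
+    Returns:
+        Dict[str, Any]: The kwarg keys and values found in the hub content document
+            for the given arg type.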
+ + """ + kwargs = dict() + keys = get_model_spec_arg_keys(arg_type, naming_convention=naming_convention) + for k in keys: + kwarg_value = hub_content_document.get(k) + if kwarg_value is not None: + kwargs[k] = kwarg_value + return kwargs + + +def make_model_specs_from_describe_hub_content_response( + response: DescribeHubContentResponse, +) -> JumpStartModelSpecs: + """Sets fields in JumpStartModelSpecs based on values in DescribeHubContentResponse + + Args: + response (Dict[str, any]): parsed DescribeHubContentResponse returned + from SageMaker:DescribeHubContent + """ + if response.hub_content_type not in {HubContentType.MODEL, HubContentType.MODEL_REFERENCE}: + raise AttributeError( + "Invalid content type, use either HubContentType.MODEL or HubContentType.MODEL_REFERENCE." + ) + region = response.get_hub_region() + specs = {} + model_id = response.hub_content_name + specs["model_id"] = model_id + specs["version"] = response.hub_content_version + hub_model_document: HubModelDocument = response.hub_content_document + specs["url"] = hub_model_document.url + specs["min_sdk_version"] = hub_model_document.min_sdk_version + specs["training_supported"] = bool(hub_model_document.training_supported) + specs["incremental_training_supported"] = bool( + hub_model_document.incremental_training_supported + ) + specs["hosting_ecr_uri"] = hub_model_document.hosting_ecr_uri + + hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_artifact_uri + ) + specs["hosting_artifact_key"] = hosting_artifact_key + specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri + hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_script_uri + ) + specs["hosting_script_key"] = hosting_script_key + specs["inference_environment_variables"] = hub_model_document.inference_environment_variables + specs["inference_vulnerable"] = False + specs["inference_dependencies"] = hub_model_document.inference_dependencies + specs["inference_vulnerabilities"] = [] + specs["training_vulnerable"] = False + specs["training_vulnerabilities"] = [] + specs["deprecated"] = False + specs["deprecated_message"] = None + specs["deprecate_warn_message"] = None + specs["usage_info_message"] = None + specs["default_inference_instance_type"] = hub_model_document.default_inference_instance_type + specs["supported_inference_instance_types"] = ( + hub_model_document.supported_inference_instance_types + ) + specs["dynamic_container_deployment_supported"] = ( + hub_model_document.dynamic_container_deployment_supported + ) + specs["hosting_resource_requirements"] = hub_model_document.hosting_resource_requirements + + specs["hosting_prepacked_artifact_key"] = None + if hub_model_document.hosting_prepacked_artifact_uri is not None: + ( + hosting_prepacked_artifact_bucket, # pylint: disable=unused-variable + hosting_prepacked_artifact_key, + ) = parse_s3_url(hub_model_document.hosting_prepacked_artifact_uri) + specs["hosting_prepacked_artifact_key"] = hosting_prepacked_artifact_key + + hub_content_document_dict: Dict[str, Any] = hub_model_document.to_json() + + specs["fit_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.FIT, hub_content_document_dict + ) + specs["model_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.MODEL, hub_content_document_dict + ) + specs["deploy_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + 
ModelSpecKwargType.DEPLOY, hub_content_document_dict + ) + specs["estimator_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.ESTIMATOR, hub_content_document_dict + ) + + specs["predictor_specs"] = hub_model_document.sage_maker_sdk_predictor_specifications + default_payloads: Dict[str, Any] = {} + if hub_model_document.default_payloads is not None: + for alias, payload in hub_model_document.default_payloads.items(): + default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake) + specs["default_payloads"] = default_payloads + specs["gated_bucket"] = hub_model_document.gated_bucket + specs["inference_volume_size"] = hub_model_document.inference_volume_size + specs["inference_enable_network_isolation"] = ( + hub_model_document.inference_enable_network_isolation + ) + specs["resource_name_base"] = hub_model_document.resource_name_base + + specs["hosting_eula_key"] = None + if hub_model_document.hosting_eula_uri is not None: + hosting_eula_bucket, hosting_eula_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_eula_uri + ) + specs["hosting_eula_key"] = hosting_eula_key + + if hub_model_document.hosting_model_package_arn: + specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn} + + specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri + + specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants + + if specs["training_supported"]: + specs["training_ecr_uri"] = hub_model_document.training_ecr_uri + ( + training_artifact_bucket, # pylint: disable=unused-variable + training_artifact_key, + ) = parse_s3_url(hub_model_document.training_artifact_uri) + specs["training_artifact_key"] = training_artifact_key + ( + training_script_bucket, # pylint: disable=unused-variable + training_script_key, + ) = parse_s3_url(hub_model_document.training_script_uri) + specs["training_script_key"] = training_script_key + specs["training_dependencies"] = hub_model_document.training_dependencies + specs["default_training_instance_type"] = hub_model_document.default_training_instance_type + specs["supported_training_instance_types"] = ( + hub_model_document.supported_training_instance_types + ) + specs["metrics"] = hub_model_document.training_metrics + specs["training_prepacked_script_key"] = None + if hub_model_document.training_prepacked_script_uri is not None: + ( + training_prepacked_script_bucket, # pylint: disable=unused-variable + training_prepacked_script_key, + ) = parse_s3_url(hub_model_document.training_prepacked_script_uri) + specs["training_prepacked_script_key"] = training_prepacked_script_key + + specs["hyperparameters"] = hub_model_document.hyperparameters + specs["training_volume_size"] = hub_model_document.training_volume_size + specs["training_enable_network_isolation"] = ( + hub_model_document.training_enable_network_isolation + ) + if hub_model_document.training_model_package_artifact_uri: + specs["training_model_package_artifact_uris"] = { + region: hub_model_document.training_model_package_artifact_uri + } + specs["training_instance_type_variants"] = ( + hub_model_document.training_instance_type_variants + ) + return JumpStartModelSpecs(_to_json(specs), is_hub_content=True) diff --git a/src/sagemaker/jumpstart/hub/types.py b/src/sagemaker/jumpstart/hub/types.py new file mode 100644 index 0000000000..1a68f84bbc --- /dev/null +++ b/src/sagemaker/jumpstart/hub/types.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores types related to SageMaker JumpStart Hub.""" +from __future__ import absolute_import +from typing import Dict +from dataclasses import dataclass + + +@dataclass +class S3ObjectLocation: + """Helper class for S3 object references.""" + + bucket: str + key: str + + def format_for_s3_copy(self) -> Dict[str, str]: + """Returns a dict formatted for S3 copy calls""" + return { + "Bucket": self.bucket, + "Key": self.key, + } + + def get_uri(self) -> str: + """Returns the s3 URI""" + return f"s3://{self.bucket}/{self.key}" diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py new file mode 100644 index 0000000000..3dfe99a8c4 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -0,0 +1,219 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
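+#
+# Illustrative usage sketch (not part of this module): the S3ObjectLocation
+# helper defined in hub/types.py above is consumed by the utilities below.
+# The bucket and key here are hypothetical:
+#
+#     from sagemaker.jumpstart.hub.types import S3ObjectLocation
+#
+#     location = S3ObjectLocation(bucket="example-hub-bucket", key="models/spec.json")
+#     location.get_uri()             # -> "s3://example-hub-bucket/models/spec.json"
+#     location.format_for_s3_copy()  # -> {"Bucket": "example-hub-bucket", "Key": "models/spec.json"}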
+# pylint: skip-file +"""This module contains utilities related to SageMaker JumpStart Hub.""" +from __future__ import absolute_import +import re +from typing import Optional +from sagemaker.jumpstart.hub.types import S3ObjectLocation +from sagemaker.s3_utils import parse_s3_url +from sagemaker.session import Session +from sagemaker.utils import aws_partition +from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo +from sagemaker.jumpstart import constants +from packaging.specifiers import SpecifierSet, InvalidSpecifier + + +def get_info_from_hub_resource_arn( + arn: str, +) -> HubArnExtractedInfo: + """Extracts descriptive information from a Hub or HubContent Arn.""" + + match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn) + if match: + partition = match.group(1) + hub_region = match.group(2) + account_id = match.group(3) + hub_name = match.group(4) + hub_content_type = match.group(5) + hub_content_name = match.group(6) + hub_content_version = match.group(7) + + return HubArnExtractedInfo( + partition=partition, + region=hub_region, + account_id=account_id, + hub_name=hub_name, + hub_content_type=hub_content_type, + hub_content_name=hub_content_name, + hub_content_version=hub_content_version, + ) + + match = re.match(constants.HUB_ARN_REGEX, arn) + if match: + partition = match.group(1) + hub_region = match.group(2) + account_id = match.group(3) + hub_name = match.group(4) + return HubArnExtractedInfo( + partition=partition, + region=hub_region, + account_id=account_id, + hub_name=hub_name, + ) + + +def construct_hub_arn_from_name( + hub_name: str, + region: Optional[str] = None, + session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Constructs a Hub arn from the Hub name using default Session values.""" + + account_id = session.account_id() + region = region or session.boto_region_name + partition = aws_partition(region) + + return f"arn:{partition}:sagemaker:{region}:{account_id}:hub/{hub_name}" + + +def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: + """Constructs a HubContent model arn from the Hub name, model name, and model version.""" + + info = get_info_from_hub_resource_arn(hub_arn) + arn = ( + f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" + f"{info.hub_name}/{HubContentType.MODEL.value}/{model_name}/{version}" + ) + + return arn + + +def construct_hub_model_reference_arn_from_inputs( + hub_arn: str, model_name: str, version: str +) -> str: + """Constructs a HubContent model arn from the Hub name, model name, and model version.""" + + info = get_info_from_hub_resource_arn(hub_arn) + arn = ( + f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" + f"{info.hub_name}/{HubContentType.MODEL_REFERENCE}/{model_name}/{version}" + ) + + return arn + + +def generate_hub_arn_for_init_kwargs( + hub_name: str, region: Optional[str] = None, session: Optional[Session] = None +): + """Generates the Hub Arn for JumpStart class args from a HubName or Arn. 
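+
+    Returns None when the given name is the public JumpStart model hub; otherwise
+    returns the value unchanged if it is already an ARN, or an ARN constructed
+    from the name, region, and session account.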
+
+    Args:
+        hub_name (str): HubName or HubArn from JumpStart class args
+        region (str): Region from JumpStart class args
+        session (Session): Custom SageMaker Session from JumpStart class args
+    """
+
+    hub_arn = None
+    if hub_name:
+        if hub_name == constants.JUMPSTART_MODEL_HUB_NAME:
+            return None
+        match = re.match(constants.HUB_ARN_REGEX, hub_name)
+        if match:
+            hub_arn = hub_name
+        else:
+            hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session)
+    return hub_arn
+
+
+def generate_default_hub_bucket_name(
+    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+) -> str:
+    """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.
+
+    Returns:
+        str: The name of the default bucket. If the name was not explicitly specified through
+            the Session or sagemaker_config, the bucket will take the form:
+            ``sagemaker-hubs-{region}-{AWS account ID}``.
+    """
+
+    region: str = sagemaker_session.boto_region_name
+    account_id: str = sagemaker_session.account_id()
+
+    # TODO: Validate and fast fail
+
+    return f"sagemaker-hubs-{region}-{account_id}"
+
+
+def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
+    """Utility to help generate an S3 object reference."""
+    if not s3_uri:
+        return None
+
+    bucket, key = parse_s3_url(s3_uri)
+
+    return S3ObjectLocation(
+        bucket=bucket,
+        key=key,
+    )
+
+
+def create_hub_bucket_if_it_does_not_exist(
+    bucket_name: Optional[str] = None,
+    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+) -> str:
+    """Creates the default SageMaker Hub bucket if it does not exist.
+
+    Returns:
+        str: The name of the default bucket. Takes the form:
+            ``sagemaker-hubs-{region}-{AWS account ID}``.
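+
+    Example (illustrative; the session determines the actual region and account):
+        bucket_name = create_hub_bucket_if_it_does_not_exist()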
+ """ + + region: str = sagemaker_session.boto_region_name + if bucket_name is None: + bucket_name: str = generate_default_hub_bucket_name(sagemaker_session) + + sagemaker_session._create_s3_bucket_if_it_does_not_exist( + bucket_name=bucket_name, + region=region, + ) + + return bucket_name + + +def is_gated_bucket(bucket_name: str) -> bool: + """Returns true if the bucket name is the JumpStart gated bucket.""" + return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET + + +def get_hub_model_version( + hub_name: str, + hub_model_name: str, + hub_model_type: str, + hub_model_version: Optional[str] = None, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Returns available Jumpstart hub model version""" + + try: + hub_content_summaries = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type + ).get("HubContentSummaries") + except Exception as ex: + raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") + + available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries] + + if hub_model_version == "*" or hub_model_version is None: + return str(max(available_model_versions)) + + try: + spec = SpecifierSet(f"=={hub_model_version}") + except InvalidSpecifier: + raise KeyError(f"Bad semantic version: {hub_model_version}") + available_versions_filtered = list(spec.filter(available_model_versions)) + if not available_versions_filtered: + raise KeyError("Model version not available in the Hub") + hub_model_version = str(max(available_versions_filtered)) + + return hub_model_version diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index df139e56b3..15cfea5c86 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -25,6 +25,7 @@ from sagemaker.enums import EndpointType from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.jumpstart.accessors import JumpStartModelsAccessor +from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import ( INVALID_MODEL_ID_ERROR_MSG, @@ -37,6 +38,7 @@ get_init_kwargs, get_register_kwargs, ) +from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint from sagemaker.jumpstart.types import ( JumpStartSerializablePayload, DeploymentConfigMetadata, @@ -50,8 +52,12 @@ deployment_config_response_data, _deployment_config_lru_cache, ) -from sagemaker.jumpstart.constants import JUMPSTART_LOGGER +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, @@ -78,6 +84,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -117,6 +124,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). 
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with @@ -296,6 +304,12 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. """ + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + def _validate_model_id_and_type(): return validate_model_id_and_get_type( model_id=model_id, @@ -303,13 +317,14 @@ def _validate_model_id_and_type(): region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, + hub_arn=hub_arn, ) self.model_type = _validate_model_id_and_type() if not self.model_type: JumpStartModelsAccessor.reset_cache() self.model_type = _validate_model_id_and_type() - if not self.model_type: + if not self.model_type and not hub_arn: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) self._model_data_is_set = model_data is not None @@ -318,6 +333,7 @@ def _validate_model_id_and_type(): model_from_estimator=False, model_type=self.model_type, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -349,6 +365,7 @@ def _validate_model_id_and_type(): self.model_id = model_init_kwargs.model_id self.model_version = model_init_kwargs.model_version + self.hub_arn = model_init_kwargs.hub_arn self.instance_type = model_init_kwargs.instance_type self.resources = model_init_kwargs.resources self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model @@ -358,11 +375,14 @@ def _validate_model_id_and_type(): self.role = role self.config_name = model_init_kwargs.config_name self.additional_model_data_sources = model_init_kwargs.additional_model_data_sources + self.model_reference_arn = model_init_kwargs.model_reference_arn if self.model_type == JumpStartModelType.PROPRIETARY: self.log_subscription_warning() - super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) + model_init_kwargs_dict = model_init_kwargs.to_kwargs_dict() + + super(JumpStartModel, self).__init__(**model_init_kwargs_dict) self.model_package_arn = model_init_kwargs.model_package_arn self.init_kwargs = model_init_kwargs.to_kwargs_dict(False) @@ -381,6 +401,7 @@ def log_subscription_warning(self) -> None: region=self.region, model_id=self.model_id, version=self.model_version, + hub_arn=self.hub_arn, model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, @@ -402,6 +423,7 @@ def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: return payloads.retrieve_all_examples( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -497,6 +519,45 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]: self._get_deployment_configs(self.config_name, self.instance_type) ) + @classmethod + def attach( + cls, + endpoint_name: str, + inference_component_name: Optional[str] = None, + model_id: Optional[str] = None, + model_version: Optional[str] = None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) -> "JumpStartModel": + """Attaches a 
JumpStartModel object to an existing SageMaker Endpoint.
+
+        The model id, version (and inference component name) can be inferred from the tags.
+        """
+
+        inferred_model_id = inferred_model_version = inferred_inference_component_name = None
+
+        if inference_component_name is None or model_id is None or model_version is None:
+            inferred_model_id, inferred_model_version, inferred_inference_component_name, _, _ = (
+                get_model_info_from_endpoint(
+                    endpoint_name=endpoint_name,
+                    inference_component_name=inference_component_name,
+                    sagemaker_session=sagemaker_session,
+                )
+            )
+
+        model_id = model_id or inferred_model_id
+        model_version = model_version or inferred_model_version or "*"
+        inference_component_name = inference_component_name or inferred_inference_component_name
+
+        model = JumpStartModel(
+            model_id=model_id,
+            model_version=model_version,
+            sagemaker_session=sagemaker_session,
+        )
+        model.endpoint_name = endpoint_name
+        model.inference_component_name = inference_component_name
+
+        return model
+
     def _create_sagemaker_model(
         self,
         instance_type=None,
@@ -575,6 +636,7 @@ def deploy(
         deserializer: Optional[BaseDeserializer] = None,
         accelerator_type: Optional[str] = None,
         endpoint_name: Optional[str] = None,
+        inference_component_name: Optional[str] = None,
         tags: Optional[Tags] = None,
         kms_key: Optional[str] = None,
         wait: Optional[bool] = True,
@@ -697,6 +759,7 @@
             model_id=self.model_id,
             model_version=self.model_version,
             region=self.region,
+            hub_arn=self.hub_arn,
             tolerate_deprecated_model=self.tolerate_deprecated_model,
             tolerate_vulnerable_model=self.tolerate_vulnerable_model,
             initial_instance_count=initial_instance_count,
@@ -705,6 +768,7 @@
             deserializer=deserializer,
             accelerator_type=accelerator_type,
             endpoint_name=endpoint_name,
+            inference_component_name=inference_component_name,
             tags=format_tags(tags),
             kms_key=kms_key,
             wait=wait,
@@ -718,6 +782,7 @@
             explainer_config=explainer_config,
             sagemaker_session=self.sagemaker_session,
             accept_eula=accept_eula,
+            model_reference_arn=self.model_reference_arn,
             endpoint_logging=endpoint_logging,
             resources=resources,
             managed_instance_scaling=managed_instance_scaling,
@@ -745,6 +810,7 @@
                 scope=JumpStartScriptScope.INFERENCE,
                 sagemaker_session=self.sagemaker_session,
                 config_name=self.config_name,
+                hub_arn=self.hub_arn,
             ).model_subscription_link
             get_proprietary_model_subscription_error(e, subscription_link)
             raise
@@ -755,6 +821,7 @@
             predictor=predictor,
             model_id=self.model_id,
             model_version=self.model_version,
+            hub_arn=self.hub_arn,
             region=self.region,
             tolerate_deprecated_model=self.tolerate_deprecated_model,
             tolerate_vulnerable_model=self.tolerate_vulnerable_model,
@@ -790,6 +857,8 @@ def register(
         data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
         skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
         source_uri: Optional[Union[str, PipelineVariable]] = None,
+        model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
+        accept_eula: Optional[bool] = None,
     ):
         """Creates a model package for creating SageMaker models or listing on Marketplace.

@@ -837,14 +906,27 @@ def register(
                 validation. Values can be "All" or "None" (default: None).
             source_uri (str or PipelineVariable): The URI of the source for the model package
                 (default: None).
-
+            model_card (ModelCard or ModelPackageModelCard): A document that contains
+                qualitative and quantitative information about a model (default: None).
+ accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Returns: A `sagemaker.model.ModelPackage` instance. """ + if model_package_group_name is None: + model_package_group_name = self.model_id + if self.model_type is JumpStartModelType.PROPRIETARY: + source_uri = self.model_package_arn + register_kwargs = get_register_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, + model_type=self.model_type, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -872,6 +954,8 @@ def register( skip_model_validation=skip_model_validation, source_uri=source_uri, config_name=self.config_name, + model_card=model_card, + accept_eula=accept_eula, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 65cd4d274c..6ed2c4fdb9 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -10,11 +10,14 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +# pylint: skip-file """This module stores types related to SageMaker JumpStart.""" from __future__ import absolute_import +import re from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard from sagemaker.utils import ( S3_PREFIX, get_instance_type_family, @@ -35,6 +38,10 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType +from sagemaker.jumpstart.hub.parser_utils import ( + camel_to_snake, + walk_and_apply_json, +) class JumpStartDataHolderType: @@ -119,6 +126,23 @@ class JumpStartS3FileType(str, Enum): PROPRIETARY_SPECS = "proprietary_specs" +class HubType(str, Enum): + """Enum for Hub objects.""" + + HUB = "Hub" + + +class HubContentType(str, Enum): + """Enum for Hub content objects.""" + + MODEL = "Model" + NOTEBOOK = "Notebook" + MODEL_REFERENCE = "ModelReference" + + +JumpStartContentDataType = Union[JumpStartS3FileType, HubType, HubContentType] + + class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): """Data class for launched region info.""" @@ -190,14 +214,18 @@ class JumpStartECRSpecs(JumpStartDataHolderType): "framework_version", "py_version", "huggingface_transformers_version", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartECRSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. 
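+            is_hub_content (Optional[bool]): Whether the spec originates from Hub content
+                (Default: False).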
""" + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -210,6 +238,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if not json_obj: return + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) + self.framework = json_obj.get("framework") self.framework_version = json_obj.get("framework_version") self.py_version = json_obj.get("py_version") @@ -219,7 +250,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartECRSpecs object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -236,14 +271,18 @@ class JumpStartHyperparameter(JumpStartDataHolderType): "max", "exclusive_min", "exclusive_max", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartHyperparameter object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of hyperparameter. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -253,6 +292,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of hyperparameter. """ + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -270,17 +311,24 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if max_val is not None: self.max = max_val + # HubContentDocument model schema does not allow exclusive min/max. + if self._is_hub_content: + return + exclusive_min_val = json_obj.get("exclusive_min") + exclusive_max_val = json_obj.get("exclusive_max") if exclusive_min_val is not None: self.exclusive_min = exclusive_min_val - - exclusive_max_val = json_obj.get("exclusive_max") if exclusive_max_val is not None: self.exclusive_max = exclusive_max_val def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartHyperparameter object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -293,14 +341,18 @@ class JumpStartEnvironmentVariable(JumpStartDataHolderType): "default", "scope", "required_for_model_class", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartEnvironmentVariable object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of environment variable. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -309,7 +361,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of environment variable. 
""" - + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -318,7 +370,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartEnvironmentVariable object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -330,14 +386,18 @@ class JumpStartPredictorSpecs(JumpStartDataHolderType): "supported_content_types", "default_accept_type", "supported_accept_types", + "_is_hub_content", ] - def __init__(self, spec: Optional[Dict[str, Any]]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartPredictorSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of predictor specs. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: @@ -350,6 +410,8 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.default_content_type = json_obj["default_content_type"] self.supported_content_types = json_obj["supported_content_types"] self.default_accept_type = json_obj["default_accept_type"] @@ -357,7 +419,11 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartPredictorSpecs object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -370,16 +436,18 @@ class JumpStartSerializablePayload(JumpStartDataHolderType): "accept", "body", "prompt_key", + "_is_hub_content", ] - _non_serializable_slots = ["raw_payload", "prompt_key"] + _non_serializable_slots = ["raw_payload", "prompt_key", "_is_hub_content"] - def __init__(self, spec: Optional[Dict[str, Any]]): + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartSerializablePayload object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of payload specs. 
""" + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: @@ -396,9 +464,11 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.raw_payload = json_obj self.content_type = json_obj["content_type"] - self.body = json_obj["body"] + self.body = json_obj.get("body") accept = json_obj.get("accept") self.prompt_key = json_obj.get("prompt_key") if accept: @@ -414,16 +484,26 @@ class JumpStartInstanceTypeVariants(JumpStartDataHolderType): __slots__ = [ "regional_aliases", + "aliases", "variants", + "_is_hub_content", ] - def __init__(self, spec: Optional[Dict[str, Any]]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartInstanceTypeVariants object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of instance type variants. """ - self.from_json(spec) + + self._is_hub_content = is_hub_content + + if self._is_hub_content: + self.from_describe_hub_content_response(spec) + else: + self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: """Sets fields in object based on json. @@ -435,14 +515,50 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + self.aliases = None self.regional_aliases: Optional[dict] = json_obj.get("regional_aliases") self.variants: Optional[dict] = json_obj.get("variants") def to_json(self) -> Dict[str, Any]: - """Returns json representation of JumpStartInstanceTypeVariants object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + """Returns json representation of JumpStartInstance object.""" + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj + def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on DescribeHubContent response. + + Args: + response (Dict[str, Any]): Dictionary representation of instance type variants. + """ + + if response is None: + return + + response = walk_and_apply_json(response, camel_to_snake) + self.aliases: Optional[dict] = response.get("aliases") + self.regional_aliases = None + self.variants: Optional[dict] = response.get("variants") + + def regionalize( # pylint: disable=inconsistent-return-statements + self, region: str + ) -> Optional[Dict[str, Any]]: + """Returns regionalized instance type variants.""" + + if self.regional_aliases is None or self.aliases is not None: + return + aliases = self.regional_aliases.get(region, {}) + variants = {} + for instance_name, properties in self.variants.items(): + if properties.get("regional_properties") is not None: + variants.update({instance_name: properties.get("regional_properties")}) + if properties.get("properties") is not None: + variants.update({instance_name: properties.get("properties")}) + return {"Aliases": aliases, "Variants": variants} + def get_instance_specific_metric_definitions( self, instance_type: str ) -> List[JumpStartHyperparameter]: @@ -640,7 +756,12 @@ def get_instance_specific_gated_model_key_env_var_value( Returns None if a model, instance type tuple does not have instance specific property. 
""" - return self._get_instance_specific_property(instance_type, "gated_model_key_env_var_value") + + gated_model_key_env_var_value = ( + "gated_model_env_var_uri" if self._is_hub_content else "gated_model_key_env_var_value" + ) + + return self._get_instance_specific_property(instance_type, gated_model_key_env_var_value) def get_instance_specific_default_inference_instance_type( self, instance_type: str @@ -692,7 +813,7 @@ def get_instance_specific_supported_inference_instance_types( ) ) - def get_image_uri(self, instance_type: str, region: str) -> Optional[str]: + def get_image_uri(self, instance_type: str, region: Optional[str] = None) -> Optional[str]: """Returns image uri from instance type and region. Returns None if no instance type is available or found. @@ -713,36 +834,63 @@ def get_model_package_arn(self, instance_type: str, region: str) -> Optional[str ) def _get_regional_property( - self, instance_type: str, region: str, property_name: str + self, instance_type: str, region: Optional[str], property_name: str ) -> Optional[str]: """Returns regional property from instance type and region. Returns None if no instance type is available or found. None is also returned if the metadata is improperly formatted. """ + # pylint: disable=too-many-return-statements + # if self.variants is None or (self.aliases is None and self.regional_aliases is None): + # return None - if None in [self.regional_aliases, self.variants]: + if self.variants is None: return None - regional_property_alias: Optional[str] = ( - self.variants.get(instance_type, {}).get("regional_properties", {}).get(property_name) - ) - if regional_property_alias is None: - instance_type_family = get_instance_type_family(instance_type) + if region is None and self.regional_aliases is not None: + return None - if instance_type_family in {"", None}: - return None + regional_property_alias: Optional[str] = None + regional_property_value: Optional[str] = None + if self.regional_aliases: regional_property_alias = ( - self.variants.get(instance_type_family, {}) + self.variants.get(instance_type, {}) .get("regional_properties", {}) .get(property_name) ) + else: + regional_property_value = ( + self.variants.get(instance_type, {}).get("properties", {}).get(property_name) + ) + + if regional_property_alias is None and regional_property_value is None: + instance_type_family = get_instance_type_family(instance_type) - if regional_property_alias is None or len(regional_property_alias) == 0: + if instance_type_family in {"", None}: + return None + + if self.regional_aliases: + regional_property_alias = ( + self.variants.get(instance_type_family, {}) + .get("regional_properties", {}) + .get(property_name) + ) + else: + # if reading from HubContent, aliases are already regionalized + regional_property_value = ( + self.variants.get(instance_type_family, {}) + .get("properties", {}) + .get(property_name) + ) + + if (regional_property_alias is None or len(regional_property_alias) == 0) and ( + regional_property_value is None or len(regional_property_value) == 0 + ): return None - if not regional_property_alias.startswith("$"): + if regional_property_alias and not regional_property_alias.startswith("$"): # No leading '$' indicates bad metadata. # There are tests to ensure this never happens. # However, to allow for fallback options in the unlikely event @@ -750,10 +898,64 @@ def _get_regional_property( # We return None, indicating the field does not exist. 
return None - if region not in self.regional_aliases: + if self.regional_aliases and region not in self.regional_aliases: return None - alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) - return alias_value + + if self.regional_aliases: + alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) + return alias_value + return regional_property_value + + +class JumpStartAdditionalDataSources(JumpStartDataHolderType): + """Data class of additional data sources.""" + + __slots__ = ["speculative_decoding", "scripts"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes an AdditionalDataSources object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. + """ + self.speculative_decoding: Optional[List[JumpStartModelDataSource]] = ( + [ + JumpStartModelDataSource(data_source) + for data_source in json_obj["speculative_decoding"] + ] + if json_obj.get("speculative_decoding") + else None + ) + self.scripts: Optional[List[JumpStartModelDataSource]] = ( + [JumpStartModelDataSource(data_source) for data_source in json_obj["scripts"]] + if json_obj.get("scripts") + else None + ) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of AdditionalDataSources object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj class ModelAccessConfig(JumpStartDataHolderType): @@ -932,57 +1134,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.artifact_version: str = json_obj["artifact_version"] -class JumpStartAdditionalDataSources(JumpStartDataHolderType): - """Data class of additional data sources.""" - - __slots__ = ["speculative_decoding", "scripts"] - - def __init__(self, spec: Dict[str, Any]): - """Initializes a AdditionalDataSources object. - - Args: - spec (Dict[str, Any]): Dictionary representation of data source. - """ - self.from_json(spec) - - def from_json(self, json_obj: Dict[str, Any]) -> None: - """Sets fields in object based on json. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of data source.
- """ - self.speculative_decoding: Optional[List[JumpStartModelDataSource]] = ( - [ - JumpStartModelDataSource(data_source) - for data_source in json_obj["speculative_decoding"] - ] - if json_obj.get("speculative_decoding") - else None - ) - self.scripts: Optional[List[JumpStartModelDataSource]] = ( - [JumpStartModelDataSource(data_source) for data_source in json_obj["scripts"]] - if json_obj.get("scripts") - else None - ) - - def to_json(self) -> Dict[str, Any]: - """Returns json representation of AdditionalDataSources object.""" - json_obj = {} - for att in self.__slots__: - if hasattr(self, att): - cur_val = getattr(self, att) - if isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - else: - json_obj[att] = cur_val - return json_obj - - class JumpStartBenchmarkStat(JumpStartDataHolderType): """Data class JumpStart benchmark stat.""" @@ -1051,10 +1202,13 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "min_sdk_version", "incremental_training_supported", "hosting_ecr_specs", + "hosting_ecr_uri", + "hosting_artifact_uri", "hosting_artifact_key", "hosting_script_key", "training_supported", "training_ecr_specs", + "training_ecr_uri", "training_artifact_key", "training_script_key", "hyperparameters", @@ -1077,7 +1231,9 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "supported_training_instance_types", "metrics", "training_prepacked_script_key", + "training_prepacked_script_version", "hosting_prepacked_artifact_key", + "hosting_prepacked_artifact_version", "model_kwargs", "deploy_kwargs", "estimator_kwargs", @@ -1100,14 +1256,19 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "hosting_additional_data_sources", "hosting_neuron_model_id", "hosting_neuron_model_version", + "hub_content_type", + "_is_hub_content", ] - def __init__(self, fields: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, fields: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartMetadataFields object. Args: fields (Dict[str, Any]): Dictionary representation of metadata fields. 
""" + self._is_hub_content = is_hub_content self.from_json(fields) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1123,16 +1284,24 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.incremental_training_supported: bool = bool( json_obj.get("incremental_training_supported", False) ) - self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( - JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) - if "hosting_ecr_specs" in json_obj - else None - ) + if self._is_hub_content: + self.hosting_ecr_uri: Optional[str] = json_obj["hosting_ecr_uri"] + self._non_serializable_slots.append("hosting_ecr_specs") + else: + self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( + JumpStartECRSpecs( + json_obj["hosting_ecr_specs"], is_hub_content=self._is_hub_content + ) + if "hosting_ecr_specs" in json_obj + else None + ) + self._non_serializable_slots.append("hosting_ecr_uri") self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key") + self.hosting_artifact_uri: Optional[str] = json_obj.get("hosting_artifact_uri") self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key") self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False)) self.inference_environment_variables = [ - JumpStartEnvironmentVariable(env_variable) + JumpStartEnvironmentVariable(env_variable, is_hub_content=self._is_hub_content) for env_variable in json_obj.get("inference_environment_variables", []) ] self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False)) @@ -1170,16 +1339,26 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get( "hosting_prepacked_artifact_key", None ) + # New fields required for Hub model. + if self._is_hub_content: + self.training_prepacked_script_version: Optional[str] = json_obj.get( + "training_prepacked_script_version" + ) + self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get( + "hosting_prepacked_artifact_version" + ) self.model_kwargs = deepcopy(json_obj.get("model_kwargs", {})) self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {})) self.predictor_specs: Optional[JumpStartPredictorSpecs] = ( - JumpStartPredictorSpecs(json_obj["predictor_specs"]) + JumpStartPredictorSpecs( + json_obj["predictor_specs"], is_hub_content=self._is_hub_content + ) if "predictor_specs" in json_obj else None ) self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( { - alias: JumpStartSerializablePayload(payload) + alias: JumpStartSerializablePayload(payload, is_hub_content=self._is_hub_content) for alias, payload in json_obj["default_payloads"].items() } if json_obj.get("default_payloads") @@ -1201,7 +1380,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True) self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( - JumpStartInstanceTypeVariants(json_obj["hosting_instance_type_variants"]) + JumpStartInstanceTypeVariants( + json_obj["hosting_instance_type_variants"], self._is_hub_content + ) if json_obj.get("hosting_instance_type_variants") else None ) @@ -1216,18 +1397,26 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) if self.training_supported: - self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( - JumpStartECRSpecs(json_obj["training_ecr_specs"]) - if "training_ecr_specs" in json_obj - else None - ) + if self._is_hub_content: + self.training_ecr_uri: Optional[str] = 
json_obj["training_ecr_uri"] + self._non_serializable_slots.append("training_ecr_specs") + else: + self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( + JumpStartECRSpecs(json_obj["training_ecr_specs"]) + if "training_ecr_specs" in json_obj + else None + ) + self._non_serializable_slots.append("training_ecr_uri") self.training_artifact_key: str = json_obj["training_artifact_key"] self.training_script_key: str = json_obj["training_script_key"] hyperparameters: Any = json_obj.get("hyperparameters") self.hyperparameters: List[JumpStartHyperparameter] = [] if hyperparameters is not None: self.hyperparameters.extend( - [JumpStartHyperparameter(hyperparameter) for hyperparameter in hyperparameters] + [ + JumpStartHyperparameter(hyperparameter, is_hub_content=self._is_hub_content) + for hyperparameter in hyperparameters + ] ) self.estimator_kwargs = deepcopy(json_obj.get("estimator_kwargs", {})) self.fit_kwargs = deepcopy(json_obj.get("fit_kwargs", {})) @@ -1239,7 +1428,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: "training_model_package_artifact_uris" ) self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( - JumpStartInstanceTypeVariants(json_obj["training_instance_type_variants"]) + JumpStartInstanceTypeVariants( + json_obj["training_instance_type_variants"], is_hub_content=self._is_hub_content + ) if json_obj.get("training_instance_type_variants") else None ) @@ -1249,7 +1440,7 @@ def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataBaseFields object.""" json_obj = {} for att in self.__slots__: - if hasattr(self, att): + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []): cur_val = getattr(self, att) if issubclass(type(cur_val), JumpStartDataHolderType): json_obj[att] = cur_val.to_json() @@ -1271,6 +1462,11 @@ def to_json(self) -> Dict[str, Any]: json_obj[att] = cur_val return json_obj + def set_hub_content_type(self, hub_content_type: HubContentType) -> None: + """Sets the hub content type.""" + if self._is_hub_content: + self.hub_content_type = hub_content_type + class JumpStartConfigComponent(JumpStartMetadataBaseFields): """Data class of JumpStart config component.""" @@ -1503,13 +1699,13 @@ class JumpStartModelSpecs(JumpStartMetadataBaseFields): __slots__ = JumpStartMetadataBaseFields.__slots__ + slots - def __init__(self, spec: Dict[str, Any]): + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartModelSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. """ - super().__init__(spec) + super().__init__(spec, is_hub_content) self.from_json(spec) if self.inference_configs and self.inference_configs.get_top_config_from_ranking(): super().from_json(self.inference_configs.get_top_config_from_ranking().resolved_config) @@ -1708,27 +1904,84 @@ def __init__( self.version = version -class JumpStartCachedS3ContentKey(JumpStartDataHolderType): - """Data class for the s3 cached content keys.""" +class JumpStartCachedContentKey(JumpStartDataHolderType): + """Data class for the cached content keys.""" - __slots__ = ["file_type", "s3_key"] + __slots__ = ["data_type", "id_info"] def __init__( self, - file_type: JumpStartS3FileType, - s3_key: str, + data_type: JumpStartContentDataType, + id_info: str, ) -> None: - """Instantiates JumpStartCachedS3ContentKey object. + """Instantiates JumpStartCachedContentKey object. Args: - file_type (JumpStartS3FileType): JumpStart file type. 
- s3_key (str): object key in s3. + data_type (JumpStartContentDataType): JumpStart content data type. + id_info (str): If S3Content, the object key in S3; if HubContent, the hub content ARN. """ - self.file_type = file_type - self.s3_key = s3_key + self.data_type = data_type + self.id_info = id_info + + +class HubArnExtractedInfo(JumpStartDataHolderType): + """Data class for info extracted from Hub arn.""" + + __slots__ = [ + "partition", + "region", + "account_id", + "hub_name", + "hub_content_type", + "hub_content_name", + "hub_content_version", + ] + + def __init__( + self, + partition: str, + region: str, + account_id: str, + hub_name: str, + hub_content_type: Optional[str] = None, + hub_content_name: Optional[str] = None, + hub_content_version: Optional[str] = None, + ) -> None: + """Instantiates HubArnExtractedInfo object.""" + + self.partition = partition + self.region = region + self.account_id = account_id + self.hub_name = hub_name + self.hub_content_name = hub_content_name + self.hub_content_type = hub_content_type + self.hub_content_version = hub_content_version + + @staticmethod + def extract_region_from_arn(arn: str) -> Optional[str]: + """Extracts the region from a Hub or HubContent ARN.""" + + HUB_CONTENT_ARN_REGEX = ( + r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" + ) + HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" + + match = re.match(HUB_CONTENT_ARN_REGEX, arn) + hub_region = None + if match: + hub_region = match.group(2) + + return hub_region + + match = re.match(HUB_ARN_REGEX, arn) + if match: + hub_region = match.group(2) + return hub_region + return hub_region -class JumpStartCachedS3ContentValue(JumpStartDataHolderType): + +class JumpStartCachedContentValue(JumpStartDataHolderType): """Data class for the s3 cached content values.""" __slots__ = ["formatted_content", "md5_hash"] @@ -1741,7 +1994,7 @@ def __init__( ], md5_hash: Optional[str] = None, ) -> None: - """Instantiates JumpStartCachedS3ContentValue object. + """Instantiates JumpStartCachedContentValue object.
Args: formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader], @@ -1776,6 +2029,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "instance_type", "tolerate_vulnerable_model", @@ -1803,12 +2057,15 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "resources", "config_name", "additional_model_data_sources", + "hub_content_type", + "model_reference_arn", ] SERIALIZATION_EXCLUSION_SET = { "instance_type", "model_id", "model_version", + "hub_arn", "model_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -1816,12 +2073,14 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_package_arn", "training_instance_type", "config_name", + "hub_content_type", } def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, instance_type: Optional[str] = None, @@ -1854,6 +2113,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = model_type self.instance_type = instance_type self.region = region @@ -1889,6 +2149,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "initial_instance_count", "instance_type", @@ -1897,6 +2158,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "deserializer", "accelerator_type", "endpoint_name", + "inference_component_name", "tags", "kms_key", "wait", @@ -1913,6 +2175,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "sagemaker_session", "training_instance_type", "accept_eula", + "model_reference_arn", "endpoint_logging", "resources", "endpoint_type", @@ -1924,6 +2187,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_id", "model_version", "model_type", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1936,6 +2200,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, @@ -1944,6 +2209,7 @@ def __init__( deserializer: Optional[Any] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, + inference_component_name: Optional[str] = None, tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, @@ -1960,6 +2226,7 @@ def __init__( sagemaker_session: Optional[Session] = None, training_instance_type: Optional[str] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, endpoint_type: Optional[EndpointType] = None, @@ -1970,6 +2237,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = model_type self.initial_instance_count = initial_instance_count self.instance_type = instance_type @@ -1978,6 +2246,7 @@ def __init__( self.deserializer = deserializer self.accelerator_type = accelerator_type self.endpoint_name = endpoint_name + self.inference_component_name = inference_component_name self.tags = format_tags(tags) self.kms_key = kms_key self.wait = wait @@ -1994,6 +2263,7 @@ def __init__( self.sagemaker_session = sagemaker_session self.training_instance_type = 
training_instance_type self.accept_eula = accept_eula + self.model_reference_arn = model_reference_arn self.endpoint_logging = endpoint_logging self.resources = resources self.endpoint_type = endpoint_type @@ -2007,6 +2277,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "instance_type", "instance_count", @@ -2069,6 +2340,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_id", "model_version", + "hub_arn", "model_type", "config_name", } @@ -2077,6 +2349,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -2136,6 +2409,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = (model_type,) self.instance_type = instance_type self.instance_count = instance_count @@ -2200,6 +2474,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "region", "inputs", @@ -2216,6 +2491,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "hub_arn", "model_type", "region", "tolerate_deprecated_model", @@ -2228,6 +2504,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, @@ -2244,6 +2521,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = model_type self.region = region self.inputs = inputs @@ -2263,6 +2541,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "instance_type", "initial_instance_count", "region", @@ -2309,6 +2588,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", "config_name", } @@ -2317,6 +2597,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -2360,6 +2641,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.instance_type = instance_type self.initial_instance_count = initial_instance_count self.region = region @@ -2408,7 +2690,9 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "tolerate_deprecated_model", "region", "model_id", + "model_type", "model_version", + "hub_arn", "sagemaker_session", "content_types", "response_types", @@ -2433,6 +2717,8 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "skip_model_validation", "source_uri", "config_name", + "model_card", + "accept_eula", ] SERIALIZATION_EXCLUSION_SET = { @@ -2441,6 +2727,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", "config_name", } @@ -2449,7 +2736,9 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, + model_type: Optional[JumpStartModelType] = 
JumpStartModelType.OPEN_WEIGHTS, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, sagemaker_session: Optional[Any] = None, @@ -2476,11 +2765,15 @@ def __init__( skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, config_name: Optional[str] = None, + model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, + accept_eula: Optional[bool] = None, ) -> None: """Instantiates JumpStartModelRegisterKwargs object.""" self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn + self.model_type = model_type self.region = region self.image_uri = image_uri self.sagemaker_session = sagemaker_session @@ -2509,6 +2802,8 @@ def __init__( self.skip_model_validation = skip_model_validation self.source_uri = source_uri self.config_name = config_name + self.model_card = model_card + self.accept_eula = accept_eula class BaseDeploymentConfigDataHolder(JumpStartDataHolderType): diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 559a960588..7a00efa8e1 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -431,6 +431,21 @@ def add_jumpstart_model_info_tags( return tags +def add_hub_content_arn_tags( + tags: Optional[List[TagsDict]], + hub_arn: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = add_single_jumpstart_tag( + hub_arn, + enums.JumpStartTag.HUB_CONTENT_ARN, + tags, + is_uri=False, + ) + return tags + + def add_jumpstart_uri_tags( tags: Optional[List[TagsDict]] = None, inference_model_uri: Optional[Union[str, dict]] = None, @@ -595,6 +610,7 @@ def verify_model_region_and_return_specs( version: Optional[str], scope: Optional[str], region: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -611,6 +627,8 @@ def verify_model_region_and_return_specs( scope (Optional[str]): scope of the JumpStart model to verify. region (Optional[str]): region of the JumpStart model to verify and obtains specs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -651,9 +669,11 @@ def verify_model_region_and_return_specs( model_specs = accessors.JumpStartModelsAccessor.get_model_specs( # type: ignore region=region, model_id=model_id, + hub_arn=hub_arn, version=version, s3_client=sagemaker_session.s3_client, model_type=model_type, + sagemaker_session=sagemaker_session, ) if ( @@ -817,6 +837,7 @@ def validate_model_id_and_get_type( model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_arn: Optional[str] = None, ) -> Optional[enums.JumpStartModelType]: """Returns model type if the model ID is supported for the given script. 
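As context for the hub_arn threading above: the new add_hub_content_arn_tags helper simply stamps JumpStart-related resources with the hub content ARN. A minimal, self-contained sketch of that behavior, assuming a placeholder tag key (the SDK takes the real key from enums.JumpStartTag.HUB_CONTENT_ARN and delegates to add_single_jumpstart_tag):

    from typing import Dict, List, Optional

    # Assumed tag key for illustration; the SDK reads it from
    # enums.JumpStartTag.HUB_CONTENT_ARN.
    HUB_CONTENT_ARN_TAG_KEY = "sagemaker-sdk:hub-content-arn"

    def add_hub_content_arn_tags_sketch(
        tags: Optional[List[Dict[str, str]]], hub_arn: str
    ) -> List[Dict[str, str]]:
        """Append the hub content ARN tag unless that key is already present."""
        tags = list(tags or [])
        if not any(tag.get("Key") == HUB_CONTENT_ARN_TAG_KEY for tag in tags):
            tags.append({"Key": HUB_CONTENT_ARN_TAG_KEY, "Value": hub_arn})
        return tags

    # Usage with a placeholder hub ARN:
    tags = add_hub_content_arn_tags_sketch(
        tags=None,
        hub_arn="arn:aws:sagemaker:us-west-2:123456789012:hub/my-curated-hub",
    )
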
@@ -828,6 +849,8 @@ def validate_model_id_and_get_type( return None if not isinstance(model_id, str): return None + if hub_arn: + return None s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME @@ -1004,6 +1027,7 @@ def get_benchmark_stats( model_id: str, model_version: str, config_names: Optional[List[str]] = None, + hub_arn: Optional[str] = None, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, @@ -1017,6 +1041,7 @@ def get_benchmark_stats( region=region, model_id=model_id, version=model_version, + hub_arn=hub_arn, sagemaker_session=sagemaker_session, scope=scope, model_type=model_type, diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index bcb0365f7b..ea8041d1ee 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -167,6 +167,7 @@ def validate_hyperparameters( model_version: str, hyperparameters: Dict[str, Any], validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, + hub_arn: Optional[str] = None, region: Optional[str] = None, sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, @@ -215,6 +216,7 @@ def validate_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=region, scope=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 36a848aa52..89a2df2135 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -42,6 +42,8 @@ _LocalPipeline, ) from sagemaker.session import Session +from sagemaker.telemetry.telemetry_logging import _telemetry_emitter +from sagemaker.telemetry.constants import Feature from sagemaker.utils import ( get_config_value, _module_import_error, @@ -83,6 +85,7 @@ def __init__(self, sagemaker_session=None): """ self.sagemaker_session = sagemaker_session or LocalSession() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_processing_job") def create_processing_job( self, ProcessingJobName, @@ -165,6 +168,7 @@ def describe_processing_job(self, ProcessingJobName): raise ClientError(error_response, "describe_processing_job") return LocalSagemakerClient._processing_jobs[ProcessingJobName].describe() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_training_job") def create_training_job( self, TrainingJobName, @@ -235,6 +239,7 @@ def describe_training_job(self, TrainingJobName): raise ClientError(error_response, "describe_training_job") return LocalSagemakerClient._training_jobs[TrainingJobName].describe() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_transform_job") def create_transform_job( self, TransformJobName, @@ -280,6 +285,7 @@ def describe_transform_job(self, TransformJobName): raise ClientError(error_response, "describe_transform_job") return LocalSagemakerClient._transform_jobs[TransformJobName].describe() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_model") def create_model( self, ModelName, PrimaryContainer, *args, **kwargs ): # pylint: disable=unused-argument @@ -329,6 +335,7 @@ def describe_endpoint_config(self, EndpointConfigName): raise 
ClientError(error_response, "describe_endpoint_config") return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_endpoint_config") def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None): """Create the endpoint configuration. @@ -360,6 +367,7 @@ def describe_endpoint(self, EndpointName): raise ClientError(error_response, "describe_endpoint") return LocalSagemakerClient._endpoints[EndpointName].describe() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_endpoint") def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None): """Create the endpoint. @@ -428,6 +436,7 @@ def delete_model(self, ModelName): if ModelName in LocalSagemakerClient._models: del LocalSagemakerClient._models[ModelName] + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_pipeline") def create_pipeline( self, pipeline, pipeline_description, **kwargs # pylint: disable=unused-argument ): diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 0c066ff801..dbf7ef7650 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -29,6 +29,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -44,6 +45,8 @@ def retrieve_default( retrieve the default training metric definitions. (Default: None). model_version (str): The version of the model for which to retrieve the default training metric definitions. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -73,6 +76,7 @@ def retrieve_default( return artifacts._retrieve_default_training_metric_definitions( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/mlflow/__init__.py b/src/sagemaker/mlflow/__init__.py new file mode 100644 index 0000000000..6549052177 --- /dev/null +++ b/src/sagemaker/mlflow/__init__.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. diff --git a/src/sagemaker/mlflow/tracking_server.py b/src/sagemaker/mlflow/tracking_server.py new file mode 100644 index 0000000000..0baa0f457b --- /dev/null +++ b/src/sagemaker/mlflow/tracking_server.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + + +"""This module contains code related to the Mlflow Tracking Server.""" + +from __future__ import absolute_import +from typing import Optional, TYPE_CHECKING +from sagemaker.apiutils import _utils + +if TYPE_CHECKING: + from sagemaker import Session + + +def generate_mlflow_presigned_url( + name: str, + expires_in_seconds: Optional[int] = None, + session_expiration_duration_in_seconds: Optional[int] = None, + sagemaker_session: Optional["Session"] = None, +) -> str: + """Generate a presigned url to access the Mlflow UI. + + Args: + name (str): Name of the Mlflow Tracking Server + expires_in_seconds (int): Expiration time of the first usage + of the presigned url in seconds. + session_expiration_duration_in_seconds (int): Session duration of the presigned url in + seconds after the first use. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + Returns: + (str): Authorized Url to access the Mlflow UI. + """ + session = sagemaker_session or _utils.default_session() + api_response = session.create_presigned_mlflow_tracking_server_url( + name, expires_in_seconds, session_expiration_duration_in_seconds + ) + return api_response["AuthorizedUrl"] diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index abe4889174..ce8142e43d 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -44,6 +44,12 @@ ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, load_sagemaker_config, ) +from sagemaker.jumpstart.enums import JumpStartModelType +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) +from sagemaker.model_card.helpers import _hash_content_str from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum from sagemaker.session import Session from sagemaker.model_metrics import ModelMetrics @@ -162,6 +168,7 @@ def __init__( git_config: Optional[Dict[str, str]] = None, resources: Optional[ResourceRequirements] = None, additional_model_data_sources: Optional[Dict[str, Any]] = None, + model_reference_arn: Optional[str] = None, ): """Initialize an SageMaker ``Model``. @@ -327,6 +334,8 @@ def __init__( (Default: None). additional_model_data_sources (Optional[Dict[str, Any]]): Additional location of SageMaker model data (default: None). + model_reference_arn (Optional[str]): Hub Content Arn of a Model Reference type + content (default: None).
""" self.model_data = model_data @@ -359,6 +368,7 @@ def __init__( sagemaker_config=self._sagemaker_config, ) self.endpoint_name = None + self.inference_component_name = None self._is_compiled_model = False self._compilation_job_name = None self._is_edge_packaged_model = False @@ -405,6 +415,7 @@ def __init__( self.content_types = None self.response_types = None self.accept_eula = None + self.model_reference_arn = model_reference_arn self._tags: Optional[Tags] = None def add_tags(self, tags: Tags) -> None: @@ -415,6 +426,16 @@ def add_tags(self, tags: Tags) -> None: """ self._tags = _validate_new_tags(tags, self._tags) + @classmethod + def attach( + cls, + endpoint_name: str, + inference_component_name: Optional[str] = None, + sagemaker_session=None, + ) -> "Model": + """Attaches a Model object to an existing SageMaker Endpoint.""" + raise NotImplementedError + @runnable_by_pipeline def register( self, @@ -442,6 +463,9 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + accept_eula: Optional[bool] = None, + model_type: Optional[JumpStartModelType] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -493,6 +517,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments @@ -514,9 +540,8 @@ def register( model_package_group_name = utils.base_name_from_image( self.image_uri, default_base_name=ModelPackage.__name__ ) - if model_package_group_name is not None: - container_def = self.prepare_container_def() + container_def = self.prepare_container_def(accept_eula=accept_eula) container_def = update_container_with_inference_params( framework=framework, framework_version=framework_version, @@ -559,6 +584,7 @@ def register( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args @@ -581,6 +607,7 @@ def create( serverless_inference_config: Optional[ServerlessInferenceConfig] = None, tags: Optional[Tags] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, ): """Create a SageMaker Model Entity @@ -622,6 +649,7 @@ def create( tags=format_tags(tags), serverless_inference_config=serverless_inference_config, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): @@ -643,6 +671,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): # pylint: disable=unused-argument """Return a dict created by ``sagemaker.container_def()``. 
@@ -686,6 +715,11 @@ def prepare_container_def( accept_eula if accept_eula is not None else getattr(self, "accept_eula", None) ), additional_model_data_sources=self.additional_model_data_sources, + model_reference_arn=( + model_reference_arn + if model_reference_arn is not None + else getattr(self, "model_reference_arn", None) + ), ) def is_repack(self) -> bool: @@ -828,6 +862,7 @@ def _create_sagemaker_model( tags: Optional[Tags] = None, serverless_inference_config=None, accept_eula=None, + model_reference_arn: Optional[str] = None, ): """Create a SageMaker Model Entity @@ -852,6 +887,8 @@ def _create_sagemaker_model( The `accept_eula` value must be explicitly defined as `True` in order to accept the end-user license agreement (EULA) that some models require. (Default: None). + model_reference_arn (Optional[str]): Hub Content Arn of a Model Reference type + content (default: None). """ if self.model_package_arn is not None or self.algorithm_arn is not None: model_package = ModelPackage( @@ -883,6 +920,7 @@ def _create_sagemaker_model( accelerator_type=accelerator_type, serverless_inference_config=serverless_inference_config, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) if not isinstance(self.sagemaker_session, PipelineSession): @@ -1325,6 +1363,7 @@ def deploy( resources: Optional[ResourceRequirements] = None, endpoint_type: EndpointType = EndpointType.MODEL_BASED, managed_instance_scaling: Optional[str] = None, + inference_component_name=None, routing_config: Optional[Dict[str, Any]] = None, **kwargs, ): @@ -1610,11 +1649,15 @@ def deploy( "ComputeResourceRequirements": resources.get_compute_resource_requirements(), } runtime_config = {"CopyCount": resources.copy_count} - inference_component_name = unique_name_from_base(self.name) + self.inference_component_name = ( + inference_component_name + or self.inference_component_name + or unique_name_from_base(self.name) + ) # [TODO]: Add endpoint_logging support self.sagemaker_session.create_inference_component( - inference_component_name=inference_component_name, + inference_component_name=self.inference_component_name, endpoint_name=self.endpoint_name, variant_name="AllTraffic", # default variant name specification=inference_component_spec, @@ -1627,7 +1670,7 @@ def deploy( predictor = self.predictor_cls( self.endpoint_name, self.sagemaker_session, - component_name=inference_component_name, + component_name=self.inference_component_name, ) if serializer: predictor.serializer = serializer @@ -1642,6 +1685,7 @@ def deploy( accelerator_type=accelerator_type, tags=tags, serverless_inference_config=serverless_inference_config, + **kwargs, ) serverless_inference_config_dict = ( serverless_inference_config._to_request_dict() if is_serverless else None @@ -2401,3 +2445,44 @@ def add_inference_specification( ) sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args) + + def update_model_card(self, model_card: Union[ModelCard, ModelPackageModelCard]): + """Updates the model card content that was created with the model package. + + Args: + model_card (ModelCard | ModelPackageModelCard): Updated model card content. + """ + + sagemaker_session = self.sagemaker_session or sagemaker.Session() + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=self.model_package_arn + ) + update_model_card_req = model_card._create_request_args() + if update_model_card_req["ModelCardStatus"] is not None: + if ( + desc_model_package["ModelCard"]["ModelCardStatus"] + ==
update_model_card_req["ModelCardStatus"] ) + del update_model_card_req["ModelCardStatus"] + + if update_model_card_req.get("ModelCardName") is not None: + del update_model_card_req["ModelCardName"] + if update_model_card_req.get("Content") is not None: + previous_content_hash = _hash_content_str( + desc_model_package["ModelCard"]["ModelCardContent"] + ) + current_content_hash = _hash_content_str(update_model_card_req["Content"]) + if ( + previous_content_hash == current_content_hash + or update_model_card_req.get("Content") == "{}" + or update_model_card_req.get("Content") == "null" + ): + del update_model_card_req["Content"] + else: + update_model_card_req["ModelCardContent"] = update_model_card_req["Content"] + del update_model_card_req["Content"] + update_model_package_args = { + "ModelPackageArn": self.model_package_arn, + "ModelCard": update_model_card_req, + } + sagemaker_session.sagemaker_client.update_model_package(**update_model_package_args) diff --git a/src/sagemaker/model_card/__init__.py b/src/sagemaker/model_card/__init__.py index 679da42a3f..b7a7d24dc7 100644 --- a/src/sagemaker/model_card/__init__.py +++ b/src/sagemaker/model_card/__init__.py @@ -29,6 +29,7 @@ AdditionalInformation, ModelCard, ModelPackage, + ModelPackageModelCard, ) from sagemaker.model_card.schema_constraints import ( # noqa: F401 # pylint: disable=unused-import diff --git a/src/sagemaker/model_card/model_card.py b/src/sagemaker/model_card/model_card.py index 33af98723f..c13e979efc 100644 --- a/src/sagemaker/model_card/model_card.py +++ b/src/sagemaker/model_card/model_card.py @@ -16,7 +16,7 @@ import json import logging from datetime import datetime -from typing import Optional, Union, List, Any +from typing import Optional, Union, List, Any, Dict from botocore.exceptions import ClientError from boto3.session import Session as boto3_Session from six.moves.urllib.parse import urlparse @@ -1883,3 +1883,29 @@ def list_export_jobs( return sagemaker_session.sagemaker_client.list_model_card_export_jobs( ModelCardName=model_card_name, **kwargs ) + + +class ModelPackageModelCard(object): + """Use an Amazon SageMaker Model Card to document qualitative and quantitative information about a model.""" # noqa E501 # pylint: disable=c0301 + + def __init__( + self, + model_card_content: Optional[Dict[str, Any]] = None, + model_card_status: Optional[str] = None, + ): + + self.model_card_content = model_card_content + self.model_card_status = model_card_status + + def _create_request_args(self): + """Generate the request body for the create model card call. + + The request is built from the instance's ``model_card_content`` and + ``model_card_status`` attributes; the method itself takes no arguments. + """ + request_args = {} + request_args["ModelCardStatus"] = self.model_card_status + request_args["Content"] = json.dumps(self.model_card_content, cls=_JSONEncoder) + return request_args diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 122647e536..2949fbaf5f 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -29,6 +29,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -44,6 +45,8 @@ def retrieve( the model artifact S3 URI.
model_version (str): The version of the JumpStart model for which to retrieve the model artifact S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). model_scope (str): The model type. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -78,6 +81,7 @@ def retrieve( return artifacts._retrieve_model_uri( model_id=model_id, model_version=model_version, # type: ignore + hub_arn=hub_arn, model_scope=model_scope, instance_type=instance_type, region=region, diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index 9c1e6ac4f4..9ed348c927 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -126,6 +126,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition set. @@ -154,6 +155,7 @@ def prepare_container_def( model_data_url=self.model_data_prefix, container_mode=self.container_mode, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def deploy( diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 714b0db945..487d336497 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -29,6 +29,10 @@ ) from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.mxnet import defaults from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer @@ -177,6 +181,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -228,6 +233,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModelCard or ModelPackageModelCard): a document that contains qualitative and + quantitative information about a model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -268,6 +275,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def prepare_container_def( @@ -276,6 +284,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition with framework configuration.
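Tying the model_card additions together, a hypothetical sketch of the new ModelPackageModelCard introduced earlier in this change (the card content and status values are illustrative):

    from sagemaker.model_card import ModelPackageModelCard

    card = ModelPackageModelCard(
        model_card_content={"model_overview": {"model_description": "demo model"}},
        model_card_status="Draft",
    )

    # register() forwards the card into create_model_package_from_containers,
    # e.g. mxnet_model.register(..., model_card=card). Internally the card
    # serializes itself into request arguments:
    request_args = card._create_request_args()
    # roughly {"ModelCardStatus": "Draft", "Content": '{"model_overview": ...}'}
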
@@ -329,6 +338,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 06d2ecfcde..403445525b 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -32,6 +32,7 @@ def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, @@ -78,11 +79,12 @@ def retrieve_all_examples( unserialized_payload_dict: Optional[Dict[str, JumpStartSerializablePayload]] = ( artifacts._retrieve_example_payloads( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + region=region, + hub_arn=hub_arn, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) @@ -123,6 +125,7 @@ def retrieve_example( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, @@ -168,6 +171,7 @@ def retrieve_example( region=region, model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, serialize=serialize, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 3bfdb1a594..b5a3cd4357 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -26,6 +26,10 @@ ) from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.session import Session from sagemaker.utils import ( name_from_image, @@ -361,6 +365,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -412,6 +417,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModelCard or ModelPackageModelCard): a document that contains qualitative and + quantitative information about a model (default: None).
Returns: If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step @@ -460,6 +467,7 @@ def register( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 780a1a56c8..df8554f7e8 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -40,6 +40,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, @@ -59,6 +60,8 @@ def retrieve_default( retrieve the default predictor. (Default: None). model_version (str): The version of the model for which to retrieve the default predictor. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -109,6 +112,7 @@ def retrieve_default( predictor=predictor, model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index f490e49375..92b96bd8c8 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -29,6 +29,10 @@ ) from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.pytorch import defaults from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer @@ -179,6 +183,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -230,6 +235,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModelCard or ModelPackageModelCard): a document that contains qualitative and + quantitative information about a model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -270,6 +277,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def prepare_container_def( @@ -278,6 +286,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """A container definition with framework configuration set in model environment variables.
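The hub_arn threading in predictor.retrieve_default above mirrors the other retrieval helpers; a hypothetical call follows (the endpoint name, model ID, version, and hub ARN are placeholders, and the endpoint-name parameter is assumed from the full signature, which the hunk above elides):

    from sagemaker import predictor

    default_predictor = predictor.retrieve_default(
        endpoint_name="my-jumpstart-endpoint",
        model_id="my-model-id",
        model_version="1.*",
        # New in this change: resolve model metadata through a curated hub.
        hub_arn="arn:aws:sagemaker:us-west-2:123456789012:hub/my-curated-hub",
    )
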
@@ -329,6 +338,7 @@ def prepare_container_def(
             self.repacked_model_data or self.model_data,
             deploy_env,
             accept_eula=accept_eula,
+            model_reference_arn=model_reference_arn,
         )
 
     def serving_image_uri(
diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py
index 0dc69d8647..53a116e4ef 100644
--- a/src/sagemaker/remote_function/client.py
+++ b/src/sagemaker/remote_function/client.py
@@ -40,6 +40,8 @@
 from sagemaker.utils import name_from_base, base_from_name
 from sagemaker.remote_function.spark_config import SparkConfig
 from sagemaker.remote_function.custom_file_filter import CustomFileFilter
+from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
+from sagemaker.telemetry.constants import Feature
 
 _API_CALL_LIMIT = {
     "SubmittingIntervalInSecs": 1,
@@ -57,6 +59,7 @@
 logger = logging_config.get_logger()
 
 
+@_telemetry_emitter(feature=Feature.REMOTE_FUNCTION, func_name="remote_function.remote")
 def remote(
     _func=None,
     *,
diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py
index 7808d0172a..d0ddea4432 100644
--- a/src/sagemaker/resource_requirements.py
+++ b/src/sagemaker/resource_requirements.py
@@ -31,6 +31,7 @@ def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     scope: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
@@ -48,6 +49,8 @@ def retrieve_default(
             retrieve the default resource requirements. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             default resource requirements. (Default: None).
+        hub_arn (str): The ARN of the SageMaker Hub from which to retrieve
+            model details. (Default: None).
         scope (str): The model type, i.e. what it is used for.
             Valid values: "training" and "inference".
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -80,12 +83,13 @@ def retrieve_default(
         raise ValueError("Must specify scope for resource requirements.")
 
     return artifacts._retrieve_default_resources(
-        model_id,
-        model_version,
-        scope,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        scope=scope,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         model_type=model_type,
         sagemaker_session=sagemaker_session,
         instance_type=instance_type,
diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py
index 6e10785498..d60095b521 100644
--- a/src/sagemaker/script_uris.py
+++ b/src/sagemaker/script_uris.py
@@ -29,6 +29,7 @@ def retrieve(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     script_scope: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
@@ -43,6 +44,8 @@ def retrieve(
             retrieve the script S3 URI.
         model_version (str): The version of the JumpStart model for which to
             retrieve the model script S3 URI.
+        hub_arn (str): The ARN of the SageMaker Hub from which to retrieve
+            model details. (Default: None).
         script_scope (str): The script type.
             Valid values: "training" and "inference".
         tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model
@@ -73,12 +76,13 @@ def retrieve(
     )
 
     return artifacts._retrieve_script_uri(
-        model_id,
-        model_version,
-        script_scope,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        script_scope=script_scope,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
         config_name=config_name,
     )
diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py
index d197df731c..ef502dc6f3 100644
--- a/src/sagemaker/serializers.py
+++ b/src/sagemaker/serializers.py
@@ -42,6 +42,7 @@ def retrieve_options(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -56,6 +57,8 @@ def retrieve_options(
             retrieve the supported serializers. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             supported serializers. (Default: None).
+        hub_arn (str): The ARN of the SageMaker Hub from which to retrieve
+            model details. (Default: None).
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
             specifications should be tolerated (exception not raised). If False, raises an
             exception if the script used by this version of the model has dependencies with known
@@ -81,11 +84,12 @@ def retrieve_options(
     )
 
     return artifacts._retrieve_serializer_options(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
         config_name=config_name,
     )
@@ -95,6 +99,7 @@ def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
@@ -110,6 +115,8 @@ def retrieve_default(
             retrieve the default serializer. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             default serializer. (Default: None).
+        hub_arn (str): The ARN of the SageMaker Hub from which to retrieve
+            model details. (Default: None).
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
             specifications should be tolerated (exception not raised).
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -135,11 +142,12 @@ def retrieve_default( ) return artifacts._retrieve_default_serializer( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, config_name=config_name, diff --git a/src/sagemaker/serve/builder/djl_builder.py b/src/sagemaker/serve/builder/djl_builder.py index 646b9fa611..72437c0fbb 100644 --- a/src/sagemaker/serve/builder/djl_builder.py +++ b/src/sagemaker/serve/builder/djl_builder.py @@ -15,7 +15,6 @@ import logging from typing import Type from abc import ABC, abstractmethod -from pathlib import Path from datetime import datetime, timedelta from sagemaker.model import Model @@ -31,12 +30,12 @@ _more_performant, _pretty_print_results, ) +from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf from sagemaker.serve.model_server.djl_serving.utils import ( - _auto_detect_engine, - _set_serve_properties, _get_admissible_tensor_parallel_degrees, _get_admissible_dtypes, _get_default_tensor_parallel_degree, + _get_default_djl_configurations, ) from sagemaker.serve.utils.local_hardware import ( _get_nb_instance, @@ -45,24 +44,18 @@ _get_gpu_info_fallback, ) from sagemaker.serve.model_server.djl_serving.prepare import ( - prepare_for_djl_serving, _create_dir_structure, ) from sagemaker.serve.utils.predictors import DjlLocalModePredictor -from sagemaker.serve.utils.types import ModelServer, _DjlEngine +from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.telemetry_logger import _capture_telemetry -from sagemaker.djl_inference.model import ( - DeepSpeedModel, - FasterTransformerModel, - HuggingFaceAccelerateModel, -) +from sagemaker.djl_inference.model import DJLModel from sagemaker.base_predictor import PredictorBase logger = logging.getLogger(__name__) # Match JumpStart DJL entrypoint format -_DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py" _CODE_FOLDER = "code" _INVALID_SAMPLE_DATA_EX = ( 'For djl-serving, sample input must be of {"inputs": str, "parameters": dict}, ' @@ -88,14 +81,11 @@ def __init__(self): self.vpc_config = None self._original_deploy = None self.secret_key = None - self.engine = None self.hf_model_config = None self._default_tensor_parallel_degree = None self._default_data_type = None self._default_max_tokens = None - self._default_max_new_tokens = None self.pysdk_model = None - self.overwrite_props_from_file = None self.schema_builder = None self.env_vars = None self.nb_instance_type = None @@ -131,37 +121,15 @@ def _validate_djl_serving_sample_data(self): def _create_djl_model(self) -> Type[Model]: """Placeholder docstring""" - code_dir = str(Path(self.model_path).joinpath(_CODE_FOLDER)) - - kwargs = { - "model_id": self.model, - "role": self.serve_settings.role_arn, - "entry_point": _DJL_MODEL_BUILDER_ENTRY_POINT, - "dtype": self._default_data_type, - "sagemaker_session": self.sagemaker_session, - "source_dir": code_dir, - "env": self.env_vars, - "hf_hub_token": self.env_vars.get("HUGGING_FACE_HUB_TOKEN"), - "image_config": self.image_config, - "vpc_config": self.vpc_config, - } - - if self.engine == _DjlEngine.DEEPSPEED: - pysdk_model = DeepSpeedModel( - 
tensor_parallel_degree=self._default_tensor_parallel_degree, - max_tokens=self._default_max_tokens, - **kwargs, - ) - elif self.engine == _DjlEngine.FASTER_TRANSFORMER: - pysdk_model = FasterTransformerModel( - tensor_parallel_degree=self._default_tensor_parallel_degree, - **kwargs, - ) - else: - pysdk_model = HuggingFaceAccelerateModel( - number_of_partitions=self._default_tensor_parallel_degree, - **kwargs, - ) + pysdk_model = DJLModel( + model_id=self.model, + role=self.serve_settings.role_arn, + sagemaker_session=self.sagemaker_session, + env=self.env_vars, + huggingface_hub_token=self.env_vars.get("HF_TOKEN"), + image_config=self.image_config, + vpc_config=self.vpc_config, + ) if not self.image_uri: self.image_uri = pysdk_model.serving_image_uri(self.sagemaker_session.boto_region_name) @@ -197,7 +165,6 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa else: raise ValueError("Mode %s is not supported!" % overwrite_mode) - manual_set_props = None if self.mode == Mode.SAGEMAKER_ENDPOINT: if self.nb_instance_type and "instance_type" not in kwargs: kwargs.update({"instance_type": self.nb_instance_type}) @@ -213,17 +180,9 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa default_tensor_parallel_degree = _get_default_tensor_parallel_degree( self.hf_model_config, tot_gpus ) - manual_set_props = { - "option.tensor_parallel_degree": str(default_tensor_parallel_degree) + "\n" - } - - prepare_for_djl_serving( - model_path=self.model_path, - model=self.pysdk_model, - dependencies=self.dependencies, - overwrite_props_from_file=self.overwrite_props_from_file, - manual_set_props=manual_set_props, - ) + self.pysdk_model.env.update( + {"TENSOR_PARALLEL_DEGREE": str(default_tensor_parallel_degree)} + ) serializer = self.schema_builder.input_serializer deserializer = self.schema_builder._output_deserializer @@ -240,7 +199,7 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa timeout if timeout else 1800, self.secret_key, predictor, - self.env_vars, + self.pysdk_model.env, ) ram_usage_after = _get_ram_usage_mb() @@ -266,6 +225,7 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa # if has not been built for local container we must use cache # that hosting has write access to. 
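             # TRANSFORMERS_CACHE is deprecated in favor of HF_HOME in newer
             # transformers releases, so both are pointed at a writable /tmp
             # cache alongside HUGGINGFACE_HUB_CACHE.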
self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp" + self.pysdk_model.env["HF_HOME"] = "/tmp" self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp" if "endpoint_logging" not in kwargs: @@ -281,25 +241,21 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa def _build_for_hf_djl(self): """Placeholder docstring""" - self.overwrite_props_from_file = True self.nb_instance_type = _get_nb_instance() _create_dir_structure(self.model_path) - self.engine, self.hf_model_config = _auto_detect_engine( - self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - ) - if not hasattr(self, "pysdk_model"): - ( - self._default_tensor_parallel_degree, - self._default_data_type, - _, - self._default_max_tokens, - self._default_max_new_tokens, - ) = _set_serve_properties(self.hf_model_config, self.schema_builder) + self.env_vars.update({"HF_MODEL_ID": self.model}) + self.hf_model_config = _get_model_config_properties_from_hf( + self.model, self.env_vars.get("HF_TOKEN") + ) + default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations( + self.model, self.hf_model_config, self.schema_builder + ) + self.env_vars.update(default_djl_configurations) self.schema_builder.sample_input["parameters"][ "max_new_tokens" - ] = self._default_max_new_tokens + ] = _default_max_new_tokens self.pysdk_model = self._create_djl_model() if self.mode == Mode.LOCAL_CONTAINER: @@ -316,8 +272,6 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): ) return self.pysdk_model - self.overwrite_props_from_file = False - admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees( self.hf_model_config ) @@ -337,8 +291,9 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): "Trying tensor parallel degree: %s, dtype: %s...", tensor_parallel_degree, dtype ) - self._default_tensor_parallel_degree = tensor_parallel_degree - self._default_data_type = dtype + self.env_vars.update( + {"TENSOR_PARALLEL_DEGREE": str(tensor_parallel_degree), "OPTION_DTYPE": dtype} + ) self.pysdk_model = self._create_djl_model() try: @@ -353,15 +308,15 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): predictor, self.schema_builder.sample_input ) - serving_properties = self.pysdk_model.generate_serving_properties() + tested_env = self.pysdk_model.env.copy() logger.info( "Average latency: %s, throughput/s: %s for configuration: %s", avg_latency, throughput_per_second, - serving_properties, + tested_env, ) benchmark_results[avg_latency] = [ - serving_properties, + tested_env, p90, avg_tokens_per_second, throughput_per_second, @@ -449,6 +404,12 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): if best_tuned_combination: self._default_tensor_parallel_degree = best_tuned_combination[1] self._default_data_type = best_tuned_combination[2] + self.env_vars.update( + { + "TENSOR_PARALLEL_DEGREE": str(self._default_tensor_parallel_degree), + "OPTION_DTYPE": self._default_data_type, + } + ) self.pysdk_model = self._create_djl_model() _pretty_print_results(benchmark_results) @@ -456,7 +417,7 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): "Model Configuration: %s was most performant with avg latency: %s, " "p90 latency: %s, average tokens per second: %s, throughput/s: %s, " "standard deviation of request %s", - self.pysdk_model.generate_serving_properties(), + self.pysdk_model.env, best_tuned_combination[0], best_tuned_combination[3], best_tuned_combination[4], @@ -464,33 +425,22 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): 
best_tuned_combination[6], ) else: - ( - self._default_tensor_parallel_degree, - self._default_data_type, - _, - self._default_max_tokens, - self._default_max_new_tokens, - ) = _set_serve_properties(self.hf_model_config, self.schema_builder) + default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations( + self.model, self.hf_model_config, self.schema_builder + ) + self.env_vars.update(default_djl_configurations) self.schema_builder.sample_input["parameters"][ "max_new_tokens" - ] = self._default_max_new_tokens + ] = _default_max_new_tokens self.pysdk_model = self._create_djl_model() logger.debug( "Failed to gather any tuning results. " "Please inspect the stack trace emitted from live logging for more details. " "Falling back to default serving.properties: %s", - self.pysdk_model.generate_serving_properties(), + self.pysdk_model.env, ) - prepare_for_djl_serving( - model_path=self.model_path, - model=self.pysdk_model, - dependencies=self.dependencies, - overwrite_props_from_file=self.overwrite_props_from_file, - ) - self.overwrite_props_from_file = True - return self.pysdk_model def _build_for_djl(self): diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 164bfe894c..d58b0618b7 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -13,11 +13,13 @@ """Holds the ModelBuilder class and the ModelServer enum.""" from __future__ import absolute_import +import importlib.util import uuid from typing import Any, Type, List, Dict, Optional, Union from dataclasses import dataclass, field import logging import os +import re from pathlib import Path @@ -26,7 +28,6 @@ from sagemaker import Session from sagemaker.model import Model from sagemaker.base_predictor import PredictorBase -from sagemaker.djl_inference import defaults from sagemaker.serializers import NumpySerializer, TorchTensorSerializer from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer from sagemaker.serve.builder.schema_builder import SchemaBuilder @@ -44,12 +45,15 @@ from sagemaker.predictor import Predictor from sagemaker.serve.model_format.mlflow.constants import ( MLFLOW_MODEL_PATH, + MLFLOW_TRACKING_ARN, + MLFLOW_RUN_ID_REGEX, + MLFLOW_REGISTRY_PATH_REGEX, + MODEL_PACKAGE_ARN_REGEX, MLFLOW_METADATA_FILE, MLFLOW_PIP_DEPENDENCY_FILE, ) from sagemaker.serve.model_format.mlflow.utils import ( _get_default_model_server_for_mlflow, - _mlflow_input_is_local_path, _download_s3_artifacts, _select_container_for_mlflow_model, _generate_mlflow_artifact_path, @@ -95,11 +99,15 @@ logger = logging.getLogger(__name__) -supported_model_server = { +# Any new server type should be added here +supported_model_servers = { ModelServer.TORCHSERVE, ModelServer.TRITON, ModelServer.DJL_SERVING, ModelServer.TENSORFLOW_SERVING, + ModelServer.MMS, + ModelServer.TGI, + ModelServer.TEI, } @@ -290,28 +298,17 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, ) def _build_validations(self): - """Placeholder docstring""" - # TODO: Beta validations - remove after the launch + """Validations needed for model server overrides, or auto-detection or fallback""" if self.mode == Mode.IN_PROCESS: raise ValueError("IN_PROCESS mode is not supported yet!") if self.inference_spec and self.model: - raise ValueError("Cannot have both the Model and Inference spec in the builder") + raise ValueError("Can only set one of the following: model, inference_spec.") if self.image_uri and not 
is_1p_image_uri(self.image_uri) and self.model_server is None: raise ValueError( "Model_server must be set when non-first-party image_uri is set. " - + "Supported model servers: %s" % supported_model_server - ) - - # Set TorchServe as default model server - if not self.model_server: - self.model_server = ModelServer.TORCHSERVE - - if self.model_server not in supported_model_server: - raise ValueError( - "%s is not supported yet! Supported model servers: %s" - % (self.model_server, supported_model_server) + + "Supported model servers: %s" % supported_model_servers ) def _save_model_inference_spec(self): @@ -516,6 +513,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs): mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH], s3_upload_path=self.s3_upload_path, sagemaker_session=self.sagemaker_session, + tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN), ) return new_model_package @@ -586,6 +584,7 @@ def _model_builder_deploy_wrapper( mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH], s3_upload_path=self.s3_upload_path, sagemaker_session=self.sagemaker_session, + tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN), ) return predictor @@ -639,11 +638,30 @@ def wrapper(*args, **kwargs): return wrapper - def _check_if_input_is_mlflow_model(self) -> bool: - """Checks whether an MLmodel file exists in the given directory. + def _handle_mlflow_input(self): + """Check whether an MLflow model is present and handle accordingly""" + self._is_mlflow_model = self._has_mlflow_arguments() + if not self._is_mlflow_model: + return + + mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH) + artifact_path = self._get_artifact_path(mlflow_model_path) + if not self._mlflow_metadata_exists(artifact_path): + logger.info( + "MLflow model metadata not detected in %s. ModelBuilder is not " + "handling MLflow model input", + mlflow_model_path, + ) + return + + self._initialize_for_mlflow(artifact_path) + _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR")) + + def _has_mlflow_arguments(self) -> bool: + """Check whether MLflow model arguments are present Returns: - bool: True if the MLmodel file exists, False otherwise. + bool: True if MLflow arguments are present, False otherwise. """ if self.inference_spec or self.model: logger.info( @@ -658,8 +676,8 @@ def _check_if_input_is_mlflow_model(self) -> bool: ) return False - path = self.model_metadata.get(MLFLOW_MODEL_PATH) - if not path: + mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH) + if not mlflow_model_path: logger.info( "%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model " "input", @@ -667,7 +685,73 @@ def _check_if_input_is_mlflow_model(self) -> bool: ) return False - # Check for S3 path + return True + + def _get_artifact_path(self, mlflow_model_path: str) -> str: + """Retrieves the model artifact location given the Mlflow model input. + + Args: + mlflow_model_path (str): The MLflow model path input. + + Returns: + str: The path to the model artifact. + """ + if (is_run_id_type := re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path)) or re.match( + MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path + ): + mlflow_tracking_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN) + if not mlflow_tracking_arn: + raise ValueError( + "%s is not provided in ModelMetadata or through set_tracking_arn " + "but MLflow model path was provided." 
% MLFLOW_TRACKING_ARN, + ) + + if not importlib.util.find_spec("sagemaker_mlflow"): + raise ImportError( + "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed" + ) + + import mlflow + + mlflow.set_tracking_uri(mlflow_tracking_arn) + if is_run_id_type: + _, run_id, model_path = mlflow_model_path.split("/", 2) + artifact_uri = mlflow.get_run(run_id).info.artifact_uri + if not artifact_uri.endswith("/"): + artifact_uri += "/" + return artifact_uri + model_path + + mlflow_client = mlflow.MlflowClient() + if not mlflow_model_path.endswith("/"): + mlflow_model_path += "/" + + if "@" in mlflow_model_path: + _, model_name_and_alias, artifact_uri = mlflow_model_path.split("/", 2) + model_name, model_alias = model_name_and_alias.split("@") + model_metadata = mlflow_client.get_model_version_by_alias(model_name, model_alias) + else: + _, model_name, model_version, artifact_uri = mlflow_model_path.split("/", 3) + model_metadata = mlflow_client.get_model_version(model_name, model_version) + + source = model_metadata.source + if not source.endswith("/"): + source += "/" + return source + artifact_uri + + if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path): + model_package = self.sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=mlflow_model_path + ) + return model_package["SourceUri"] + + return mlflow_model_path + + def _mlflow_metadata_exists(self, path: str) -> bool: + """Checks whether an MLmodel file exists in the given directory. + + Returns: + bool: True if the MLmodel file exists, False otherwise. + """ if path.startswith("s3://"): s3_downloader = S3Downloader() if not path.endswith("/"): @@ -679,17 +763,18 @@ def _check_if_input_is_mlflow_model(self) -> bool: file_path = os.path.join(path, MLFLOW_METADATA_FILE) return os.path.isfile(file_path) - def _initialize_for_mlflow(self) -> None: - """Initialize mlflow model artifacts, image uri and model server.""" - mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH) - if not _mlflow_input_is_local_path(mlflow_path): - # TODO: extend to package arn, run id and etc. - logger.info( - "Start downloading model artifacts from %s to %s", mlflow_path, self.model_path - ) - _download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session) + def _initialize_for_mlflow(self, artifact_path: str) -> None: + """Initialize mlflow model artifacts, image uri and model server. + + Args: + artifact_path (str): The path to the artifact store. 
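+
+        Accepted forms of ``artifact_path`` (illustrative examples):
+            "s3://my-bucket/path/to/model/" is downloaded into the model path;
+            "/local/path/to/model/" is copied into the model path.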
+ """ + if artifact_path.startswith("s3://"): + _download_s3_artifacts(artifact_path, self.model_path, self.sagemaker_session) + elif os.path.exists(artifact_path): + _copy_directory_contents(artifact_path, self.model_path) else: - _copy_directory_contents(mlflow_path, self.model_path) + raise ValueError("Invalid path: %s" % artifact_path) mlflow_model_metadata_path = _generate_mlflow_artifact_path( self.model_path, MLFLOW_METADATA_FILE ) @@ -742,6 +827,8 @@ def build( # pylint: disable=R0911 self.role_arn = role_arn self.sagemaker_session = sagemaker_session or Session() + self.sagemaker_session.settings._local_download_dir = self.model_path + # https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258 # decorate to_string() due to # https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015 @@ -753,14 +840,12 @@ def build( # pylint: disable=R0911 self.serve_settings = self._get_serve_setting() self._is_custom_image_uri = self.image_uri is not None - self._is_mlflow_model = self._check_if_input_is_mlflow_model() - if self._is_mlflow_model: - logger.warning( - "Support of MLflow format models is experimental and is not intended" - " for production at this moment." - ) - self._initialize_for_mlflow() - _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR")) + self._handle_mlflow_input() + + self._build_validations() + + if self.model_server: + return self._build_for_model_server() if isinstance(self.model, str): model_task = None @@ -789,15 +874,30 @@ def build( # pylint: disable=R0911 return self._build_for_tei() elif self._can_fit_on_single_gpu(): return self._build_for_transformers() - elif ( - self.model in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES - or self.model in defaults.FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES - ): - return self._build_for_djl() else: return self._build_for_transformers() - self._build_validations() + # Set TorchServe as default model server + if not self.model_server: + self.model_server = ModelServer.TORCHSERVE + return self._build_for_torchserve() + + raise ValueError("%s model server is not supported" % self.model_server) + + def _build_for_model_server(self): # pylint: disable=R0911, R1710 + """Model server overrides""" + if self.model_server not in supported_model_servers: + raise ValueError( + "%s is not supported yet! 
Supported model servers: %s" + % (self.model_server, supported_model_servers) + ) + + mlflow_path = None + if self.model_metadata: + mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH) + + if not self.model and not mlflow_path: + raise ValueError("Missing required parameter `model` or 'ml_flow' path") if self.model_server == ModelServer.TORCHSERVE: return self._build_for_torchserve() @@ -808,7 +908,17 @@ def build( # pylint: disable=R0911 if self.model_server == ModelServer.TENSORFLOW_SERVING: return self._build_for_tensorflow_serving() - raise ValueError("%s model server is not supported" % self.model_server) + if self.model_server == ModelServer.DJL_SERVING: + return self._build_for_djl() + + if self.model_server == ModelServer.TEI: + return self._build_for_tei() + + if self.model_server == ModelServer.TGI: + return self._build_for_tgi() + + if self.model_server == ModelServer.MMS: + return self._build_for_transformers() def save( self, @@ -861,6 +971,19 @@ def validate(self, model_dir: str) -> Type[bool]: return get_metadata(model_dir) + def set_tracking_arn(self, arn: str): + """Set tracking server ARN""" + # TODO: support native MLflow URIs + if importlib.util.find_spec("sagemaker_mlflow"): + import mlflow + + mlflow.set_tracking_uri(arn) + self.model_metadata[MLFLOW_TRACKING_ARN] = arn + else: + raise ImportError( + "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed" + ) + def _hf_schema_builder_init(self, model_task: str): """Initialize the schema builder for the given HF_TASK diff --git a/src/sagemaker/serve/builder/tei_builder.py b/src/sagemaker/serve/builder/tei_builder.py index 79b0d276b7..e251eb4f81 100644 --- a/src/sagemaker/serve/builder/tei_builder.py +++ b/src/sagemaker/serve/builder/tei_builder.py @@ -18,7 +18,7 @@ from sagemaker import image_uris from sagemaker.model import Model -from sagemaker.djl_inference.model import _get_model_config_properties_from_hf +from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf from sagemaker.huggingface import HuggingFaceModel from sagemaker.serve.utils.local_hardware import ( @@ -171,6 +171,7 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa # if has not been built for local container we must use cache # that hosting has write access to. self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp" + self.pysdk_model.env["HF_HOME"] = "/tmp" self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp" if "endpoint_logging" not in kwargs: diff --git a/src/sagemaker/serve/builder/tgi_builder.py b/src/sagemaker/serve/builder/tgi_builder.py index 13755b1a43..e6cbe41c90 100644 --- a/src/sagemaker/serve/builder/tgi_builder.py +++ b/src/sagemaker/serve/builder/tgi_builder.py @@ -31,7 +31,7 @@ _more_performant, _pretty_print_results_tgi, ) -from sagemaker.djl_inference.model import _get_model_config_properties_from_hf +from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf from sagemaker.serve.model_server.djl_serving.utils import ( _get_admissible_tensor_parallel_degrees, _get_default_tensor_parallel_degree, @@ -210,6 +210,7 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa # if has not been built for local container we must use cache # that hosting has write access to. 
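             # HF_HOME supersedes TRANSFORMERS_CACHE in recent transformers
             # versions; setting both keeps the hub cache writable either way.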
self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp" + self.pysdk_model.env["HF_HOME"] = "/tmp" self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp" if "endpoint_logging" not in kwargs: diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index 47ea8189b2..dded7bd0bd 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -22,7 +22,7 @@ from sagemaker.serve.utils.local_hardware import ( _get_nb_instance, ) -from sagemaker.djl_inference.model import _get_model_config_properties_from_hf +from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf from sagemaker.huggingface import HuggingFaceModel from sagemaker.serve.model_server.multi_model_server.prepare import ( _create_dir_structure, diff --git a/src/sagemaker/serve/model_format/mlflow/constants.py b/src/sagemaker/serve/model_format/mlflow/constants.py index 28a3cbdc8d..d7ddcd9ef0 100644 --- a/src/sagemaker/serve/model_format/mlflow/constants.py +++ b/src/sagemaker/serve/model_format/mlflow/constants.py @@ -22,9 +22,10 @@ MODEL_PACKAGE_ARN_REGEX = ( r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$" ) -MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$" -MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$" +MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9\-_\.]*)+$" +MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+[@/]?[a-zA-Z0-9\-_\.][/a-zA-Z0-9\-_\.]*$" S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$" +MLFLOW_TRACKING_ARN = "MLFLOW_TRACKING_ARN" MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH" MLFLOW_METADATA_FILE = "MLmodel" MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt" diff --git a/src/sagemaker/serve/model_format/mlflow/utils.py b/src/sagemaker/serve/model_format/mlflow/utils.py index c92a6a8a27..0d41cf4e33 100644 --- a/src/sagemaker/serve/model_format/mlflow/utils.py +++ b/src/sagemaker/serve/model_format/mlflow/utils.py @@ -227,28 +227,6 @@ def _get_python_version_from_parsed_mlflow_model_file( raise ValueError(f"{MLFLOW_PYFUNC} cannot be found in MLmodel file.") -def _mlflow_input_is_local_path(model_path: str) -> bool: - """Checks if the given model_path is a local filesystem path. - - Args: - - model_path (str): The model path to check. - - Returns: - - bool: True if model_path is a local path, False otherwise. - """ - if model_path.startswith("s3://"): - return False - - if "/runs/" in model_path or model_path.startswith("runs:"): - return False - - # Check if it's not a local file path - if not os.path.exists(model_path): - return False - - return True - - def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> None: """Downloads all artifacts from a specified S3 path to a local destination path. diff --git a/src/sagemaker/serve/model_server/djl_serving/inference.py b/src/sagemaker/serve/model_server/djl_serving/inference.py deleted file mode 100644 index 2dba9eb877..0000000000 --- a/src/sagemaker/serve/model_server/djl_serving/inference.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. 
This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""DJL Handler Template - -Getting Started DJL Handle provided via ModelBuilder. -Feel free to re-purpose this script for your DJL usecase -and re-deploy via ModelBuilder().deploy(). -""" -from __future__ import absolute_import - -from djl_python.inputs import Input -from djl_python.outputs import Output - - -class HandleTemplate: - """A DJL Handler class template that uses the default DeepSpeed, FasterTransformer, and HuggingFaceAccelerate Handlers - - Reference the default handlers here: - - https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/deepspeed.py - - https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/fastertransformer.py - - https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/huggingface.py - """ - - def __init__(self): - self.initialized = False - self.handle = None - - def initialize(self, inputs: Input): - """Template method to load you model with specified engine.""" - self.initialized = True - - if "DeepSpeed" == inputs.get_property("engine"): - from djl_python.deepspeed import handle - elif "FasterTransformer" == inputs.get_property("engine"): - from djl_python.fastertransformer import handle - else: - from djl_python.huggingface import handle - - self._handle = handle - - def inference(self, inputs: Input): - """Template method used to invoke the model. Please implement this if you'd like to construct your own script""" - - -_handle_template = HandleTemplate() - - -def handle(inputs: Input) -> Output: - """Driver function required by djl-serving""" - if not _handle_template.initialized: - _handle_template.initialize(inputs) - - return _handle_template._handle(inputs) diff --git a/src/sagemaker/serve/model_server/djl_serving/prepare.py b/src/sagemaker/serve/model_server/djl_serving/prepare.py index 810acc8aff..40cb04152c 100644 --- a/src/sagemaker/serve/model_server/djl_serving/prepare.py +++ b/src/sagemaker/serve/model_server/djl_serving/prepare.py @@ -13,7 +13,6 @@ """Prepare DjlModel for Deployment""" from __future__ import absolute_import -import shutil import json import tarfile import logging @@ -22,139 +21,51 @@ from sagemaker.utils import _tmpdir, custom_extractall_tarfile from sagemaker.s3 import S3Downloader -from sagemaker.djl_inference import DJLModel -from sagemaker.djl_inference.model import _read_existing_serving_properties from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage -_SERVING_PROPERTIES_FILE = "serving.properties" -_ENTRY_POINT_SCRIPT = "inference.py" _SETTING_PROPERTY_STMT = "Setting property: %s to %s" logger = logging.getLogger(__name__) -def _has_serving_properties_file(code_dir: Path) -> bool: - """Check for existing serving properties in the directory""" - return code_dir.joinpath(_SERVING_PROPERTIES_FILE).is_file() - - -def _move_to_code_dir(js_model_dir: str, code_dir: Path): - """Move DJL Jumpstart resources from model to code_dir""" - js_model_resources = Path(js_model_dir).joinpath("model") - for resource in js_model_resources.glob("*"): - try: - shutil.move(resource, code_dir) - except shutil.Error as e: - if "already exists" in str(e): - continue - - -def _extract_js_resource(js_model_dir: str, js_id: str): +def _extract_js_resource(js_model_dir: str, 
code_dir: Path, js_id: str): """Uncompress the jumpstart resource""" tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz") with tarfile.open(str(tmp_sourcedir)) as resources: - custom_extractall_tarfile(resources, js_model_dir) + custom_extractall_tarfile(resources, code_dir) -def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path): +def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> tuple: """Copy the associated JumpStart Resource into the code directory""" logger.info("Downloading JumpStart artifacts from S3...") s3_downloader = S3Downloader() - invalid_model_data_format = False - with _tmpdir(directory=str(code_dir)) as js_model_dir: - if isinstance(model_data, str): - if model_data.endswith(".tar.gz"): - logger.info("Uncompressing JumpStart artifacts for faster loading...") - s3_downloader.download(model_data, js_model_dir) - _extract_js_resource(js_model_dir, js_id) - else: - logger.info("Copying uncompressed JumpStart artifacts...") + if isinstance(model_data, str): + if model_data.endswith(".tar.gz"): + logger.info("Uncompressing JumpStart artifacts for faster loading...") + with _tmpdir(directory=str(code_dir)) as js_model_dir: s3_downloader.download(model_data, js_model_dir) - elif ( - isinstance(model_data, dict) - and model_data.get("S3DataSource") - and model_data.get("S3DataSource").get("S3Uri") - ): - logger.info("Copying uncompressed JumpStart artifacts...") - s3_downloader.download(model_data.get("S3DataSource").get("S3Uri"), js_model_dir) + _extract_js_resource(js_model_dir, code_dir, js_id) else: - invalid_model_data_format = True - if not invalid_model_data_format: - _move_to_code_dir(js_model_dir, code_dir) - - if invalid_model_data_format: + logger.info("Copying uncompressed JumpStart artifacts...") + s3_downloader.download(model_data, code_dir) + elif ( + isinstance(model_data, dict) + and model_data.get("S3DataSource") + and model_data.get("S3DataSource").get("S3Uri") + ): + logger.info("Copying uncompressed JumpStart artifacts...") + s3_downloader.download(model_data.get("S3DataSource").get("S3Uri"), code_dir) + else: raise ValueError("JumpStart model data compression format is unsupported: %s", model_data) - existing_properties = _read_existing_serving_properties(code_dir) config_json_file = code_dir.joinpath("config.json") - hf_model_config = None if config_json_file.is_file(): with open(str(config_json_file)) as config_json: hf_model_config = json.load(config_json) - return (existing_properties, hf_model_config, True) - - -def _generate_properties_file( - model: DJLModel, code_dir: Path, overwrite_props_from_file: bool, manual_set_props: dict -): - """Construct serving properties file taking into account of overrides or manual specs""" - if _has_serving_properties_file(code_dir): - existing_properties = _read_existing_serving_properties(code_dir) - else: - existing_properties = {} - - serving_properties_dict = model.generate_serving_properties() - serving_properties_file = code_dir.joinpath(_SERVING_PROPERTIES_FILE) - - with open(serving_properties_file, mode="w+") as file: - covered_keys = set() - - if manual_set_props: - for key, value in manual_set_props.items(): - logger.info(_SETTING_PROPERTY_STMT, key, value.strip()) - covered_keys.add(key) - file.write(f"{key}={value}") - - for key, value in serving_properties_dict.items(): - if not overwrite_props_from_file: - logger.info(_SETTING_PROPERTY_STMT, key, value) - file.write(f"{key}={value}\n") - else: - existing_property = 
existing_properties.get(key) - covered_keys.add(key) - if not existing_property: - logger.info(_SETTING_PROPERTY_STMT, key, value) - file.write(f"{key}={value}\n") - else: - logger.info(_SETTING_PROPERTY_STMT, key, existing_property.strip()) - file.write(f"{key}={existing_property}") - - if overwrite_props_from_file: - # for addition provided properties - for key, value in existing_properties.items(): - if key not in covered_keys: - logger.info(_SETTING_PROPERTY_STMT, key, value.strip()) - file.write(f"{key}={value}") - - -def _store_share_libs(model_path: Path, shared_libs): - """Placeholder Docstring""" - shared_libs_dir = model_path.joinpath("shared_libs") - shared_libs_dir.mkdir(exist_ok=True) - for shared_lib in shared_libs: - shutil.copy2(Path(shared_lib), shared_libs_dir) - - -def _copy_inference_script(code_dir): - """Placeholder Docstring""" - if code_dir.joinpath("inference.py").is_file(): - return - - inference_file = Path(__file__).parent.joinpath(_ENTRY_POINT_SCRIPT) - shutil.copy2(inference_file, code_dir) + return (hf_model_config, True) def _create_dir_structure(model_path: str) -> tuple: @@ -174,36 +85,6 @@ def _create_dir_structure(model_path: str) -> tuple: return (model_path, code_dir) -def prepare_for_djl_serving( - model_path: str, - model: DJLModel, - shared_libs: List[str] = None, - dependencies: str = None, - overwrite_props_from_file: bool = True, - manual_set_props: dict = None, -): - """Prepare serving when a HF model id is given - - Args:to - model_path (str) : Argument - model (DJLModel) : Argument - shared_libs (List[]) : Argument - dependencies (str) : Argument - - Returns: - ( str ) : - - """ - model_path, code_dir = _create_dir_structure(model_path) - - if shared_libs: - _store_share_libs(model_path, shared_libs) - - _copy_inference_script(code_dir) - - _generate_properties_file(model, code_dir, overwrite_props_from_file, manual_set_props) - - def prepare_djl_js_resources( model_path: str, js_id: str, diff --git a/src/sagemaker/serve/model_server/djl_serving/server.py b/src/sagemaker/serve/model_server/djl_serving/server.py index 8b152e5b81..80214332b0 100644 --- a/src/sagemaker/serve/model_server/djl_serving/server.py +++ b/src/sagemaker/serve/model_server/djl_serving/server.py @@ -19,6 +19,7 @@ _DEFAULT_ENV_VARS = { "SERVING_OPTS": "-Dai.djl.logging.level=debug", "TRANSFORMERS_CACHE": "/opt/ml/model/", + "HF_HOME": "/opt/ml/model/", "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", } diff --git a/src/sagemaker/serve/model_server/djl_serving/utils.py b/src/sagemaker/serve/model_server/djl_serving/utils.py index 03719542d2..93d16001df 100644 --- a/src/sagemaker/serve/model_server/djl_serving/utils.py +++ b/src/sagemaker/serve/model_server/djl_serving/utils.py @@ -1,12 +1,8 @@ """DJL ModelBuilder Utils""" from __future__ import absolute_import -from urllib.error import HTTPError import math import logging -from sagemaker.serve.utils.types import _DjlEngine -from sagemaker.djl_inference import defaults -from sagemaker.djl_inference.model import _get_model_config_properties_from_hf from sagemaker.serve.utils.local_hardware import _get_available_gpus from sagemaker.serve.builder.schema_builder import SchemaBuilder @@ -17,50 +13,6 @@ TOKENS_PER_WORD = 0.75 -def _auto_detect_engine(model_id: str, hf_hub_token: str) -> tuple: - """Placeholder docstring""" - try: - hf_model_config = _get_model_config_properties_from_hf(model_id, hf_hub_token) - model_type = hf_model_config.get("model_type") - - if len(model_type) < 1: - logger.warning( - "Unable to detect the model 
architecture from provided model_id %s.\ - Defaulting to HuggingFaceAccelerate." - % model_id - ) - engine = _DjlEngine.HUGGINGFACE_ACCELERATE - elif model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES: - logger.info("Model architecture %s is recommended to be run on DeepSpeed." % model_type) - engine = _DjlEngine.DEEPSPEED - elif model_type in defaults.FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES: - logger.info( - "Model architecture %s is recommended to be run on FasterTransformer." % model_type - ) - engine = _DjlEngine.FASTER_TRANSFORMER - else: - logger.info( - "Model architecture %s does not have a recommended engine. Defaulting to HuggingFaceAccelerate." - % model_type - ) - engine = _DjlEngine.HUGGINGFACE_ACCELERATE - except HTTPError as e: - raise ValueError( - "The provided HuggingFace Model ID could not be accessed from HuggingFace Hub. %s", - str(e), - ) - except ValueError as e: - raise e - except Exception as e: - logger.warning( - "Unable to detect the model's architecture: %s. Defaulting to HuggingFaceAccelerate." - % str(e) - ) - engine = _DjlEngine.HUGGINGFACE_ACCELERATE - - return (engine, hf_model_config) - - def _get_default_tensor_parallel_degree(hf_model_config: dict, gpu_count: int = None) -> int: """Placeholder docstring""" available_gpus = _get_available_gpus() @@ -89,7 +41,7 @@ def _get_default_tensor_parallel_degree(hf_model_config: dict, gpu_count: int = def _get_default_data_type() -> tuple: """Placeholder docstring""" - return "fp16" + return "bf16" def _get_default_batch_size() -> int: @@ -144,22 +96,23 @@ def _get_default_max_tokens(sample_input, sample_output) -> tuple: return (max_total_tokens, max_new_tokens) -def _set_serve_properties(hf_model_config: dict, schema_builder: SchemaBuilder) -> tuple: +def _get_default_djl_configurations( + model_id: str, hf_model_config: dict, schema_builder: SchemaBuilder +) -> tuple: """Placeholder docstring""" default_tensor_parallel_degree = _get_default_tensor_parallel_degree(hf_model_config) + if default_tensor_parallel_degree is None: + default_tensor_parallel_degree = "max" default_data_type = _get_default_data_type() - default_batch_size = _get_default_batch_size() default_max_tokens, default_max_new_tokens = _get_default_max_tokens( schema_builder.sample_input, schema_builder.sample_output ) - return ( - default_tensor_parallel_degree, - default_data_type, - default_batch_size, - default_max_tokens, - default_max_new_tokens, - ) + env = { + "TENSOR_PARALLEL_DEGREE": str(default_tensor_parallel_degree), + "OPTION_DTYPE": default_data_type, + } + return (env, default_max_new_tokens) def _get_admissible_tensor_parallel_degrees(hf_model_config: dict) -> int: diff --git a/src/sagemaker/serve/model_server/tei/server.py b/src/sagemaker/serve/model_server/tei/server.py index 67fca0e847..25c27e6dda 100644 --- a/src/sagemaker/serve/model_server/tei/server.py +++ b/src/sagemaker/serve/model_server/tei/server.py @@ -18,6 +18,7 @@ _SHM_SIZE = "2G" _DEFAULT_ENV_VARS = { "TRANSFORMERS_CACHE": "/opt/ml/model/", + "HF_HOME": "/opt/ml/model/", "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", } diff --git a/src/sagemaker/serve/model_server/tgi/server.py b/src/sagemaker/serve/model_server/tgi/server.py index ef39e890c8..75cf3bd402 100644 --- a/src/sagemaker/serve/model_server/tgi/server.py +++ b/src/sagemaker/serve/model_server/tgi/server.py @@ -17,6 +17,7 @@ _SHM_SIZE = "2G" _DEFAULT_ENV_VARS = { "TRANSFORMERS_CACHE": "/opt/ml/model/", + "HF_HOME": "/opt/ml/model/", "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", } diff --git 
a/src/sagemaker/serve/utils/hf_utils.py b/src/sagemaker/serve/utils/hf_utils.py new file mode 100644 index 0000000000..75f46eeeb9 --- /dev/null +++ b/src/sagemaker/serve/utils/hf_utils.py @@ -0,0 +1,53 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Utility functions for fetching model information from HuggingFace Hub""" +from __future__ import absolute_import +import json +import urllib.request +from json import JSONDecodeError +from urllib.error import HTTPError, URLError +import logging + +logger = logging.getLogger(__name__) + + +def _get_model_config_properties_from_hf(model_id: str, hf_hub_token: str = None): + """Placeholder docstring""" + + config_url = f"https://huggingface.co/{model_id}/raw/main/config.json" + model_config = None + try: + if hf_hub_token: + config_url = urllib.request.Request( + config_url, headers={"Authorization": "Bearer " + hf_hub_token} + ) + with urllib.request.urlopen(config_url) as response: + model_config = json.load(response) + except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e: + if "HTTP Error 401: Unauthorized" in str(e): + raise ValueError( + "Trying to access a gated/private HuggingFace model without valid credentials. " + "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars" + ) + logger.warning( + "Exception encountered while trying to read config file %s. " "Details: %s", + config_url, + e, + ) + if not model_config: + raise ValueError( + f"Did not find a config.json or model_index.json file in huggingface hub for " + f"{model_id}. 
Please make sure a config.json exists (or model_index.json for Stable "
+            f"Diffusion Models) for this model in the huggingface hub"
+        )
+    return model_config
diff --git a/src/sagemaker/serve/utils/lineage_constants.py b/src/sagemaker/serve/utils/lineage_constants.py
index 51be20739f..dce4a41139 100644
--- a/src/sagemaker/serve/utils/lineage_constants.py
+++ b/src/sagemaker/serve/utils/lineage_constants.py
@@ -16,6 +16,8 @@
 
 LINEAGE_POLLER_INTERVAL_SECS = 15
 LINEAGE_POLLER_MAX_TIMEOUT_SECS = 120
+TRACKING_SERVER_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):mlflow-tracking-server/(.*?)$"
+TRACKING_SERVER_CREATION_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
 MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE = "ModelBuilderInputModelData"
 MLFLOW_S3_PATH = "S3"
 MLFLOW_MODEL_PACKAGE_PATH = "ModelPackage"
diff --git a/src/sagemaker/serve/utils/lineage_utils.py b/src/sagemaker/serve/utils/lineage_utils.py
index 3435e138c9..7278dd8a3c 100644
--- a/src/sagemaker/serve/utils/lineage_utils.py
+++ b/src/sagemaker/serve/utils/lineage_utils.py
@@ -17,7 +17,7 @@
 import time
 import re
 import logging
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 from botocore.exceptions import ClientError
 
@@ -35,6 +35,8 @@
 from sagemaker.serve.utils.lineage_constants import (
     LINEAGE_POLLER_MAX_TIMEOUT_SECS,
     LINEAGE_POLLER_INTERVAL_SECS,
+    TRACKING_SERVER_ARN_REGEX,
+    TRACKING_SERVER_CREATION_TIME_FORMAT,
     MLFLOW_S3_PATH,
     MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
     MLFLOW_LOCAL_PATH,
@@ -51,24 +53,41 @@
 
 
 def _load_artifact_by_source_uri(
-    source_uri: str, artifact_type: str, sagemaker_session: Session
+    source_uri: str,
+    sagemaker_session: Session,
+    source_types_to_match: Optional[List[str]] = None,
+    artifact_type: Optional[str] = None,
 ) -> Optional[ArtifactSummary]:
     """Load lineage artifact by source uri
 
     Arguments:
         source_uri (str): The s3 uri used for uploading transformed model artifacts.
-        artifact_type (str): The type of the lineage artifact.
         sagemaker_session (Session): Session object which manages interactions
             with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
             function creates one using the default AWS configuration chain.
+        source_types_to_match (Optional[List[str]]): A list of source type values to match against
+            the artifact's source types. If provided, the artifact's source types must match this
+            list.
+        artifact_type (Optional[str]): The type of the lineage artifact.
 
     Returns:
         ArtifactSummary: The Artifact Summary for the provided S3 URI.
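+
+    Example (hypothetical values; ``session`` is an existing ``sagemaker.Session``):
+        artifact = _load_artifact_by_source_uri(
+            source_uri="s3://my-bucket/code/model.tar.gz",
+            sagemaker_session=session,
+            source_types_to_match=["ModelBuilderInputModelData"],
+        )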
""" artifacts = Artifact.list(source_uri=source_uri, sagemaker_session=sagemaker_session) for artifact_summary in artifacts: - if artifact_summary.artifact_type == artifact_type: - return artifact_summary + if artifact_type is None or artifact_summary.artifact_type == artifact_type: + if source_types_to_match: + if artifact_summary.source.source_types is not None: + artifact_source_types = [ + source_type["Value"] for source_type in artifact_summary.source.source_types + ] + if set(artifact_source_types) == set(source_types_to_match): + return artifact_summary + else: + return None + else: + return artifact_summary + return None @@ -90,7 +109,9 @@ def _poll_lineage_artifact( logger.info("Polling lineage artifact for model data in %s", s3_uri) start_time = time.time() while time.time() - start_time < LINEAGE_POLLER_MAX_TIMEOUT_SECS: - result = _load_artifact_by_source_uri(s3_uri, artifact_type, sagemaker_session) + result = _load_artifact_by_source_uri( + s3_uri, sagemaker_session, artifact_type=artifact_type + ) if result is not None: return result time.sleep(LINEAGE_POLLER_INTERVAL_SECS) @@ -105,12 +126,12 @@ def _get_mlflow_model_path_type(mlflow_model_path: str) -> str: Returns: str: Description of what the input string is identified as. """ - mlflow_rub_id_pattern = MLFLOW_RUN_ID_REGEX + mlflow_run_id_pattern = MLFLOW_RUN_ID_REGEX mlflow_registry_id_pattern = MLFLOW_REGISTRY_PATH_REGEX sagemaker_arn_pattern = MODEL_PACKAGE_ARN_REGEX s3_pattern = S3_PATH_REGEX - if re.match(mlflow_rub_id_pattern, mlflow_model_path): + if re.match(mlflow_run_id_pattern, mlflow_model_path): return MLFLOW_RUN_ID if re.match(mlflow_registry_id_pattern, mlflow_model_path): return MLFLOW_REGISTRY_PATH @@ -127,12 +148,14 @@ def _get_mlflow_model_path_type(mlflow_model_path: str) -> str: def _create_mlflow_model_path_lineage_artifact( mlflow_model_path: str, sagemaker_session: Session, + source_types_to_match: Optional[List[str]] = None, ) -> Optional[Artifact]: """Creates a lineage artifact for the given MLflow model path. Args: mlflow_model_path (str): The path to the MLflow model. sagemaker_session (Session): The SageMaker session object. + source_types_to_match (Optional[List[str]]): Artifact source types. Returns: Optional[Artifact]: The created lineage artifact, or None if an error occurred. @@ -142,8 +165,17 @@ def _create_mlflow_model_path_lineage_artifact( model_builder_input_model_data_type=_artifact_name, ) try: + source_types = [dict(SourceIdType="Custom", Value="ModelBuilderInputModelData")] + if source_types_to_match: + source_types += [ + dict(SourceIdType="Custom", Value=source_type) + for source_type in source_types_to_match + if source_type != "ModelBuilderInputModelData" + ] + return Artifact.create( source_uri=mlflow_model_path, + source_types=source_types, artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, artifact_name=_artifact_name, properties=properties, @@ -160,6 +192,7 @@ def _create_mlflow_model_path_lineage_artifact( def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( mlflow_model_path: str, sagemaker_session: Session, + tracking_server_arn: Optional[str] = None, ) -> Optional[Union[Artifact, ArtifactSummary]]: """Retrieves an existing artifact for the given MLflow model path or @@ -170,20 +203,35 @@ def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( sagemaker_session (Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. 
If not specified, the function creates one using the default AWS configuration chain. - + tracking_server_arn (Optional[str]): The MLflow tracking server ARN. Returns: Optional[Union[Artifact, ArtifactSummary]]: The existing or newly created artifact, or None if an error occurred. """ + source_types_to_match = ["ModelBuilderInputModelData"] + input_type = _get_mlflow_model_path_type(mlflow_model_path) + if tracking_server_arn and input_type in [MLFLOW_RUN_ID, MLFLOW_REGISTRY_PATH]: + match = re.match(TRACKING_SERVER_ARN_REGEX, tracking_server_arn) + mlflow_tracking_server_name = match.group(4) + describe_result = sagemaker_session.sagemaker_client.describe_mlflow_tracking_server( + TrackingServerName=mlflow_tracking_server_name + ) + tracking_server_creation_time = describe_result["CreationTime"].strftime( + TRACKING_SERVER_CREATION_TIME_FORMAT + ) + source_types_to_match += [tracking_server_arn, tracking_server_creation_time] _loaded_artifact = _load_artifact_by_source_uri( - mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session + mlflow_model_path, + sagemaker_session, + source_types_to_match, ) if _loaded_artifact is not None: return _loaded_artifact return _create_mlflow_model_path_lineage_artifact( mlflow_model_path, sagemaker_session, + source_types_to_match, ) @@ -229,6 +277,7 @@ def _maintain_lineage_tracking_for_mlflow_model( mlflow_model_path: str, s3_upload_path: str, sagemaker_session: Session, + tracking_server_arn: Optional[str] = None, ) -> None: """Maintains lineage tracking for an MLflow model by creating or retrieving artifacts. @@ -238,6 +287,7 @@ def _maintain_lineage_tracking_for_mlflow_model( sagemaker_session (Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the function creates one using the default AWS configuration chain. + tracking_server_arn (Optional[str]): The MLflow tracking server ARN. 
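+
+    Example (hypothetical values; ``session`` is an existing ``sagemaker.Session``):
+        _maintain_lineage_tracking_for_mlflow_model(
+            mlflow_model_path="runs:/abc123def456/model",
+            s3_upload_path="s3://my-bucket/transformed-model/",
+            sagemaker_session=session,
+            tracking_server_arn="arn:aws:sagemaker:us-west-2:123456789012"
+            ":mlflow-tracking-server/my-tracking-server",
+        )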
""" artifact_for_transformed_model_data = _poll_lineage_artifact( s3_uri=s3_upload_path, @@ -249,6 +299,7 @@ def _maintain_lineage_tracking_for_mlflow_model( _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( mlflow_model_path=mlflow_model_path, sagemaker_session=sagemaker_session, + tracking_server_arn=tracking_server_arn, ) ) if mlflow_model_artifact: diff --git a/src/sagemaker/serve/utils/tuning.py b/src/sagemaker/serve/utils/tuning.py index 22f3c06d47..b93c01b522 100644 --- a/src/sagemaker/serve/utils/tuning.py +++ b/src/sagemaker/serve/utils/tuning.py @@ -33,8 +33,8 @@ def _pretty_print_results(results: dict): for key, value in ordered.items(): avg_latencies.append(key) - tensor_parallel_degrees.append(value[0]["option.tensor_parallel_degree"]) - dtypes.append(value[0]["option.dtype"]) + tensor_parallel_degrees.append(value[0]["TENSOR_PARALLEL_DEGREE"]) + dtypes.append(value[0]["OPTION_DTYPE"]) p90s.append(value[1]) avg_tokens_per_seconds.append(value[2]) throughput_per_seconds.append(value[3]) diff --git a/src/sagemaker/serve/utils/types.py b/src/sagemaker/serve/utils/types.py index 2e5e4f40d7..e50be62440 100644 --- a/src/sagemaker/serve/utils/types.py +++ b/src/sagemaker/serve/utils/types.py @@ -21,18 +21,6 @@ def __str__(self): TEI = 7 -class _DjlEngine(Enum): - """An enum for Djl Engines""" - - def __str__(self): - """Placeholder docstring""" - return str(self.name) - - DEEPSPEED = 1 - FASTER_TRANSFORMER = 2 - HUGGINGFACE_ACCELERATE = 3 - - class HardwareType(Enum): """An enum for hardware type""" diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 6593751b58..dfc8fc8266 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -631,43 +631,68 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): bucket = s3.Bucket(name=bucket_name) if bucket.creation_date is None: - try: - # trying head bucket call - s3.meta.client.head_bucket(Bucket=bucket.name) - except ClientError as e: - # bucket does not exist or forbidden to access - error_code = e.response["Error"]["Code"] - message = e.response["Error"]["Message"] + self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True) + + elif self._default_bucket_set_by_sdk: + self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False) + + expected_bucket_owner_id = self.account_id() + self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) + + def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id): + """Checks if the bucket belongs to a particular owner and throws a Client Error if it is not + + Args: + bucket_name (str): Name of the S3 bucket + s3 (str): S3 object from boto session + expected_bucket_owner_id (str): Owner ID string + + """ + try: + s3.meta.client.head_bucket( + Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id + ) + except ClientError as e: + error_code = e.response["Error"]["Code"] + message = e.response["Error"]["Message"] + if error_code == "403" and message == "Forbidden": + LOGGER.error( + "Since default_bucket param was not set, SageMaker Python SDK tried to use " + "%s bucket. " + "This bucket cannot be configured to use as it is not owned by Account %s. 
" + "To unblock it's recommended to use custom default_bucket " + "parameter in sagemaker.Session", + bucket_name, + expected_bucket_owner_id, + ) + raise + + def general_bucket_check_if_user_has_permission( + self, bucket_name, s3, bucket, region, bucket_creation_date_none + ): + """Checks if the person running has the permissions to the bucket + + If there is any other error that comes up with calling head bucket, it is raised up here + If there is no bucket , it will create one + + Args: + bucket_name (str): Name of the S3 bucket + s3 (str): S3 object from boto session + region (str): The region in which to create the bucket. + bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not + """ + try: + s3.meta.client.head_bucket(Bucket=bucket_name) + except ClientError as e: + error_code = e.response["Error"]["Code"] + message = e.response["Error"]["Message"] + # bucket does not exist or forbidden to access + if bucket_creation_date_none: if error_code == "404" and message == "Not Found": - # bucket does not exist, create one - try: - if region == "us-east-1": - # 'us-east-1' cannot be specified because it is the default region: - # https://github.com/boto/boto3/issues/125 - s3.create_bucket(Bucket=bucket_name) - else: - s3.create_bucket( - Bucket=bucket_name, - CreateBucketConfiguration={"LocationConstraint": region}, - ) - - logger.info("Created S3 bucket: %s", bucket_name) - except ClientError as e: - error_code = e.response["Error"]["Code"] - message = e.response["Error"]["Message"] - - if ( - error_code == "OperationAborted" - and "conflicting conditional operation" in message - ): - # If this bucket is already being concurrently created, - # we don't need to create it again. - pass - else: - raise + self.create_bucket_for_not_exist_error(bucket_name, region, s3) elif error_code == "403" and message == "Forbidden": - logger.error( + LOGGER.error( "Bucket %s exists, but access is forbidden. Please try again after " "adding appropriate access.", bucket.name, @@ -676,27 +701,37 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): else: raise - if self._default_bucket_set_by_sdk: - # make sure the s3 bucket is configured in users account. - expected_bucket_owner_id = self.account_id() - try: - s3.meta.client.head_bucket( - Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id + def create_bucket_for_not_exist_error(self, bucket_name, region, s3): + """Creates the S3 bucket in the given region + + Args: + bucket_name (str): Name of the S3 bucket + s3 (str): S3 object from boto session + region (str): The region in which to create the bucket. + """ + # bucket does not exist, create one + try: + if region == "us-east-1": + # 'us-east-1' cannot be specified because it is the default region: + # https://github.com/boto/boto3/issues/125 + s3.create_bucket(Bucket=bucket_name) + else: + s3.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": region}, ) - except ClientError as e: - error_code = e.response["Error"]["Code"] - message = e.response["Error"]["Message"] - if error_code == "403" and message == "Forbidden": - LOGGER.error( - "Since default_bucket param was not set, SageMaker Python SDK tried to use " - "%s bucket. " - "This bucket cannot be configured to use as it is not owned by Account %s. 
" - "To unblock it's recommended to use custom default_bucket " - "parameter in sagemaker.Session", - bucket_name, - expected_bucket_owner_id, - ) - raise + + logger.info("Created S3 bucket: %s", bucket_name) + except ClientError as e: + error_code = e.response["Error"]["Code"] + message = e.response["Error"]["Message"] + + if error_code == "OperationAborted" and "conflicting conditional operation" in message: + # If this bucket is already being concurrently created, + # we don't need to create it again. + pass + else: + raise def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tags: str): """Appends tags specified in the sagemaker_config to the given list of tags. @@ -4073,6 +4108,7 @@ def create_model_package_from_containers( task=None, skip_model_validation="None", source_uri=None, + model_card=None, ): """Get request dictionary for CreateModelPackage API. @@ -4110,6 +4146,8 @@ def create_model_package_from_containers( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). """ if containers: # Containers are provided. Now we can merge missing entries from config. @@ -4167,6 +4205,7 @@ def create_model_package_from_containers( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def submit(request): @@ -4715,7 +4754,7 @@ def create_inference_component( tags = self._append_sagemaker_config_tags( tags, "{}.{}.{}".format(SAGEMAKER, INFERENCE_COMPONENT, TAGS) ) - if len(tags) != 0: + if tags and len(tags) != 0: request["Tags"] = tags self.sagemaker_client.create_inference_component(**request) @@ -6723,6 +6762,323 @@ def wait_for_inference_recommendations_job( _check_job_status(job_name, desc, "Status") return desc + def create_presigned_mlflow_tracking_server_url( + self, + tracking_server_name: str, + expires_in_seconds: int = None, + session_expiration_duration_in_seconds: int = None, + ) -> Dict[str, Any]: + """Creates a Presigned Url to acess the Mlflow UI. + + Args: + tracking_server_name (str): Name of the Mlflow Tracking Server. + expires_in_seconds (int): Expiration duration of the URL. + session_expiration_duration_in_seconds (int): Session duration of the URL. + Returns: + (dict): Return value from the ``CreatePresignedMlflowTrackingServerUrl`` API. + + """ + + create_presigned_url_args = {"TrackingServerName": tracking_server_name} + if expires_in_seconds is not None: + create_presigned_url_args["ExpiresInSeconds"] = expires_in_seconds + + if session_expiration_duration_in_seconds is not None: + create_presigned_url_args["SessionExpirationDurationInSeconds"] = ( + session_expiration_duration_in_seconds + ) + + return self.sagemaker_client.create_presigned_mlflow_tracking_server_url( + **create_presigned_url_args + ) + + def create_hub( + self, + hub_name: str, + hub_description: str, + hub_display_name: str = None, + hub_search_keywords: List[str] = None, + s3_storage_config: Dict[str, Any] = None, + tags: List[Dict[str, Any]] = None, + ) -> Dict[str, str]: + """Creates a SageMaker Hub + + Args: + hub_name (str): The name of the Hub to create. + hub_description (str): A description of the Hub. + hub_display_name (str): The display name of the Hub. + hub_search_keywords (list): The searchable keywords for the Hub. 
+ s3_storage_config (S3StorageConfig): The Amazon S3 storage configuration for the Hub.
+ tags (list): Any tags to associate with the Hub.
+
+ Returns:
+ (dict): Return value from the ``CreateHub`` API.
+ """
+ request = {"HubName": hub_name, "HubDescription": hub_description}
+
+ if hub_display_name:
+ request["HubDisplayName"] = hub_display_name
+ else:
+ request["HubDisplayName"] = hub_name
+
+ if hub_search_keywords:
+ request["HubSearchKeywords"] = hub_search_keywords
+ if s3_storage_config:
+ request["S3StorageConfig"] = s3_storage_config
+ if tags:
+ request["Tags"] = tags
+
+ return self.sagemaker_client.create_hub(**request)
+
+ def describe_hub(self, hub_name: str) -> Dict[str, Any]:
+ """Describes a SageMaker Hub
+
+ Args:
+ hub_name (str): The name of the hub to describe.
+
+ Returns:
+ (dict): Return value for ``DescribeHub`` API
+ """
+ request = {"HubName": hub_name}
+
+ return self.sagemaker_client.describe_hub(**request)
+
+ def list_hubs(
+ self,
+ creation_time_after: str = None,
+ creation_time_before: str = None,
+ max_results: int = None,
+ max_schema_version: str = None,
+ name_contains: str = None,
+ next_token: str = None,
+ sort_by: str = None,
+ sort_order: str = None,
+ ) -> Dict[str, Any]:
+ """Lists all existing SageMaker Hubs
+
+ Args:
+ creation_time_after (str): Only list Hubs that were created after
+ the time specified.
+ creation_time_before (str): Only list Hubs that were created
+ before the time specified.
+ max_results (int): The maximum number of Hubs to list.
+ max_schema_version (str): The upper bound of the HubContentSchemaVersion.
+ name_contains (str): Only list Hubs if the name contains the specified string.
+ next_token (str): If the response to a previous ``ListHubs`` request was
+ truncated, the response includes a ``NextToken``. To retrieve the next set of
+ hubs, use the token in the next request.
+ sort_by (str): Sort Hubs by either name or creation time.
+ sort_order (str): Sort Hubs by ascending or descending order.
+ Returns:
+ (dict): Return value for ``ListHubs`` API
+ """
+ request = {}
+ if creation_time_after:
+ request["CreationTimeAfter"] = creation_time_after
+ if creation_time_before:
+ request["CreationTimeBefore"] = creation_time_before
+ if max_results:
+ request["MaxResults"] = max_results
+ if max_schema_version:
+ request["MaxSchemaVersion"] = max_schema_version
+ if name_contains:
+ request["NameContains"] = name_contains
+ if next_token:
+ request["NextToken"] = next_token
+ if sort_by:
+ request["SortBy"] = sort_by
+ if sort_order:
+ request["SortOrder"] = sort_order
+
+ return self.sagemaker_client.list_hubs(**request)
+
+ def list_hub_contents(
+ self,
+ hub_name: str,
+ hub_content_type: str,
+ creation_time_after: str = None,
+ creation_time_before: str = None,
+ max_results: int = None,
+ max_schema_version: str = None,
+ name_contains: str = None,
+ next_token: str = None,
+ sort_by: str = None,
+ sort_order: str = None,
+ ) -> Dict[str, Any]:
+ """Lists the HubContents in a SageMaker Hub
+
+ Args:
+ hub_name (str): The name of the Hub to list the contents of.
+ hub_content_type (str): The type of the HubContent to list.
+ creation_time_after (str): Only list HubContent that was created after the
+ time specified.
+ creation_time_before (str): Only list HubContent that was created before the
+ time specified.
+ max_results (int): The maximum number of HubContent entries to list.
+ max_schema_version (str): The upper bound of the HubContentSchemaVersion.
+ name_contains (str): Only list HubContent if the name contains the specified string.
+ next_token (str): If the response to a previous ``ListHubContents`` request was
+ truncated, the response includes a ``NextToken``. To retrieve the next set of
+ hub content, use the token in the next request.
+ sort_by (str): Sort HubContent by either name or creation time.
+ sort_order (str): Sort HubContent by ascending or descending order.
+ Returns:
+ (dict): Return value for ``ListHubContents`` API
+ """
+ request = {"HubName": hub_name, "HubContentType": hub_content_type}
+ if creation_time_after:
+ request["CreationTimeAfter"] = creation_time_after
+ if creation_time_before:
+ request["CreationTimeBefore"] = creation_time_before
+ if max_results:
+ request["MaxResults"] = max_results
+ if max_schema_version:
+ request["MaxSchemaVersion"] = max_schema_version
+ if name_contains:
+ request["NameContains"] = name_contains
+ if next_token:
+ request["NextToken"] = next_token
+ if sort_by:
+ request["SortBy"] = sort_by
+ if sort_order:
+ request["SortOrder"] = sort_order
+
+ return self.sagemaker_client.list_hub_contents(**request)
+
+ def delete_hub(self, hub_name: str) -> None:
+ """Deletes a SageMaker Hub
+
+ Args:
+ hub_name (str): The name of the hub to delete.
+ """
+ request = {"HubName": hub_name}
+
+ return self.sagemaker_client.delete_hub(**request)
+
+ def create_hub_content_reference(
+ self,
+ hub_name: str,
+ source_hub_content_arn: str,
+ hub_content_name: str = None,
+ min_version: str = None,
+ ) -> Dict[str, str]:
+ """Creates a given HubContent reference in a SageMaker Hub
+
+ Args:
+ hub_name (str): The name of the Hub that you want to add the reference to.
+ source_hub_content_arn (str): The ARN of the hub content in the public/source Hub.
+ hub_content_name (str): The name of the reference that you want to add to the Hub.
+ min_version (str): A minimum version of the hub content to add to the Hub.
+
+ Returns:
+ (dict): Return value for ``CreateHubContentReference`` API
+ """
+
+ request = {"HubName": hub_name, "SageMakerPublicHubContentArn": source_hub_content_arn}
+
+ if hub_content_name:
+ request["HubContentName"] = hub_content_name
+ if min_version:
+ request["MinVersion"] = min_version
+
+ return self.sagemaker_client.create_hub_content_reference(**request)
+
+ def delete_hub_content_reference(
+ self, hub_name: str, hub_content_type: str, hub_content_name: str
+ ) -> None:
+ """Deletes a given HubContent reference in a SageMaker Hub
+
+ Args:
+ hub_name (str): The name of the Hub that you want to delete content in.
+ hub_content_type (str): The type of the content that you want to delete from a Hub.
+ hub_content_name (str): The name of the content that you want to delete from a Hub.
+ """
+ request = {
+ "HubName": hub_name,
+ "HubContentType": hub_content_type,
+ "HubContentName": hub_content_name,
+ }
+
+ return self.sagemaker_client.delete_hub_content_reference(**request)
+
+ def describe_hub_content(
+ self,
+ hub_content_name: str,
+ hub_content_type: str,
+ hub_name: str,
+ hub_content_version: str = None,
+ ) -> Dict[str, Any]:
+ """Describes a HubContent in a SageMaker Hub
+
+ Args:
+ hub_content_name (str): The name of the HubContent to describe.
+ hub_content_type (str): The type of HubContent in the Hub.
+ hub_name (str): The name of the Hub that contains the HubContent to describe.
+ hub_content_version (str): The version of the HubContent to describe.
+
+ Returns:
+ (dict): Return value for ``DescribeHubContent`` API
+ """
+ request = {
+ "HubContentName": hub_content_name,
+ "HubContentType": hub_content_type,
+ "HubName": hub_name,
+ }
+ if hub_content_version:
+ request["HubContentVersion"] = hub_content_version
+
+ return self.sagemaker_client.describe_hub_content(**request)
+
+ def list_hub_content_versions(
+ self,
+ hub_name,
+ hub_content_type: str,
+ hub_content_name: str,
+ min_version: str = None,
+ max_schema_version: str = None,
+ creation_time_after: str = None,
+ creation_time_before: str = None,
+ max_results: int = None,
+ next_token: str = None,
+ sort_by: str = None,
+ sort_order: str = None,
+ ) -> Dict[str, Any]:
+ """List all versions of a HubContent in a SageMaker Hub
+
+ Args:
+ hub_name (str): The name of the Hub that contains the HubContent.
+ hub_content_type (str): The type of HubContent in the Hub.
+ hub_content_name (str): The name of the HubContent whose versions to list.
+
+ Returns:
+ (dict): Return value for ``ListHubContentVersions`` API
+ """
+
+ request = {
+ "HubName": hub_name,
+ "HubContentName": hub_content_name,
+ "HubContentType": hub_content_type,
+ }
+
+ if min_version:
+ request["MinVersion"] = min_version
+ if creation_time_after:
+ request["CreationTimeAfter"] = creation_time_after
+ if creation_time_before:
+ request["CreationTimeBefore"] = creation_time_before
+ if max_results:
+ request["MaxResults"] = max_results
+ if max_schema_version:
+ request["MaxSchemaVersion"] = max_schema_version
+ if next_token:
+ request["NextToken"] = next_token
+ if sort_by:
+ request["SortBy"] = sort_by
+ if sort_order:
+ request["SortOrder"] = sort_order
+
+ return self.sagemaker_client.list_hub_content_versions(**request)
+

def get_model_package_args(
content_types=None,
@@ -6748,6 +7104,7 @@
task=None,
skip_model_validation=None,
source_uri=None,
+ model_card=None,
):
"""Get arguments for create_model_package method.
@@ -6787,6 +7144,8 @@
skip_model_validation (str): Indicates if you want to skip model validation.
Values can be "All" or "None" (default: None).
source_uri (str): The URI of the source for the model package (default: None).
+ model_card (ModelCard or ModelPackageModelCard): a document containing qualitative and
+ quantitative information about a model (default: None).

Returns:
dict: A dictionary of method argument names and values.
@@ -6843,6 +7202,14 @@
if skip_model_validation is not None:
model_package_args["skip_model_validation"] = skip_model_validation
if source_uri is not None:
model_package_args["source_uri"] = source_uri
+ if model_card is not None:
+ original_req = model_card._create_request_args()
+ if original_req.get("ModelCardName") is not None:
+ del original_req["ModelCardName"]
+ if original_req.get("Content") is not None:
+ original_req["ModelCardContent"] = original_req["Content"]
+ del original_req["Content"]
+ model_package_args["model_card"] = original_req

return model_package_args
@@ -6868,6 +7235,7 @@ def get_create_model_package_request(
task=None,
skip_model_validation="None",
source_uri=None,
+ model_card=None,
):
"""Get request dictionary for CreateModelPackage API.
@@ -6905,6 +7273,8 @@
skip_model_validation (str): Indicates if you want to skip model validation.
Values can be "All" or "None" (default: None).
source_uri (str): The URI of the source for the model package (default: None).
+ model_card (ModelCard or ModelPackageModelCard): a document containing qualitative and
+ quantitative information about a model (default: None).
"""

if all([model_package_name, model_package_group_name]):
@@ -7002,6 +7372,9 @@ def get_create_model_package_request(
request_dict["CertifyForMarketplace"] = marketplace_cert
request_dict["ModelApprovalStatus"] = approval_status
request_dict["SkipModelValidation"] = skip_model_validation
+ if model_card is not None:
+ request_dict["ModelCard"] = model_card
+
return request_dict
@@ -7156,6 +7529,7 @@ def container_def(
image_config=None,
accept_eula=None,
additional_model_data_sources=None,
+ model_reference_arn=None,
):
"""Create a definition for executing a container as part of a SageMaker model.
@@ -7213,6 +7587,11 @@
c_def["ModelDataSource"]["S3DataSource"]["ModelAccessConfig"] = {
"AcceptEula": accept_eula
}
+ if model_reference_arn:
+ c_def["ModelDataSource"]["S3DataSource"]["HubAccessConfig"] = {
+ "HubContentArn": model_reference_arn
+ }
+
elif model_data_url is not None:
c_def["ModelDataUrl"] = model_data_url
diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py
index 27833c1d9c..1ab28eac37 100644
--- a/src/sagemaker/sklearn/model.py
+++ b/src/sagemaker/sklearn/model.py
@@ -23,6 +23,10 @@
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
+from sagemaker.model_card import (
+ ModelCard,
+ ModelPackageModelCard,
+)
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.sklearn import defaults
@@ -172,6 +176,7 @@ def register(
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
+ model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -223,6 +228,8 @@
validation. Values can be "All" or "None" (default: None).
source_uri (str or PipelineVariable): The URI of the source for the model package
(default: None).
+ model_card (ModelCard or ModelPackageModelCard): a document containing qualitative and
+ quantitative information about a model (default: None).

Returns:
A `sagemaker.model.ModelPackage` instance.
@@ -263,6 +270,7 @@
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
+ model_card=model_card,
)

def prepare_container_def(
@@ -271,6 +279,7 @@
self,
instance_type=None,
accelerator_type=None,
serverless_inference_config=None,
accept_eula=None,
+ model_reference_arn=None,
):
"""Container definition with framework configuration set in model environment variables.
@@ -320,6 +329,7 @@
model_data_uri,
deploy_env,
accept_eula=accept_eula,
+ model_reference_arn=model_reference_arn,
)

def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None):
diff --git a/src/sagemaker/telemetry/__init__.py b/src/sagemaker/telemetry/__init__.py
new file mode 100644
index 0000000000..ada3f1f09f
--- /dev/null
+++ b/src/sagemaker/telemetry/__init__.py
@@ -0,0 +1,16 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Placeholder docstring""" +from __future__ import absolute_import + +from .telemetry_logging import _telemetry_emitter # noqa: F401 diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py new file mode 100644 index 0000000000..332d706351 --- /dev/null +++ b/src/sagemaker/telemetry/constants.py @@ -0,0 +1,42 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Constants used in SageMaker Python SDK telemetry.""" + +from __future__ import absolute_import +from enum import Enum + +# Default AWS region used by SageMaker +DEFAULT_AWS_REGION = "us-west-2" + + +class Feature(Enum): + """Enumeration of feature names used in telemetry.""" + + SDK_DEFAULTS = 1 + LOCAL_MODE = 2 + REMOTE_FUNCTION = 3 + + def __str__(self): # pylint: disable=E0307 + """Return the feature name.""" + return self.name + + +class Status(Enum): + """Enumeration of status values used in telemetry.""" + + SUCCESS = 1 + FAILURE = 0 + + def __str__(self): # pylint: disable=E0307 + """Return the status name.""" + return self.name diff --git a/src/sagemaker/telemetry/telemetry_logging.py b/src/sagemaker/telemetry/telemetry_logging.py new file mode 100644 index 0000000000..d2b91a321c --- /dev/null +++ b/src/sagemaker/telemetry/telemetry_logging.py @@ -0,0 +1,256 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
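As a rough usage sketch for the `_telemetry_emitter` decorator that the new telemetry_logging.py module (below) defines; the class and method names here are hypothetical and not part of this patch:

    from sagemaker.telemetry import _telemetry_emitter
    from sagemaker.telemetry.constants import Feature

    class ExampleResource:
        # Hypothetical class; the decorator only requires that the first
        # positional argument expose a `sagemaker_session` attribute.
        def __init__(self, sagemaker_session):
            self.sagemaker_session = sagemaker_session

        @_telemetry_emitter(Feature.LOCAL_MODE, "ExampleResource.do_work")
        def do_work(self):
            # The wrapper resolves the TelemetryOptOut config flag, times the
            # call, and emits a success or failure record unless opted out.
            return "done"

Note that the enums above stringify to their names, which is what makes them usable as keys of the FEATURE_TO_CODE and STATUS_TO_CODE maps defined below.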
+"""Telemetry module for SageMaker Python SDK to collect usage data and metrics.""" +from __future__ import absolute_import +import logging +import platform +import sys +from time import perf_counter +from typing import List +import functools +import requests + +import boto3 +from sagemaker.session import Session +from sagemaker.utils import resolve_value_from_config +from sagemaker.config.config_schema import TELEMETRY_OPT_OUT_PATH +from sagemaker.telemetry.constants import ( + Feature, + Status, + DEFAULT_AWS_REGION, +) +from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file + +logger = logging.getLogger(__name__) + +OS_NAME = platform.system() or "UnresolvedOS" +OS_VERSION = platform.release() or "UnresolvedOSVersion" +OS_NAME_VERSION = "{}/{}".format(OS_NAME, OS_VERSION) +PYTHON_VERSION = "{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro +) + +TELEMETRY_OPT_OUT_MESSAGING = ( + "SageMaker Python SDK will collect telemetry to help us better understand our user's needs, " + "diagnose issues, and deliver additional features.\n" + "To opt out of telemetry, please disable via TelemetryOptOut parameter in SDK defaults config. " + "For more information, refer to https://sagemaker.readthedocs.io/en/stable/overview.html" + "#configuring-and-using-defaults-with-the-sagemaker-python-sdk." +) + +FEATURE_TO_CODE = { + str(Feature.SDK_DEFAULTS): 1, + str(Feature.LOCAL_MODE): 2, + str(Feature.REMOTE_FUNCTION): 3, +} + +STATUS_TO_CODE = { + str(Status.SUCCESS): 1, + str(Status.FAILURE): 0, +} + + +def _telemetry_emitter(feature: str, func_name: str): + """Decorator to emit telemetry logs for SageMaker Python SDK functions""" + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + sagemaker_session = None + if len(args) > 0 and hasattr(args[0], "sagemaker_session"): + # Get the sagemaker_session from the instance method args + sagemaker_session = args[0].sagemaker_session + elif feature == Feature.REMOTE_FUNCTION: + # Get the sagemaker_session from the function keyword arguments for remote function + sagemaker_session = kwargs.get( + "sagemaker_session", _get_default_sagemaker_session() + ) + + if sagemaker_session: + logger.debug("sagemaker_session found, preparing to emit telemetry...") + logger.info(TELEMETRY_OPT_OUT_MESSAGING) + response = None + caught_ex = None + studio_app_type = process_studio_metadata_file() + + # Check if telemetry is opted out + telemetry_opt_out_flag = resolve_value_from_config( + direct_input=None, + config_path=TELEMETRY_OPT_OUT_PATH, + default_value=False, + sagemaker_session=sagemaker_session, + ) + logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag) + + # Construct the feature list to track feature combinations + feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]] + + if sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS: + feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)]) + + if sagemaker_session.local_mode and feature != Feature.LOCAL_MODE: + feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)]) + + # Construct the extra info to track platform and environment usage metadata + extra = ( + f"{func_name}" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-env={PYTHON_VERSION}" + f"&x-sys={OS_NAME_VERSION}" + f"&x-platform={studio_app_type}" + ) + + # Add endpoint ARN to the extra info if available + if sagemaker_session.endpoint_arn: + extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}" + + start_timer = 
perf_counter() + try: + # Call the original function + response = func(*args, **kwargs) + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + extra += f"&x-latency={round(elapsed, 2)}" + if not telemetry_opt_out_flag: + _send_telemetry_request( + STATUS_TO_CODE[str(Status.SUCCESS)], + feature_list, + sagemaker_session, + None, + None, + extra, + ) + except Exception as e: # pylint: disable=W0703 + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + extra += f"&x-latency={round(elapsed, 2)}" + if not telemetry_opt_out_flag: + _send_telemetry_request( + STATUS_TO_CODE[str(Status.FAILURE)], + feature_list, + sagemaker_session, + str(e), + e.__class__.__name__, + extra, + ) + caught_ex = e + finally: + if caught_ex: + raise caught_ex + return response # pylint: disable=W0150 + else: + logger.debug( + "Unable to send telemetry for function %s. " + "sagemaker_session is not provided or not valid.", + func_name, + ) + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def _send_telemetry_request( + status: int, + feature_list: List[int], + session: Session, + failure_reason: str = None, + failure_type: str = None, + extra_info: str = None, +) -> None: + """Make GET request to an empty object in S3 bucket""" + try: + accountId = _get_accountId(session) + region = _get_region_or_default(session) + url = _construct_url( + accountId, + region, + str(status), + str( + ",".join(map(str, feature_list)) + ), # Remove brackets and quotes to cut down on length + failure_reason, + failure_type, + extra_info, + ) + # Send the telemetry request + logger.debug("Sending telemetry request to [%s]", url) + _requests_helper(url, 2) + logger.debug("SageMaker Python SDK telemetry successfully emitted.") + except Exception: # pylint: disable=W0703 + logger.debug("SageMaker Python SDK telemetry not emitted!") + + +def _construct_url( + accountId: str, + region: str, + status: str, + feature: str, + failure_reason: str, + failure_type: str, + extra_info: str, +) -> str: + """Construct the URL for the telemetry request""" + + base_url = ( + f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?" 
+ f"x-accountId={accountId}" + f"&x-status={status}" + f"&x-feature={feature}" + ) + logger.debug("Failure reason: %s", failure_reason) + if failure_reason: + base_url += f"&x-failureReason={failure_reason}" + base_url += f"&x-failureType={failure_type}" + if extra_info: + base_url += f"&x-extra={extra_info}" + return base_url + + +def _requests_helper(url, timeout): + """Make a GET request to the given URL""" + + response = None + try: + response = requests.get(url, timeout) + except requests.exceptions.RequestException as e: + logger.exception("Request exception: %s", str(e)) + return response + + +def _get_accountId(session): + """Return the account ID from the boto session""" + + try: + sts = session.boto_session.client("sts") + return sts.get_caller_identity()["Account"] + except Exception: # pylint: disable=W0703 + return None + + +def _get_region_or_default(session): + """Return the region name from the boto session or default to us-west-2""" + + try: + return session.boto_session.region_name + except Exception: # pylint: disable=W0703 + return DEFAULT_AWS_REGION + + +def _get_default_sagemaker_session(): + """Return the default sagemaker session""" + boto_session = boto3.Session(region_name=DEFAULT_AWS_REGION) + sagemaker_session = Session(boto_session=boto_session) + + return sagemaker_session diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 77f162207c..c06fe74887 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -22,6 +22,10 @@ from sagemaker.deprecations import removed_kwargs from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.workflow import is_pipeline_variable @@ -234,6 +238,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -285,6 +290,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -325,6 +332,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def deploy( @@ -389,6 +397,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Prepare the container definition. 
@@ -465,6 +474,7 @@ def prepare_container_def(
model_data,
env,
accept_eula=accept_eula,
+ model_reference_arn=model_reference_arn,
)

def _get_container_env(self):
diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py
index 841cd68083..e405d1034a 100644
--- a/src/sagemaker/workflow/_utils.py
+++ b/src/sagemaker/workflow/_utils.py
@@ -329,6 +329,7 @@ def __init__(
task=None,
skip_model_validation=None,
source_uri=None,
+ model_card=None,
**kwargs,
):
"""Constructor of a register model step.
@@ -381,6 +382,8 @@
skip_model_validation (str): Indicates if you want to skip model validation.
Values can be "All" or "None" (default: None).
source_uri (str): The URI of the source for the model package (default: None).
+ model_card (ModelCard or ModelPackageModelCard): a document containing qualitative and
+ quantitative information about a model (default: None).
**kwargs: additional arguments to `create_model`.
"""
super(_RegisterModelStep, self).__init__(
@@ -418,6 +421,7 @@
self.container_def_list = container_def_list
self.skip_model_validation = skip_model_validation
self.source_uri = source_uri
+ self.model_card = model_card

self._properties = Properties(
step_name=name, step=self, shape_name="DescribeModelPackageOutput"
)
@@ -493,6 +497,7 @@
task=self.task,
skip_model_validation=self.skip_model_validation,
source_uri=self.source_uri,
+ model_card=self.model_card,
)
request_dict = get_create_model_package_request(**model_package_args)
diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py
index 0eedf4aa96..c88c82efa9 100644
--- a/src/sagemaker/workflow/step_collections.py
+++ b/src/sagemaker/workflow/step_collections.py
@@ -97,6 +97,7 @@ def __init__(
data_input_configuration=None,
skip_model_validation=None,
source_uri=None,
+ model_card=None,
**kwargs,
):
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -155,7 +156,8 @@
skip_model_validation (str): Indicates if you want to skip model validation.
Values can be "All" or "None" (default: None).
source_uri (str): The URI of the source for the model package (default: None).
-
+ model_card (ModelCard or ModelPackageModelCard): a document containing qualitative and
+ quantitative information about a model (default: None).
**kwargs: additional arguments to `create_model`.
""" super().__init__(name=name, depends_on=depends_on) @@ -294,6 +296,7 @@ def __init__( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, **kwargs, ) if not repack_model: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 8101f32721..6d69801847 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -23,6 +23,10 @@ from sagemaker.fw_utils import model_code_key_prefix from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.predictor import Predictor from sagemaker.serializers import LibSVMSerializer from sagemaker.utils import to_string @@ -160,6 +164,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -211,6 +216,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -251,6 +258,7 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, ) def prepare_container_def( @@ -259,6 +267,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition with framework configuration. 
@@ -306,6 +315,7 @@ def prepare_container_def( model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): diff --git a/tests/conftest.py b/tests/conftest.py index 7bab05dfb3..ceb2a03f51 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -269,7 +269,9 @@ def pytorch_training_py_version(pytorch_training_version, request): @pytest.fixture(scope="module", params=["py2", "py3"]) def pytorch_inference_py_version(pytorch_inference_version, request): - if Version(pytorch_inference_version) >= Version("2.0"): + if Version(pytorch_inference_version) >= Version("2.3"): + return "py311" + elif Version(pytorch_inference_version) >= Version("2.0"): return "py310" elif Version(pytorch_inference_version) >= Version("1.13"): return "py39" diff --git a/tests/data/serve_resources/mlflow/pytorch/requirements.txt b/tests/data/serve_resources/mlflow/pytorch/requirements.txt index 895e2173bf..de976327a4 100644 --- a/tests/data/serve_resources/mlflow/pytorch/requirements.txt +++ b/tests/data/serve_resources/mlflow/pytorch/requirements.txt @@ -1,4 +1,4 @@ -mlflow==2.12.1 +mlflow==2.13.2 astunparse==1.6.3 cffi==1.16.0 cloudpickle==2.2.1 diff --git a/tests/data/serve_resources/mlflow/tensorflow/requirements.txt b/tests/data/serve_resources/mlflow/tensorflow/requirements.txt index d4ff5b4782..ff99d3b92e 100644 --- a/tests/data/serve_resources/mlflow/tensorflow/requirements.txt +++ b/tests/data/serve_resources/mlflow/tensorflow/requirements.txt @@ -1,4 +1,4 @@ -mlflow==2.12.1 +mlflow==2.13.2 cloudpickle==2.2.1 numpy==1.26.4 tensorflow==2.16.1 diff --git a/tests/data/serve_resources/mlflow/xgboost/requirements.txt b/tests/data/serve_resources/mlflow/xgboost/requirements.txt index 18d687aec6..1130dcaec5 100644 --- a/tests/data/serve_resources/mlflow/xgboost/requirements.txt +++ b/tests/data/serve_resources/mlflow/xgboost/requirements.txt @@ -1,4 +1,4 @@ -mlflow==2.12.1 +mlflow==2.13.2 lz4==4.3.2 numpy==1.24.4 pandas==2.0.3 diff --git a/tests/integ/sagemaker/conftest.py b/tests/integ/sagemaker/conftest.py index 043b0c703e..46539e6de3 100644 --- a/tests/integ/sagemaker/conftest.py +++ b/tests/integ/sagemaker/conftest.py @@ -102,7 +102,9 @@ "channels:\n" " - defaults\n" "dependencies:\n" - " - scipy=1.10.1\n" + " - requests=2.32.3\n" + " - charset-normalizer=3.3.2\n" + " - scipy=1.13.1\n" " - pip:\n" " - /sagemaker-{sagemaker_version}.tar.gz\n" "prefix: /opt/conda/bin/conda\n" diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index 0da64ecf05..00c87fac1b 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -174,6 +174,7 @@ def test_gated_model_training_v2(setup): tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), + instance_type="ml.g5.2xlarge", ) payload = { diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 5205765e2f..6bc0a5c996 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -219,6 +219,11 @@ def test_jumpstart_gated_model_inference_component_enabled(setup): assert 
response is not None + model = JumpStartModel.attach(predictor.endpoint_name, sagemaker_session=get_sm_session()) + assert model.model_id == model_id + assert model.endpoint_name == predictor.endpoint_name + assert model.inference_component_name == predictor.component_name + @mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") def test_instatiating_model(mock_warning_logger, setup): @@ -260,6 +265,8 @@ def test_jumpstart_model_register(setup): response = predictor.predict("hello world!") + predictor.delete_predictor() + assert response is not None @@ -286,3 +293,59 @@ def test_proprietary_jumpstart_model(setup): response = predictor.predict(payload) assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable if test account is subscribed to the proprietary model", +) +def test_register_proprietary_jumpstart_model(setup): + + model_id = "ai21-jurassic-2-light" + + model = JumpStartModel( + model_id=model_id, + model_version="2.0.004", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + model_package = model.register() + + predictor = model_package.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}] + ) + payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + + response = predictor.predict(payload) + + predictor.delete_predictor() + + assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable if test account is subscribed to the proprietary model", +) +def test_register_gated_jumpstart_model(setup): + + model_id = "meta-textgenerationneuron-llama-2-7b" + model = JumpStartModel( + model_id=model_id, + model_version="1.1.0", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + model_package = model.register(accept_eula=True) + + predictor = model_package.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + accept_eula=True, + ) + payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + + response = predictor.predict(payload) + + predictor.delete_predictor() + + assert response is not None diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 0733649cb2..0bdbc18c99 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -18,12 +18,15 @@ # and the RegisterModel and CreateModelStep have been replaced with the new interface - ModelStep from __future__ import absolute_import +import json import logging import os import re import pytest +from sagemaker.model_card.model_card import ModelCard, ModelOverview, ModelPackageModelCard +from sagemaker.model_card.schema_constraints import ModelCardStatusEnum import tests from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution from sagemaker.tensorflow import TensorFlow, TensorFlowModel @@ -56,6 +59,15 @@ ) from tests.integ.kms_utils import get_or_create_kms_key from tests.integ import DATA_DIR +from sagemaker.model_card import ( + IntendedUses, + BusinessDetails, + EvaluationJob, + AdditionalInformation, + Metric, + MetricGroup, + MetricTypeEnum, +) @pytest.fixture @@ -703,6 +715,425 @@ def test_model_registration_with_drift_check_baselines( pass +def test_model_registration_with_model_card_object( + sagemaker_session_for_pipeline, + 
role, + pipeline_name, +): + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + instance_type = "ml.m5.xlarge" + + # upload model data to s3 + model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz") + model_base_uri = "s3://{}/{}/input/model/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("model"), + ) + model_uri = S3Uploader.upload( + model_local_path, model_base_uri, sagemaker_session=sagemaker_session_for_pipeline + ) + model_uri_param = ParameterString(name="model_uri", default_value=model_uri) + + # upload metrics to s3 + metrics_data = ( + '{"regression_metrics": {"mse": {"value": 4.925353410353891, ' + '"standard_deviation": 2.219186917819692}}}' + ) + metrics_base_uri = "s3://{}/{}/input/metrics/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("metrics"), + ) + metrics_uri = S3Uploader.upload_string_as_file_body( + body=metrics_data, + desired_s3_uri=metrics_base_uri, + sagemaker_session=sagemaker_session_for_pipeline, + ) + metrics_uri_param = ParameterString(name="metrics_uri", default_value=metrics_uri) + + model_metrics = ModelMetrics( + bias=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + explainability=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_pre_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_post_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + ) + customer_metadata_properties = {"key1": "value1"} + domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + skip_model_validation = "All" + + # If image_uri is not provided, the instance_type should not be a pipeline variable + # since instance_type is used to retrieve image_uri in compile time (PySDK) + estimator = XGBoost( + entry_point="training.py", + source_dir=os.path.join(DATA_DIR, "sip"), + instance_type=instance_type, + instance_count=instance_count, + framework_version="0.90-2", + sagemaker_session=sagemaker_session_for_pipeline, + py_version="py3", + role=role, + ) + intended_uses = IntendedUses( + purpose_of_model="Test model card.", + intended_uses="Not used except this test.", + factors_affecting_model_efficiency="No.", + risk_rating="Low", + explanations_for_risk_rating="Just an example.", + ) + business_details = BusinessDetails( + business_problem="The business problem that your model is used to solve.", + business_stakeholders="The stakeholders who have the interest in the business that your model is used for.", + line_of_business="Services that the business is offering.", + ) + additional_information = AdditionalInformation( + ethical_considerations="Your model ethical consideration.", + caveats_and_recommendations="Your model's caveats and recommendations.", + custom_details={"custom details1": "details value"}, + ) + manual_metric_group = MetricGroup( + name="binary classification metrics", + metric_data=[Metric(name="accuracy", type=MetricTypeEnum.NUMBER, value=0.5)], + ) + example_evaluation_job = EvaluationJob( + name="Example evaluation job", + evaluation_observation="Evaluation observations.", + 
datasets=["s3://path/to/evaluation/data"], + metric_groups=[manual_metric_group], + ) + evaluation_details = [example_evaluation_job] + + model_overview = ModelOverview(model_creator="TestCreator") + + my_card = ModelCard( + name="TestName", + sagemaker_session=sagemaker_session_for_pipeline, + status=ModelCardStatusEnum.DRAFT, + model_overview=model_overview, + intended_uses=intended_uses, + business_details=business_details, + evaluation_details=evaluation_details, + additional_information=additional_information, + ) + + step_register = RegisterModel( + name="MyRegisterModelStep", + estimator=estimator, + model_data=model_uri_param, + content_types=["application/json"], + response_types=["application/json"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_group_name="testModelPackageGroup", + model_metrics=model_metrics, + customer_metadata_properties=customer_metadata_properties, + domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, + model_card=my_card, + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[ + model_uri_param, + metrics_uri_param, + instance_count, + ], + steps=[step_register], + sagemaker_session=sagemaker_session_for_pipeline, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + + for _ in retries( + max_retry_count=5, + exception_message_prefix="Waiting for a successful execution of pipeline", + seconds_to_sleep=10, + ): + execution = pipeline.start( + parameters={"model_uri": model_uri, "metrics_uri": metrics_uri} + ) + response = execution.describe() + + assert response["PipelineArn"] == create_arn + + wait_pipeline_execution(execution=execution) + execution_steps = execution.list_steps() + + assert len(execution_steps) == 1 + failure_reason = execution_steps[0].get("FailureReason", "") + if failure_reason != "": + logging.error( + f"Pipeline execution failed with error: {failure_reason}." " Retrying.." + ) + continue + assert execution_steps[0]["StepStatus"] == "Succeeded" + assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel" + + response = sagemaker_session_for_pipeline.sagemaker_client.describe_model_package( + ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"] + ) + + assert ( + response["ModelMetrics"]["Explainability"]["Report"]["ContentType"] + == "application/json" + ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties + assert response["Domain"] == domain + assert response["Task"] == task + assert response["SamplePayloadUrl"] == sample_payload_url + assert response["SkipModelValidation"] == skip_model_validation + assert (response["ModelCard"]["ModelCardStatus"]) == ModelCardStatusEnum.DRAFT + model_card_content = json.loads(response["ModelCard"]["ModelCardContent"]) + assert (model_card_content["model_overview"]["model_creator"]) == "TestCreator" + assert (model_card_content["intended_uses"]["purpose_of_model"]) == "Test model card." + assert ( + model_card_content["business_details"]["line_of_business"] + ) == "Services that the business is offering." 
+            assert (model_card_content["evaluation_details"][0]["name"]) == "Example evaluation job"
+
+            break
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
+def test_model_registration_with_model_card_json(
+    sagemaker_session_for_pipeline,
+    role,
+    pipeline_name,
+):
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+    instance_type = "ml.m5.xlarge"
+
+    # upload model data to s3
+    model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz")
+    model_base_uri = "s3://{}/{}/input/model/{}".format(
+        sagemaker_session_for_pipeline.default_bucket(),
+        "register_model_test_with_drift_baseline",
+        utils.unique_name_from_base("model"),
+    )
+    model_uri = S3Uploader.upload(
+        model_local_path, model_base_uri, sagemaker_session=sagemaker_session_for_pipeline
+    )
+    model_uri_param = ParameterString(name="model_uri", default_value=model_uri)
+
+    # upload metrics to s3
+    metrics_data = (
+        '{"regression_metrics": {"mse": {"value": 4.925353410353891, '
+        '"standard_deviation": 2.219186917819692}}}'
+    )
+    metrics_base_uri = "s3://{}/{}/input/metrics/{}".format(
+        sagemaker_session_for_pipeline.default_bucket(),
+        "register_model_test_with_drift_baseline",
+        utils.unique_name_from_base("metrics"),
+    )
+    metrics_uri = S3Uploader.upload_string_as_file_body(
+        body=metrics_data,
+        desired_s3_uri=metrics_base_uri,
+        sagemaker_session=sagemaker_session_for_pipeline,
+    )
+    metrics_uri_param = ParameterString(name="metrics_uri", default_value=metrics_uri)
+
+    model_metrics = ModelMetrics(
+        bias=MetricsSource(
+            s3_uri=metrics_uri_param,
+            content_type="application/json",
+        ),
+        explainability=MetricsSource(
+            s3_uri=metrics_uri_param,
+            content_type="application/json",
+        ),
+        bias_pre_training=MetricsSource(
+            s3_uri=metrics_uri_param,
+            content_type="application/json",
+        ),
+        bias_post_training=MetricsSource(
+            s3_uri=metrics_uri_param,
+            content_type="application/json",
+        ),
+    )
+    customer_metadata_properties = {"key1": "value1"}
+    domain = "COMPUTER_VISION"
+    task = "IMAGE_CLASSIFICATION"
+    sample_payload_url = "s3://test-bucket/model"
+    framework = "TENSORFLOW"
+    framework_version = "2.9"
+    nearest_model_name = "resnet50"
+    data_input_configuration = '{"input_1":[1,224,224,3]}'
+    skip_model_validation = "All"
+
+    # If image_uri is not provided, instance_type should not be a pipeline variable,
+    # since instance_type is used to retrieve image_uri at compile time (PySDK)
+    estimator = XGBoost(
+        entry_point="training.py",
+        source_dir=os.path.join(DATA_DIR, "sip"),
+        instance_type=instance_type,
+        instance_count=instance_count,
+        framework_version="0.90-2",
+        sagemaker_session=sagemaker_session_for_pipeline,
+        py_version="py3",
+        role=role,
+    )
+
+    model_card_content = {
+        "model_overview": {
+            "model_creator": "TestCreator",
+        },
+        "intended_uses": {
+            "purpose_of_model": "Test model card.",
+            "intended_uses": "Not used except this test.",
+            "factors_affecting_model_efficiency": "No.",
+            "risk_rating": "Low",
+            "explanations_for_risk_rating": "Just an example.",
+        },
+        "business_details": {
+            "business_problem": "The business problem that your model is used to solve.",
+            "business_stakeholders": "The stakeholders who have the interest in the business.",
+            "line_of_business": "Services that the business is offering.",
+        },
+        "evaluation_details": [
+            {
+                "name": "Example evaluation job",
+                "evaluation_observation": "Evaluation observations.",
+                "metric_groups": [
+                    {
+                        "name": "binary classification metrics",
+                        "metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}],
+                    }
+                ],
+            }
+        ],
+        "additional_information": {
+            "ethical_considerations": "Your model ethical consideration.",
+            "caveats_and_recommendations": "Your model's caveats and recommendations.",
+            "custom_details": {"custom details1": "details value"},
+        },
+    }
+    my_card = ModelPackageModelCard(
+        model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=model_card_content
+    )
+
+    step_register = RegisterModel(
+        name="MyRegisterModelStep",
+        estimator=estimator,
+        model_data=model_uri_param,
+        content_types=["application/json"],
+        response_types=["application/json"],
+        inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
+        transform_instances=["ml.m5.xlarge"],
+        model_package_group_name="testModelPackageGroup",
+        model_metrics=model_metrics,
+        customer_metadata_properties=customer_metadata_properties,
+        domain=domain,
+        sample_payload_url=sample_payload_url,
+        task=task,
+        framework=framework,
+        framework_version=framework_version,
+        nearest_model_name=nearest_model_name,
+        data_input_configuration=data_input_configuration,
+        skip_model_validation=skip_model_validation,
+        model_card=my_card,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[
+            model_uri_param,
+            metrics_uri_param,
+            instance_count,
+        ],
+        steps=[step_register],
+        sagemaker_session=sagemaker_session_for_pipeline,
+    )
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+
+        for _ in retries(
+            max_retry_count=5,
+            exception_message_prefix="Waiting for a successful execution of pipeline",
+            seconds_to_sleep=10,
+        ):
+            execution = pipeline.start(
+                parameters={"model_uri": model_uri, "metrics_uri": metrics_uri}
+            )
+            response = execution.describe()
+
+            assert response["PipelineArn"] == create_arn
+
+            wait_pipeline_execution(execution=execution)
+            execution_steps = execution.list_steps()
+
+            assert len(execution_steps) == 1
+            failure_reason = execution_steps[0].get("FailureReason", "")
+            if failure_reason != "":
+                logging.error(
+                    f"Pipeline execution failed with error: {failure_reason}. Retrying..."
+                )
+                continue
+            assert execution_steps[0]["StepStatus"] == "Succeeded"
+            assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel"
+
+            response = sagemaker_session_for_pipeline.sagemaker_client.describe_model_package(
+                ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"]
+            )
+
+            assert (
+                response["ModelMetrics"]["Explainability"]["Report"]["ContentType"]
+                == "application/json"
+            )
+            assert response["CustomerMetadataProperties"] == customer_metadata_properties
+            assert response["Domain"] == domain
+            assert response["Task"] == task
+            assert response["SamplePayloadUrl"] == sample_payload_url
+            assert response["SkipModelValidation"] == skip_model_validation
+            assert (response["ModelCard"]["ModelCardStatus"]) == ModelCardStatusEnum.DRAFT
+            model_card_content = json.loads(response["ModelCard"]["ModelCardContent"])
+            assert (model_card_content["model_overview"]["model_creator"]) == "TestCreator"
+            assert (model_card_content["intended_uses"]["purpose_of_model"]) == "Test model card."
+            assert (
+                model_card_content["business_details"]["line_of_business"]
+            ) == "Services that the business is offering."
+ assert (model_card_content["evaluation_details"][0]["name"]) == "Example evaluation job" + + break + finally: + try: + pipeline.delete() + except Exception: + pass + + def test_model_registration_with_model_repack( sagemaker_session_for_pipeline, role, diff --git a/tests/integ/test_model_package.py b/tests/integ/test_model_package.py index 914c5db7ed..f59901ee61 100644 --- a/tests/integ/test_model_package.py +++ b/tests/integ/test_model_package.py @@ -12,8 +12,17 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import json import os -from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum +from sagemaker.model_card.model_card import ( + AdditionalInformation, + BusinessDetails, + IntendedUses, + ModelCard, + ModelOverview, + ModelPackageModelCard, +) +from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum, ModelCardStatusEnum from sagemaker.utils import unique_name_from_base from tests.integ import DATA_DIR from sagemaker.xgboost import XGBoostModel @@ -183,6 +192,216 @@ def test_update_source_uri(sagemaker_session): assert desc_model_package["SourceUri"] == source_uri +def test_update_model_card_with_model_card_object(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + intended_uses = IntendedUses( + purpose_of_model="Test model card.", + intended_uses="Not used except this test.", + factors_affecting_model_efficiency="No.", + risk_rating="Low", + explanations_for_risk_rating="Just an example.", + ) + business_details = BusinessDetails( + business_problem="The business problem that your model is used to solve.", + business_stakeholders="The stakeholders who have the interest in the business that your model is used for.", + line_of_business="Services that the business is offering.", + ) + additional_information = AdditionalInformation( + ethical_considerations="Your model ethical consideration.", + caveats_and_recommendations="Your model's caveats and recommendations.", + custom_details={"custom details1": "details value"}, + ) + + model_overview = ModelOverview(model_creator="TestCreator") + + my_card = ModelCard( + name="TestName", + sagemaker_session=sagemaker_session, + status=ModelCardStatusEnum.DRAFT, + model_overview=model_overview, + intended_uses=intended_uses, + business_details=business_details, + additional_information=additional_information, + ) + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + model = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + + model_package = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + model_card=my_card, + ) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + updated_model_overview = ModelOverview(model_creator="updatedCreator") + updated_intended_uses = IntendedUses( + purpose_of_model="Updated Test model card.", + ) + updated_my_card = ModelCard( + name="TestName", + sagemaker_session=sagemaker_session, + model_overview=updated_model_overview, + intended_uses=updated_intended_uses, + 
)
+    model_package.update_model_card(updated_my_card)
+    desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+
+    model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"])
+    assert model_card_content["intended_uses"]["purpose_of_model"] == "Updated Test model card."
+    assert model_card_content["model_overview"]["model_creator"] == "updatedCreator"
+    updated_my_card_status = ModelCard(
+        name="TestName",
+        sagemaker_session=sagemaker_session,
+        status=ModelCardStatusEnum.PENDING_REVIEW,
+    )
+    model_package.update_model_card(updated_my_card_status)
+    desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+
+    assert desc_model_package["ModelCard"]["ModelCardStatus"] == ModelCardStatusEnum.PENDING_REVIEW
+
+
+def test_update_model_card_with_model_card_json(sagemaker_session):
+    model_group_name = unique_name_from_base("test-model-group")
+    model_card_content = {
+        "model_overview": {
+            "model_creator": "TestCreator",
+        },
+        "intended_uses": {
+            "purpose_of_model": "Test model card.",
+            "intended_uses": "Not used except this test.",
+            "factors_affecting_model_efficiency": "No.",
+            "risk_rating": "Low",
+            "explanations_for_risk_rating": "Just an example.",
+        },
+        "business_details": {
+            "business_problem": "The business problem that your model is used to solve.",
+            "business_stakeholders": "The stakeholders who have the interest in the business.",
+            "line_of_business": "Services that the business is offering.",
+        },
+        "evaluation_details": [
+            {
+                "name": "Example evaluation job",
+                "evaluation_observation": "Evaluation observations.",
+                "metric_groups": [
+                    {
+                        "name": "binary classification metrics",
+                        "metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}],
+                    }
+                ],
+            }
+        ],
+        "additional_information": {
+            "ethical_considerations": "Your model ethical consideration.",
+            "caveats_and_recommendations": "Your model's caveats and recommendations.",
+            "custom_details": {"custom details1": "details value"},
+        },
+    }
+    my_card = ModelPackageModelCard(
+        model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=model_card_content
+    )
+
+    sagemaker_session.sagemaker_client.create_model_package_group(
+        ModelPackageGroupName=model_group_name
+    )
+
+    xgb_model_data_s3 = sagemaker_session.upload_data(
+        path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
+        key_prefix="integ-test-data/xgboost/model",
+    )
+    model = XGBoostModel(
+        model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
+    )
+
+    model_package = model.register(
+        content_types=["text/csv"],
+        response_types=["text/csv"],
+        inference_instances=["ml.m5.large"],
+        transform_instances=["ml.m5.large"],
+        model_package_group_name=model_group_name,
+        model_card=my_card,
+    )
+
+    desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+
+    updated_model_card_content = {
+        "model_overview": {
+            "model_creator": "updatedCreator",
+        },
+        "intended_uses": {
+            "purpose_of_model": "Updated Test model card.",
+            "intended_uses": "Not used except this test.",
+            "factors_affecting_model_efficiency": "No.",
+            "risk_rating": "Low",
+            "explanations_for_risk_rating": "Just an example.",
+        },
+        "business_details": {
+            "business_problem": "The business problem that your model is used to solve.",
+            "business_stakeholders": "The stakeholders who have the interest in the business.",
+            "line_of_business": "Services that the business is offering.",
+        },
+        "evaluation_details": [
+            {
+                "name": "Example evaluation job",
+                "evaluation_observation": "Evaluation observations.",
+                "metric_groups": [
+                    {
+                        "name": "binary classification metrics",
+                        "metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}],
+                    }
+                ],
+            }
+        ],
+        "additional_information": {
+            "ethical_considerations": "Your model ethical consideration.",
+            "caveats_and_recommendations": "Your model's caveats and recommendations.",
+            "custom_details": {"custom details1": "details value"},
+        },
+    }
+    updated_my_card = ModelPackageModelCard(
+        model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=updated_model_card_content
+    )
+    model_package.update_model_card(updated_my_card)
+    desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+
+    model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"])
+    assert model_card_content["intended_uses"]["purpose_of_model"] == "Updated Test model card."
+    assert model_card_content["model_overview"]["model_creator"] == "updatedCreator"
+    updated_my_card_status = ModelPackageModelCard(
+        model_card_status=ModelCardStatusEnum.PENDING_REVIEW,
+    )
+    model_package.update_model_card(updated_my_card_status)
+    desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+
+    assert desc_model_package["ModelCard"]["ModelCardStatus"] == ModelCardStatusEnum.PENDING_REVIEW
+
+
 def test_clone_model_package_using_source_uri(sagemaker_session):
     model_group_name = unique_name_from_base("test-model-group")
diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py
index 11165a0625..91c132f053 100644
--- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py
+++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py
@@ -13,7 +13,7 @@
 from __future__ import absolute_import
 
 import boto3
-from mock.mock import patch, Mock
+from mock.mock import patch, Mock, ANY
 
 from sagemaker import accept_types
 from sagemaker.jumpstart.utils import verify_model_region_and_return_specs
@@ -54,9 +54,11 @@ def test_jumpstart_default_accept_types(
     patched_get_model_specs.assert_called_once_with(
         region=region,
         model_id=model_id,
+        hub_arn=None,
         version=model_version,
-        s3_client=mock_client,
+        s3_client=ANY,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        sagemaker_session=mock_session,
     )
@@ -91,6 +93,8 @@ def test_jumpstart_supported_accept_types(
         region=region,
         model_id=model_id,
         version=model_version,
-        s3_client=mock_client,
+        s3_client=ANY,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py
index d116c8121b..dda1e30db2 100644
--- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py
+++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py
@@ -56,6 +56,8 @@ def test_jumpstart_default_content_types(
         version=model_version,
         s3_client=mock_client,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
 
 
@@ -75,15 +77,12 @@ def 
test_jumpstart_supported_content_types( model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" - supported_content_types = content_types.retrieve_options( + content_types.retrieve_options( region=region, model_id=model_id, model_version=model_version, sagemaker_session=mock_session, ) - assert supported_content_types == [ - "application/x-text", - ] patched_get_model_specs.assert_called_once_with( region=region, @@ -91,4 +90,6 @@ def test_jumpstart_supported_content_types( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index f0102068e7..9bbca51654 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -58,6 +58,8 @@ def test_jumpstart_default_deserializers( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) @@ -98,4 +100,6 @@ def test_jumpstart_deserializer_options( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index 5f00f93abf..13f720870c 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -61,6 +61,8 @@ def test_jumpstart_default_environment_variables( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -85,6 +87,8 @@ def test_jumpstart_default_environment_variables( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -147,6 +151,8 @@ def test_jumpstart_sdk_environment_variables( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -172,6 +178,8 @@ def test_jumpstart_sdk_environment_variables( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index 40ee4978cf..565ebbce87 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -54,6 +54,8 @@ def test_jumpstart_default_hyperparameters( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -72,6 +74,8 @@ def test_jumpstart_default_hyperparameters( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -98,6 +102,8 @@ def test_jumpstart_default_hyperparameters( version="1.*", s3_client=mock_client, 
model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index fdc29b4d90..edf2cfca59 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -146,6 +146,8 @@ def add_options_to_hyperparameter(*largs, **kwargs): version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -452,6 +454,8 @@ def add_options_to_hyperparameter(*largs, **kwargs): version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -514,6 +518,8 @@ def test_jumpstart_validate_all_hyperparameters( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 88b95b9403..cc3723c3c5 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -53,9 +53,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -74,9 +76,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -95,9 +99,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -116,9 +122,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/image_uris/test_smp_v2.py b/tests/unit/sagemaker/image_uris/test_smp_v2.py index b53a45133e..e9c8cec292 100644 --- a/tests/unit/sagemaker/image_uris/test_smp_v2.py +++ b/tests/unit/sagemaker/image_uris/test_smp_v2.py @@ -35,7 +35,7 @@ def test_smp_v2(load_config): for region in ACCOUNTS.keys(): for instance_type in CONTAINER_VERSIONS.keys(): cuda_vers = CONTAINER_VERSIONS[instance_type] - if "2.1" in version or "2.2" in version: + if "2.1" in version or "2.2" in version or "2.3" in version: cuda_vers = 
"cu121" uri = image_uris.get_training_image_uri( diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 2e51afd3f7..5db149c4c3 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -51,6 +51,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -70,6 +72,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -95,6 +99,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -122,6 +128,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index eb2598b357..1ae489acf8 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1250,288 +1250,557 @@ "dynamic_container_deployment_supported": True, }, }, - "env-var-variant-model": { - "model_id": "huggingface-llm-falcon-180b-bf16", - "url": "https://huggingface.co/tiiuae/falcon-180B", - "version": "1.0.0", - "min_sdk_version": "2.175.0", - "training_supported": False, + # noqa: E501 + "gemma-model-2b-v1_1_0": { + "model_id": "huggingface-llm-gemma-2b-instruct", + "url": "https://huggingface.co/google/gemma-2b-it", + "version": "1.1.0", + "min_sdk_version": "2.189.0", + "training_supported": True, "incremental_training_supported": False, "hosting_ecr_specs": { "framework": "huggingface-llm", - "framework_version": "0.9.3", - "py_version": "py39", - "huggingface_transformers_version": "4.29.2", + "framework_version": "1.4.2", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_artifact_key": "huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference/v1.0.0/", "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" - "-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.1", + "hosting_prepacked_artifact_key": ( + "huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference-prepack/v1.0.0/" + ), + "hosting_prepacked_artifact_version": "1.0.0", "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/terms/gemmaTerms.txt", "inference_vulnerable": False, "inference_dependencies": [], "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], + "training_dependencies": [ + "accelerate==0.26.1", + 
"bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], "training_vulnerabilities": [], "deprecated": False, - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, - }, + "hyperparameters": [ { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "name": "peft_type", "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + "default": "lora", + "options": ["lora", "None"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "instruction_tuned", "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "chat_dataset", "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "ENDPOINT_SERVER_TIMEOUT", + "name": "epoch", "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", }, { - "name": "HF_MODEL_ID", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "name": "lora_r", + "type": "int", + "default": 64, + "min": 1, + "max": 1000, + "scope": "algorithm", }, + {"name": "lora_alpha", "type": "int", "default": 16, "min": 0, "scope": "algorithm"}, { - "name": "SM_NUM_GPUS", - "type": "text", - "default": "8", - "scope": "container", - "required_for_model_class": True, + "name": "lora_dropout", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", }, + {"name": "bits", "type": "int", "default": 4, "scope": "algorithm"}, { - "name": "MAX_INPUT_LENGTH", + "name": "double_quant", "type": "text", - "default": "1024", - "scope": "container", - "required_for_model_class": True, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "MAX_TOTAL_TOKENS", + "name": "quant_type", "type": "text", - "default": "2048", - "scope": "container", - "required_for_model_class": True, + "default": "nf4", + "options": ["fp4", "nf4"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "name": "per_device_train_batch_size", "type": "int", "default": 1, - "scope": "container", - "required_for_model_class": True, + "min": 1, + "max": 1000, + "scope": "algorithm", }, - ], - "metrics": [], - "default_inference_instance_type": "ml.p4de.24xlarge", - 
"supported_inference_instance_types": ["ml.p4de.24xlarge"], - "model_kwargs": {}, - "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 3600, - }, - "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", - "default_accept_type": "application/json", - }, - "inference_volume_size": 512, - "inference_enable_network_isolation": True, - "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "hf-llm-falcon-180b-bf16", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - } + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 2, + "min": 1, + "max": 1000, + "scope": "algorithm", }, - "variants": { - "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "80"}}}, - "ml.p4d.24xlarge": { - "properties": { - "environment_variables": { - "YODEL": "NACEREMA", - } - } - }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.1, + "min": 0, + "max": 1, + "scope": "algorithm", }, - }, - }, - "inference-instance-types-variant-model": { - "model_id": "huggingface-llm-falcon-180b-bf16", - "url": "https://huggingface.co/tiiuae/falcon-180B", - "version": "1.0.0", - "min_sdk_version": "2.175.0", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "huggingface-llm", - "framework_version": "0.9.3", - "py_version": "py39", - "huggingface_transformers_version": "4.29.2", - }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" - "-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.1", - "hosting_use_script_uri": False, - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "inference_environment_variables": [ { - "name": "SAGEMAKER_PROGRAM", + "name": "train_from_scratch", "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "name": "fp16", "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + "default": 
"True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "bf16", "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "evaluation_strategy", "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, + "default": "steps", + "options": ["steps", "epoch", "no"], + "scope": "algorithm", }, { - "name": "ENDPOINT_SERVER_TIMEOUT", + "name": "eval_steps", "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, + "default": 20, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "MODEL_CACHE_ROOT", + "name": "gradient_accumulation_steps", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 8, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.2, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_ENV", + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": 1024, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "default": "None", + "scope": "algorithm", }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, { - "name": "HF_MODEL_ID", + "name": "gradient_checkpointing", "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SM_NUM_GPUS", + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_first_step", "type": 
"text", - "default": "8", - "scope": "container", - "required_for_model_class": True, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "MAX_INPUT_LENGTH", + "name": "logging_nan_inf_filter", "type": "text", - "default": "1024", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "steps", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "int", "default": 1, "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "deepspeed", + "type": "text", + "default": "False", + "options": ["False"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", - "required_for_model_class": True, }, { - "name": "MAX_TOTAL_TOKENS", + "name": "sagemaker_program", "type": "text", - "default": "2048", + "default": "transfer_learning.py", "scope": "container", - "required_for_model_class": True, }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", "scope": "container", - "required_for_model_class": True, }, ], - "metrics": [], - "default_inference_instance_type": "ml.p4de.24xlarge", - "supported_inference_instance_types": ["ml.p4de.24xlarge"], - "default_training_instance_type": "ml.p4de.24xlarge", - "supported_training_instance_types": ["ml.p4de.24xlarge"], - "model_kwargs": {}, + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_key": ( + "source-directory-tarballs/huggingface/transfer_learning/llm/prepack/v1.1.1/sourcedir.tar.gz" + ), + "training_prepacked_script_version": "1.1.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + 
"default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "8192", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_BATCH_PREFILL_TOKENS", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:train-loss", "Regex": "'loss': ([0-9]+\\.[0-9]+)"}, + ], + "default_inference_instance_type": "ml.g5.xlarge", + "supported_inference_instance_types": [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.2xlarge", + "supported_training_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "model_kwargs": {}, "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 3600, + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, }, + "fit_kwargs": {}, "predictor_specs": { "supported_content_types": ["application/json"], "supported_accept_types": ["application/json"], @@ -1539,406 +1808,698 @@ "default_accept_type": "application/json", }, "inference_volume_size": 512, + "training_volume_size": 512, "inference_enable_network_isolation": True, - "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "hf-llm-falcon-180b-bf16", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "gpu_image_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/stud-gpu", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - } - }, - "variants": { - "ml.p2.12xlarge": { - "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "supported_inference_instance_types": ["ml.p5.xlarge"], - "default_inference_instance_type": "ml.p5.xlarge", - "metrics": [ - { - "Name": 
"huggingface-textgeneration:eval-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:instance-typemetric-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", - }, - ], - } + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/oasst_top/train/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-llm-gemma-2b-instruct", + "default_payloads": { + "HelloWorld": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", }, - "p2": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], - "default_inference_instance_type": "ml.p2.xlarge", - "metrics": [ - { - "Name": "huggingface-textgeneration:wtafigo", - "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:eval-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", - }, - ], + "body": { + "inputs": ( + "user\nWrite a hello world program\nmodel" + ), + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, }, }, - "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}}, - "p4": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/number2/" - }, + }, + "MachineLearningPoem": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", }, - "g4": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + "body": { + "inputs": "Write me a poem about Machine Learning.", + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, }, }, - "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "g9": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "prepacked_artifact_key": "asfs/adsf/sda/f", - "hyperparameters": [ - { - "name": "num_bag_sets", - "type": "int", - "default": 5, - "min": 5, - "scope": "algorithm", - }, - { - "name": "num_stack_levels", - "type": "int", - "default": 6, - "min": 7, - "max": 3, - "scope": "algorithm", - }, - { - "name": "refit_full", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "set_best_to_refit_full", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "save_space", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "verbosity", - "type": "int", - "default": 2, 
- "min": 0, - "max": 4, - "scope": "algorithm", - }, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", - }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", - }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", - }, - ], - }, + }, + }, + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": ( + "626614931356.dkr.ecr.af-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, - "p9": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"artifact_key": "do/re/mi"}, + "ap-east-1": { + "gpu_ecr_uri_1": ( + "871362719292.dkr.ecr.ap-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, - "m2": { - "regional_properties": {"image_uri": "$cpu_image_uri"}, - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "400"}}, + "ap-northeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, - "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "local": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "ml.g5.48xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} + "ap-northeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, - "ml.g5.12xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} + "ap-south-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, - "g5": { - "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4", "JOHN": "DOE"} - } + "ap-southeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) }, - "ml.g9.12xlarge": { + "ap-southeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ca-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "cn-north-1": { + "gpu_ecr_uri_1": ( + "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-north-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-north-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-south-1": { + "gpu_ecr_uri_1": ( + "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-1.amazonaws.com/" + 
"huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-west-3": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-3.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "il-central-1": { + "gpu_ecr_uri_1": ( + "780543022126.dkr.ecr.il-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "me-south-1": { + "gpu_ecr_uri_1": ( + "217643126080.dkr.ecr.me-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "sa-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.sa-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-east-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g4dn.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": ( + "626614931356.dkr.ecr.af-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-east-1": { + "gpu_ecr_uri_1": ( + "871362719292.dkr.ecr.ap-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/" + 
"huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-south-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ca-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "cn-north-1": { + "gpu_ecr_uri_1": ( + "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-north-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-north-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-south-1": { + "gpu_ecr_uri_1": ( + "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-west-3": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-3.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "il-central-1": { + "gpu_ecr_uri_1": ( + "780543022126.dkr.ecr.il-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "me-south-1": { + "gpu_ecr_uri_1": ( + "217643126080.dkr.ecr.me-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "sa-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.sa-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-east-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + }, + "variants": { + "g4dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, "properties": { - 
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "prepacked_artifact_key": "nlahdasf/asdf/asd/f", - "hyperparameters": [ - { - "name": "eval_metric", - "type": "text", - "default": "auto", - "scope": "algorithm", - }, - { - "name": "presets", - "type": "text", - "default": "medium_quality", - "options": [ - "best_quality", - "high_quality", - "good_quality", - "medium_quality", - "optimize_for_deployment", - "interpretable", - ], - "scope": "algorithm", - }, - { - "name": "auto_stack", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "num_bag_folds", - "type": "text", - "default": "0", - "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], - "scope": "algorithm", - }, - { - "name": "num_bag_sets", - "type": "int", - "default": 1, - "min": 1, - "scope": "algorithm", - }, - { - "name": "num_stack_levels", - "type": "int", - "default": 0, - "min": 0, - "max": 3, - "scope": "algorithm", - }, - ], - } + "gated_model_key_env_var_value": ( + "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, }, - "ml.p9.12xlarge": { + "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "artifact_key": "you/not/entertained", - } + "gated_model_key_env_var_value": ( + "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, }, - "g6": { + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, "properties": { - "environment_variables": {"BLAH": "4"}, - "artifact_key": "path/to/training/artifact.tar.gz", - "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", - } + "gated_model_key_env_var_value": ( + "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, }, - "trn1": { + "p4d": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, "properties": { - "supported_inference_instance_types": ["ml.inf1.xlarge", "ml.inf1.2xlarge"], - "default_inference_instance_type": "ml.inf1.xlarge", - } + "gated_model_key_env_var_value": ( + "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, }, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, - "training_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "hosting_resource_requirements": {"min_memory_mb": 8192, "num_accelerators": 1}, + "dynamic_container_deployment_supported": True, + }, + # noqa: E501 + "env-var-variant-model": { + "model_id": "huggingface-llm-falcon-180b-bf16", + "url": "https://huggingface.co/tiiuae/falcon-180B", + "version": "1.0.0", + "min_sdk_version": "2.175.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "0.9.3", + "py_version": "py39", + "huggingface_transformers_version": "4.29.2", }, - "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "training_script_key": 
"source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", - "training_prepacked_script_key": None, - "training_model_package_artifact_uris": None, - "deprecate_warn_message": None, - "deprecated_message": None, - "hosting_eula_key": None, - "hyperparameters": [ + "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" + "-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.1", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ { - "name": "epochs", - "type": "int", - "default": 3, - "min": 1, - "max": 1000, - "scope": "algorithm", + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, }, { - "name": "adam-learning-rate", - "type": "float", - "default": 0.05, - "min": 1e-08, - "max": 1, - "scope": "algorithm", + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, }, { - "name": "batch-size", + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", "type": "int", - "default": 4, - "min": 1, - "max": 1024, - "scope": "algorithm", + "default": 3600, + "scope": "container", + "required_for_model_class": True, }, { - "name": "sagemaker_submit_directory", + "name": "MODEL_CACHE_ROOT", "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "default": "/opt/ml/model", "scope": "container", + "required_for_model_class": True, }, { - "name": "sagemaker_program", + "name": "SAGEMAKER_ENV", "type": "text", - "default": "transfer_learning.py", + "default": "1", "scope": "container", + "required_for_model_class": True, }, { - "name": "sagemaker_container_log_level", + "name": "HF_MODEL_ID", "type": "text", - "default": "20", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "8", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, "scope": "container", + "required_for_model_class": True, }, ], - "training_vulnerable": False, - "deprecated": False, - "estimator_kwargs": { - "encrypt_inter_container_traffic": True, + "metrics": [], + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + 
"container_startup_health_check_timeout": 3600, }, - "training_volume_size": 456, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, "inference_enable_network_isolation": True, - "training_enable_network_isolation": False, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "hf-llm-falcon-180b-bf16", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + } + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "80"}}}, + "ml.p4d.24xlarge": { + "properties": { + "environment_variables": { + "YODEL": "NACEREMA", + } + } + }, + }, + }, }, - "variant-model": { - "model_id": "pytorch-ic-mobilenet-v2", - "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", + "inference-instance-types-variant-model": { + "model_id": "huggingface-llm-falcon-180b-bf16", + "url": "https://huggingface.co/tiiuae/falcon-180B", "version": "1.0.0", - "min_sdk_version": "2.49.0", + "min_sdk_version": "2.175.0", "training_supported": True, - "incremental_training_supported": True, - "hosting_model_package_arns": { - "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" - "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" - }, + "incremental_training_supported": False, "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "framework": "huggingface-llm", + "framework_version": "0.9.3", + "py_version": "py39", + "huggingface_transformers_version": "4.29.2", }, - "training_instance_type_variants": { - "regional_aliases": {}, - "variants": { - "ml.p2.12xlarge": { - "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "hyperparameters": [ - { - "name": "eval_metric", - "type": "text", - "default": "auto", - "scope": "algorithm", - }, - { - "name": "presets", - "type": "text", - "default": "medium_quality", - "options": [ - "best_quality", - "high_quality", - "good_quality", - "medium_quality", - "optimize_for_deployment", - "interpretable", - ], - "scope": "algorithm", - }, - { - "name": "auto_stack", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "num_bag_folds", - "type": "text", - "default": "0", - "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], - "scope": "algorithm", - }, - { - "name": "num_bag_sets", - "type": "int", - "default": 1, - "min": 1, - "scope": "algorithm", - }, - { - "name": "batch-size", - "type": "int", - 
"default": 1, - "min": 1, - "scope": "algorithm", - }, - { - "name": "num_stack_levels", - "type": "int", - "default": 0, - "min": 0, - "max": 3, - "scope": "algorithm", - }, - ], + "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" + "-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.1", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "8", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "default_training_instance_type": "ml.p4de.24xlarge", + "supported_training_instance_types": ["ml.p4de.24xlarge"], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "hf-llm-falcon-180b-bf16", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + 
"huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "gpu_image_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/stud-gpu", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + } + }, + "variants": { + "ml.p2.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "supported_inference_instance_types": ["ml.p5.xlarge"], + "default_inference_instance_type": "ml.p5.xlarge", "metrics": [ { - "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Name": "huggingface-textgeneration:eval-loss", "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", }, { - "Name": "huggingface-textgeneration:eval-loss", + "Name": "huggingface-textgeneration:instance-typemetric-loss", "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", }, { @@ -1953,8 +2514,49 @@ } }, "p2": { - "regional_properties": {"image_uri": "$gpu_ecr_uri_2"}, + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], + "default_inference_instance_type": "ml.p2.xlarge", + "metrics": [ + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", + }, + ], + }, + }, + "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}}, + "p4": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/number2/" + }, + }, + "g4": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + }, + }, + "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "g9": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { + "prepacked_artifact_key": "asfs/adsf/sda/f", "hyperparameters": [ { "name": "num_bag_sets", @@ -2019,87 +2621,105 @@ "scope": "container", }, ], - "metrics": [ + }, + }, + "p9": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": {"artifact_key": "do/re/mi"}, + }, + "m2": { + "regional_properties": {"image_uri": "$cpu_image_uri"}, + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "400"}}, + }, + "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "local": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.48xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} + }, + "ml.g5.12xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} + }, + "g5": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4", "JOHN": "DOE"} + } + }, + "ml.g9.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "prepacked_artifact_key": "nlahdasf/asdf/asd/f", + "hyperparameters": [ { - "Name": "huggingface-textgeneration:wtafigo", - "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + "name": "eval_metric", + "type": "text", + "default": "auto", + "scope": "algorithm", }, { - "Name": "huggingface-textgeneration:eval-loss", - 
"Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + "name": "presets", + "type": "text", + "default": "medium_quality", + "options": [ + "best_quality", + "high_quality", + "good_quality", + "medium_quality", + "optimize_for_deployment", + "interpretable", + ], + "scope": "algorithm", }, { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + "name": "auto_stack", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", - }, - ], - }, - }, - }, - }, - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - "inf_model_package_arn": "us-west-2/blah/blah/blah/inf", - "gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu", - } - }, - "variants": { - "p2": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } - }, - "p3": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } - }, - "p4": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", + "name": "num_bag_folds", + "type": "text", + "default": "0", + "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "scope": "algorithm", + }, + { + "name": "num_bag_sets", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 0, + "min": 0, + "max": 3, + "scope": "algorithm", + }, + ], } }, - "g4dn": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", + "ml.p9.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "artifact_key": "you/not/entertained", } }, - "g5": { + "g6": { "properties": { - "resource_requirements": { - "num_accelerators": 888810, - "randon-field-2": 2222, - } + "environment_variables": {"BLAH": "4"}, + "artifact_key": "path/to/training/artifact.tar.gz", + "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", } }, - "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "ml.g5.xlarge": { + "trn1": { "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}, - "resource_requirements": {"num_accelerators": 10}, + "supported_inference_instance_types": ["ml.inf1.xlarge", "ml.inf1.2xlarge"], + "default_inference_instance_type": "ml.inf1.xlarge", } }, - "ml.g5.48xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} - }, - "ml.g5.12xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} - }, - "inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, - "inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, }, }, "training_ecr_specs": { @@ -2107,18 +2727,9 @@ "framework_version": "1.5.0", "py_version": "py3", }, - "dynamic_container_deployment_supported": True, - "hosting_resource_requirements": { - "min_memory_mb": 81999, - "num_accelerators": 1, - "random_field_1": 1, - }, - "hosting_artifact_key": 
"pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", "training_prepacked_script_key": None, - "hosting_prepacked_artifact_key": None, "training_model_package_artifact_uris": None, "deprecate_warn_message": None, "deprecated_message": None, @@ -2167,169 +2778,494 @@ "scope": "container", }, ], - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "ENDPOINT_SERVER_TIMEOUT", - "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, - "scope": "container", - "required_for_model_class": True, - }, - ], - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], "deprecated": False, - "default_inference_instance_type": "ml.p2.xlarge", - "supported_inference_instance_types": [ - "ml.p2.xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", - "ml.m5.large", - "ml.m5.xlarge", - "ml.c5.xlarge", - "ml.c5.2xlarge", - ], - "default_training_instance_type": "ml.p3.2xlarge", - "supported_training_instance_types": [ - "ml.p3.2xlarge", - "ml.p2.xlarge", - "ml.g4dn.2xlarge", - "ml.m5.xlarge", - "ml.c5.2xlarge", - ], - "hosting_use_script_uri": True, - "metrics": [ - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'loss default': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", - "Regex": "'loss default': ([0-9]+\\.[0-9]+)", - }, - ], - "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, - "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, "estimator_kwargs": { "encrypt_inter_container_traffic": True, }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, - "predictor_specs": { - "supported_content_types": ["application/x-image"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-image", - "default_accept_type": "application/json", - }, - "inference_volume_size": 123, "training_volume_size": 456, "inference_enable_network_isolation": True, "training_enable_network_isolation": False, - "resource_name_base": "dfsdfsds", }, - "gated_llama_neuron_model": { - "model_id": "meta-textgenerationneuron-llama-2-7b", - "url": 
"https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + # noqa: E501 + "variant-model": { + "model_id": "pytorch-ic-mobilenet-v2", + "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", "version": "1.0.0", - "min_sdk_version": "2.198.0", + "min_sdk_version": "2.49.0", "training_supported": True, - "incremental_training_supported": False, + "incremental_training_supported": True, + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" + "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + }, "hosting_ecr_specs": { - "framework": "djl-neuronx", - "framework_version": "0.24.0", - "py_version": "py39", + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", }, - "hosting_artifact_key": "meta-textgenerationneuron/meta-textgenerationneuron-llama-2-7b/artifac" - "ts/inference/v1.0.0/", - "hosting_script_key": "source-directory-tarballs/meta/inference/textgenerationneuron/v1.0.0/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "meta-textgenerationneuron/meta-textgenerationneuro" - "n-llama-2-7b/artifacts/inference-prepack/v1.0.0/", - "hosting_prepacked_artifact_version": "1.0.0", - "hosting_use_script_uri": False, - "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", - "inference_vulnerable": False, - "inference_dependencies": [ - "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", - "sagemaker_jumpstart_script_utilities==1.1.8", - ], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [ - "sagemaker_jumpstart_huggingface_script_utilities==1.1.3", - "sagemaker_jumpstart_script_utilities==1.1.9", - "sagemaker_jumpstart_tabular_script_utilities==1.0.0", - ], - "training_vulnerabilities": [], - "deprecated": False, - "hyperparameters": [ - { - "name": "max_input_length", - "type": "int", - "default": 2048, - "min": 128, - "scope": "algorithm", - }, - { - "name": "preprocessing_num_workers", - "type": "text", - "default": "None", - "scope": "algorithm", - }, - { - "name": "learning_rate", - "type": "float", + "training_instance_type_variants": { + "regional_aliases": {}, + "variants": { + "ml.p2.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "hyperparameters": [ + { + "name": "eval_metric", + "type": "text", + "default": "auto", + "scope": "algorithm", + }, + { + "name": "presets", + "type": "text", + "default": "medium_quality", + "options": [ + "best_quality", + "high_quality", + "good_quality", + "medium_quality", + "optimize_for_deployment", + "interpretable", + ], + "scope": "algorithm", + }, + { + "name": "auto_stack", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "num_bag_folds", + "type": "text", + "default": "0", + "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "scope": "algorithm", + }, + { + "name": "num_bag_sets", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "batch-size", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 0, + "min": 0, + "max": 3, + "scope": "algorithm", + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + 
"Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", + }, + ], + } + }, + "p2": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_2"}, + "properties": { + "hyperparameters": [ + { + "name": "num_bag_sets", + "type": "int", + "default": 5, + "min": 5, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 6, + "min": 7, + "max": 3, + "scope": "algorithm", + }, + { + "name": "refit_full", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "set_best_to_refit_full", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_space", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "verbosity", + "type": "int", + "default": 2, + "min": 0, + "max": 4, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", + }, + ], + }, + }, + }, + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + "inf_model_package_arn": "us-west-2/blah/blah/blah/inf", + "gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu", + } + }, + "variants": { + "p2": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "p3": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "p4": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "g4dn": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "g5": { + "properties": { + "resource_requirements": { + "num_accelerators": 888810, + "randon-field-2": 2222, + } + } + }, + "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}, + "resource_requirements": {"num_accelerators": 10}, + } + }, + "ml.g5.48xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} + }, + "ml.g5.12xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} + }, + "inf1": {"regional_properties": 
{"model_package_arn": "$inf_model_package_arn"}}, + "inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, + }, + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "dynamic_container_deployment_supported": True, + "hosting_resource_requirements": { + "min_memory_mb": 81999, + "num_accelerators": 1, + "random_field_1": 1, + }, + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": None, + "hosting_prepacked_artifact_key": None, + "training_model_package_artifact_uris": None, + "deprecate_warn_message": None, + "deprecated_message": None, + "hosting_eula_key": None, + "hyperparameters": [ + { + "name": "epochs", + "type": "int", + "default": 3, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "adam-learning-rate", + "type": "float", + "default": 0.05, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "batch-size", + "type": "int", + "default": 4, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": [ + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + ], + "default_training_instance_type": "ml.p3.2xlarge", + "supported_training_instance_types": [ + 
"ml.p3.2xlarge", + "ml.p2.xlarge", + "ml.g4dn.2xlarge", + "ml.m5.xlarge", + "ml.c5.2xlarge", + ], + "hosting_use_script_uri": True, + "metrics": [ + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + ], + "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, + "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + }, + "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, + "predictor_specs": { + "supported_content_types": ["application/x-image"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/x-image", + "default_accept_type": "application/json", + }, + "inference_volume_size": 123, + "training_volume_size": 456, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": False, + "resource_name_base": "dfsdfsds", + }, + "gated_llama_neuron_model": { + "model_id": "meta-textgenerationneuron-llama-2-7b", + "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + "version": "1.0.0", + "min_sdk_version": "2.198.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "djl-neuronx", + "framework_version": "0.24.0", + "py_version": "py39", + }, + "hosting_artifact_key": "meta-textgenerationneuron/meta-textgenerationneuron-llama-2-7b/artifac" + "ts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/meta/inference/textgenerationneuron/v1.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "meta-textgenerationneuron/meta-textgenerationneuro" + "n-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", + "inference_vulnerable": False, + "inference_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", + "sagemaker_jumpstart_script_utilities==1.1.8", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.1.3", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "max_input_length", + "type": "int", + "default": 2048, + "min": 128, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", "default": 6e-06, "min": 1e-08, "max": 1, @@ -7107,35 +8043,155 @@ }, {"name": "max_depth", "type": "int", "default": 6, "min": 1, "scope": "algorithm"}, { - "name": "subsample", + "name": "subsample", + "type": "float", + "default": 1, + "min": 1e-20, + "max": 1, + "scope": "algorithm", + }, + { + "name": "colsample_bytree", + "type": "float", + "default": 1, + "min": 1e-20, + "max": 1, + "scope": "algorithm", + }, + { + "name": "reg_lambda", + "type": "float", + "default": 1, + "min": 0, + "max": 200, + "scope": "algorithm", + }, + { + "name": "reg_alpha", + "type": "float", + "default": 0, + "min": 0, + "max": 200, + "scope": "algorithm", + }, + { + "name": 
"sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/xgboost/transfer_learning/classification/" + "v1.0.0/sourcedir.tar.gz", + "training_ecr_specs": { + "framework_version": "1.3-1", + "framework": "xgboost", + "py_version": "py3", + }, + "training_artifact_key": "xgboost-training/train-xgboost-classification-model.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + }, + {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "text", + "default": "1", + "scope": "container", + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + }, + ], + }, + "sklearn-classification-linear": { + "model_id": "sklearn-classification-linear", + "url": "https://scikit-learn.org/stable/", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "sklearn", + "framework_version": "0.23-1", + "py_version": "py3", + }, + "hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz", + "hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz", + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "tol", "type": "float", - "default": 1, + "default": 0.0001, "min": 1e-20, - "max": 1, + "max": 50, "scope": "algorithm", }, { - "name": "colsample_bytree", - "type": "float", - "default": 1, - "min": 1e-20, - "max": 1, + "name": "penalty", + "type": "text", + "default": "l2", + "options": ["l1", "l2", "elasticnet", "none"], "scope": "algorithm", }, { - "name": "reg_lambda", + "name": "alpha", "type": "float", - "default": 1, - "min": 0, - "max": 200, + "default": 0.0001, + "min": 1e-20, + "max": 999, "scope": "algorithm", }, { - "name": "reg_alpha", + "name": "l1_ratio", "type": "float", - "default": 0, + "default": 0.15, "min": 0, - "max": 200, + "max": 1, "scope": "algorithm", }, { @@ -7157,14 +8213,14 @@ "scope": "container", }, ], - "training_script_key": "source-directory-tarballs/xgboost/transfer_learning/classification/" + "training_script_key": "source-directory-tarballs/sklearn/transfer_learning/classification/" "v1.0.0/sourcedir.tar.gz", "training_ecr_specs": { - "framework_version": "1.3-1", - "framework": "xgboost", + "framework_version": "0.23-1", + "framework": "sklearn", "py_version": "py3", }, - "training_artifact_key": "xgboost-training/train-xgboost-classification-model.tar.gz", 
+ "training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", @@ -7198,1013 +8254,1889 @@ "scope": "container", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - }, + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + }, + ], + }, +} + +BASE_SPEC = { + "model_id": "pytorch-ic-mobilenet-v2", + "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", + "version": "1.0.0", + "min_sdk_version": "2.49.0", + "training_supported": True, + "incremental_training_supported": True, + "gated_bucket": False, + "default_payloads": None, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "hosting_instance_type_variants": None, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "training_instance_type_variants": None, + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_artifact_uri": None, + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": None, + "hosting_prepacked_artifact_key": None, + "training_model_package_artifact_uris": None, + "deprecate_warn_message": None, + "deprecated_message": None, + "hosting_model_package_arns": {}, + "hosting_eula_key": None, + "model_subscription_link": None, + "hyperparameters": [ + { + "name": "epochs", + "type": "int", + "default": 3, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "adam-learning-rate", + "type": "float", + "default": 0.05, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "batch-size", + "type": "int", + "default": 4, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", 
+ "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": [ + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + ], + "default_training_instance_type": "ml.p3.2xlarge", + "supported_training_instance_types": [ + "ml.p3.2xlarge", + "ml.p2.xlarge", + "ml.g4dn.2xlarge", + "ml.m5.xlarge", + "ml.c5.2xlarge", + ], + "hosting_use_script_uri": True, + "usage_info_message": None, + "metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], + "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, + "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + }, + "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, + "predictor_specs": { + "supported_content_types": ["application/x-image"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/x-image", + "default_accept_type": "application/json", + }, + "inference_volume_size": 123, + "training_volume_size": 456, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": False, + "resource_name_base": "dfsdfsds", + "hosting_resource_requirements": {"num_accelerators": 1, "min_memory_mb": 34360}, + "dynamic_container_deployment_supported": True, + "inference_configs": None, + "inference_config_components": None, + "training_configs": None, + "training_config_components": None, + "inference_config_rankings": None, + "training_config_rankings": None, + "hosting_additional_data_sources": None, + "hosting_neuron_model_id": None, + "hosting_neuron_model_version": None, +} + +BASE_HOSTING_ADDITIONAL_DATA_SOURCES = { + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "speculative_decoding_channel", + "artifact_version": "version", + "s3_data_source": { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://bucket/path1", + "hub_access_config": None, + "model_access_config": None, + }, + } + ], + "scripts": [ + { + "channel_name": "scripts_channel", + "artifact_version": "version", + "s3_data_source": { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://bucket/path1", + "hub_access_config": None, + "model_access_config": None, + }, + } ], }, - "sklearn-classification-linear": { - "model_id": "sklearn-classification-linear", - "url": "https://scikit-learn.org/stable/", +} + +BASE_HEADER = { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v1.0.0.json", +} + +BASE_MANIFEST = [ + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "1.0.0", - "min_sdk_version": "2.68.1", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "sklearn", - "framework_version": "0.23-1", - 
"py_version": "py3", - }, - "hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz", - "hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz", - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "hyperparameters": [ - { - "name": "tol", - "type": "float", - "default": 0.0001, - "min": 1e-20, - "max": 50, - "scope": "algorithm", - }, - { - "name": "penalty", - "type": "text", - "default": "l2", - "options": ["l1", "l2", "elasticnet", "none"], - "scope": "algorithm", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-imagenet-" + "inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "3.0.0", + "min_version": "4.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v3.0.0.json", + }, +] + +BASE_PROPRIETARY_HEADER = { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], +} + +BASE_PROPRIETARY_MANIFEST = [ + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "lighton-mini-instruct40b", + "version": "v1.0", + "min_version": "2.0.0", + "spec_key": "proprietary-models/lighton-mini-instruct40b/proprietary_specs_v1.0.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "ai21-paraphrase", + "version": "1.0.005", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "ai21-paraphrase", + "version": "v1.00-rc2-not-valid-version", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "nc-soft-model-1", + "version": "v3.0-not-valid-version!", + "min_version": "2.0.0", + "spec_key": "proprietary-models/nc-soft-model-1/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, +] + +BASE_PROPRIETARY_SPEC = { + "model_id": "ai21-jurassic-2-light", + "version": "2.0.004", + "min_sdk_version": "2.999.0", + "listing_id": "prodview-roz6zicyvi666", + "product_id": "1bd680a0-f29b-479d-91c3-9899743021cf", + 
"model_subscription_link": "https://aws.amazon.com/marketplace/ai/procurement?productId=1bd680a0", + "hosting_notebook_key": "pmm-notebooks/pmm-notebook-ai21-jurassic-2-light.ipynb", + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 600, + }, + "default_payloads": { + "Shakespeare": { + "content_type": "application/json", + "prompt_key": "prompt", + "output_keys": {"generated_text": "[0].completions[0].data.text"}, + "body": {"prompt": "To be, or", "maxTokens": 1, "temperature": 0}, + } + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "hosting_model_package_arns": { + "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-light-v2-0-004", + "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/j2-light-v2-0-004", + "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/j2-light-v2-0-004", + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/j2-light-v2-0-004", + "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/j2-light-v2-0-004", + "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/j2-light-v2-0-004", + "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/j2-light-v2-0-004", + "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/j2-light-v2-0-004", + "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/j2-light-v2-0-004", + "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/j2-light-v2-0-004", + "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/j2-light-v2-0-004", + "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/j2-light-v2-0-004", + "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/j2-light-v2-0-004", + "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/j2-light-v2-0-004", + "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/j2-light-v2-0-004", + "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/j2-light-v2-0-004", + }, +} + + +INFERENCE_CONFIGS = { + "inference_configs": { + "neuron-inference": { + "benchmark_metrics": { + "ml.inf2.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - { - "name": "alpha", - "type": "float", - "default": 0.0001, - "min": 1e-20, - "max": 999, - "scope": "algorithm", + "component_names": ["neuron-inference"], + }, + "neuron-inference-budget": { + "benchmark_metrics": { + "ml.inf2.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - { - "name": "l1_ratio", - "type": "float", - "default": 0.15, - "min": 0, - "max": 1, - "scope": "algorithm", + "component_names": ["neuron-base"], + }, + "gpu-inference-budget": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", + "component_names": ["gpu-inference-budget"], + }, + "gpu-inference": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + 
{"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", + "component_names": ["gpu-inference"], + }, + "gpu-inference-model-package": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", + "component_names": ["gpu-inference-model-package"], + }, + "gpu-accelerated": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - ], - "training_script_key": "source-directory-tarballs/sklearn/transfer_learning/classification/" - "v1.0.0/sourcedir.tar.gz", - "training_ecr_specs": { - "framework_version": "0.23-1", - "framework": "sklearn", - "py_version": "py3", + "component_names": ["gpu-accelerated"], }, - "training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz", - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", + }, + "inference_config_components": { + "neuron-base": { + "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"] + }, + "neuron-inference": { + "default_inference_instance_type": "ml.inf2.xlarge", + "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], + "hosting_ecr_specs": { + "framework": "huggingface-llm-neuronx", + "framework_version": "0.0.17", + "py_version": "py310", }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", + }, + "neuron-budget": {"inference_environment_variables": {"BUDGET": "1234"}}, + "gpu-inference": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + } + }, + "variants": { + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", + }, + "gpu-inference-model-package": { + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" + "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "text", - 
"default": "1", - "scope": "container", + }, + "gpu-inference-budget": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting:1.13.1-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", + }, + "gpu-accelerated": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, }, - ], + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "draft_model_name", + "artifact_version": "1.2.1", + "s3_data_source": { + "compression_type": "None", + "model_access_config": {"accept_eula": False}, + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/draft/model/artifact/", + }, + } + ], + }, + }, }, } -BASE_SPEC = { - "model_id": "pytorch-ic-mobilenet-v2", - "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", - "version": "1.0.0", - "min_sdk_version": "2.49.0", - "training_supported": True, - "incremental_training_supported": True, - "gated_bucket": False, - "default_payloads": None, - "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", - }, - "hosting_instance_type_variants": None, - "training_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", - }, - "training_instance_type_variants": None, - "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", - "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", - "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", - "training_prepacked_script_key": None, - "hosting_prepacked_artifact_key": None, - "training_model_package_artifact_uris": None, - "deprecate_warn_message": None, - "deprecated_message": None, - "hosting_model_package_arns": {}, - "hosting_eula_key": None, - "model_subscription_link": None, - "hyperparameters": [ - { - "name": "epochs", - "type": "int", - "default": 3, - "min": 1, - "max": 1000, - "scope": "algorithm", - }, - { - "name": "adam-learning-rate", - "type": "float", - "default": 0.05, - "min": 1e-08, - "max": 1, - "scope": "algorithm", - }, - { - "name": "batch-size", - "type": "int", - "default": 4, - "min": 1, - "max": 1024, - "scope": "algorithm", - }, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", +TRAINING_CONFIGS = { + "training_configs": { + "neuron-training": { + "benchmark_metrics": { + "ml.tr1n1.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ], + 
"ml.tr1n1.4xlarge": [ + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ], + }, + "component_names": ["neuron-training"], + "default_inference_config": "neuron-inference", + "default_incremental_training_config": "neuron-training", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", + "neuron-training-budget": { + "benchmark_metrics": { + "ml.tr1n1.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ], + "ml.tr1n1.4xlarge": [ + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ], + }, + "component_names": ["neuron-training-budget"], + "default_inference_config": "neuron-inference-budget", + "default_incremental_training_config": "neuron-training-budget", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", + "gpu-training": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "200", "unit": "Tokens/S", "concurrency": "1"} + ], + }, + "component_names": ["gpu-training"], + "default_inference_config": "gpu-inference", + "default_incremental_training_config": "gpu-training", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, - ], - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, + "gpu-training-budget": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": "1"} + ] + }, + "component_names": ["gpu-training-budget"], + "default_inference_config": "gpu-inference-budget", + "default_incremental_training_config": "gpu-training-budget", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + }, + "training_config_components": { + "neuron-training": { + "default_training_instance_type": "ml.trn1.2xlarge", + "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "gpu-training": { + "default_training_instance_type": 
"ml.p2.xlarge", + "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-training:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, + "neuron-training-budget": { + "default_training_instance_type": "ml.trn1.2xlarge", + "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, }, - { - "name": "ENDPOINT_SERVER_TIMEOUT", - "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, + "gpu-training-budget": { + "default_training_instance_type": "ml.p2.xlarge", + "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training:1.13.1-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + }, +} + + +INFERENCE_CONFIG_RANKINGS = { + "inference_config_rankings": { + "overall": { + "description": "Overall rankings of configs", + "rankings": [ + "neuron-inference", + "neuron-inference-budget", + "gpu-inference", + "gpu-inference-budget", + "gpu-accelerated", + ], }, - { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "performance": { + "description": "Configs ranked based on performance", + "rankings": [ + "neuron-inference", + "gpu-inference", + "neuron-inference-budget", + "gpu-inference-budget", + ], }, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, - "scope": "container", - "required_for_model_class": True, + "cost": { + "description": "Configs ranked based on cost", + "rankings": [ + "neuron-inference-budget", + "gpu-inference-budget", + "neuron-inference", + "gpu-inference", + ], }, - ], - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "default_inference_instance_type": "ml.p2.xlarge", - "supported_inference_instance_types": [ - "ml.p2.xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", - "ml.m5.large", - "ml.m5.xlarge", - "ml.c5.xlarge", - 
"ml.c5.2xlarge", - ], - "default_training_instance_type": "ml.p3.2xlarge", - "supported_training_instance_types": [ - "ml.p3.2xlarge", - "ml.p2.xlarge", - "ml.g4dn.2xlarge", - "ml.m5.xlarge", - "ml.c5.2xlarge", - ], - "hosting_use_script_uri": True, - "usage_info_message": None, - "metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], - "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, - "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, - "estimator_kwargs": { - "encrypt_inter_container_traffic": True, - }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, - "predictor_specs": { - "supported_content_types": ["application/x-image"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-image", - "default_accept_type": "application/json", - }, - "inference_volume_size": 123, - "training_volume_size": 456, - "inference_enable_network_isolation": True, - "training_enable_network_isolation": False, - "resource_name_base": "dfsdfsds", - "hosting_resource_requirements": {"num_accelerators": 1, "min_memory_mb": 34360}, - "dynamic_container_deployment_supported": True, - "inference_configs": None, - "inference_config_components": None, - "training_configs": None, - "training_config_components": None, - "inference_config_rankings": None, - "training_config_rankings": None, - "hosting_additional_data_sources": None, - "hosting_neuron_model_id": None, - "hosting_neuron_model_version": None, + } } -BASE_HOSTING_ADDITIONAL_DATA_SOURCES = { - "hosting_additional_data_sources": { - "speculative_decoding": [ - { - "channel_name": "speculative_decoding_channel", - "artifact_version": "version", - "s3_data_source": { - "compression_type": "None", - "s3_data_type": "S3Prefix", - "s3_uri": "s3://bucket/path1", - "hub_access_config": None, - "model_access_config": None, - }, - } - ], - "scripts": [ - { - "channel_name": "scripts_channel", - "artifact_version": "version", - "s3_data_source": { - "compression_type": "None", - "s3_data_type": "S3Prefix", - "s3_uri": "s3://bucket/path1", - "hub_access_config": None, - "model_access_config": None, - }, - } - ], - }, +TRAINING_CONFIG_RANKINGS = { + "training_config_rankings": { + "overall": { + "description": "Overall rankings of configs", + "rankings": [ + "neuron-training", + "neuron-training-budget", + "gpu-training", + "gpu-training-budget", + ], + }, + "performance_training": { + "description": "Configs ranked based on performance", + "rankings": [ + "neuron-training", + "gpu-training", + "neuron-training-budget", + "gpu-training-budget", + ], + "instance_type_overrides": { + "ml.p2.xlarge": [ + "neuron-training", + "neuron-training-budget", + "gpu-training", + "gpu-training-budget", + ] + }, + }, + "cost_training": { + "description": "Configs ranked based on cost", + "rankings": [ + "neuron-training-budget", + "gpu-training-budget", + "neuron-training", + "gpu-training", + ], + }, + } } -BASE_HEADER = { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "1.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet" - "-inception-v3-classification-4/specs_v1.0.0.json", -} -BASE_MANIFEST = [ - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "1.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet" - "-inception-v3-classification-4/specs_v1.0.0.json", - 
}, +DEPLOYMENT_CONFIGS = [ { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "2.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet" - "-inception-v3-classification-4/specs_v2.0.0.json", + "DeploymentConfigName": "neuron-inference", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], }, { - "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", - "version": "1.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-" - "imagenet-inception-v3-classification-4/specs_v1.0.0.json", + "DeploymentConfigName": "neuron-inference-budget", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], }, { - "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", - "version": "2.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-imagenet-" - "inception-v3-classification-4/specs_v2.0.0.json", + "DeploymentConfigName": "gpu-inference-budget", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + 
"ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], }, { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "3.0.0", - "min_version": "4.49.0", - "spec_key": "community_models_specs/tensorflow-ic-" - "imagenet-inception-v3-classification-4/specs_v3.0.0.json", + "DeploymentConfigName": "gpu-inference", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], }, ] -BASE_PROPRIETARY_HEADER = { - "model_id": "ai21-summarization", - "version": "1.1.003", - "min_version": "2.0.0", - "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", - "search_keywords": ["Text2Text", "Generation"], -} - -BASE_PROPRIETARY_MANIFEST = [ - { - "model_id": "ai21-summarization", - "version": "1.1.003", - "min_version": "2.0.0", - "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", - "search_keywords": ["Text2Text", "Generation"], - }, - { - "model_id": "lighton-mini-instruct40b", - "version": "v1.0", - "min_version": "2.0.0", - "spec_key": "proprietary-models/lighton-mini-instruct40b/proprietary_specs_v1.0.json", - "search_keywords": ["Text2Text", "Generation"], - }, - { - "model_id": "ai21-paraphrase", - "version": "1.0.005", - "min_version": "2.0.0", - "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", - "search_keywords": ["Text2Text", "Generation"], - }, - { - "model_id": "ai21-paraphrase", - "version": "v1.00-rc2-not-valid-version", - "min_version": "2.0.0", - "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", - "search_keywords": ["Text2Text", "Generation"], - }, - { - "model_id": "nc-soft-model-1", - "version": "v3.0-not-valid-version!", - "min_version": "2.0.0", - "spec_key": "proprietary-models/nc-soft-model-1/proprietary_specs_1.0.005.json", - "search_keywords": ["Text2Text", "Generation"], - }, -] -BASE_PROPRIETARY_SPEC = { - "model_id": "ai21-jurassic-2-light", - "version": "2.0.004", - "min_sdk_version": "2.999.0", - "listing_id": 
"prodview-roz6zicyvi666", - "product_id": "1bd680a0-f29b-479d-91c3-9899743021cf", - "model_subscription_link": "https://aws.amazon.com/marketplace/ai/procurement?productId=1bd680a0", - "hosting_notebook_key": "pmm-notebooks/pmm-notebook-ai21-jurassic-2-light.ipynb", - "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 600, - }, - "default_payloads": { - "Shakespeare": { - "content_type": "application/json", - "prompt_key": "prompt", - "output_keys": {"generated_text": "[0].completions[0].data.text"}, - "body": {"prompt": "To be, or", "maxTokens": 1, "temperature": 0}, +INIT_KWARGS = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu" + "-py310-cu121-ubuntu20.04", + "model_data": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface-textgeneration" + "-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", } }, - "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", - "default_accept_type": "application/json", - }, - "default_inference_instance_type": "ml.p4de.24xlarge", - "supported_inference_instance_types": ["ml.p4de.24xlarge"], - "hosting_model_package_arns": { - "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-light-v2-0-004", - "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/j2-light-v2-0-004", - "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/j2-light-v2-0-004", - "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/j2-light-v2-0-004", - "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/j2-light-v2-0-004", - "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/j2-light-v2-0-004", - "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/j2-light-v2-0-004", - "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/j2-light-v2-0-004", - "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/j2-light-v2-0-004", - "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/j2-light-v2-0-004", - "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/j2-light-v2-0-004", - "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/j2-light-v2-0-004", - "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/j2-light-v2-0-004", - "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/j2-light-v2-0-004", - "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/j2-light-v2-0-004", - "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/j2-light-v2-0-004", + "instance_type": "ml.p2.xlarge", + "env": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", }, + "role": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "name": "hf-textgeneration-bloom-1b1-2024-04-22-20-23-48-799", + "enable_network_isolation": True, } - -INFERENCE_CONFIGS = { - "inference_configs": { - "neuron-inference": { 
- "benchmark_metrics": { - "ml.inf2.2xlarge": [ - {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} - ] +HUB_MODEL_DOCUMENT_DICTS = { + "huggingface-llm-gemma-2b-instruct": { + "Url": "https://huggingface.co/google/gemma-2b-it", + "MinSdkVersion": "2.189.0", + "TrainingSupported": True, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04", # noqa: E501 + "HostingArtifactS3DataType": "S3Prefix", + "HostingArtifactCompressionType": "None", + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference-prepack/v1.0.0/", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.0", + "HostingUseScriptUri": False, + "HostingEulaUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/fmhMetadata/terms/gemmaTerms.txt", + "TrainingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/llm/prepack/v1.1.1/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptVersion": "1.1.1", + "TrainingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", # noqa: E501 + "TrainingArtifactS3DataType": "S3Prefix", + "TrainingArtifactCompressionType": "None", + "TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "Hyperparameters": [ + { + "Name": "peft_type", + "Type": "text", + "Default": "lora", + "Options": ["lora", "None"], + "Scope": "algorithm", + }, + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "chat_dataset", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "epoch", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "lora_r", + "Type": "int", + "Default": 64, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + {"Name": "lora_alpha", "Type": "int", "Default": 16, "Min": 0, "Scope": "algorithm"}, + { + "Name": "lora_dropout", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + {"Name": "bits", "Type": "int", "Default": 4, "Scope": "algorithm"}, + { + "Name": "double_quant", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "quant_Type", + "Type": "text", + "Default": "nf4", + "Options": ["fp4", "nf4"], + "Scope": "algorithm", + }, + { + "Name": "per_device_train_batch_size", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "per_device_eval_batch_size", + "Type": "int", + "Default": 2, + "Min": 
1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "warmup_ratio", + "Type": "float", + "Default": 0.1, + "Min": 0, + "Max": 1, + "Scope": "algorithm", }, - "component_names": ["neuron-inference"], - }, - "neuron-inference-budget": { - "benchmark_metrics": { - "ml.inf2.2xlarge": [ - {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} - ] + { + "Name": "train_from_scratch", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["neuron-base"], - }, - "gpu-inference-budget": { - "benchmark_metrics": { - "ml.p3.2xlarge": [ - {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} - ] + { + "Name": "fp16", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["gpu-inference-budget"], - }, - "gpu-inference": { - "benchmark_metrics": { - "ml.p3.2xlarge": [ - {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} - ] + { + "Name": "bf16", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["gpu-inference"], - }, - "gpu-inference-model-package": { - "benchmark_metrics": { - "ml.p3.2xlarge": [ - {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} - ] + { + "Name": "evaluation_strategy", + "Type": "text", + "Default": "steps", + "Options": ["steps", "epoch", "no"], + "Scope": "algorithm", }, - "component_names": ["gpu-inference-model-package"], - }, - "gpu-accelerated": { - "benchmark_metrics": { - "ml.p3.2xlarge": [ - {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} - ] + { + "Name": "eval_steps", + "Type": "int", + "Default": 20, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", }, - "component_names": ["gpu-accelerated"], - }, - }, - "inference_config_components": { - "neuron-base": { - "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"] - }, - "neuron-inference": { - "default_inference_instance_type": "ml.inf2.xlarge", - "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], - "hosting_ecr_specs": { - "framework": "huggingface-llm-neuronx", - "framework_version": "0.0.17", - "py_version": "py310", + { + "Name": "gradient_accumulation_steps", + "Type": "int", + "Default": 4, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", }, - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } - }, - "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + { + "Name": "logging_steps", + "Type": "int", + "Default": 8, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", }, - }, - "neuron-budget": {"inference_environment_variables": {"BUDGET": "1234"}}, - "gpu-inference": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - } - }, - "variants": { - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p2": 
{"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - }, + { + "Name": "weight_decay", + "Type": "float", + "Default": 0.2, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", }, - }, - "gpu-inference-model-package": { - "default_inference_instance_type": "ml.p2.xlarge", - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "hosting_model_package_arns": { - "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" - "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + { + "Name": "load_best_model_at_end", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", }, - }, - "gpu-inference-budget": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting:1.13.1-py310-sdk2.14.1-ubuntu20.04" - } - }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - }, + { + "Name": "max_train_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", }, - }, - "gpu-accelerated": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } - }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - }, + { + "Name": "max_val_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", }, - "hosting_additional_data_sources": { - "speculative_decoding": [ - { - "channel_name": "draft_model_name", - "artifact_version": "1.2.1", - "s3_data_source": { - "compression_type": "None", - "model_access_config": {"accept_eula": False}, - "s3_data_type": "S3Prefix", - "s3_uri": "key/to/draft/model/artifact/", - }, - } - ], + { + "Name": "seed", + "Type": "int", + "Default": 10, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "max_input_length", + "Type": "int", + "Default": 1024, + "Min": -1, + "Scope": "algorithm", + }, + { + "Name": "validation_split_ratio", + "Type": "float", + "Default": 0.2, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "train_data_split_seed", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "preprocessing_num_workers", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + {"Name": "max_steps", "Type": "int", "Default": -1, "Scope": "algorithm"}, + { + "Name": "gradient_checkpointing", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "early_stopping_patience", + "Type": "int", + "Default": 3, + "Min": 1, + "Scope": "algorithm", + }, + { + "Name": "early_stopping_threshold", + "Type": "float", + "Default": 0.0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "adam_beta1", + "Type": "float", + "Default": 0.9, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "adam_beta2", + "Type": "float", + "Default": 0.999, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "adam_epsilon", + "Type": "float", + 
"Default": 1e-08, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "max_grad_norm", + "Type": "float", + "Default": 1.0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "label_smoothing_factor", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "logging_first_step", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "logging_nan_inf_filter", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "save_strategy", + "Type": "text", + "Default": "steps", + "Options": ["no", "epoch", "steps"], + "Scope": "algorithm", + }, + { + "Name": "save_steps", + "Type": "int", + "Default": 500, + "Min": 1, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "save_total_limit", + "Type": "int", + "Default": 1, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "dataloader_drop_last", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "dataloader_num_workers", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "eval_accumulation_steps", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + { + "Name": "auto_find_batch_size", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "lr_scheduler_type", + "Type": "text", + "Default": "constant_with_warmup", + "Options": ["constant_with_warmup", "linear"], + "Scope": "algorithm", + }, + { + "Name": "warmup_steps", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "deepspeed", + "Type": "text", + "Default": "False", + "Options": ["False"], + "Scope": "algorithm", + }, + { + "Name": "sagemaker_submit_directory", + "Type": "text", + "Default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "Scope": "container", + }, + { + "Name": "sagemaker_program", + "Type": "text", + "Default": "transfer_learning.py", + "Scope": "container", + }, + { + "Name": "sagemaker_container_log_level", + "Type": "text", + "Default": "20", + "Scope": "container", + }, + ], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_SUBMIT_DIRECTORY", + "Type": "text", + "Default": "/opt/ml/model/code", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "Type": "text", + "Default": "20", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "Type": "text", + "Default": "3600", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "ENDPOINT_SERVER_TIMEOUT", + "Type": "int", + "Default": 3600, + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MODEL_CACHE_ROOT", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_ENV", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "HF_MODEL_ID", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MAX_INPUT_LENGTH", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": 
True, }, - }, - }, -} - -TRAINING_CONFIGS = { - "training_configs": { - "neuron-training": { - "benchmark_metrics": { - "ml.tr1n1.2xlarge": [ - {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} - ], - "ml.tr1n1.4xlarge": [ - {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} - ], + { + "Name": "MAX_TOTAL_TOKENS", + "Type": "text", + "Default": "8192", + "Scope": "container", + "RequiredForModelClass": True, }, - "component_names": ["neuron-training"], - "default_inference_config": "neuron-inference", - "default_incremental_training_config": "neuron-training", - "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], - "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], - }, - "neuron-training-budget": { - "benchmark_metrics": { - "ml.tr1n1.2xlarge": [ - {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} - ], - "ml.tr1n1.4xlarge": [ - {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} - ], + { + "Name": "MAX_BATCH_PREFILL_TOKENS", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, }, - "component_names": ["neuron-training-budget"], - "default_inference_config": "neuron-inference-budget", - "default_incremental_training_config": "neuron-training-budget", - "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], - "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], - }, - "gpu-training": { - "benchmark_metrics": { - "ml.p3.2xlarge": [ - {"name": "Latency", "value": "200", "unit": "Tokens/S", "concurrency": "1"} - ], + { + "Name": "SM_NUM_GPUS", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, }, - "component_names": ["gpu-training"], - "default_inference_config": "gpu-inference", - "default_incremental_training_config": "gpu-training", - "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], - "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], - }, - "gpu-training-budget": { - "benchmark_metrics": { - "ml.p3.2xlarge": [ - {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": "1"} - ] + { + "Name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "Type": "int", + "Default": 1, + "Scope": "container", + "RequiredForModelClass": True, }, - "component_names": ["gpu-training-budget"], - "default_inference_config": "gpu-inference-budget", - "default_incremental_training_config": "gpu-training-budget", - "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], - "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], - }, - }, - "training_config_components": { - "neuron-training": { - "default_training_instance_type": "ml.trn1.2xlarge", - "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", - "training_ecr_specs": { - "framework": "huggingface", - "framework_version": "2.0.0", - "py_version": "py310", - "huggingface_transformers_version": "4.28.1", + ], + "TrainingMetrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", }, - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" 
- } + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss': ([0-9]+\\.[0-9]+)", + }, # noqa: E501 + ], + "InferenceDependencies": [], + "TrainingDependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], + "DefaultInferenceInstanceType": "ml.g5.xlarge", + "SupportedInferenceInstanceTypes": [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "DefaultTrainingInstanceType": "ml.g5.2xlarge", + "SupportedTrainingInstanceTypes": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + "InferenceVolumeSize": 512, + "TrainingVolumeSize": 512, + "InferenceEnableNetworkIsolation": True, + "TrainingEnableNetworkIsolation": True, + "FineTuningSupported": True, + "ValidationSupported": True, + "DefaultTrainingDatasetUri": "s3://jumpstart-cache-prod-us-west-2/training-datasets/oasst_top/train/", # noqa: E501 + "ResourceNameBase": "hf-llm-gemma-2b-instruct", + "DefaultPayloads": { + "HelloWorld": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "user\nWrite a hello world program\nmodel", # noqa: E501 + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, }, - "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, }, - }, - "gpu-training": { - "default_training_instance_type": "ml.p2.xlarge", - "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-training:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "MachineLearningPoem": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "Body": { + "Inputs": "Write me a poem about Machine Learning.", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, }, }, }, - "neuron-training-budget": { - "default_training_instance_type": "ml.trn1.2xlarge", - "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], - "training_artifact_key": 
"artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } - }, - "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + "GatedBucket": True, + "HostingResourceRequirements": {"MinMemoryMb": 8192, "NumAccelerators": 1}, + "HostingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, }, }, - "gpu-training-budget": { - "default_training_instance_type": "ml.p2.xlarge", - "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-training:1.13.1-py310-sdk2.14.1-ubuntu20.04" - } + "TrainingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, + }, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, + }, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, - }, -} - - 
-INFERENCE_CONFIG_RANKINGS = { - "inference_config_rankings": { - "overall": { - "description": "Overall rankings of configs", - "rankings": [ - "neuron-inference", - "neuron-inference-budget", - "gpu-inference", - "gpu-inference-budget", - "gpu-accelerated", - ], - }, - "performance": { - "description": "Configs ranked based on performance", - "rankings": [ - "neuron-inference", - "gpu-inference", - "neuron-inference-budget", - "gpu-inference-budget", - ], - }, - "cost": { - "description": "Configs ranked based on cost", - "rankings": [ - "neuron-inference-budget", - "gpu-inference-budget", - "neuron-inference", - "gpu-inference", + "ContextualHelp": { + "HubFormatTrainData": [ + "A train and an optional validation directories. Each directory contains a CSV/JSON/TXT. ", + "- For CSV/JSON files, the text data is used from the column called 'text' or the first column if no column called 'text' is found", # noqa: E501 + "- The number of files under train and validation (if provided) should equal to one, respectively.", + " [Learn how to setup an AWS S3 bucket.](https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html)", # noqa: E501 ], - }, - } -} - -TRAINING_CONFIG_RANKINGS = { - "training_config_rankings": { - "overall": { - "description": "Overall rankings of configs", - "rankings": [ - "neuron-training", - "neuron-training-budget", - "gpu-training", - "gpu-training-budget", + "HubDefaultTrainData": [ + "Dataset: [SEC](https://www.sec.gov/edgar/searchedgar/companysearch)", + "SEC filing contains regulatory documents that companies and issuers of securities must submit to the Securities and Exchange Commission (SEC) on a regular basis.", # noqa: E501 + "License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode)", ], }, - "performance_training": { - "description": "Configs ranked based on performance", - "rankings": [ - "neuron-training", - "gpu-training", - "neuron-training-budget", - "gpu-training-budget", - ], - "instance_type_overrides": { - "ml.p2.xlarge": [ - "neuron-training", - "neuron-training-budget", - "gpu-training", - "gpu-training-budget", - ] + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "EncryptInterContainerTraffic": True, + "DisableOutputCompression": True, + "MaxRuntimeInSeconds": 360000, + "DynamicContainerDeploymentSupported": True, + "TrainingModelPackageArtifactUri": None, + "Dependencies": [], + }, + "meta-textgeneration-llama-2-70b": { + "Url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + "MinSdkVersion": "2.198.0", + "TrainingSupported": True, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # noqa: E501 + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-70b/artifacts/inference/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-70b/artifacts/inference-prepack/v1.0.0/", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.0", + "HostingUseScriptUri": False, + "HostingEulaUri": "s3://jumpstart-cache-prod-us-west-2/fmhMetadata/eula/llamaEula.txt", + "InferenceDependencies": [], + "TrainingDependencies": [ + "accelerate==0.21.0", + 
"bitsandbytes==0.39.1", + "black==23.7.0", + "brotli==1.0.9", + "datasets==2.14.1", + "fire==0.5.0", + "huggingface-hub==0.20.3", + "inflate64==0.3.1", + "loralib==0.1.1", + "multivolumefile==0.2.3", + "mypy-extensions==1.0.0", + "nvidia-cublas-cu12==12.1.3.1", + "nvidia-cuda-cupti-cu12==12.1.105", + "nvidia-cuda-nvrtc-cu12==12.1.105", + "nvidia-cuda-runtime-cu12==12.1.105", + "nvidia-cudnn-cu12==8.9.2.26", + "nvidia-cufft-cu12==11.0.2.54", + "nvidia-curand-cu12==10.3.2.106", + "nvidia-cusolver-cu12==11.4.5.107", + "nvidia-cusolver-cu12==11.4.5.107", + "nvidia-cusparse-cu12==12.1.0.106", + "nvidia-nccl-cu12==2.19.3", + "nvidia-nvjitlink-cu12==12.3.101", + "nvidia-nvtx-cu12==12.1.105", + "pathspec==0.11.1", + "peft==0.4.0", + "py7zr==0.20.5", + "pybcj==1.0.1", + "pycryptodomex==3.18.0", + "pyppmd==1.0.0", + "pyzstd==0.15.9", + "safetensors==0.3.1", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.4", + "sagemaker_jumpstart_script_utilities==1.1.9", + "scipy==1.11.1", + "termcolor==2.3.0", + "texttable==1.6.7", + "tokenize-rt==5.1.0", + "tokenizers==0.13.3", + "torch==2.2.0", + "transformers==4.33.3", + "triton==2.2.0", + "typing-extensions==4.8.0", + ], + "Hyperparameters": [ + { + "Name": "epoch", + "Type": "int", + "Default": 5, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", }, - }, - "cost_training": { - "description": "Configs ranked based on cost", - "rankings": [ - "neuron-training-budget", - "gpu-training-budget", - "neuron-training", - "gpu-training", - ], - }, - } -} - - -DEPLOYMENT_CONFIGS = [ - { - "DeploymentConfigName": "neuron-inference", - "DeploymentArgs": { - "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" - ".0-gpu-py310-cu121-ubuntu20.04", - "ModelData": { - "S3DataSource": { - "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" - "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", - "S3DataType": "S3Prefix", - "CompressionType": "None", - } + { + "Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", }, - "Environment": { - "SAGEMAKER_PROGRAM": "inference.py", - "ENDPOINT_SERVER_TIMEOUT": "3600", - "MODEL_CACHE_ROOT": "/opt/ml/model", - "SAGEMAKER_ENV": "1", - "HF_MODEL_ID": "/opt/ml/model", - "SM_NUM_GPUS": "1", - "MAX_INPUT_LENGTH": "2047", - "MAX_TOTAL_TOKENS": "2048", - "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "InstanceType": "ml.p2.xlarge", - "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, - "ModelDataDownloadTimeout": None, - "ContainerStartupHealthCheckTimeout": None, - }, - "AccelerationConfigs": None, - "BenchmarkMetrics": [ - {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} ], - }, - { - "DeploymentConfigName": "neuron-inference-budget", - "DeploymentArgs": { - "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" - ".0-gpu-py310-cu121-ubuntu20.04", - "ModelData": { - "S3DataSource": { - "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" - "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", - "S3DataType": "S3Prefix", - "CompressionType": "None", - } + "TrainingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/transfer_learning/textgeneration/v1.0.11/sourcedir.tar.gz", # 
noqa: E501 + "TrainingPrepackedScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.5/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptVersion": "1.0.5", + "TrainingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # TODO: not a training image # noqa: E501 + "TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-training/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + "InferenceEnvironmentVariables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, }, - "Environment": { - "SAGEMAKER_PROGRAM": "inference.py", - "ENDPOINT_SERVER_TIMEOUT": "3600", - "MODEL_CACHE_ROOT": "/opt/ml/model", - "SAGEMAKER_ENV": "1", - "HF_MODEL_ID": "/opt/ml/model", - "SM_NUM_GPUS": "1", - "MAX_INPUT_LENGTH": "2047", - "MAX_TOTAL_TOKENS": "2048", - "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, }, - "InstanceType": "ml.p2.xlarge", - "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, - "ModelDataDownloadTimeout": None, - "ContainerStartupHealthCheckTimeout": None, - }, - "AccelerationConfigs": None, - "BenchmarkMetrics": [ - {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} ], - }, - { - "DeploymentConfigName": "gpu-inference-budget", - "DeploymentArgs": { - "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" - ".0-gpu-py310-cu121-ubuntu20.04", - "ModelData": { - "S3DataSource": { - "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" - "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", - "S3DataType": "S3Prefix", - "CompressionType": "None", - } + "TrainingMetrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", }, - "Environment": { - "SAGEMAKER_PROGRAM": "inference.py", - "ENDPOINT_SERVER_TIMEOUT": "3600", - "MODEL_CACHE_ROOT": "/opt/ml/model", - "SAGEMAKER_ENV": "1", - "HF_MODEL_ID": "/opt/ml/model", - "SM_NUM_GPUS": "1", - "MAX_INPUT_LENGTH": "2047", - "MAX_TOTAL_TOKENS": "2048", - "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", }, - "InstanceType": "ml.p2.xlarge", - "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, - "ModelDataDownloadTimeout": None, - "ContainerStartupHealthCheckTimeout": None, - }, - "AccelerationConfigs": None, - "BenchmarkMetrics": [ - {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} ], - }, - { - "DeploymentConfigName": "gpu-inference", - "DeploymentArgs": { - "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" - ".0-gpu-py310-cu121-ubuntu20.04", - "ModelData": { - "S3DataSource": { - "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" - "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", - "S3DataType": "S3Prefix", - "CompressionType": "None", - } + "DefaultInferenceInstanceType": "ml.g5.48xlarge", + 
"supported_inference_instance_types": ["ml.g5.48xlarge", "ml.p4d.24xlarge"], + "default_training_instance_type": "ml.g5.48xlarge", + "SupportedInferenceInstanceTypes": ["ml.g5.48xlarge", "ml.p4d.24xlarge"], + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "EncryptInterContainerTraffic": True, + "DisableOutputCompression": True, + "MaxRuntimeInSeconds": 360000, + "SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + "InferenceVolumeSize": 256, + "TrainingVolumeSize": 256, + "InferenceEnableNetworkIsolation": True, + "TrainingEnableNetworkIsolation": True, + "DefaultTrainingDatasetUri": "s3://jumpstart-cache-prod-us-west-2/training-datasets/sec_amazon/", # noqa: E501 + "ValidationSupported": True, + "FineTuningSupported": True, + "ResourceNameBase": "meta-textgeneration-llama-2-70b", + "DefaultPayloads": { + "meaningOfLife": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "inputs": "I believe the meaning of life is", + "parameters": { + "max_new_tokens": 64, + "top_p": 0.9, + "temperature": 0.6, + "decoder_input_details": True, + "details": True, + }, + }, }, - "Environment": { - "SAGEMAKER_PROGRAM": "inference.py", - "ENDPOINT_SERVER_TIMEOUT": "3600", - "MODEL_CACHE_ROOT": "/opt/ml/model", - "SAGEMAKER_ENV": "1", - "HF_MODEL_ID": "/opt/ml/model", - "SM_NUM_GPUS": "1", - "MAX_INPUT_LENGTH": "2047", - "MAX_TOTAL_TOKENS": "2048", - "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "theoryOfRelativity": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": {"generated_text": "[0].generated_text"}, + "Body": { + "inputs": "Simply put, the theory of relativity states that ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, }, - "InstanceType": "ml.p2.xlarge", - "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, - "ModelDataDownloadTimeout": None, - "ContainerStartupHealthCheckTimeout": None, }, - "AccelerationConfigs": None, - "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], - }, -] - - -INIT_KWARGS = { - "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu" - "-py310-cu121-ubuntu20.04", - "model_data": { - "S3DataSource": { - "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface-textgeneration" - "-bloom-1b1/artifacts/inference-prepack/v4.0.0/", - "S3DataType": "S3Prefix", - "CompressionType": "None", - } + "GatedBucket": True, + "HostingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": 
"$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "TrainingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "meta-training/g5/v1.0.0/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + }, + }, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "meta-training/p4d/v1.0.0/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + }, + }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "HostingArtifactS3DataType": "S3Prefix", + "HostingArtifactCompressionType": "None", + "HostingResourceRequirements": {"MinMemoryMb": 393216, "NumAccelerators": 8}, + "DynamicContainerDeploymentSupported": True, + "TrainingModelPackageArtifactUri": None, + "Task": "text generation", + "DataType": "text", + "Framework": "meta", + "Dependencies": [], }, - "instance_type": "ml.p2.xlarge", - "env": { - "SAGEMAKER_PROGRAM": "inference.py", - "ENDPOINT_SERVER_TIMEOUT": "3600", - "MODEL_CACHE_ROOT": "/opt/ml/model", - "SAGEMAKER_ENV": "1", - "HF_MODEL_ID": "/opt/ml/model", - "SM_NUM_GPUS": "1", - "MAX_INPUT_LENGTH": "2047", - "MAX_TOTAL_TOKENS": "2048", - "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "huggingface-textembedding-bloom-7b1": { + "Url": "https://huggingface.co/bigscience/bloom-7b1", + "MinSdkVersion": "2.144.0", + "TrainingSupported": False, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # noqa: E501 + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/infer-huggingface-textembedding-bloom-7b1.tar.gz", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/textembedding/v1.0.1/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/infer-prepack-huggingface-textembedding-bloom-7b1.tar.gz", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.1", + "InferenceDependencies": [ + "accelerate==0.16.0", + "bitsandbytes==0.37.0", + "filelock==3.9.0", + "huggingface_hub==0.12.0", + "regex==2022.7.9", + "tokenizers==0.13.2", + "transformers==4.26.0", + ], + "TrainingDependencies": [], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + } + ], + "TrainingMetrics": [], + "DefaultInferenceInstanceType": "ml.g5.12xlarge", + "SupportedInferenceInstanceTypes": [ + "ml.g5.12xlarge", + 
"ml.g5.24xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.g4dn.12xlarge", + ], + "deploy_kwargs": { + "ModelDataDownloadTimeout": 3600, + "ContainerStartupHealthCheckTimeout": 3600, + }, + "SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json", "application/x-text"], + "SupportedAcceptTypes": ["application/json;verbose", "application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + "InferenceVolumeSize": 256, + "InferenceEnableNetworkIsolation": True, + "ValidationSupported": False, + "FineTuningSupported": False, + "ResourceNameBase": "hf-textembedding-bloom-7b1", + "HostingInstanceTypeVariants": { + "Aliases": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", # noqa: E501 + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-gpu-py38", + }, + "Variants": { + "c4": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"properties": {"image_uri": "$alias_ecr_uri_3"}}, + }, + }, + "TrainingModelPackageArtifactUri": None, + "DynamicContainerDeploymentSupported": False, + "License": "BigScience RAIL", + "Dependencies": [], }, - "role": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", - "name": "hf-textgeneration-bloom-1b1-2024-04-22-20-23-48-799", - "enable_network_isolation": True, } diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 6f5f3dba05..3678685db5 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1143,6 +1143,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "tolerate_vulnerable_model", "tolerate_deprecated_model", "config_name", + "hub_name", } 
assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1172,6 +1173,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "self", "name", "resources", + "model_reference_arn", } assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip @@ -1251,6 +1253,7 @@ def test_no_predictor_returns_default_predictor( tolerate_vulnerable_model=False, sagemaker_session=estimator.sagemaker_session, config_name=None, + hub_arn=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -1421,6 +1424,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, config_name=None, + hub_arn=None, ) @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @@ -1477,6 +1481,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, config_name=None, + hub_arn=None, ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @@ -1544,6 +1549,9 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta estimator.deploy(image_uri="blah") assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge" + estimator.deploy(image_uri="blah", instance_type="ml.quantum.large") + assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.quantum.large" + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch( @@ -1752,6 +1760,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -1759,6 +1768,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), ] ) @@ -1780,6 +1790,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -1787,6 +1798,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), ] ) diff --git a/tests/unit/sagemaker/jumpstart/hub/__init__.py b/tests/unit/sagemaker/jumpstart/hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py new file mode 100644 index 0000000000..e2085e5ab9 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -0,0 +1,235 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
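+#
+# The tests below exercise the public Hub surface (create, describe_model,
+# create_model_reference, delete_model_reference). As an illustrative sketch only
+# (the hub and model names here are hypothetical, not fixtures from this module):
+#
+#     hub = Hub(hub_name="my-hub", sagemaker_session=session)
+#     hub.create(description="my curated hub")
+#     hub.create_model_reference(model_arn=public_model_arn, model_name="my-model")
+#     hub.delete_model_reference("my-model")
+#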
+from __future__ import absolute_import +from datetime import datetime +from unittest.mock import patch, MagicMock +import pytest +from mock import Mock +from sagemaker.jumpstart.hub.hub import Hub +from sagemaker.jumpstart.hub.types import S3ObjectLocation + + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + +MODULE_PATH = "sagemaker.jumpstart.hub.hub.Hub" + +FAKE_TIME = datetime(1997, 8, 14, 00, 00, 00) + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session") + sagemaker_session_mock = Mock( + name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION + ) + sagemaker_session_mock._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + sagemaker_session_mock.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": "s3://mock-bucket-123"} + } + sagemaker_session_mock.account_id.return_value = ACCOUNT_ID + return sagemaker_session_mock + + +@pytest.fixture +def mock_instance(sagemaker_session): + mock_instance = MagicMock() + mock_instance.hub_name = "test-hub" + mock_instance._sagemaker_session = sagemaker_session + return mock_instance + + +def test_instantiates(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + assert hub.hub_name == HUB_NAME + assert hub.region == "us-east-1" + assert hub._sagemaker_session == sagemaker_session + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + None, + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") +def test_create_with_no_bucket_name( + mock_generate_hub_storage_location, + sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + storage_location = S3ObjectLocation( + "sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}" + ) + mock_generate_hub_storage_location.return_value = storage_location + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + sagemaker_session.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"} + } + hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "s3_storage_config": { + "S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}" + }, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", "mock-bucket-123", None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + 
"mock-bucket-123", + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") +def test_create_with_bucket_name( + mock_generate_hub_storage_location, + sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}") + mock_generate_hub_storage_location.return_value = storage_location + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"}, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") +def test_describe_model_success(mock_describe_hub_content_response, sagemaker_session): + mock_describe_hub_content_response.return_value = Mock() + mock_list_hub_content_versions = sagemaker_session.list_hub_content_versions + mock_list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0"}, + {"HubContentVersion": "2.0"}, + {"HubContentVersion": "3.0"}, + ] + } + + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + + with patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") as mock_get_hub_model_version: + mock_get_hub_model_version.return_value = "3.0" + + hub.describe_model("test-model") + + mock_list_hub_content_versions.assert_called_with( + hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="Model" + ) + sagemaker_session.describe_hub_content.assert_called_with( + hub_name=HUB_NAME, + hub_content_name="test-model", + hub_content_version="3.0", + hub_content_type="Model", + ) + + +def test_create_hub_content_reference(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + model_name = "mock-model-one-huggingface" + min_version = "1.1.1" + public_model_arn = ( + f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/{model_name}" + ) + create_hub_content_reference = { + "HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{HUB_NAME}", + "HubContentReferenceArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/{HUB_NAME}/ModelRef/{model_name}", # noqa: E501 + } + sagemaker_session.create_hub_content_reference = Mock(return_value=create_hub_content_reference) + + request = { + "hub_name": HUB_NAME, + "source_hub_content_arn": public_model_arn, + "hub_content_name": model_name, + "min_version": min_version, + } + + response = hub.create_model_reference( + model_arn=public_model_arn, model_name=model_name, min_version=min_version + ) + sagemaker_session.create_hub_content_reference.assert_called_with(**request) + + assert response == { + "HubArn": 
"arn:aws:sagemaker:us-east-1:123456789123:hub/mock-hub-name", + "HubContentReferenceArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/mock-hub-name/ModelRef/mock-model-one-huggingface", # noqa: E501 + } + + +def test_delete_hub_content_reference(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + model_name = "mock-model-one-huggingface" + + hub.delete_model_reference(model_name) + sagemaker_session.delete_hub_content_reference.assert_called_with( + hub_name=HUB_NAME, + hub_content_type="ModelReference", + hub_content_name="mock-model-one-huggingface", + ) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py new file mode 100644 index 0000000000..c4b95443ec --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py @@ -0,0 +1,981 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import numpy as np +from sagemaker.jumpstart.types import ( + JumpStartHyperparameter, + JumpStartInstanceTypeVariants, + JumpStartEnvironmentVariable, + JumpStartPredictorSpecs, + JumpStartSerializablePayload, +) +from sagemaker.jumpstart.hub.interfaces import HubModelDocument +from tests.unit.sagemaker.jumpstart.constants import ( + SPECIAL_MODEL_SPECS_DICT, + HUB_MODEL_DOCUMENT_DICTS, +) + +gemma_model_spec = SPECIAL_MODEL_SPECS_DICT["gemma-model-2b-v1_1_0"] + + +def test_hub_content_document_from_json_obj(): + region = "us-west-2" + gemma_model_document = HubModelDocument( + json_obj=HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"], region=region + ) + assert gemma_model_document.url == "https://huggingface.co/google/gemma-2b-it" + assert gemma_model_document.min_sdk_version == "2.189.0" + assert gemma_model_document.training_supported is True + assert gemma_model_document.incremental_training_supported is False + assert ( + gemma_model_document.hosting_ecr_uri + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:" + "2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + with pytest.raises(AttributeError) as excinfo: + gemma_model_document.hosting_ecr_specs + assert str(excinfo.value) == "'HubModelDocument' object has no attribute 'hosting_ecr_specs'" + assert gemma_model_document.hosting_artifact_s3_data_type == "S3Prefix" + assert gemma_model_document.hosting_artifact_compression_type == "None" + assert ( + gemma_model_document.hosting_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct" + "/artifacts/inference/v1.0.0/" + ) + assert ( + gemma_model_document.hosting_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/" + "llm/v1.0.1/sourcedir.tar.gz" + ) + assert gemma_model_document.inference_dependencies == [] + assert gemma_model_document.training_dependencies == [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + 
"flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ] + assert ( + gemma_model_document.hosting_prepacked_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/" + "artifacts/inference-prepack/v1.0.0/" + ) + assert gemma_model_document.hosting_prepacked_artifact_version == "1.0.0" + assert gemma_model_document.hosting_use_script_uri is False + assert ( + gemma_model_document.hosting_eula_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/fmhMetadata/terms/gemmaTerms.txt" + ) + assert ( + gemma_model_document.training_ecr_uri + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers" + "4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + with pytest.raises(AttributeError) as excinfo: + gemma_model_document.training_ecr_specs + assert str(excinfo.value) == "'HubModelDocument' object has no attribute 'training_ecr_specs'" + assert ( + gemma_model_document.training_prepacked_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/" + "llm/prepack/v1.1.1/sourcedir.tar.gz" + ) + assert gemma_model_document.training_prepacked_script_version == "1.1.1" + assert ( + gemma_model_document.training_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/" + "llm/v1.1.1/sourcedir.tar.gz" + ) + assert gemma_model_document.training_artifact_s3_data_type == "S3Prefix" + assert gemma_model_document.training_artifact_compression_type == "None" + assert ( + gemma_model_document.training_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct" + ".tar.gz" + ) + assert gemma_model_document.hyperparameters == [ + JumpStartHyperparameter( + { + "Name": "peft_type", + "Type": "text", + "Default": "lora", + "Options": ["lora", "None"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "chat_dataset", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "epoch", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lora_r", + "Type": "int", + "Default": 64, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "lora_alpha", "Type": "int", "Default": 16, "Min": 0, "Scope": "algorithm"}, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lora_dropout", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( 
+ {"Name": "bits", "Type": "int", "Default": 4, "Scope": "algorithm"}, is_hub_content=True + ), + JumpStartHyperparameter( + { + "Name": "double_quant", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "quant_Type", + "Type": "text", + "Default": "nf4", + "Options": ["fp4", "nf4"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "per_device_train_batch_size", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "per_device_eval_batch_size", + "Type": "int", + "Default": 2, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "warmup_ratio", + "Type": "float", + "Default": 0.1, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "train_from_scratch", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "fp16", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "bf16", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "evaluation_strategy", + "Type": "text", + "Default": "steps", + "Options": ["steps", "epoch", "no"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "eval_steps", + "Type": "int", + "Default": 20, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "gradient_accumulation_steps", + "Type": "int", + "Default": 4, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_steps", + "Type": "int", + "Default": 8, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "weight_decay", + "Type": "float", + "Default": 0.2, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "load_best_model_at_end", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_train_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_val_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "seed", + "Type": "int", + "Default": 10, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_input_length", + "Type": "int", + "Default": 1024, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "validation_split_ratio", + "Type": "float", + "Default": 0.2, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + 
"Name": "train_data_split_seed", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "preprocessing_num_workers", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "max_steps", "Type": "int", "Default": -1, "Scope": "algorithm"}, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "gradient_checkpointing", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "early_stopping_patience", + "Type": "int", + "Default": 3, + "Min": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "early_stopping_threshold", + "Type": "float", + "Default": 0.0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_beta1", + "Type": "float", + "Default": 0.9, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_beta2", + "Type": "float", + "Default": 0.999, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_epsilon", + "Type": "float", + "Default": 1e-08, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_grad_norm", + "Type": "float", + "Default": 1.0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "label_smoothing_factor", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_first_step", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_nan_inf_filter", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_strategy", + "Type": "text", + "Default": "steps", + "Options": ["no", "epoch", "steps"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_steps", + "Type": "int", + "Default": 500, + "Min": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_total_limit", + "Type": "int", + "Default": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "dataloader_drop_last", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "dataloader_num_workers", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "eval_accumulation_steps", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "auto_find_batch_size", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lr_scheduler_type", + "Type": "text", + "Default": 
"constant_with_warmup", + "Options": ["constant_with_warmup", "linear"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "warmup_steps", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "deepspeed", + "Type": "text", + "Default": "False", + "Options": ["False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_submit_directory", + "Type": "text", + "Default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "Scope": "container", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_program", + "Type": "text", + "Default": "transfer_learning.py", + "Scope": "container", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_container_log_level", + "Type": "text", + "Default": "20", + "Scope": "container", + }, + is_hub_content=True, + ), + ] + assert gemma_model_document.inference_environment_variables == [ + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_SUBMIT_DIRECTORY", + "Type": "text", + "Default": "/opt/ml/model/code", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "Type": "text", + "Default": "20", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "Type": "text", + "Default": "3600", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "ENDPOINT_SERVER_TIMEOUT", + "Type": "int", + "Default": 3600, + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MODEL_CACHE_ROOT", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_ENV", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "HF_MODEL_ID", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_INPUT_LENGTH", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_TOTAL_TOKENS", + "Type": "text", + "Default": "8192", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_BATCH_PREFILL_TOKENS", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SM_NUM_GPUS", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": 
"SAGEMAKER_MODEL_SERVER_WORKERS", + "Type": "int", + "Default": 1, + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + ] + assert gemma_model_document.training_metrics == [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss': ([0-9]+\\.[0-9]+)", + }, + ] + assert gemma_model_document.default_inference_instance_type == "ml.g5.xlarge" + assert gemma_model_document.supported_inference_instance_types == [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ] + assert gemma_model_document.default_training_instance_type == "ml.g5.2xlarge" + assert np.array_equal( + gemma_model_document.supported_training_instance_types, + [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + ) + assert gemma_model_document.sage_maker_sdk_predictor_specifications == JumpStartPredictorSpecs( + { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + is_hub_content=True, + ) + assert gemma_model_document.inference_volume_size == 512 + assert gemma_model_document.training_volume_size == 512 + assert gemma_model_document.inference_enable_network_isolation is True + assert gemma_model_document.training_enable_network_isolation is True + assert gemma_model_document.fine_tuning_supported is True + assert gemma_model_document.validation_supported is True + assert ( + gemma_model_document.default_training_dataset_uri + == "s3://jumpstart-cache-prod-us-west-2/training-datasets/oasst_top/train/" + ) + assert gemma_model_document.resource_name_base == "hf-llm-gemma-2b-instruct" + assert gemma_model_document.default_payloads == { + "HelloWorld": JumpStartSerializablePayload( + { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "user\nWrite a hello world program" + "\nmodel", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, + }, + }, + is_hub_content=True, + ), + "MachineLearningPoem": JumpStartSerializablePayload( + { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "Write me a poem about Machine Learning.", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, + }, + }, + is_hub_content=True, + ), + } + assert gemma_model_document.gated_bucket is True + assert gemma_model_document.hosting_resource_requirements == { + "MinMemoryMb": 8192, + "NumAccelerators": 1, + } + assert gemma_model_document.hosting_instance_type_variants == JumpStartInstanceTypeVariants( + { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch" + "-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": 
"$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + is_hub_content=True, + ) + assert gemma_model_document.training_instance_type_variants == JumpStartInstanceTypeVariants( + { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-" + "training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "Variants": { + "g4dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + is_hub_content=True, + ) + assert gemma_model_document.contextual_help == { + "HubFormatTrainData": [ + "A train and an optional validation directories. Each directory contains a CSV/JSON/TXT. 
", + "- For CSV/JSON files, the text data is used from the column called 'text' or the " + "first column if no column called 'text' is found", + "- The number of files under train and validation (if provided) should equal to one," + " respectively.", + " [Learn how to setup an AWS S3 bucket.]" + "(https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html)", + ], + "HubDefaultTrainData": [ + "Dataset: [SEC](https://www.sec.gov/edgar/searchedgar/companysearch)", + "SEC filing contains regulatory documents that companies and issuers of securities must " + "submit to the Securities and Exchange Commission (SEC) on a regular basis.", + "License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode)", + ], + } + assert gemma_model_document.model_data_download_timeout == 1200 + assert gemma_model_document.container_startup_health_check_timeout == 1200 + assert gemma_model_document.encrypt_inter_container_traffic is True + assert gemma_model_document.disable_output_compression is True + assert gemma_model_document.max_runtime_in_seconds == 360000 + assert gemma_model_document.dynamic_container_deployment_supported is True + assert gemma_model_document.training_model_package_artifact_uri is None + assert gemma_model_document.dependencies == [] diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py new file mode 100644 index 0000000000..ee50805792 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -0,0 +1,256 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from unittest.mock import patch, Mock +from sagemaker.jumpstart.types import HubArnExtractedInfo +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.hub import utils + + +def test_get_info_from_hub_resource_arn(): + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Model/my-mock-model/1.0.2" + ) + assert utils.get_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Model", + hub_content_name="my-mock-model", + hub_content_version="1.0.2", + ) + + notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Notebook/my-mock-notebook/1.0.2" + assert utils.get_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Notebook", + hub_content_name="my-mock-notebook", + hub_content_version="1.0.2", + ) + + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" + assert utils.get_info_from_hub_resource_arn(hub_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + ) + + invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "nonsense-string" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + +def test_construct_hub_arn_from_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-west-2" + hub_name = "my-cool-hub" + + assert ( + utils.construct_hub_arn_from_name(hub_name=hub_name, session=mock_sagemaker_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-cool-hub" + ) + + assert ( + utils.construct_hub_arn_from_name( + hub_name=hub_name, region="us-east-1", session=mock_sagemaker_session + ) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-cool-hub" + ) + + +def test_construct_hub_model_arn_from_inputs(): + model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/1.0.2" + ) + + version = "*" + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/*" + ) + + +def test_generate_hub_arn_for_init_kwargs(): + hub_name = "my-hub-name" + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock default session with default values + mock_default_session = Mock() + mock_default_session.account_id.return_value = "123456789123" + mock_default_session.boto_region_name = JUMPSTART_DEFAULT_REGION_NAME + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, session=mock_default_session) + == 
"arn:aws:sagemaker:us-west-2:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, "us-east-1", session=mock_default_session) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, "eu-west-1", mock_custom_session) + == "arn:aws:sagemaker:eu-west-1:000000000000:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, None, mock_custom_session) + == "arn:aws:sagemaker:us-east-2:000000000000:hub/my-hub-name" + ) + + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, session=mock_default_session) == hub_arn + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", session=mock_default_session) + == hub_arn + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn + ) + + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn + + +def test_create_hub_bucket_if_it_does_not_exist_hub_arn(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.client("sts").get_caller_identity.return_value = { + "Account": "123456789123" + } + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" + mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + mock_sagemaker_session.boto_region_name = "us-east-1" + + bucket_name = "sagemaker-hubs-us-east-1-123456789123" + created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( + sagemaker_session=mock_sagemaker_session + ) + + mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() + assert created_hub_bucket_name == bucket_name + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn + + +def test_is_gated_bucket(): + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True + + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-east-1") is True + + assert utils.is_gated_bucket("jumpstart-cache-prod-us-west-2") is False + + assert utils.is_gated_bucket("") is False + + +def test_create_hub_bucket_if_it_does_not_exist(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.client("sts").get_caller_identity.return_value = { + "Account": "123456789123" + } + mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + mock_sagemaker_session.boto_region_name = "us-east-1" + bucket_name = "sagemaker-hubs-us-east-1-123456789123" + created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( + sagemaker_session=mock_sagemaker_session + ) + + mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() + assert created_hub_bucket_name == bucket_name + + +@patch("sagemaker.session.Session") +def test_get_hub_model_version_success(mock_session): + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = "1.0.0" + mock_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, + ] + } + + result = 
utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == "1.0.0" + + +@patch("sagemaker.session.Session") +def test_get_hub_model_version_None(mock_session): + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = None + mock_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, + ] + } + + result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == "2.0.0" + + +@patch("sagemaker.session.Session") +def test_get_hub_model_version_wildcard_char(mock_session): + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = "*" + mock_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, + ] + } + + result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == "2.0.0" diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index a3d70933eb..bddb3bc9bc 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -505,9 +505,11 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.model.Model.register") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_proprietary_model_endpoint( self, + mock_model_register: mock.Mock, mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, @@ -541,8 +543,17 @@ def test_proprietary_model_endpoint( enable_network_isolation=False, ) + model.register() model.deploy() + mock_model_register.assert_called_once_with( + model_type=JumpStartModelType.PROPRIETARY, + content_types=["application/json"], + response_types=["application/json"], + model_package_group_name=model_id, + source_uri=model.model_package_arn, + ) + mock_model_deploy.assert_called_once_with( initial_instance_count=1, instance_type="ml.p4de.24xlarge", @@ -759,7 +770,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): Please add the new argument to the skip set below, and reach out to JumpStart team.""" - init_args_to_skip: Set[str] = set([]) + init_args_to_skip: Set[str] = set(["model_reference_arn"]) deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Model.__init__ @@ -777,6 +788,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): "instance_type", "model_package_arn", "config_name", + "hub_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -854,6 +866,7 @@ def test_no_predictor_returns_default_predictor( model_id=model_id, model_version="*", region=region, + hub_arn=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, @@ -994,6 +1007,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + 
hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -1001,6 +1015,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), ] ) @@ -1025,6 +1040,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -1032,6 +1048,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), ] ) @@ -1423,6 +1440,60 @@ def test_model_artifact_variant_model( enable_network_isolation=True, ) + @mock.patch("sagemaker.jumpstart.model.get_model_info_from_endpoint") + @mock.patch("sagemaker.jumpstart.model.JumpStartModel.__init__") + def test_attach( + self, + mock_js_model_init, + mock_get_model_info_from_endpoint, + ): + mock_js_model_init.return_value = None + mock_get_model_info_from_endpoint.return_value = ( + "model-id", + "model-version", + None, + None, + None, + ) + val = JumpStartModel.attach("some-endpoint") + mock_get_model_info_from_endpoint.assert_called_once_with( + endpoint_name="some-endpoint", + inference_component_name=None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + mock_js_model_init.assert_called_once_with( + model_id="model-id", + model_version="model-version", + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + assert isinstance(val, JumpStartModel) + + mock_get_model_info_from_endpoint.reset_mock() + JumpStartModel.attach("some-endpoint", model_id="some-id") + mock_get_model_info_from_endpoint.assert_called_once_with( + endpoint_name="some-endpoint", + inference_component_name=None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + + mock_get_model_info_from_endpoint.reset_mock() + JumpStartModel.attach("some-endpoint", model_id="some-id", model_version="some-version") + mock_get_model_info_from_endpoint.assert_called_once_with( + endpoint_name="some-endpoint", + inference_component_name=None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + + # providing model id, version, and ic name should bypass check with endpoint tags + mock_get_model_info_from_endpoint.reset_mock() + JumpStartModel.attach( + "some-endpoint", + model_id="some-id", + model_version="some-version", + inference_component_name="some-ic-name", + ) + mock_get_model_info_from_endpoint.assert_not_called() + @mock.patch( "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} ) @@ -1457,8 +1528,10 @@ def test_model_registry_accept_and_response_types( model.register() mock_model_register.assert_called_once_with( + model_type=JumpStartModelType.OPEN_WEIGHTS, content_types=["application/x-text"], response_types=["application/json;verbose", "application/json"], + model_package_group_name=model.model_id, ) @mock.patch( diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 50fe6da0a6..c97e6ba895 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -495,8 +495,8 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache.get_manifest_file_s3_key() == manifest_file_key assert cache.get_region() == region assert cache.get_bucket() == bucket - assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert 
cache._content_cache._max_cache_items == max_s3_cache_items + assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._open_weight_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items @@ -535,8 +535,8 @@ def test_jumpstart_proprietary_cache_accepts_input_parameters(): ) assert cache.get_region() == region assert cache.get_bucket() == bucket - assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert cache._content_cache._max_cache_items == max_s3_cache_items + assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._proprietary_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 301afe4d53..a06b48deb7 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -4,7 +4,9 @@ import json from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, ANY + +import boto3 import pytest from sagemaker.jumpstart.constants import ( @@ -755,6 +757,9 @@ def test_get_model_url( patched_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) + mock_client = boto3.client("s3") + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, version = "xgboost-classification-model", "1.0.0" assert "https://xgboost.readthedocs.io/en/latest/" == get_model_url(model_id, version) @@ -773,12 +778,14 @@ def test_get_model_url( **{key: value for key, value in kwargs.items() if key != "region"}, ) - get_model_url(model_id, version, region="us-west-2") + get_model_url(model_id, version, region="us-west-2", sagemaker_session=mock_session) patched_get_model_specs.assert_called_once_with( model_id=model_id, version=version, region="us-west-2", - s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index a3425a7b90..8368f72d58 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -129,6 +129,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( sagemaker_session=mock_session, model_type=JumpStartModelType.OPEN_WEIGHTS, config_name=None, + hub_arn=None, ) diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 06099ee066..884639b5d6 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -377,6 +377,7 @@ def test_jumpstart_model_specs(): { "name": "epochs", "type": "int", + # "_is_hub_content": False, "default": 3, "min": 1, "max": 1000, @@ -387,6 +388,7 @@ def test_jumpstart_model_specs(): { "name": "adam-learning-rate", "type": "float", + # "_is_hub_content": False, "default": 0.05, "min": 1e-08, "max": 1, @@ -397,6 +399,7 @@ def test_jumpstart_model_specs(): { "name": "batch-size", "type": "int", + # "_is_hub_content": False, "default": 4, "min": 1, "max": 1024, @@ -407,6 +410,7 @@ def test_jumpstart_model_specs(): { "name": 
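The cache assertions above pick up the rename of the JumpStart cache's internal _s3_cache to _content_cache, reflecting that the cache can now hold hub content as well as S3-backed manifests and specs. A minimal sketch of the renamed attribute, assuming the constructor keywords exercised by test_jumpstart_cache_accepts_input_parameters:

    import datetime

    from sagemaker.jumpstart.cache import JumpStartModelsCache

    cache = JumpStartModelsCache(
        region="us-west-2",
        max_s3_cache_items=20,
        s3_cache_expiration_horizon=datetime.timedelta(hours=6),
    )
    assert cache._content_cache._max_cache_items == 20  # was cache._s3_cache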
"sagemaker_submit_directory", "type": "text", + # "_is_hub_content": False, "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", } @@ -415,6 +419,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_program", "type": "text", + # "_is_hub_content": False, "default": "transfer_learning.py", "scope": "container", } @@ -423,6 +428,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_container_log_level", "type": "text", + # "_is_hub_content": False, "default": "20", "scope": "container", } diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index cc4ef71cee..de274f0374 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -23,8 +23,8 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, @@ -115,7 +115,9 @@ def get_prototype_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -131,7 +133,9 @@ def get_special_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -147,7 +151,9 @@ def get_special_model_spec_for_inference_component_based_endpoint( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. 
For this mock, we only retrieve model specs based on the model ID and adding @@ -170,8 +176,10 @@ def get_spec_from_base_spec( model_id: str = None, version_str: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: if version and version_str: @@ -194,6 +202,7 @@ def get_spec_from_base_spec( "catboost" not in model_id, "lightgbm" not in model_id, "sklearn" not in model_id, + "ai21" not in model_id, ] ): raise KeyError("Bad model ID") @@ -216,8 +225,10 @@ def get_base_spec_with_prototype_configs( region: str = None, model_id: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: spec = copy.deepcopy(BASE_SPEC) inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} @@ -255,6 +266,8 @@ def get_prototype_spec_with_configs( version: str = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: str = None, + sagemaker_session: boto3.Session = None, ) -> JumpStartModelSpecs: spec = copy.deepcopy(PROTOTYPICAL_MODEL_SPECS_DICT[model_id]) inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} @@ -268,33 +281,31 @@ def get_prototype_spec_with_configs( def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, - key: JumpStartCachedS3ContentKey, - value: JumpStartCachedS3ContentValue, -) -> JumpStartCachedS3ContentValue: + key: JumpStartCachedContentKey, + value: JumpStartCachedContentValue, +) -> JumpStartCachedContentValue: - filetype, s3_key = key.file_type, key.s3_key - if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: + data_type, id_info = key.data_type, key.id_info + if data_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: - return JumpStartCachedS3ContentValue( - formatted_content=get_formatted_manifest(BASE_MANIFEST) - ) + return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST)) - if filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: - _, model_id, specs_version = s3_key.split("/") + if data_type == JumpStartS3FileType.OPEN_WEIGHT_SPECS: + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) - if filetype == JumpStartS3FileType.PROPRIETARY_MANIFEST: - return JumpStartCachedS3ContentValue( + if data_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return JumpStartCachedContentValue( formatted_content=get_formatted_manifest(BASE_PROPRIETARY_MANIFEST) ) - if filetype == JumpStartS3FileType.PROPRIETARY_SPECS: - _, model_id, specs_version = s3_key.split("/") + if data_type == JumpStartS3FileType.PROPRIETARY_SPECS: + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("proprietary_specs_", "").replace(".json", "") - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec( model_id=model_id, version=version, @@ -302,7 +313,7 @@ def patched_retrieval_function( ) ) - raise ValueError(f"Bad value for filetype: {filetype}") + raise ValueError(f"Bad value for filetype: {data_type}") def 
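The patched retrieval function above now dispatches on JumpStartCachedContentKey.data_type and id_info in place of the S3-specific file_type and s3_key, so the same cache machinery can serve content that does not live in S3. Condensed to its core, using only names from this diff:

    def retrieve(key: JumpStartCachedContentKey) -> JumpStartCachedContentValue:
        data_type, id_info = key.data_type, key.id_info
        if data_type == JumpStartS3FileType.OPEN_WEIGHT_SPECS:
            # id_info carries the same "<prefix>/<model_id>/specs_vX.json"
            # layout the old s3_key did.
            _, model_id, specs_version = id_info.split("/")
            version = specs_version.replace("specs_v", "").replace(".json", "")
            return JumpStartCachedContentValue(
                formatted_content=get_spec_from_base_spec(model_id=model_id, version=version)
            )
        raise ValueError(f"Bad value for data_type: {data_type}")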
overwrite_dictionary( @@ -337,7 +348,7 @@ def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, An def get_mock_init_kwargs( model_id: str, config_name: Optional[str] = None ) -> JumpStartModelInitKwargs: - return JumpStartModelInitKwargs( + kwargs = JumpStartModelInitKwargs( model_id=model_id, model_type=JumpStartModelType.OPEN_WEIGHTS, model_data=INIT_KWARGS.get("model_data"), @@ -347,6 +358,9 @@ def get_mock_init_kwargs( resources=ResourceRequirements(), config_name=config_name, ) + setattr(kwargs, "model_reference_arn", None) + setattr(kwargs, "hub_content_type", None) + return kwargs def get_base_deployment_configs_metadata( diff --git a/tests/unit/sagemaker/local/test_local_session.py b/tests/unit/sagemaker/local/test_local_session.py index ceae674704..ce8fd19b5c 100644 --- a/tests/unit/sagemaker/local/test_local_session.py +++ b/tests/unit/sagemaker/local/test_local_session.py @@ -47,7 +47,8 @@ @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -142,7 +143,8 @@ def test_create_processing_job(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_not_fully_replicated(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_not_fully_replicated(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -197,7 +199,8 @@ def test_create_processing_job_not_fully_replicated(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_invalid_upload_mode(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_invalid_upload_mode(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -252,7 +255,8 @@ def test_create_processing_job_invalid_upload_mode(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_invalid_processing_input(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_invalid_processing_input(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -302,7 +306,8 @@ def test_create_processing_job_invalid_processing_input(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_invalid_processing_output(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_invalid_processing_output(process, LocalSession, mock_telemetry): 
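Every local-session test in this file gains a patch on sagemaker.telemetry.telemetry_logging.resolve_value_from_config so that telemetry's config resolution stays inert under test. Note the ordering rule: stacked @patch decorators inject mocks bottom-up, so the top-most patch arrives as the last mock argument. Skeleton of the pattern (test body elided):

    from unittest.mock import patch

    @patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
    @patch("sagemaker.local.image._SageMakerContainer.process")
    @patch("sagemaker.local.local_session.LocalSession")
    def test_example(LocalSession, process, mock_telemetry):
        ...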
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -360,7 +365,8 @@ def test_describe_invalid_processing_job(*args): @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") @patch("sagemaker.local.local_session.LocalSession") -def test_create_training_job(train, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_training_job(train, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -427,7 +433,8 @@ def test_describe_invalid_training_job(*args): @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") @patch("sagemaker.local.local_session.LocalSession") -def test_create_training_job_invalid_data_source(train, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_training_job_invalid_data_source(train, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -466,7 +473,8 @@ def test_create_training_job_invalid_data_source(train, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") @patch("sagemaker.local.local_session.LocalSession") -def test_create_training_job_not_fully_replicated(train, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_training_job_not_fully_replicated(train, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -503,7 +511,8 @@ def test_create_training_job_not_fully_replicated(train, LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_create_model(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_model(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER) @@ -512,7 +521,8 @@ def test_create_model(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_delete_model(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_delete_model(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER) @@ -523,7 +533,8 @@ def test_delete_model(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_describe_model(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_describe_model(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() with pytest.raises(ClientError): @@ -536,9 +547,10 @@ def test_describe_model(LocalSession): assert response["PrimaryContainer"]["ModelDataUrl"] == "/some/model/path" +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) @patch("sagemaker.local.local_session._LocalTransformJob") @patch("sagemaker.local.local_session.LocalSession") -def test_create_transform_job(LocalSession, _LocalTransformJob): +def test_create_transform_job(LocalSession, 
_LocalTransformJob, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_transform_job("transform-job", "some-model", None, None, None) @@ -572,7 +584,8 @@ def test_logs_for_processing_job(process, LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_describe_endpoint_config(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_describe_endpoint_config(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() # No Endpoint Config Created @@ -588,7 +601,8 @@ def test_describe_endpoint_config(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_create_endpoint_config(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_endpoint_config(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS) @@ -598,7 +612,8 @@ def test_create_endpoint_config(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_delete_endpoint_config(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_delete_endpoint_config(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS) @@ -613,12 +628,15 @@ def test_delete_endpoint_config(LocalSession): ) +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) @patch("sagemaker.local.image._SageMakerContainer.serve") @patch("sagemaker.local.local_session.LocalSession") @patch("urllib3.PoolManager.request") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_model") -def test_describe_endpoint(describe_model, describe_endpoint_config, request, *args): +def test_describe_endpoint( + describe_model, describe_endpoint_config, request, mock_telemetry, *args +): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() request.return_value = OK_RESPONSE @@ -658,12 +676,13 @@ def test_describe_endpoint(describe_model, describe_endpoint_config, request, *a assert response["EndpointName"] == "test-endpoint" +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) @patch("sagemaker.local.image._SageMakerContainer.serve") @patch("sagemaker.local.local_session.LocalSession") @patch("urllib3.PoolManager.request") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_model") -def test_create_endpoint(describe_model, describe_endpoint_config, request, *args): +def test_create_endpoint(describe_model, describe_endpoint_config, request, mock_telemetry, *args): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() request.return_value = OK_RESPONSE diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index 835a09a58c..12d3a2169d 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ 
b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -59,6 +59,8 @@ def test_jumpstart_default_metric_definitions( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -79,6 +81,8 @@ def test_jumpstart_default_metric_definitions( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/mlflow/__init__.py b/tests/unit/sagemaker/mlflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/mlflow/test_tracking_server.py b/tests/unit/sagemaker/mlflow/test_tracking_server.py new file mode 100644 index 0000000000..1fc4943f16 --- /dev/null +++ b/tests/unit/sagemaker/mlflow/test_tracking_server.py @@ -0,0 +1,42 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import absolute_import +from sagemaker.mlflow.tracking_server import generate_mlflow_presigned_url + + +def test_generate_presigned_url(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_presigned_mlflow_tracking_server_url.return_value = { + "AuthorizedUrl": "https://t-wo.example.com", + } + url = generate_mlflow_presigned_url( + "w", + expires_in_seconds=10, + session_expiration_duration_in_seconds=5, + sagemaker_session=sagemaker_session, + ) + client.create_presigned_mlflow_tracking_server_url.assert_called_with( + TrackingServerName="w", ExpiresInSeconds=10, SessionExpirationDurationInSeconds=5 + ) + assert url == "https://t-wo.example.com" + + +def test_generate_presigned_url_minimal(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_presigned_mlflow_tracking_server_url.return_value = { + "AuthorizedUrl": "https://t-wo.example.com", + } + url = generate_mlflow_presigned_url("w", sagemaker_session=sagemaker_session) + client.create_presigned_mlflow_tracking_server_url.assert_called_with(TrackingServerName="w") + assert url == "https://t-wo.example.com" diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 69ea2c1f56..50f6c370d5 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -114,7 +114,11 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, @@ -930,7 +934,11 @@ def test_deploy_customized_volume_size_and_timeout( assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, 
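The new test_tracking_server.py above pins the contract of generate_mlflow_presigned_url: the tracking server name is required, and the two expiry keywords are forwarded to create_presigned_mlflow_tracking_server_url only when supplied. Usage sketch (server name illustrative; with no explicit session, a default one would be created):

    from sagemaker.mlflow.tracking_server import generate_mlflow_presigned_url

    url = generate_mlflow_presigned_url(
        "my-tracking-server",
        expires_in_seconds=10,
        session_expiration_duration_in_seconds=5,
    )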
serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index c0b18a3eb3..e43ad0ed0a 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -287,7 +287,11 @@ def test_create_sagemaker_model(prepare_container_def, sagemaker_session): model._create_sagemaker_model() prepare_container_def.assert_called_with( - None, accelerator_type=None, serverless_inference_config=None, accept_eula=None + None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) sagemaker_session.create_model.assert_called_with( name=MODEL_NAME, @@ -305,7 +309,11 @@ def test_create_sagemaker_model_instance_type(prepare_container_def, sagemaker_s model._create_sagemaker_model(INSTANCE_TYPE) prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) @@ -321,6 +329,7 @@ def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemake accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ) @@ -336,6 +345,7 @@ def test_create_sagemaker_model_with_eula(prepare_container_def, sagemaker_sessi accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=True, + model_reference_arn=None, ) @@ -351,6 +361,7 @@ def test_create_sagemaker_model_with_eula_false(prepare_container_def, sagemaker accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=False, + model_reference_arn=None, ) diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index 9bfc830a75..062ffaf2ed 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -19,7 +19,8 @@ import sagemaker from sagemaker.model import ModelPackage -from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum +from sagemaker.model_card.model_card import ModelCard, ModelOverview +from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum, ModelCardStatusEnum MODEL_PACKAGE_VERSIONED_ARN = ( "arn:aws:sagemaker:us-west-2:001234567890:model-package/testmodelgroup/1" @@ -56,6 +57,10 @@ "ModelPackageStatus": "Completed", "ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502", "CertifyForMarketplace": False, + "ModelCard": { + "ModelCardStatus": "Draft", + "ModelCardContent": '{"model_overview": {"model_creator": "updatedCreator", "model_artifact": []}}', + }, } MODEL_DATA = { @@ -442,3 +447,48 @@ def test_update_source_uri(sagemaker_session): sagemaker_session.sagemaker_client.update_model_package.assert_called_with( ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, SourceUri=source_uri ) + + +def test_update_model_card(sagemaker_session): + model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) + + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=model_package_response + ) + model_package = ModelPackage( + role="role", + model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, + 
sagemaker_session=sagemaker_session, + ) + + update_my_card = ModelCard( + name="UpdateTestName", + sagemaker_session=sagemaker_session, + status=ModelCardStatusEnum.PENDING_REVIEW, + ) + model_package.update_model_card(update_my_card) + update_my_card_req = update_my_card._create_request_args() + del update_my_card_req["ModelCardName"] + del update_my_card_req["Content"] + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelCard=update_my_card_req + ) + + model_overview = ModelOverview( + model_creator="UpdatedNewCreator", + ) + update_my_card_1 = ModelCard( + name="UpdateTestName", + sagemaker_session=sagemaker_session, + status=ModelCardStatusEnum.DRAFT, + model_overview=model_overview, + ) + model_package.update_model_card(update_my_card_1) + update_my_card_req_1 = update_my_card_1._create_request_args() + del update_my_card_req_1["ModelCardName"] + del update_my_card_req_1["ModelCardStatus"] + update_my_card_req_1["ModelCardContent"] = update_my_card_req_1["Content"] + del update_my_card_req_1["Content"] + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelCard=update_my_card_req_1 + ) diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 8ec9478d8a..e71207d439 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -54,6 +54,8 @@ def test_jumpstart_common_model_uri( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -72,6 +74,8 @@ def test_jumpstart_common_model_uri( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -91,6 +95,8 @@ def test_jumpstart_common_model_uri( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -110,6 +116,8 @@ def test_jumpstart_common_model_uri( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index 20d05a933e..1d752f89ed 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -15,6 +15,7 @@ import os import threading import time +import inspect import pytest from mock import MagicMock, patch, Mock, ANY, call @@ -1498,7 +1499,6 @@ def test_consistency_between_remote_and_step_decorator(): from sagemaker.workflow.function_step import step remote_args_to_ignore = [ - "_remote", "include_local_workdir", "custom_file_filter", "s3_kms_key", @@ -1508,7 +1508,7 @@ def test_consistency_between_remote_and_step_decorator(): step_args_to_ignore = ["_step", "name", "display_name", "description", "retry_policies"] - remote_decorator_args = remote.__code__.co_varnames + remote_decorator_args = inspect.signature(remote).parameters.keys() common_remote_decorator_args = 
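test_update_model_card above fixes the request shape for ModelPackage.update_model_card: the card's ModelCardName is always dropped, and when content is present it is sent as ModelCardContent, so the UpdateModelPackage call carries only status and/or content. A hedged usage sketch:

    from sagemaker.model import ModelPackage
    from sagemaker.model_card.model_card import ModelCard
    from sagemaker.model_card.schema_constraints import ModelCardStatusEnum

    package = ModelPackage(
        role="role",
        model_package_arn="arn:aws:sagemaker:us-west-2:001234567890:model-package/testmodelgroup/1",  # illustrative
    )
    card = ModelCard(name="ReviewCard", status=ModelCardStatusEnum.PENDING_REVIEW)
    package.update_model_card(card)  # issues UpdateModelPackage with a ModelCard payload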
set(remote_args_to_ignore) ^ set(remote_decorator_args) step_decorator_args = step.__code__.co_varnames @@ -1522,8 +1522,7 @@ def test_consistency_between_remote_and_executor(): executor_arg_list.remove("self") executor_arg_list.remove("max_parallel_jobs") - remote_args_list = list(remote.__code__.co_varnames) - remote_args_list.remove("_remote") + remote_args_list = list(inspect.signature(remote).parameters.keys()) remote_args_list.remove("_func") assert executor_arg_list == remote_args_list diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 1c0cfa35b3..d149e08cab 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -56,6 +56,8 @@ def test_jumpstart_resource_requirements( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -76,6 +78,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.g5.xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -89,6 +92,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.g5.555xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -102,6 +106,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.f9.555xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -138,6 +143,8 @@ def test_jumpstart_no_supported_resource_requirements( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 16b7256ed2..b67f238cac 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -53,7 +53,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -71,7 +73,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -90,7 +94,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -109,7 +115,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, 
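Swapping remote.__code__.co_varnames for inspect.signature(remote).parameters.keys() in test_client.py above is the sturdier choice: co_varnames lists every local variable in the function body (which is why the internal _remote name previously had to be filtered out), whereas inspect.signature yields exactly the declared parameters. For example:

    import inspect

    def f(a, b=1):
        c = a + b  # a local, not a parameter
        return c

    print(f.__code__.co_varnames)                 # ('a', 'b', 'c')
    print(list(inspect.signature(f).parameters))  # ['a', 'b']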
model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index 90ec5df6b5..dde308dcfb 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -53,9 +53,11 @@ def test_jumpstart_default_serializers( patched_get_model_specs.assert_called_once_with( region=region, model_id=model_id, + hub_arn=None, version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -99,6 +101,8 @@ def test_jumpstart_serializer_options( region=region, model_id=model_id, version=model_version, + hub_arn=None, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/serve/builder/test_djl_builder.py b/tests/unit/sagemaker/serve/builder/test_djl_builder.py index ccabdb86b3..7b0c67f326 100644 --- a/tests/unit/sagemaker/serve/builder/test_djl_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_djl_builder.py @@ -15,14 +15,9 @@ import unittest from sagemaker.serve.builder.model_builder import ModelBuilder -from sagemaker.serve.utils.types import _DjlEngine from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve import ModelServer -from sagemaker.djl_inference.model import ( - DeepSpeedModel, - FasterTransformerModel, - HuggingFaceAccelerateModel, -) +from sagemaker.djl_inference.model import DJLModel from sagemaker.serve.utils.exceptions import ( LocalDeepPingException, LocalModelLoadException, @@ -33,45 +28,23 @@ from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG mock_model_id = "TheBloke/Llama-2-7b-chat-fp16" -mock_t5_model_id = "google/flan-t5-xxl" mock_prompt = "Hello, I'm a language model," mock_response = "Hello, I'm a language model, and I'm here to help you with your English." 
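From here on, the DJL builder tests no longer construct engine-specific models (DeepSpeedModel, FasterTransformerModel, HuggingFaceAccelerateModel); everything builds a single DJLModel configured through environment variables, and the expected image is an LMI image rather than an engine-specific one. The builder entry point is unchanged, as the tests below exercise:

    from sagemaker.serve import ModelServer
    from sagemaker.serve.builder.model_builder import ModelBuilder
    from sagemaker.serve.mode.function_pointers import Mode

    builder = ModelBuilder(
        model="TheBloke/Llama-2-7b-chat-fp16",
        schema_builder=mock_schema_builder,  # sample input/output as above
        mode=Mode.LOCAL_CONTAINER,
        model_server=ModelServer.DJL_SERVING,
    )
    model = builder.build()  # a DJLModel; env vars such as OPTION_ENGINE replace serving.properties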
mock_sample_input = {"inputs": mock_prompt, "parameters": {}} mock_sample_output = [{"generated_text": mock_response}] -mock_expected_huggingfaceaccelerate_serving_properties = { - "engine": "Python", - "option.entryPoint": "inference.py", - "option.model_id": "TheBloke/Llama-2-7b-chat-fp16", - "option.tensor_parallel_degree": 4, - "option.dtype": "fp16", -} -mock_expected_deepspeed_serving_properties = { - "engine": "DeepSpeed", - "option.entryPoint": "inference.py", - "option.model_id": "TheBloke/Llama-2-7b-chat-fp16", - "option.tensor_parallel_degree": 4, - "option.dtype": "fp16", - "option.max_tokens": 256, - "option.triangular_masking": True, - "option.return_tuple": True, -} -mock_expected_fastertransformer_serving_properties = { - "engine": "FasterTransformer", - "option.entryPoint": "inference.py", - "option.model_id": "google/flan-t5-xxl", - "option.tensor_parallel_degree": 4, - "option.dtype": "fp16", +mock_default_configs = { + "HF_MODEL_ID": mock_model_id, + "OPTION_ENGINE": "Python", + "TENSOR_PARALLEL_DEGREE": "max", + "OPTION_DTYPE": "bf16", + "MODEL_LOADING_TIMEOUT": "1800", } mock_most_performant_serving_properties = { - "engine": "Python", - "option.entryPoint": "inference.py", - "option.model_id": "TheBloke/Llama-2-7b-chat-fp16", - "option.tensor_parallel_degree": 1, - "option.dtype": "bf16", + "OPTION_ENGINE": "Python", + "HF_MODEL_ID": "TheBloke/Llama-2-7b-chat-fp16", + "TENSOR_PARALLEL_DEGREE": "1", + "OPTION_DTYPE": "bf16", } -mock_model_config_properties = {"model_type": "llama", "num_attention_heads": 32} -mock_model_config_properties_faster_transformer = {"model_type": "t5", "num_attention_heads": 32} -mock_set_serving_properties = (4, "fp16", 1, 256, 256) mock_schema_builder = MagicMock() mock_schema_builder.sample_input = mock_sample_input @@ -88,24 +61,12 @@ class TestDjlBuilder(unittest.TestCase): "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") def test_build_deploy_for_djl_local_container( self, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): @@ -124,20 +85,12 @@ def test_build_deploy_for_djl_local_container( model = builder.build() builder.serve_settings.telemetry_opt_out = True - assert isinstance(model, HuggingFaceAccelerateModel) - assert ( - model.generate_serving_properties() - == mock_expected_huggingfaceaccelerate_serving_properties - ) - assert builder._default_tensor_parallel_degree == 4 - assert builder._default_data_type == "fp16" - assert builder._default_max_tokens == 256 - assert builder._default_max_new_tokens == 256 - assert builder.schema_builder.sample_input["parameters"]["max_new_tokens"] == 256 + assert isinstance(model, DJLModel) + assert builder.schema_builder.sample_input["parameters"]["max_new_tokens"] == 128 assert builder.nb_instance_type == "ml.g5.24xlarge" assert model.image_config == MOCK_IMAGE_CONFIG assert 
model.vpc_config == MOCK_VPC_CONFIG - assert "deepspeed" in builder.image_uri + assert "lmi" in builder.image_uri builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() predictor = model.deploy(model_data_download_timeout=1800) @@ -153,100 +106,11 @@ def test_build_deploy_for_djl_local_container( with self.assertRaises(ValueError) as _: model.deploy(mode=Mode.IN_PROCESS) - @patch( - "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", - return_value=False, - ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=( - _DjlEngine.FASTER_TRANSFORMER, - mock_model_config_properties_faster_transformer, - ), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") - def test_build_for_djl_local_container_faster_transformer( - self, - mock_get_nb_instance, - mock_set_serving_properties, - mock_auto_detect_engine, - mock_is_jumpstart_model, - ): - builder = ModelBuilder( - model=mock_t5_model_id, - schema_builder=mock_schema_builder, - mode=Mode.LOCAL_CONTAINER, - model_server=ModelServer.DJL_SERVING, - image_config=MOCK_IMAGE_CONFIG, - vpc_config=MOCK_VPC_CONFIG, - ) - model = builder.build() - builder.serve_settings.telemetry_opt_out = True - - assert isinstance(model, FasterTransformerModel) - assert ( - model.generate_serving_properties() - == mock_expected_fastertransformer_serving_properties - ) - assert model.image_config == MOCK_IMAGE_CONFIG - assert model.vpc_config == MOCK_VPC_CONFIG - assert "fastertransformer" in builder.image_uri - - @patch( - "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", - return_value=False, - ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.DEEPSPEED, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") - def test_build_for_djl_local_container_deepspeed( - self, - mock_get_nb_instance, - mock_set_serving_properties, - mock_auto_detect_engine, - mock_is_jumpstart_model, - ): - builder = ModelBuilder( - model=mock_model_id, - schema_builder=mock_schema_builder, - mode=Mode.LOCAL_CONTAINER, - model_server=ModelServer.DJL_SERVING, - image_config=MOCK_IMAGE_CONFIG, - vpc_config=MOCK_VPC_CONFIG, - ) - model = builder.build() - builder.serve_settings.telemetry_opt_out = True - - assert isinstance(model, DeepSpeedModel) - assert model.image_config == MOCK_IMAGE_CONFIG - assert model.vpc_config == MOCK_VPC_CONFIG - assert model.generate_serving_properties() == mock_expected_deepspeed_serving_properties - assert "deepspeed" in builder.image_uri - @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) 
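The tuning tests that follow change shape accordingly: model.tune() now reports the winning configuration through model.env (for example TENSOR_PARALLEL_DEGREE and OPTION_DTYPE) instead of via generate_serving_properties(), and when a benchmark run fails the env simply falls back to the defaults. In sketch form:

    tuned_model = model.tune()
    # On success, the benchmark winner is visible as environment variables, e.g.
    # {"OPTION_ENGINE": "Python", "TENSOR_PARALLEL_DEGREE": "1", "OPTION_DTYPE": "bf16", ...}
    print(tuned_model.env)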
@patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") @patch( @@ -268,9 +132,6 @@ def test_tune_for_djl_local_container( mock_admissible_tensor_parallel_degrees, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): @@ -287,22 +148,13 @@ def test_tune_for_djl_local_container( model = builder.build() builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() - assert tuned_model.generate_serving_properties() == mock_most_performant_serving_properties + assert tuned_model.env == mock_most_performant_serving_properties @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") @patch( @@ -319,9 +171,6 @@ def test_tune_for_djl_local_container_deep_ping_ex( mock_serial_benchmarks, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): @@ -337,25 +186,13 @@ def test_tune_for_djl_local_container_deep_ping_ex( model = builder.build() builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() - assert ( - tuned_model.generate_serving_properties() - == mock_expected_huggingfaceaccelerate_serving_properties - ) + assert tuned_model.env == mock_default_configs @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") @patch( @@ -372,9 +209,6 @@ def test_tune_for_djl_local_container_load_ex( mock_serial_benchmarks, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): @@ -390,25 +224,13 @@ def test_tune_for_djl_local_container_load_ex( model = builder.build() builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() - assert ( - tuned_model.generate_serving_properties() - == mock_expected_huggingfaceaccelerate_serving_properties - ) + assert tuned_model.env == mock_default_configs @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( 
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") @patch( @@ -425,9 +247,6 @@ def test_tune_for_djl_local_container_oom_ex( mock_serial_benchmarks, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): @@ -443,25 +262,13 @@ def test_tune_for_djl_local_container_oom_ex( model = builder.build() builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() - assert ( - tuned_model.generate_serving_properties() - == mock_expected_huggingfaceaccelerate_serving_properties - ) + assert tuned_model.env == mock_default_configs @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") @patch( @@ -478,9 +285,6 @@ def test_tune_for_djl_local_container_invoke_ex( mock_serial_benchmarks, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): @@ -496,10 +300,7 @@ def test_tune_for_djl_local_container_invoke_ex( model = builder.build() builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() - assert ( - tuned_model.generate_serving_properties() - == mock_expected_huggingfaceaccelerate_serving_properties - ) + assert tuned_model.env == mock_default_configs @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 3671d2382e..19a06dd5bb 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -21,6 +21,7 @@ from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.serve.model_format.mlflow.constants import MLFLOW_TRACKING_ARN from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.predictors import TensorflowServingLocalPredictor @@ -51,11 +52,14 @@ mock_secret_key = "mock_secret_key" mock_instance_type = "mock instance type" -supported_model_server = { +supported_model_servers = { 
ModelServer.TORCHSERVE, ModelServer.TRITON, ModelServer.DJL_SERVING, ModelServer.TENSORFLOW_SERVING, + ModelServer.MMS, + ModelServer.TGI, + ModelServer.TEI, } mock_session = MagicMock() @@ -79,7 +83,7 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object)) self.assertRaisesRegex( Exception, - "Cannot have both the Model and Inference spec in the builder", + "Can only set one of the following: model, inference_spec.", builder.build, Mode.SAGEMAKER_ENDPOINT, mock_role_arn, @@ -92,7 +96,7 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings): self.assertRaisesRegex( Exception, "%s is not supported yet! Supported model servers: %s" - % (builder.model_server, supported_model_server), + % (builder.model_server, supported_model_servers), builder.build, Mode.SAGEMAKER_ENDPOINT, mock_role_arn, @@ -105,7 +109,7 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings self.assertRaisesRegex( Exception, "Model_server must be set when non-first-party image_uri is set. " - + "Supported model servers: %s" % supported_model_server, + + "Supported model servers: %s" % supported_model_servers, builder.build, Mode.SAGEMAKER_ENDPOINT, mock_role_arn, @@ -126,6 +130,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set mock_session, ) + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl") + def test_model_server_override_djl_with_model(self, mock_build_for_djl, mock_serve_settings): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.DJL_SERVING, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_djl.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_settings): + builder = ModelBuilder( + model_server=ModelServer.DJL_SERVING, model=None, inference_spec=None + ) + self.assertRaisesRegex( + Exception, + "Missing required parameter `model` or 'ml_flow' path", + builder.build, + Mode.SAGEMAKER_ENDPOINT, + mock_role_arn, + mock_session, + ) + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve") + def test_model_server_override_torchserve_with_model( + self, mock_build_for_ts, mock_serve_settings + ): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings): + builder = ModelBuilder(model_server=ModelServer.TORCHSERVE) + self.assertRaisesRegex( + Exception, + "Missing required parameter `model` or 'ml_flow' path", + builder.build, + Mode.SAGEMAKER_ENDPOINT, + mock_role_arn, + mock_session, + ) + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + 
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton") + def test_model_server_override_triton_with_model(self, mock_build_for_ts, mock_serve_settings): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.TRITON, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving") + def test_model_server_override_tensor_with_model(self, mock_build_for_ts, mock_serve_settings): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.TENSORFLOW_SERVING, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei") + def test_model_server_override_tei_with_model(self, mock_build_for_ts, mock_serve_settings): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.TEI, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi") + def test_model_server_override_tgi_with_model(self, mock_build_for_ts, mock_serve_settings): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.TGI, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") + def test_model_server_override_transformers_with_model( + self, mock_build_for_ts, mock_serve_settings + ): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.MMS, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + @patch("os.makedirs", Mock()) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @@ -1010,8 +1128,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") 
@patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1056,8 +1174,8 @@ def test_build_happy_path_when_schema_builder_not_present( @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1102,8 +1220,8 @@ def test_build_negative_path_when_schema_builder_not_present( @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1140,51 +1258,11 @@ def test_build_can_fit_on_single_gpu( mock_can_fit_on_single_gpu.assert_called_once() - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl") - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") - @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") - @patch("sagemaker.huggingface.llm_utils.urllib") - @patch("sagemaker.huggingface.llm_utils.json") - @patch("sagemaker.model_uris.retrieve") - @patch("sagemaker.serve.builder.model_builder._ServeSettings") - def test_build_is_deepspeed_model( - self, - mock_serveSettings, - mock_model_uris_retrieve, - mock_llm_utils_json, - mock_llm_utils_urllib, - mock_model_json, - mock_model_urllib, - mock_image_uris_retrieve, - mock_can_fit_on_single_gpu, - mock_build_for_djl, - ): - mock_setting_object = mock_serveSettings.return_value - mock_setting_object.role_arn = mock_role_arn - mock_setting_object.s3_model_data_url = mock_s3_model_data_url - - mock_model_uris_retrieve.side_effect = KeyError - mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} - mock_llm_utils_urllib.request.Request.side_effect = Mock() - - mock_model_json.load.return_value = {"some": "config"} - mock_model_urllib.request.Request.side_effect = Mock() - - mock_image_uris_retrieve.return_value = "https://some-image-uri" - mock_can_fit_on_single_gpu.return_value = False - - model_builder = ModelBuilder(model="stable-diffusion") - model_builder.build(sagemaker_session=mock_session) - - mock_build_for_djl.assert_called_once() - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1224,8 +1302,8 @@ def test_build_for_transformers_happy_case( 
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._try_fetch_gpu_info") @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1263,12 +1341,12 @@ def test_build_for_transformers_happy_case_with_values( mock_build_for_transformers.assert_called_once() - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl", Mock()) + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) @patch("sagemaker.serve.builder.model_builder._get_gpu_info") @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1312,8 +1390,8 @@ def test_build_for_transformers_happy_case_with_valid_gpu_info( @patch("sagemaker.serve.builder.model_builder._get_gpu_info_fallback") @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1359,51 +1437,11 @@ def test_build_for_transformers_happy_case_with_valid_gpu_fallback( ) self.assertEqual(model_builder._can_fit_on_single_gpu(), True) - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl") - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") - @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") - @patch("sagemaker.huggingface.llm_utils.urllib") - @patch("sagemaker.huggingface.llm_utils.json") - @patch("sagemaker.model_uris.retrieve") - @patch("sagemaker.serve.builder.model_builder._ServeSettings") - def test_build_is_fast_transformers_model( - self, - mock_serveSettings, - mock_model_uris_retrieve, - mock_llm_utils_json, - mock_llm_utils_urllib, - mock_model_json, - mock_model_urllib, - mock_image_uris_retrieve, - mock_can_fit_on_single_gpu, - mock_build_for_djl, - ): - mock_setting_object = mock_serveSettings.return_value - mock_setting_object.role_arn = mock_role_arn - mock_setting_object.s3_model_data_url = mock_s3_model_data_url - - mock_model_uris_retrieve.side_effect = KeyError - mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} - mock_llm_utils_urllib.request.Request.side_effect = Mock() - - mock_model_json.load.return_value = {"some": "config"} - mock_model_urllib.request.Request.side_effect = Mock() - - mock_image_uris_retrieve.return_value = "https://some-image-uri" - mock_can_fit_on_single_gpu.return_value = False - - model_builder = ModelBuilder(model="gpt_neo") - 
model_builder.build(sagemaker_session=mock_session) - - mock_build_for_djl.assert_called_once() - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1442,8 +1480,8 @@ def test_build_fallback_to_transformers( @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1480,8 +1518,8 @@ def test_text_generation( @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1519,8 +1557,8 @@ def test_sentence_similarity( @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) @patch("sagemaker.serve.builder.model_builder.ModelBuilder._try_fetch_gpu_info") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1558,8 +1596,8 @@ def test_try_fetch_gpu_info_throws( @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1596,8 +1634,8 @@ def test_total_inference_model_size_mib_throws( @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1643,8 +1681,8 @@ def test_build_happy_path_override_with_task_provided( self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output) @patch("sagemaker.image_uris.retrieve") - 
@patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -2341,3 +2379,187 @@ def test_optimize(self, mock_send_telemetry, mock_get_serve_setting): mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( **expected_create_optimization_job_args ) + + def test_handle_mlflow_input_without_mlflow_model_path(self): + builder = ModelBuilder(model_metadata={}) + assert not builder._has_mlflow_arguments() + + @patch("importlib.util.find_spec") + @patch("mlflow.set_tracking_uri") + @patch("mlflow.get_run") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_run_id( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_run, + mock_set_tracking_uri, + mock_find_spec, + ): + mock_find_spec.return_value = True + mock_run_info = Mock() + mock_run_info.info.artifact_uri = "s3://bucket/path" + mock_get_run.return_value = mock_run_info + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "runs:/runid/mlflow-path", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + } + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path/mlflow-path") + + @patch("importlib.util.find_spec") + @patch("mlflow.set_tracking_uri") + @patch("mlflow.MlflowClient.get_model_version") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_registry_path_with_model_version( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_model_version, + mock_set_tracking_uri, + mock_find_spec, + ): + mock_find_spec.return_value = True + mock_registry_path = Mock() + mock_registry_path.source = "s3://bucket/path/" + mock_get_model_version.return_value = mock_registry_path + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "models:/model-name/1", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + } + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path/") + + @patch("importlib.util.find_spec") + @patch("mlflow.set_tracking_uri") + @patch("mlflow.MlflowClient.get_model_version_by_alias") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + 
def test_handle_mlflow_input_registry_path_with_model_alias( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_model_version_by_alias, + mock_set_tracking_uri, + mock_find_spec, + ): + mock_find_spec.return_value = True + mock_registry_path = Mock() + mock_registry_path.source = "s3://bucket/path" + mock_get_model_version_by_alias.return_value = mock_registry_path + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "models:/model-name@production", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + } + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path/") + + @patch("mlflow.MlflowClient.get_model_version") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_registry_path_missing_tracking_server_arn( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_model_version, + ): + mock_registry_path = Mock() + mock_registry_path.source = "s3://bucket/path" + mock_get_model_version.return_value = mock_registry_path + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "models:/model-name/1", + } + ) + self.assertRaisesRegex( + Exception, + "%s is not provided in ModelMetadata or through set_tracking_arn " + "but MLflow model path was provided." 
% MLFLOW_TRACKING_ARN, + builder._handle_mlflow_input, + ) + + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_model_package_arn( + self, mock_validate, mock_s3_downloader, mock_initialize, mock_check_mlflow_model + ): + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + mock_model_package = {"SourceUri": "s3://bucket/path"} + mock_session.sagemaker_client.describe_model_package.return_value = mock_model_package + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "arn:aws:sagemaker:us-west-2:000000000000:model-package/test", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + }, + sagemaker_session=mock_session, + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path") + + @patch("importlib.util.find_spec", Mock(return_value=True)) + @patch("mlflow.set_tracking_uri") + def test_set_tracking_arn_success(self, mock_set_tracking_uri): + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "arn:aws:sagemaker:us-west-2:000000000000:model-package/test", + } + ) + tracking_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + builder.set_tracking_arn(tracking_arn) + mock_set_tracking_uri.assert_called_once_with(tracking_arn) + assert builder.model_metadata[MLFLOW_TRACKING_ARN] == tracking_arn + + @patch("importlib.util.find_spec", Mock(return_value=False)) + def test_set_tracking_arn_mlflow_not_installed(self): + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "arn:aws:sagemaker:us-west-2:000000000000:model-package/test", + } + ) + tracking_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + self.assertRaisesRegex( + ImportError, + "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed", + builder.set_tracking_arn, + tracking_arn, + ) diff --git a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py index 23d1315647..819800ba46 100644 --- a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py +++ b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py @@ -32,7 +32,6 @@ _get_framework_version_from_requirements, _get_deployment_flavor, _get_python_version_from_parsed_mlflow_model_file, - _mlflow_input_is_local_path, _download_s3_artifacts, _select_container_for_mlflow_model, _validate_input_for_mlflow, @@ -197,17 +196,6 @@ def test_get_python_version_from_parsed_mlflow_model_file(): _get_python_version_from_parsed_mlflow_model_file({}) -@patch("os.path.exists") -def test_mlflow_input_is_local_path(mock_path_exists): - valid_path = "/path/to/mlflow_model" - mock_path_exists.side_effect = lambda path: path == valid_path - - assert not _mlflow_input_is_local_path("s3://my_bucket/path/to/model") - assert not _mlflow_input_is_local_path("runs:/run-id/run/relative/path/to/model") - assert not _mlflow_input_is_local_path("/invalid/path") - assert _mlflow_input_is_local_path(valid_path) - - def test_download_s3_artifacts(): pass diff --git a/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py 
b/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py index caa8884186..183d15d13e 100644 --- a/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py +++ b/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py @@ -13,12 +13,11 @@ from __future__ import absolute_import from unittest import TestCase -from unittest.mock import Mock, PropertyMock, patch, mock_open, call +from unittest.mock import Mock, PropertyMock, patch, mock_open from sagemaker.serve.model_server.djl_serving.prepare import ( _copy_jumpstart_artifacts, _create_dir_structure, - _move_to_code_dir, _extract_js_resource, ) from tests.unit.sagemaker.serve.model_server.constants import ( @@ -31,7 +30,7 @@ MOCK_INVALID_MODEL_DATA_DICT, ) -MOCK_DJL_JUMPSTART_GLOBED_RESOURCES = ["./inference.py", "./serving.properties", "./config.json"] +MOCK_DJL_JUMPSTART_GLOBED_RESOURCES = ["./config.json"] class DjlPrepareTests(TestCase): @@ -69,114 +68,65 @@ def test_create_dir_structure_invalid_path(self, mock_path): self.assertEquals("model_dir is not a valid directory", str(context.exception)) @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") - @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") - @patch( - "sagemaker.serve.model_server.djl_serving.prepare._read_existing_serving_properties", - return_value={}, - ) - @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") @patch("builtins.open", new_callable=mock_open, read_data="data") @patch("json.load", return_value={}) def test_prepare_djl_js_resources_for_jumpstart_uncompressed_str( self, mock_load, mock_open, - mock_move_to_code_dir, - mock_existing_props, - mock_tmpdir, mock_s3_downloader, ): mock_code_dir = Mock() - mock_config_json_file = Mock() - mock_config_json_file.is_file.return_value = True - mock_code_dir.joinpath.return_value = mock_config_json_file - mock_s3_downloader_obj = Mock() mock_s3_downloader.return_value = mock_s3_downloader_obj - mock_tmpdir_obj = Mock() - mock_js_dir = Mock() - mock_js_dir.return_value = MOCK_TMP_DIR - type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=mock_js_dir) - type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) - mock_tmpdir.return_value = mock_tmpdir_obj - - existing_properties, hf_model_config, success = _copy_jumpstart_artifacts( + _copy_jumpstart_artifacts( MOCK_UNCOMPRESSED_MODEL_DATA_STR, MOCK_JUMPSTART_ID, mock_code_dir ) mock_s3_downloader_obj.download.assert_called_once_with( - MOCK_UNCOMPRESSED_MODEL_DATA_STR, MOCK_TMP_DIR + MOCK_UNCOMPRESSED_MODEL_DATA_STR, mock_code_dir ) - mock_move_to_code_dir.assert_called_once_with(MOCK_TMP_DIR, mock_code_dir) - mock_code_dir.joinpath.assert_called_once_with("config.json") - self.assertEqual(existing_properties, {}) - self.assertEqual(hf_model_config, {}) - self.assertEqual(success, True) @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") - @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") - @patch( - "sagemaker.serve.model_server.djl_serving.prepare._read_existing_serving_properties", - return_value={}, - ) - @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") @patch("builtins.open", new_callable=mock_open, read_data="data") @patch("json.load", return_value={}) def test_prepare_djl_js_resources_for_jumpstart_uncompressed_dict( self, mock_load, mock_open, - mock_move_to_code_dir, - mock_existing_props, - mock_tmpdir, mock_s3_downloader, ): mock_code_dir = Mock() - mock_config_json_file = Mock() - 
mock_config_json_file.is_file.return_value = True - mock_code_dir.joinpath.return_value = mock_config_json_file - mock_s3_downloader_obj = Mock() mock_s3_downloader.return_value = mock_s3_downloader_obj - mock_tmpdir_obj = Mock() - mock_js_dir = Mock() - mock_js_dir.return_value = MOCK_TMP_DIR - type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=mock_js_dir) - type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) - mock_tmpdir.return_value = mock_tmpdir_obj - - existing_properties, hf_model_config, success = _copy_jumpstart_artifacts( + _copy_jumpstart_artifacts( MOCK_UNCOMPRESSED_MODEL_DATA_DICT, MOCK_JUMPSTART_ID, mock_code_dir ) mock_s3_downloader_obj.download.assert_called_once_with( - MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT, MOCK_TMP_DIR + MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT, mock_code_dir ) - mock_move_to_code_dir.assert_called_once_with(MOCK_TMP_DIR, mock_code_dir) - mock_code_dir.joinpath.assert_called_once_with("config.json") - self.assertEqual(existing_properties, {}) - self.assertEqual(hf_model_config, {}) - self.assertEqual(success, True) - @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") - @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") + @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("json.load", return_value={}) def test_prepare_djl_js_resources_for_jumpstart_invalid_model_data( - self, mock_move_to_code_dir, mock_tmpdir + self, + mock_load, + mock_open, + mock_s3_downloader, ): mock_code_dir = Mock() - mock_tmpdir_obj = Mock() - type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=Mock()) - type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) - mock_tmpdir.return_value = mock_tmpdir_obj + mock_s3_downloader_obj = Mock() + mock_s3_downloader.return_value = mock_s3_downloader_obj with self.assertRaises(ValueError) as context: _copy_jumpstart_artifacts( MOCK_INVALID_MODEL_DATA_DICT, MOCK_JUMPSTART_ID, mock_code_dir ) - assert not mock_move_to_code_dir.called self.assertTrue( "JumpStart model data compression format is unsupported" in str(context.exception) ) @@ -184,27 +134,17 @@ def test_prepare_djl_js_resources_for_jumpstart_invalid_model_data( @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") @patch("sagemaker.serve.model_server.djl_serving.prepare._extract_js_resource") @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") - @patch( - "sagemaker.serve.model_server.djl_serving.prepare._read_existing_serving_properties", - return_value={}, - ) - @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") @patch("builtins.open", new_callable=mock_open, read_data="data") @patch("json.load", return_value={}) def test_prepare_djl_js_resources_for_jumpstart_compressed_str( self, mock_load, mock_open, - mock_move_to_code_dir, - mock_existing_props, mock_tmpdir, mock_extract_js_resource, mock_s3_downloader, ): mock_code_dir = Mock() - mock_config_json_file = Mock() - mock_config_json_file.is_file.return_value = True - mock_code_dir.joinpath.return_value = mock_config_json_file mock_s3_downloader_obj = Mock() mock_s3_downloader.return_value = mock_s3_downloader_obj @@ -216,41 +156,14 @@ def test_prepare_djl_js_resources_for_jumpstart_compressed_str( type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) mock_tmpdir.return_value = mock_tmpdir_obj - existing_properties, hf_model_config, success = _copy_jumpstart_artifacts( - 
MOCK_COMPRESSED_MODEL_DATA_STR, MOCK_JUMPSTART_ID, mock_code_dir - ) + _copy_jumpstart_artifacts(MOCK_COMPRESSED_MODEL_DATA_STR, MOCK_JUMPSTART_ID, mock_code_dir) mock_s3_downloader_obj.download.assert_called_once_with( MOCK_COMPRESSED_MODEL_DATA_STR, MOCK_TMP_DIR ) - mock_extract_js_resource.assert_called_with(MOCK_TMP_DIR, MOCK_JUMPSTART_ID) - mock_move_to_code_dir.assert_called_once_with(MOCK_TMP_DIR, mock_code_dir) - mock_code_dir.joinpath.assert_called_once_with("config.json") - self.assertEqual(existing_properties, {}) - self.assertEqual(hf_model_config, {}) - self.assertEqual(success, True) - - @patch("sagemaker.serve.model_server.djl_serving.prepare.Path") - @patch("sagemaker.serve.model_server.djl_serving.prepare.shutil") - def test_move_to_code_dir_success(self, mock_shutil, mock_path): - mock_path_obj = Mock() - mock_js_model_resources = Mock() - mock_js_model_resources.glob.return_value = MOCK_DJL_JUMPSTART_GLOBED_RESOURCES - mock_path_obj.joinpath.return_value = mock_js_model_resources - mock_path.return_value = mock_path_obj - - mock_js_model_dir = "" - mock_code_dir = Mock() - _move_to_code_dir(mock_js_model_dir, mock_code_dir) - - mock_path_obj.joinpath.assert_called_once_with("model") - - expected_moves = [ - call("./inference.py", mock_code_dir), - call("./serving.properties", mock_code_dir), - call("./config.json", mock_code_dir), - ] - mock_shutil.move.assert_has_calls(expected_moves) + mock_extract_js_resource.assert_called_once_with( + MOCK_TMP_DIR, mock_code_dir, MOCK_JUMPSTART_ID + ) @patch("sagemaker.serve.model_server.djl_serving.prepare.Path") @patch("sagemaker.serve.model_server.djl_serving.prepare.tarfile") @@ -268,8 +181,9 @@ def test_extract_js_resources_success(self, mock_tarfile, mock_path): mock_tarfile.open.return_value = mock_tar_obj js_model_dir = "" - _extract_js_resource(js_model_dir, MOCK_JUMPSTART_ID) + code_dir = Mock() + _extract_js_resource(js_model_dir, code_dir, MOCK_JUMPSTART_ID) mock_path.assert_called_once_with(js_model_dir) mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz") - mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir, filter="data") + mock_resource_obj.extractall.assert_called_once_with(path=code_dir, filter="data") diff --git a/tests/unit/sagemaker/serve/model_server/tei/test_server.py b/tests/unit/sagemaker/serve/model_server/tei/test_server.py index 16dcf12b5a..2344a61fbc 100644 --- a/tests/unit/sagemaker/serve/model_server/tei/test_server.py +++ b/tests/unit/sagemaker/serve/model_server/tei/test_server.py @@ -66,6 +66,7 @@ def test_start_invoke_destroy_local_tei_server(self, mock_requests): volumes={PosixPath("model_path/code"): {"bind": "/opt/ml/model/", "mode": "rw"}}, environment={ "TRANSFORMERS_CACHE": "/opt/ml/model/", + "HF_HOME": "/opt/ml/model/", "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", "KEY": "VALUE", "SAGEMAKER_SERVE_SECRET_KEY": "secret_key", diff --git a/tests/unit/sagemaker/serve/utils/test_lineage_utils.py b/tests/unit/sagemaker/serve/utils/test_lineage_utils.py index 25e4fe246e..99da766031 100644 --- a/tests/unit/sagemaker/serve/utils/test_lineage_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_lineage_utils.py @@ -14,6 +14,7 @@ from unittest.mock import call +import datetime import pytest from botocore.exceptions import ClientError from mock import Mock, patch @@ -22,6 +23,7 @@ from sagemaker.lineage.query import LineageSourceEnum from sagemaker.serve.utils.lineage_constants import ( + TRACKING_SERVER_CREATION_TIME_FORMAT, MLFLOW_RUN_ID, 
MLFLOW_MODEL_PACKAGE_PATH, MLFLOW_S3_PATH, @@ -55,7 +57,7 @@ def test_load_artifact_by_source_uri(mock_artifact_list): mock_artifact_list.return_value = mock_artifacts result = _load_artifact_by_source_uri( - source_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session + source_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value ) mock_artifact_list.assert_called_once_with( @@ -77,7 +79,7 @@ def test_load_artifact_by_source_uri_no_match(mock_artifact_list): mock_artifact_list.return_value = mock_artifacts result = _load_artifact_by_source_uri( - source_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session + source_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value ) mock_artifact_list.assert_called_once_with( @@ -104,7 +106,7 @@ def test_poll_lineage_artifact_found(mock_load_artifact): assert result == mock_artifact mock_load_artifact.assert_has_calls( [ - call(s3_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session), + call(s3_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value), ] ) @@ -130,7 +132,7 @@ def test_poll_lineage_artifact_not_found(mock_load_artifact): @pytest.mark.parametrize( "mlflow_model_path, expected_output", [ - ("runs:/abc123", MLFLOW_RUN_ID), + ("runs:/abc123/my-model", MLFLOW_RUN_ID), ("models:/my-model/1", MLFLOW_REGISTRY_PATH), ( "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model-package", @@ -163,7 +165,8 @@ def test_get_mlflow_model_path_type_invalid(): def test_create_mlflow_model_path_lineage_artifact_success( mock_artifact_create, mock_get_mlflow_path_type ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_source_types = [dict(SourceIdType="Custom", Value="ModelBuilderInputModelData")] sagemaker_session = Mock(spec=Session) mock_artifact = Mock(spec=Artifact) mock_get_mlflow_path_type.return_value = "mlflow_run_id" @@ -175,6 +178,7 @@ def test_create_mlflow_model_path_lineage_artifact_success( mock_get_mlflow_path_type.assert_called_once_with(mlflow_model_path) mock_artifact_create.assert_called_once_with( source_uri=mlflow_model_path, + source_types=mock_source_types, artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, artifact_name="mlflow_run_id", properties={"model_builder_input_model_data_type": "mlflow_run_id"}, @@ -187,7 +191,7 @@ def test_create_mlflow_model_path_lineage_artifact_success( def test_create_mlflow_model_path_lineage_artifact_validation_exception( mock_artifact_create, mock_get_mlflow_path_type ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" sagemaker_session = Mock(spec=Session) mock_get_mlflow_path_type.return_value = "mlflow_run_id" mock_artifact_create.side_effect = ClientError( @@ -204,7 +208,7 @@ def test_create_mlflow_model_path_lineage_artifact_validation_exception( def test_create_mlflow_model_path_lineage_artifact_other_exception( mock_artifact_create, mock_get_mlflow_path_type ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" sagemaker_session = Mock(spec=Session) mock_get_mlflow_path_type.return_value = "mlflow_run_id" mock_artifact_create.side_effect = ClientError( @@ -220,18 +224,33 @@ def test_create_mlflow_model_path_lineage_artifact_other_exception( def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_existing( mock_load_artifact, mock_create_artifact ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" + 
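+    # The tracking server ARN and its CreationTime are expected to appear in the
+    # artifact source_types that _load_artifact_by_source_uri matches against.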
mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) + mock_creation_time = datetime.datetime(2024, 5, 15, 0, 0, 0) sagemaker_session = Mock(spec=Session) + mock_sagemaker_client = Mock() + mock_describe_response = {"CreationTime": mock_creation_time} + mock_sagemaker_client.describe_mlflow_tracking_server.return_value = mock_describe_response + sagemaker_session.sagemaker_client = mock_sagemaker_client + mock_source_types_to_match = [ + "ModelBuilderInputModelData", + mock_tracking_server_arn, + mock_creation_time.strftime(TRACKING_SERVER_CREATION_TIME_FORMAT), + ] mock_artifact_summary = Mock(spec=ArtifactSummary) mock_load_artifact.return_value = mock_artifact_summary result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( - mlflow_model_path, sagemaker_session + mlflow_model_path, sagemaker_session, mock_tracking_server_arn ) assert result == mock_artifact_summary mock_load_artifact.assert_called_once_with( - mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session + mlflow_model_path, + sagemaker_session, + mock_source_types_to_match, ) mock_create_artifact.assert_not_called() @@ -241,21 +260,38 @@ def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_exi def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_create( mock_load_artifact, mock_create_artifact ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) + mock_creation_time = datetime.datetime(2024, 5, 15, 0, 0, 0) sagemaker_session = Mock(spec=Session) + mock_sagemaker_client = Mock() + mock_describe_response = {"CreationTime": mock_creation_time} + mock_sagemaker_client.describe_mlflow_tracking_server.return_value = mock_describe_response + sagemaker_session.sagemaker_client = mock_sagemaker_client + mock_source_types_to_match = [ + "ModelBuilderInputModelData", + mock_tracking_server_arn, + mock_creation_time.strftime(TRACKING_SERVER_CREATION_TIME_FORMAT), + ] mock_artifact = Mock(spec=Artifact) mock_load_artifact.return_value = None mock_create_artifact.return_value = mock_artifact result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( - mlflow_model_path, sagemaker_session + mlflow_model_path, sagemaker_session, mock_tracking_server_arn ) assert result == mock_artifact mock_load_artifact.assert_called_once_with( - mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session + mlflow_model_path, + sagemaker_session, + mock_source_types_to_match, + ) + mock_create_artifact.assert_called_once_with( + mlflow_model_path, sagemaker_session, mock_source_types_to_match ) - mock_create_artifact.assert_called_once_with(mlflow_model_path, sagemaker_session) @patch("sagemaker.lineage.association.Association.create") @@ -320,7 +356,10 @@ def test_add_association_between_artifacts_other_exception(mock_association_crea def test_maintain_lineage_tracking_for_mlflow_model_success( mock_add_association, mock_retrieve_create_artifact, mock_poll_artifact ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) s3_upload_path = "s3://mybucket/path/to/model" sagemaker_session = Mock(spec=Session) mock_model_data_artifact = Mock(spec=ArtifactSummary) @@ 
-329,7 +368,7 @@ def test_maintain_lineage_tracking_for_mlflow_model_success( mock_retrieve_create_artifact.return_value = mock_mlflow_model_artifact _maintain_lineage_tracking_for_mlflow_model( - mlflow_model_path, s3_upload_path, sagemaker_session + mlflow_model_path, s3_upload_path, sagemaker_session, mock_tracking_server_arn ) mock_poll_artifact.assert_called_once_with( @@ -338,7 +377,9 @@ def test_maintain_lineage_tracking_for_mlflow_model_success( sagemaker_session=sagemaker_session, ) mock_retrieve_create_artifact.assert_called_once_with( - mlflow_model_path=mlflow_model_path, sagemaker_session=sagemaker_session + mlflow_model_path=mlflow_model_path, + tracking_server_arn=mock_tracking_server_arn, + sagemaker_session=sagemaker_session, ) mock_add_association.assert_called_once_with( mlflow_model_path_artifact_arn=mock_mlflow_model_artifact.artifact_arn, @@ -355,14 +396,17 @@ def test_maintain_lineage_tracking_for_mlflow_model_success( def test_maintain_lineage_tracking_for_mlflow_model_no_model_data_artifact( mock_add_association, mock_retrieve_create_artifact, mock_poll_artifact ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) s3_upload_path = "s3://mybucket/path/to/model" sagemaker_session = Mock(spec=Session) mock_poll_artifact.return_value = None mock_retrieve_create_artifact.return_value = None _maintain_lineage_tracking_for_mlflow_model( - mlflow_model_path, s3_upload_path, sagemaker_session + mlflow_model_path, s3_upload_path, sagemaker_session, mock_tracking_server_arn ) mock_poll_artifact.assert_called_once_with( diff --git a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py new file mode 100644 index 0000000000..9107256b5b --- /dev/null +++ b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py @@ -0,0 +1,302 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
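+
+# Unit tests for the sagemaker.telemetry.telemetry_logging helpers introduced in this change.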
+from __future__ import absolute_import
+import unittest
+import pytest
+import requests
+from unittest.mock import Mock, patch, MagicMock
+import boto3
+import sagemaker
+from sagemaker.telemetry.constants import Feature
+from sagemaker.telemetry.telemetry_logging import (
+    _send_telemetry_request,
+    _telemetry_emitter,
+    _construct_url,
+    _get_accountId,
+    _requests_helper,
+    _get_region_or_default,
+    _get_default_sagemaker_session,
+    OS_NAME_VERSION,
+    PYTHON_VERSION,
+)
+from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file
+from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException
+
+MOCK_SESSION = Mock()
+MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex")
+MOCK_FEATURE = Feature.SDK_DEFAULTS
+MOCK_FUNC_NAME = "Mock.local_session.create_model"
+MOCK_ENDPOINT_ARN = "arn:aws:sagemaker:us-west-2:123456789012:endpoint/test"
+
+
+class LocalSagemakerClientMock:
+    def __init__(self):
+        self.sagemaker_session = MOCK_SESSION
+
+    @_telemetry_emitter(MOCK_FEATURE, MOCK_FUNC_NAME)
+    def mock_create_model(self, mock_exception_func=None):
+        if mock_exception_func:
+            mock_exception_func()
+
+
+class TestTelemetryLogging(unittest.TestCase):
+    @patch("sagemaker.telemetry.telemetry_logging._requests_helper")
+    @patch("sagemaker.telemetry.telemetry_logging._get_accountId")
+    def test_log_successfully(self, mock_get_accountId, mock_request_helper):
+        """Test to check if the telemetry logging is successful"""
+        MOCK_SESSION.boto_session.region_name = "us-west-2"
+        mock_get_accountId.return_value = "testAccountId"
+        _send_telemetry_request("someStatus", "1", MOCK_SESSION)
+        mock_request_helper.assert_called_with(
+            "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/"
+            "telemetry?x-accountId=testAccountId&x-status=someStatus&x-feature=1",
+            2,
+        )
+
+    @patch("sagemaker.telemetry.telemetry_logging._get_accountId")
+    def test_log_handle_exception(self, mock_get_accountId):
+        """Test to check if the exception is handled while logging telemetry"""
+        mock_get_accountId.side_effect = Exception("Internal error")
+        # Should not raise: _send_telemetry_request is expected to swallow internal errors.
+        _send_telemetry_request("someStatus", "1", MOCK_SESSION)
+
+    @patch("sagemaker.telemetry.telemetry_logging._get_accountId")
+    @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default")
+    def test_send_telemetry_request_success(self, mock_get_region, mock_get_accountId):
+        """Test to check the _send_telemetry_request function with success status"""
+        mock_get_accountId.return_value = "testAccountId"
+        mock_get_region.return_value = "us-west-2"
+
+        with patch(
+            "sagemaker.telemetry.telemetry_logging._requests_helper"
+        ) as mock_requests_helper:
+            mock_requests_helper.return_value = None
+            _send_telemetry_request(1, [1, 2], MagicMock(), None, None, "extra_info")
+            mock_requests_helper.assert_called_with(
+                "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/"
+                "telemetry?x-accountId=testAccountId&x-status=1&x-feature=1,2&x-extra=extra_info",
+                2,
+            )
+
+    @patch("sagemaker.telemetry.telemetry_logging._get_accountId")
+    @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default")
+    def test_send_telemetry_request_failure(self, mock_get_region, mock_get_accountId):
+        """Test to check the _send_telemetry_request function with failure status"""
+        mock_get_accountId.return_value = "testAccountId"
+        mock_get_region.return_value = "us-west-2"
+
+        with patch(
+            "sagemaker.telemetry.telemetry_logging._requests_helper"
+        ) as mock_requests_helper:
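+            # A failure status should also append x-failureReason and x-failureType to the URL.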
+            mock_requests_helper.return_value = None
+            _send_telemetry_request(
+                0, [1, 2], MagicMock(), "failure_reason", "failure_type", "extra_info"
+            )
+            mock_requests_helper.assert_called_with(
+                "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/"
+                "telemetry?x-accountId=testAccountId&x-status=0&x-feature=1,2"
+                "&x-failureReason=failure_reason&x-failureType=failure_type&x-extra=extra_info",
+                2,
+            )
+
+    @patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request")
+    @patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config")
+    def test_telemetry_emitter_decorator_no_call_when_disabled(
+        self, mock_resolve_config, mock_send_telemetry_request
+    ):
+        """Test to check that no telemetry request is sent when telemetry is disabled"""
+        mock_resolve_config.return_value = True
+        # Exercise the decorated call; with telemetry opted out, nothing should be sent.
+        mock_local_client = LocalSagemakerClientMock()
+        mock_local_client.mock_create_model()
+
+        assert not mock_send_telemetry_request.called
+
+    @patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request")
+    @patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config")
+    def test_telemetry_emitter_decorator_success(
+        self, mock_resolve_config, mock_send_telemetry_request
+    ):
+        """Test to verify the _telemetry_emitter decorator with success status"""
+        mock_resolve_config.return_value = False
+        mock_local_client = LocalSagemakerClientMock()
+        mock_local_client.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN
+        mock_local_client.mock_create_model()
+        app_type = process_studio_metadata_file()
+
+        args = mock_send_telemetry_request.call_args.args
+        latency = str(args[5]).split("latency=")[1]
+        expected_extra_str = (
+            f"{MOCK_FUNC_NAME}"
+            f"&x-sdkVersion={SDK_VERSION}"
+            f"&x-env={PYTHON_VERSION}"
+            f"&x-sys={OS_NAME_VERSION}"
+            f"&x-platform={app_type}"
+            f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
+            f"&x-latency={latency}"
+        )
+
+        mock_send_telemetry_request.assert_called_once_with(
+            1, [1, 2], MOCK_SESSION, None, None, expected_extra_str
+        )
+
+    @patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request")
+    @patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config")
+    def test_telemetry_emitter_decorator_handle_exception_success(
+        self, mock_resolve_config, mock_send_telemetry_request
+    ):
+        """Test to verify the _telemetry_emitter decorator when function emits exception"""
+        mock_resolve_config.return_value = False
+        mock_local_client = LocalSagemakerClientMock()
+        mock_local_client.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN
+        app_type = process_studio_metadata_file()
+
+        mock_exception = Mock()
+        mock_exception_obj = MOCK_EXCEPTION
+        mock_exception.side_effect = mock_exception_obj
+
+        with self.assertRaises(ModelBuilderException) as _:
+            mock_local_client.mock_create_model(mock_exception)
+
+        args = mock_send_telemetry_request.call_args.args
+        latency = str(args[5]).split("latency=")[1]
+        expected_extra_str = (
+            f"{MOCK_FUNC_NAME}"
+            f"&x-sdkVersion={SDK_VERSION}"
+            f"&x-env={PYTHON_VERSION}"
+            f"&x-sys={OS_NAME_VERSION}"
+            f"&x-platform={app_type}"
+            f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
+            f"&x-latency={latency}"
+        )
+
+        mock_send_telemetry_request.assert_called_once_with(
+            0,
+            [1, 2],
+            MOCK_SESSION,
+            str(mock_exception_obj),
+            mock_exception_obj.__class__.__name__,
+            expected_extra_str,
+        )
+
+    def test_construct_url_with_failure_reason_and_extra_info(self):
+        """Test to verify the _construct_url function with failure reason and extra info"""
+        mock_accountId = "testAccountId"
+        mock_status = 0
+        mock_feature = "1,2"
+        mock_failure_reason = str(MOCK_EXCEPTION)
+        mock_failure_type =
MOCK_EXCEPTION.__class__.__name__ + mock_extra_info = "mock_extra_info" + mock_region = "us-west-2" + + resulted_url = _construct_url( + accountId=mock_accountId, + region=mock_region, + status=mock_status, + feature=mock_feature, + failure_reason=mock_failure_reason, + failure_type=mock_failure_type, + extra_info=mock_extra_info, + ) + + expected_base_url = ( + f"https://sm-pysdk-t-{mock_region}.s3.{mock_region}.amazonaws.com/telemetry?" + f"x-accountId={mock_accountId}" + f"&x-status={mock_status}" + f"&x-feature={mock_feature}" + f"&x-failureReason={mock_failure_reason}" + f"&x-failureType={mock_failure_type}" + f"&x-extra={mock_extra_info}" + ) + self.assertEqual(resulted_url, expected_base_url) + + @patch("sagemaker.telemetry.telemetry_logging.requests.get") + def test_requests_helper_success(self, mock_requests_get): + """Test to verify the _requests_helper function with success status""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_requests_get.return_value = mock_response + url = "https://example.com" + timeout = 10 + + response = _requests_helper(url, timeout) + + mock_requests_get.assert_called_once_with(url, timeout) + self.assertEqual(response, mock_response) + + @patch("sagemaker.telemetry.telemetry_logging.requests.get") + def test_requests_helper_exception(self, mock_requests_get): + """Test to verify the _requests_helper function with exception""" + mock_requests_get.side_effect = requests.exceptions.RequestException("Error making request") + url = "https://example.com" + timeout = 10 + + response = _requests_helper(url, timeout) + + mock_requests_get.assert_called_once_with(url, timeout) + self.assertIsNone(response) + + def test_get_accountId_success(self): + """Test to verify the _get_accountId function with success status""" + boto_mock = MagicMock(name="boto_session") + boto_mock.client("sts").get_caller_identity.return_value = {"Account": "testAccountId"} + session = sagemaker.Session(boto_session=boto_mock) + account_id = _get_accountId(session) + + self.assertEqual(account_id, "testAccountId") + + def test_get_accountId_exception(self): + """Test to verify the _get_accountId function with exception""" + sts_client_mock = MagicMock() + sts_client_mock.side_effect = Exception("Error creating STS client") + boto_mock = MagicMock(name="boto_session") + boto_mock.client("sts").get_caller_identity.return_value = sts_client_mock + session = sagemaker.Session(boto_session=boto_mock) + + with pytest.raises(Exception) as exception: + account_id = _get_accountId(session) + assert account_id is None + assert "Error creating STS client" in str(exception) + + def test_get_region_or_default_success(self): + """Test to verify the _get_region_or_default function with success status""" + mock_session = MagicMock() + mock_session.boto_session = MagicMock(region_name="us-east-1") + + region = _get_region_or_default(mock_session) + + assert region == "us-east-1" + + def test_get_region_or_default_exception(self): + """Test to verify the _get_region_or_default function with exception""" + mock_session = MagicMock() + mock_session.boto_session = MagicMock() + mock_session.boto_session.region_name.side_effect = Exception("Error creating boto session") + + with pytest.raises(Exception) as exception: + region = _get_region_or_default(mock_session) + assert region == "us-west-2" + assert "Error creating boto session" in str(exception) + + @patch.object(boto3.Session, "region_name", "us-west-2") + def test_get_default_sagemaker_session(self): + sagemaker_session = 
_get_default_sagemaker_session() + + assert isinstance(sagemaker_session, sagemaker.Session) is True + assert sagemaker_session.boto_session.region_name == "us-west-2" + + @patch.object(boto3.Session, "region_name", None) + def test_get_default_sagemaker_session_with_no_region(self): + with self.assertRaises(ValueError) as context: + _get_default_sagemaker_session() + + assert "Must setup local AWS configuration with a region supported by SageMaker." in str( + context.exception + ) diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 9f0d68f01d..6ce4b50c75 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -13,6 +13,8 @@ from __future__ import absolute_import import datetime +from unittest.mock import Mock + import pytest from botocore.exceptions import ClientError from mock import MagicMock @@ -42,8 +44,14 @@ def test_default_bucket_s3_create_call(sagemaker_session): error_response={"Error": {"Code": "404", "Message": "Not Found"}}, operation_name="foo", ) - sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error - bucket_name = sagemaker_session.default_bucket() + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = Mock( + side_effect=error + ) + + try: + bucket_name = sagemaker_session.default_bucket() + except ClientError: + pass create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls _1, _2, create_kwargs = create_calls[0] @@ -53,7 +61,6 @@ def test_default_bucket_s3_create_call(sagemaker_session): "CreateBucketConfiguration": {"LocationConstraint": "us-west-2"}, "Bucket": bucket_name, } - assert sagemaker_session._default_bucket == bucket_name def test_default_bucket_s3_needs_access(sagemaker_session, caplog): diff --git a/tests/unit/test_djl_inference.py b/tests/unit/test_djl_inference.py index cc8a99cf1c..6b0f5a6f92 100644 --- a/tests/unit/test_djl_inference.py +++ b/tests/unit/test_djl_inference.py @@ -12,42 +12,25 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import -import logging - -import json -from json import JSONDecodeError - import pytest -from mock import Mock, MagicMock -from mock import patch, mock_open +from mock import Mock from sagemaker.djl_inference import ( - defaults, DJLModel, - DJLPredictor, - HuggingFaceAccelerateModel, - DeepSpeedModel, ) -from sagemaker.djl_inference.model import DJLServingEngineEntryPointDefaults -from sagemaker.s3_utils import s3_path_join from sagemaker.session_settings import SessionSettings -from tests.unit import ( - _test_default_bucket_and_prefix_combinations, - DEFAULT_S3_BUCKET_NAME, - DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, -) +from sagemaker import image_uris VALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model" -INVALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model.tar.gz" +VALID_COMPRESSED_MODEL_DATA = "s3://mybucket/model.tar.gz" HF_MODEL_ID = "hf_hub_model_id" -ENTRY_POINT = "entrypoint.py" -SOURCE_DIR = "source_dir/" -ENV = {"ENV_VAR": "env_value"} ROLE = "dummy_role" REGION = "us-west-2" -BUCKET = "mybucket" -IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazon.com/djl-inference:0.24.0-deepspeed0.10.0-cu118" -GPU_INSTANCE = "ml.g5.12xlarge" +VERSION = "0.28.0" + +LMI_IMAGE_URI = image_uris.retrieve(framework="djl-lmi", version=VERSION, region=REGION) +TRT_IMAGE_URI = image_uris.retrieve(framework="djl-tensorrtllm", version=VERSION, region=REGION) +TNX_IMAGE_URI = image_uris.retrieve(framework="djl-neuronx", version=VERSION, region=REGION) @pytest.fixture() @@ -66,756 +49,134 @@ def sagemaker_session(): endpoint_from_production_variants=Mock(name="endpoint_from_production_variants"), default_bucket_prefix=None, ) - session.default_bucket = Mock(name="default_bucket", return_value=BUCKET) + session.default_bucket = Mock(name="default_bucket", return_value="bucket") # For tests which doesn't verify config file injection, operate with empty config session.sagemaker_config = {} return session -def test_create_model_invalid_s3_uri(): - with pytest.raises(ValueError) as invalid_s3_data: - _ = DJLModel( - INVALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - ) - assert str(invalid_s3_data.value).startswith( - "DJLModel does not support model artifacts in tar.gz" - ) - - -@patch("urllib.request.urlopen") -def test_create_model_valid_hf_hub_model_id( - mock_urlopen, - sagemaker_session, -): - model_config = { - "model_type": "opt", - "num_attention_heads": 4, - } - - cm = MagicMock() - cm.getcode.return_value = 200 - cm.read.return_value = json.dumps(model_config).encode("utf-8") - cm.__enter__.return_value = cm - mock_urlopen.return_value = cm +def test_create_djl_model_only_model_id(sagemaker_session): model = DJLModel( - HF_MODEL_ID, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - assert model.engine == DJLServingEngineEntryPointDefaults.DEEPSPEED - expected_url = f"https://huggingface.co/{HF_MODEL_ID}/raw/main/config.json" - mock_urlopen.assert_any_call(expected_url) - - serving_properties = model.generate_serving_properties() - assert serving_properties["option.model_id"] == HF_MODEL_ID - - -@patch("json.load") -@patch("urllib.request.urlopen") -def test_create_model_invalid_hf_hub_model_id( - mock_urlopen, - json_load, - sagemaker_session, -): - expected_url = f"https://huggingface.co/{HF_MODEL_ID}/raw/main/config.json" - with pytest.raises(ValueError) as invalid_model_id: - cm = MagicMock() - cm.__enter__.return_value = cm - mock_urlopen.return_value = cm - json_load.side_effect = JSONDecodeError("", "", 0) - _ = DJLModel( - HF_MODEL_ID, - ROLE, 
- sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - mock_urlopen.assert_any_call(expected_url) - assert str(invalid_model_id.value).startswith( - "Did not find a config.json or model_index.json file in huggingface hub" - ) - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_create_model_automatic_engine_selection(mock_s3_list, mock_read_file, sagemaker_session): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - hf_model_config = { - "model_type": "t5", - "num_attention_heads": 4, - } - mock_read_file.return_value = json.dumps(hf_model_config) - hf_model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, + model_id=VALID_UNCOMPRESSED_MODEL_DATA, sagemaker_session=sagemaker_session, - number_of_partitions=4, + role=ROLE, ) - assert hf_model.engine == DJLServingEngineEntryPointDefaults.FASTER_TRANSFORMER - - hf_model_config = { - "model_type": "gpt2", - "num_attention_heads": 25, - } - mock_read_file.return_value = json.dumps(hf_model_config) - hf_model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - assert hf_model.engine == DJLServingEngineEntryPointDefaults.HUGGINGFACE_ACCELERATE - - for model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES: - ds_model_config = { - "model_type": model_type, - "num_attention_heads": 12, - } - mock_read_file.return_value = json.dumps(ds_model_config) - ds_model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=2, - ) - mock_s3_list.assert_any_call( - VALID_UNCOMPRESSED_MODEL_DATA, sagemaker_session=sagemaker_session - ) - if model_type == defaults.STABLE_DIFFUSION_MODEL_TYPE: - assert ds_model.engine == DJLServingEngineEntryPointDefaults.STABLE_DIFFUSION - else: - assert ds_model.engine == DJLServingEngineEntryPointDefaults.DEEPSPEED + assert model.engine == "Python" + assert model.image_uri == LMI_IMAGE_URI + assert model.env == {"HF_MODEL_ID": VALID_UNCOMPRESSED_MODEL_DATA, "OPTION_ENGINE": "Python"} -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_create_deepspeed_model(mock_s3_list, mock_read_file, sagemaker_session): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - ds_model_config = { - "model_type": "opt", - "n_head": 12, - } - mock_read_file.return_value = json.dumps(ds_model_config) - ds_model = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - tensor_parallel_degree=4, - ) - assert ds_model.engine == DJLServingEngineEntryPointDefaults.DEEPSPEED - - ds_model_config = { - "model_type": "opt", - "n_head": 25, - } - mock_read_file.return_value = json.dumps(ds_model_config) - with pytest.raises(ValueError) as invalid_partitions: - _ = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - tensor_parallel_degree=4, - ) - assert str(invalid_partitions.value).startswith("The number of attention heads is not evenly") - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_create_huggingface_model(mock_s3_list, mock_read_file, sagemaker_session): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - hf_model_config = { - "model_type": "opt", - "n_head": 12, - } - mock_read_file.return_value = json.dumps(hf_model_config) - hf_model = HuggingFaceAccelerateModel( - 
VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - assert hf_model.engine == DJLServingEngineEntryPointDefaults.HUGGINGFACE_ACCELERATE - - hf_model_config = { - "model_type": "t5", - "n_head": 13, - } - mock_read_file.return_value = json.dumps(hf_model_config) - hf_model = HuggingFaceAccelerateModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - assert hf_model.engine == DJLServingEngineEntryPointDefaults.HUGGINGFACE_ACCELERATE - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_model_unsupported_methods(mock_s3_list, mock_read_file, sagemaker_session): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "opt", - "n_head": 12, - } - mock_read_file.return_value = json.dumps(model_config) +def test_create_djl_model_only_model_data(sagemaker_session): model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, + model_data={ + "S3DataSource": { + "S3Uri": VALID_COMPRESSED_MODEL_DATA, + "S3DataType": "S3Object", + "CompressionType": "Gzip", + } + }, sagemaker_session=sagemaker_session, + role=ROLE, ) + assert model.engine == "Python" + assert model.image_uri == LMI_IMAGE_URI + assert model.env == {"OPTION_ENGINE": "Python"} - with pytest.raises(NotImplementedError) as invalid_method: - model.package_for_edge() - assert str(invalid_method.value).startswith("DJLModels do not support Sagemaker Edge") - - with pytest.raises(NotImplementedError) as invalid_method: - model.compile() - assert str(invalid_method.value).startswith( - "DJLModels do not currently support compilation with SageMaker Neo" - ) - - with pytest.raises(NotImplementedError) as invalid_method: - model.transformer() - assert str(invalid_method.value).startswith( - "DJLModels do not currently support Batch Transform inference jobs" - ) - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_deploy_base_model_invalid_instance(mock_s3_list, mock_read_file, sagemaker_session): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "gpt-neox", - "n_head": 25, - } - mock_read_file.return_value = json.dumps(model_config) +def test_create_djl_model_with_task(sagemaker_session): model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, + model_id=VALID_UNCOMPRESSED_MODEL_DATA, sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - - with pytest.raises(ValueError) as invalid_instance: - _ = model.deploy("ml.m5.12xlarge") - assert str(invalid_instance.value).startswith("Invalid instance type. 
DJLModels only support") - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_generate_deepspeed_serving_properties_invalid_configurations( - mock_s3_list, mock_read_file, sagemaker_session -): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "bert", - "n_head": 4, - } - mock_read_file.return_value = json.dumps(model_config) - model = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - tensor_parallel_degree=4, - enable_cuda_graph=True, - ) - with pytest.raises(ValueError) as invalid_config: - _ = model.generate_serving_properties() - assert str(invalid_config.value).startswith("enable_cuda_graph is not supported") - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_generate_huggingface_serving_properties_invalid_configurations( - mock_s3_list, mock_read_file, sagemaker_session -): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "t5", - "n_head": 4, - } - mock_read_file.return_value = json.dumps(model_config) - model = HuggingFaceAccelerateModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - dtype="fp16", - load_in_8bit=True, - ) - with pytest.raises(ValueError) as invalid_config: - _ = model.generate_serving_properties() - assert str(invalid_config.value).startswith("Set dtype='int8' to use load_in_8bit") - - model = HuggingFaceAccelerateModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=2, - device_id=1, - ) - with pytest.raises(ValueError) as invalid_config: - _ = model.generate_serving_properties() - assert str(invalid_config.value).startswith( - "device_id cannot be set when number_of_partitions is > 1" - ) - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_generate_serving_properties_with_valid_configurations( - mock_s3_list, mock_read_file, sagemaker_session -): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "gpt-neox", - "n_head": 25, - } - mock_read_file.return_value = json.dumps(model_config) - model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - min_workers=1, - max_workers=3, - job_queue_size=4, - dtype="fp16", - parallel_loading=True, - model_loading_timeout=120, - prediction_timeout=4, - source_dir=SOURCE_DIR, - entry_point=ENTRY_POINT, - task="text-classification", - ) - serving_properties = model.generate_serving_properties() - expected_dict = { - "engine": "Python", - "option.entryPoint": ENTRY_POINT, - "option.model_id": VALID_UNCOMPRESSED_MODEL_DATA, - "option.tensor_parallel_degree": 4, - "option.task": "text-classification", - "option.dtype": "fp16", - "minWorkers": 1, - "maxWorkers": 3, - "job_queue_size": 4, - "option.parallel_loading": True, - "option.model_loading_timeout": 120, - "option.prediction_timeout": 4, - } - assert serving_properties == expected_dict - serving_properties.clear() - expected_dict.clear() - - model_config = { - "model_type": "opt", - "n_head": 4, - } - mock_read_file.return_value = json.dumps(model_config) - model = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - tensor_parallel_degree=1, + role=ROLE, task="text-generation", - 
dtype="bf16", - max_tokens=2048, - low_cpu_mem_usage=True, - enable_cuda_graph=True, ) - serving_properties = model.generate_serving_properties() - expected_dict = { - "engine": "DeepSpeed", - "option.entryPoint": "djl_python.deepspeed", - "option.model_id": VALID_UNCOMPRESSED_MODEL_DATA, - "option.tensor_parallel_degree": 1, - "option.task": "text-generation", - "option.dtype": "bf16", - "option.max_tokens": 2048, - "option.enable_cuda_graph": True, - "option.low_cpu_mem_usage": True, - "option.triangular_masking": True, - "option.return_tuple": True, + assert model.engine == "Python" + assert model.image_uri == LMI_IMAGE_URI + assert model.env == { + "HF_MODEL_ID": VALID_UNCOMPRESSED_MODEL_DATA, + "OPTION_ENGINE": "Python", + "HF_TASK": "text-generation", } - assert serving_properties == expected_dict - serving_properties.clear() - expected_dict.clear() - model = HuggingFaceAccelerateModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=1, - device_id=4, - device_map="balanced", - dtype="fp32", - low_cpu_mem_usage=False, - ) - serving_properties = model.generate_serving_properties() - expected_dict = { - "engine": "Python", - "option.entryPoint": "djl_python.huggingface", - "option.model_id": VALID_UNCOMPRESSED_MODEL_DATA, - "option.tensor_parallel_degree": 1, - "option.dtype": "fp32", - "option.device_id": 4, - "option.device_map": "balanced", - } - assert serving_properties == expected_dict - - -@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI) -@patch("shutil.rmtree") -@patch("sagemaker.utils.base_name_from_image") -@patch("tempfile.mkdtemp") -@patch("sagemaker.container_def") -@patch("sagemaker.utils._tmpdir") -@patch("sagemaker.utils._create_or_update_code_dir") -@patch("sagemaker.fw_utils.tar_and_upload_dir") -@patch("os.mkdir") -@patch("os.path.exists") -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_deploy_model_no_local_code( - mock_s3_list, - mock_read_file, - mock_path_exists, - mock_mkdir, - mock_tar_upload, - mock_create_code_dir, - mock_tmpdir, - mock_container_def, - mock_mktmp, - mock_name_from_base, - mock_shutil_rmtree, - mock_imguri_retrieve, - sagemaker_session, -): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "bloom", - "n_heads": 120, - } - mock_read_file.return_value = json.dumps(model_config) model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, + model_id=HF_MODEL_ID, sagemaker_session=sagemaker_session, - number_of_partitions=4, - dtype="fp16", - container_log_level=logging.DEBUG, - env=ENV, + role=ROLE, + task="text-embedding", ) - - assert model.image_uri is None - - mock_path_exists.side_effect = [True, False, True] - mock_mktmp.return_value = "/tmp/dir" - mock_tar_upload.return_value = Mock(s3_prefix="s3prefix") - expected_env = {"ENV_VAR": "env_value", "SERVING_OPTS": '"-Dai.djl.logging.level=debug"'} - with patch("builtins.open", mock_open()) as fake_serving_properties: - predictor = model.deploy(GPU_INSTANCE) - - assert isinstance(predictor, DJLPredictor) - mock_mktmp.assert_called_once_with(prefix="tmp", suffix="", dir=None) - mock_mkdir.assert_called() - assert fake_serving_properties.call_count == 2 - fake_serving_properties.assert_any_call("/tmp/dir/code/serving.properties", "w+") - fake_serving_properties.assert_any_call("/tmp/dir/code/serving.properties", "r") - model.sagemaker_session.create_model.assert_called_once() - 
mock_container_def.assert_called_once_with( - IMAGE_URI, model_data_url="s3prefix", env=expected_env - ) - - -@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI) -@patch("shutil.rmtree") -@patch("sagemaker.utils.base_name_from_image") -@patch("tempfile.mkdtemp") -@patch("sagemaker.container_def") -@patch("sagemaker.utils._tmpdir") -@patch("sagemaker.utils._create_or_update_code_dir") -@patch("os.mkdir") -@patch("os.path.exists") -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -@patch("sagemaker.s3.S3Uploader.upload") -@patch("sagemaker.estimator.Estimator.fit") -@patch("sagemaker.fw_utils.model_code_key_prefix") -@patch("os.path.isfile") -@patch("boto3.client") -def test_partition( - mock_client, - mock_is_file, - mock_model_key_prefix, - mock_estimator_fit, - mock_upload, - mock_s3_list, - mock_read_file, - mock_path_exists, - mock_mkdir, - mock_create_code_dir, - mock_tmpdir, - mock_container_def, - mock_mktmp, - mock_name_from_base, - mock_shutil_rmtree, - mock_imguri_retrieve, - sagemaker_session, -): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "bloom", - "n_heads": 120, + assert model.engine == "OnnxRuntime" + assert model.image_uri == LMI_IMAGE_URI + assert model.env == { + "HF_MODEL_ID": HF_MODEL_ID, + "OPTION_ENGINE": "OnnxRuntime", + "HF_TASK": "text-embedding", } - mock_read_file.return_value = json.dumps(model_config) - model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - ) - - assert model.image_uri is None - mock_is_file.return_value = False - mock_path_exists.side_effect = [True, False, True] - mock_mktmp.return_value = "/tmp/dir" - expected_env = {"ENV_VAR": "env_value", "SERVING_OPTS": '"-Dai.djl.logging.level=debug"'} - mock_upload.return_value = "s3prefix" - - s3_output_uri = f"s3://{BUCKET}/partitions/" - mock_model_key_prefix.return_value = "s3prefix" - with patch("builtins.open", mock_open()) as fake_serving_properties: - model.partition(GPU_INSTANCE, s3_output_uri) - - mock_mktmp.assert_called_once_with(prefix="tmp", suffix="", dir=None) - mock_mkdir.assert_called() - assert fake_serving_properties.call_count == 2 - fake_serving_properties.assert_any_call("/tmp/dir/code/serving.properties", "w+") - fake_serving_properties.assert_any_call("/tmp/dir/code/serving.properties", "r") - mock_container_def.assert_called_once_with( - IMAGE_URI, model_data_url="s3prefix", env=expected_env - ) - - assert model.model_id == f"{s3_output_uri}aot-partitioned-checkpoints" - -@patch("sagemaker.djl_inference.model.fw_utils.model_code_key_prefix") -@patch("sagemaker.djl_inference.model._get_model_config_properties_from_s3") -@patch("sagemaker.djl_inference.model.fw_utils.tar_and_upload_dir") -def test__upload_model_to_s3__with_upload_as_tar__default_bucket_and_prefix_combinations( - tar_and_upload_dir, - _get_model_config_properties_from_s3, - model_code_key_prefix, -): - # Skip appending of timestamps that this normally does - model_code_key_prefix.side_effect = lambda a, b, c: s3_path_join(a, b, c) - def with_user_input(sess): +def test_create_djl_model_with_provided_image(sagemaker_session): + for img_uri in [LMI_IMAGE_URI, TRT_IMAGE_URI, TNX_IMAGE_URI]: model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - number_of_partitions=4, - data_type="fp16", - container_log_level=logging.DEBUG, - 
env=ENV, - code_location="s3://test-bucket/test-prefix/test-prefix-2", - image_uri="image_uri", - ) - model._upload_model_to_s3(upload_as_tar=True) - args = tar_and_upload_dir.call_args.args - return "s3://%s/%s" % (args[1], args[2]) - - def without_user_input(sess): - model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - number_of_partitions=4, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - image_uri="image_uri", + model_id=VALID_UNCOMPRESSED_MODEL_DATA, + sagemaker_session=sagemaker_session, + role=ROLE, + image_uri=img_uri, ) - model._upload_model_to_s3(upload_as_tar=True) - args = tar_and_upload_dir.call_args.args - return "s3://%s/%s" % (args[1], args[2]) - - actual, expected = _test_default_bucket_and_prefix_combinations( - function_with_user_input=with_user_input, - function_without_user_input=without_user_input, - expected__without_user_input__with_default_bucket_and_default_prefix=( - f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/image_uri" - ), - expected__without_user_input__with_default_bucket_only=( - f"s3://{DEFAULT_S3_BUCKET_NAME}/image_uri" - ), - expected__with_user_input__with_default_bucket_and_prefix=( - "s3://test-bucket/test-prefix/test-prefix-2/image_uri" - ), - expected__with_user_input__with_default_bucket_only=( - "s3://test-bucket/test-prefix/test-prefix-2/image_uri" - ), - ) - assert actual == expected - - -@patch("sagemaker.djl_inference.model.fw_utils.model_code_key_prefix") -@patch("sagemaker.djl_inference.model._get_model_config_properties_from_s3") -@patch("sagemaker.djl_inference.model.S3Uploader.upload") -def test__upload_model_to_s3__without_upload_as_tar__default_bucket_and_prefix_combinations( - upload, - _get_model_config_properties_from_s3, - model_code_key_prefix, -): - """This test is similar to test__upload_model_to_s3__with_upload_as_tar__default_bucket_and_prefix_combinations - - except upload_as_tar is False and S3Uploader.upload is checked - """ - - # Skip appending of timestamps that this normally does - model_code_key_prefix.side_effect = lambda a, b, c: s3_path_join(a, b, c) + assert model.engine == "Python" + assert model.image_uri == img_uri + assert model.env == { + "HF_MODEL_ID": VALID_UNCOMPRESSED_MODEL_DATA, + "OPTION_ENGINE": "Python", + } - def with_user_input(sess): + for framework in ["djl-lmi", "djl-tensorrtllm", "djl-neuronx"]: model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - number_of_partitions=4, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - code_location="s3://test-bucket/test-prefix/test-prefix-2", - image_uri="image_uri", + model_id=VALID_UNCOMPRESSED_MODEL_DATA, + sagemaker_session=sagemaker_session, + role=ROLE, + djl_framework=framework, ) - model._upload_model_to_s3(upload_as_tar=False) - args = upload.call_args.args - return args[1] - - def without_user_input(sess): - model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - number_of_partitions=4, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - image_uri="image_uri", + assert model.engine == "Python" + assert model.image_uri == image_uris.retrieve( + framework=framework, version=VERSION, region=REGION ) - model._upload_model_to_s3(upload_as_tar=False) - args = upload.call_args.args - return args[1] - - actual, expected = _test_default_bucket_and_prefix_combinations( - function_with_user_input=with_user_input, - function_without_user_input=without_user_input, - 
expected__without_user_input__with_default_bucket_and_default_prefix=( - f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/image_uri/aot-model" - ), - expected__without_user_input__with_default_bucket_only=( - f"s3://{DEFAULT_S3_BUCKET_NAME}/image_uri/aot-model" - ), - expected__with_user_input__with_default_bucket_and_prefix=( - "s3://test-bucket/test-prefix/test-prefix-2/image_uri/aot-model" - ), - expected__with_user_input__with_default_bucket_only=( - "s3://test-bucket/test-prefix/test-prefix-2/image_uri/aot-model" - ), - ) - assert actual == expected - - -@pytest.mark.parametrize( - ( - "code_location," - "expected__without_user_input__with_default_bucket_and_default_prefix, " - "expected__without_user_input__with_default_bucket_only, " - "expected__with_user_input__with_default_bucket_and_prefix, " - "expected__with_user_input__with_default_bucket_only" - ), - [ - ( - "s3://code-test-bucket/code-test-prefix/code-test-prefix-2", - "s3://code-test-bucket/code-test-prefix/code-test-prefix-2/image_uri", - "s3://code-test-bucket/code-test-prefix/code-test-prefix-2/image_uri", - "s3://test-bucket/test-prefix/test-prefix-2", - "s3://test-bucket/test-prefix/test-prefix-2", - ), - ( - None, - f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/image_uri", - f"s3://{DEFAULT_S3_BUCKET_NAME}/image_uri", - "s3://test-bucket/test-prefix/test-prefix-2", - "s3://test-bucket/test-prefix/test-prefix-2", - ), - ], -) -@patch("sagemaker.djl_inference.model.fw_utils.model_code_key_prefix") -@patch("sagemaker.djl_inference.model._get_model_config_properties_from_s3") -@patch("sagemaker.djl_inference.model.fw_utils.tar_and_upload_dir") -@patch("sagemaker.djl_inference.model._create_estimator") -def test_partition_default_bucket_and_prefix_combinations( - _create_estimator, - tar_and_upload_dir, - _get_model_config_properties_from_s3, - model_code_key_prefix, - code_location, - expected__without_user_input__with_default_bucket_and_default_prefix, - expected__without_user_input__with_default_bucket_only, - expected__with_user_input__with_default_bucket_and_prefix, - expected__with_user_input__with_default_bucket_only, -): - # Skip appending of timestamps that this normally does - model_code_key_prefix.side_effect = lambda a, b, c: s3_path_join(a, b, c) - - def with_user_input(sess): - model = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - code_location=code_location, - image_uri="image_uri", - ) - model.partition(GPU_INSTANCE, s3_output_uri="s3://test-bucket/test-prefix/test-prefix-2") - kwargs = _create_estimator.call_args.kwargs - return kwargs["s3_output_uri"] + assert model.env == { + "HF_MODEL_ID": VALID_UNCOMPRESSED_MODEL_DATA, + "OPTION_ENGINE": "Python", + } - def without_user_input(sess): - model = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - code_location=code_location, - image_uri="image_uri", - ) - model.partition(GPU_INSTANCE) - kwargs = _create_estimator.call_args.kwargs - return kwargs["s3_output_uri"] - actual, expected = _test_default_bucket_and_prefix_combinations( - function_with_user_input=with_user_input, - function_without_user_input=without_user_input, - expected__without_user_input__with_default_bucket_and_default_prefix=( - expected__without_user_input__with_default_bucket_and_default_prefix - ), - 
expected__without_user_input__with_default_bucket_only=expected__without_user_input__with_default_bucket_only, - expected__with_user_input__with_default_bucket_and_prefix=( - expected__with_user_input__with_default_bucket_and_prefix - ), - expected__with_user_input__with_default_bucket_only=expected__with_user_input__with_default_bucket_only, - ) - assert actual == expected +def test_create_djl_model_all_provided_args(sagemaker_session): + model = DJLModel( + model_id=HF_MODEL_ID, + sagemaker_session=sagemaker_session, + role=ROLE, + task="text-generation", + djl_framework="djl-tensorrtllm", + dtype="fp16", + tensor_parallel_degree=4, + min_workers=1, + max_workers=4, + job_queue_size=12, + parallel_loading=True, + model_loading_timeout=10, + prediction_timeout=3, + huggingface_hub_token="token", + ) + + assert model.engine == "Python" + assert model.image_uri == TRT_IMAGE_URI + assert model.env == { + "HF_MODEL_ID": HF_MODEL_ID, + "OPTION_ENGINE": "Python", + "HF_TASK": "text-generation", + "TENSOR_PARALLEL_DEGREE": "4", + "SERVING_MIN_WORKERS": "1", + "SERVING_MAX_WORKERS": "4", + "SERVING_JOB_QUEUE_SIZE": "12", + "OPTION_PARALLEL_LOADING": "True", + "OPTION_MODEL_LOADING_TIMEOUT": "10", + "OPTION_PREDICT_TIMEOUT": "3", + "HF_TOKEN": "token", + "OPTION_DTYPE": "fp16", + } diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index fd45601801..b557a9c9f0 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -51,6 +51,8 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.interactive_apps import SupportedInteractiveAppTypes from sagemaker.model import FrameworkModel +from sagemaker.model_card.model_card import ModelCard, ModelOverview +from sagemaker.model_card.schema_constraints import ModelCardStatusEnum from sagemaker.mxnet.estimator import MXNet from sagemaker.predictor import Predictor from sagemaker.pytorch.estimator import PyTorch @@ -264,6 +266,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): return MODEL_CONTAINER_DEF @@ -4336,6 +4339,12 @@ def test_register_default_image(sagemaker_session): framework_version = "2.9" nearest_model_name = "resnet50" data_input_config = '{"input_1":[1,224,224,3]}' + model_overview = ModelOverview(model_creator="TestCreator") + model_card = ModelCard( + name="TestCard", + status=ModelCardStatusEnum.DRAFT, + model_overview=model_overview, + ) estimator.register( content_types=content_types, @@ -4349,9 +4358,13 @@ def test_register_default_image(sagemaker_session): framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_config, + model_card=model_card, ) sagemaker_session.create_model.assert_not_called() - + exp_model_card = { + "ModelCardStatus": "Draft", + "ModelCardContent": '{"model_overview": {"model_creator": "TestCreator", "model_artifact": []}}', + } expected_create_model_package_request = { "containers": [{"Image": estimator.image_uri, "ModelDataUrl": estimator.model_data}], "content_types": content_types, @@ -4362,6 +4375,7 @@ def test_register_default_image(sagemaker_session): "marketplace_cert": False, "sample_payload_url": sample_payload_url, "task": task, + "model_card": exp_model_card, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request @@ -5243,6 +5257,7 @@ def test_all_framework_estimators_add_jumpstart_uri_tags( entry_point="inference.py", role=ROLE, tags=[{"Key": 
"blah", "Value": "yoyoma"}], + model_reference_arn=None, ) assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index f7dede1ce9..c776dfe479 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -24,6 +24,8 @@ from botocore.exceptions import ClientError from mock import ANY, MagicMock, Mock, patch, call, mock_open +from sagemaker.model_card.schema_constraints import ModelCardStatusEnum + from .common import _raise_unexpected_client_error import sagemaker from sagemaker import TrainingInput, Session, get_execution_role, exceptions @@ -5343,6 +5345,21 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): domain = "COMPUTER_VISION" task = "IMAGE_CLASSIFICATION" sample_payload_url = "s3://test-bucket/model" + model_card = { + "ModelCardStatus": ModelCardStatusEnum.DRAFT, + "Content": { + "model_overview": { + "model_creator": "TestCreator", + }, + "intended_uses": { + "purpose_of_model": "Test model card.", + "intended_uses": "Not used except this test.", + "factors_affecting_model_efficiency": "No.", + "risk_rating": "Low", + "explanations_for_risk_rating": "Just an example.", + }, + }, + } sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5361,6 +5378,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + model_card=model_card, ) expected_args = { "ModelPackageName": model_package_name, @@ -5382,6 +5400,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "SamplePayloadUrl": sample_payload_url, "Task": task, "SkipModelValidation": skip_model_validation, + "ModelCard": model_card, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) @@ -6263,6 +6282,24 @@ def test_create_inference_recommendations_job_propogate_other_exception( assert "AccessDeniedException" in str(error) +def test_create_presigned_mlflow_tracking_server_url(sagemaker_session): + sagemaker_session.create_presigned_mlflow_tracking_server_url("ts", 1, 2) + assert ( + sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.called_with( + TrackingServerName="ts", ExpiresInSeconds=1, SessionExpirationDurationInSeconds=2 + ) + ) + + +def test_create_presigned_mlflow_tracking_server_url_minimal(sagemaker_session): + sagemaker_session.create_presigned_mlflow_tracking_server_url("ts") + assert ( + sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.called_with( + TrackingServerName="ts" + ) + ) + + DEFAULT_LOG_EVENTS_INFERENCE_RECOMMENDER = [ MockBotoException("ResourceNotFoundException"), {"nextForwardToken": None, "events": [{"timestamp": 1, "message": "hi there #1"}]}, @@ -6972,3 +7009,170 @@ def test_download_data_with_file_and_directory(makedirs, sagemaker_session): Filename="./foo/bar/mode.tar.gz", ExtraArgs=None, ) + + +def test_create_hub(sagemaker_session): + sagemaker_session.create_hub( + hub_name="mock-hub-name", + hub_description="this is my sagemaker hub", + hub_display_name="Mock Hub", + hub_search_keywords=["mock", "hub", "123"], + s3_storage_config={"S3OutputPath": "s3://my-hub-bucket/"}, + tags=[{"Key": "tag-key-1", "Value": "tag-value-1"}], + ) + + request = { + "HubName": "mock-hub-name", + "HubDescription": "this is my sagemaker hub", + "HubDisplayName": "Mock Hub", + 
"HubSearchKeywords": ["mock", "hub", "123"], + "S3StorageConfig": {"S3OutputPath": "s3://my-hub-bucket/"}, + "Tags": [{"Key": "tag-key-1", "Value": "tag-value-1"}], + } + + sagemaker_session.sagemaker_client.create_hub.assert_called_with(**request) + + +def test_describe_hub(sagemaker_session): + sagemaker_session.describe_hub( + hub_name="mock-hub-name", + ) + + request = { + "HubName": "mock-hub-name", + } + + sagemaker_session.sagemaker_client.describe_hub.assert_called_with(**request) + + +def test_list_hubs(sagemaker_session): + sagemaker_session.list_hubs( + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08-2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + name_contains="mock-hub", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08-2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "NameContains": "mock-hub", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hubs.assert_called_with(**request) + + +def test_list_hub_contents(sagemaker_session): + sagemaker_session.list_hub_contents( + hub_name="mock-hub-123", + hub_content_type="MODELREF", + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08/2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + name_contains="mock-hub", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODELREF", + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08/2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "NameContains": "mock-hub", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hub_contents.assert_called_with(**request) + + +def test_list_hub_content_versions(sagemaker_session): + sagemaker_session.list_hub_content_versions( + hub_name="mock-hub-123", + hub_content_type="MODELREF", + hub_content_name="mock-hub-content-1", + min_version="1.0.0", + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08/2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODELREF", + "HubContentName": "mock-hub-content-1", + "MinVersion": "1.0.0", + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08/2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hub_content_versions.assert_called_with(**request) + + +def test_delete_hub(sagemaker_session): + sagemaker_session.delete_hub( + hub_name="mock-hub-123", + ) + + request = { + "HubName": "mock-hub-123", + } + + sagemaker_session.sagemaker_client.delete_hub.assert_called_with(**request) + + +def test_create_hub_content_reference(sagemaker_session): + sagemaker_session.create_hub_content_reference( + hub_name="mock-hub-name", + source_hub_content_arn=( + "arn:aws:sagemaker:us-east-1:" + "123456789123:" + "hub-content/JumpStartHub/" + "model/mock-hub-content-1" + ), + hub_content_name="mock-hub-content-1", + min_version="1.1.1", + ) + + request = { + "HubName": "mock-hub-name", + "SageMakerPublicHubContentArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/mock-hub-content-1", # noqa: E501 + "HubContentName": "mock-hub-content-1", + 
"MinVersion": "1.1.1", + } + + sagemaker_session.sagemaker_client.create_hub_content_reference.assert_called_with(**request) + + +def test_delete_hub_content_reference(sagemaker_session): + sagemaker_session.delete_hub_content_reference( + hub_name="mock-hub-name", + hub_content_type="ModelReference", + hub_content_name="mock-hub-content-1", + ) + + request = { + "HubName": "mock-hub-name", + "HubContentType": "ModelReference", + "HubContentName": "mock-hub-content-1", + } + + sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request) diff --git a/tox.ini b/tox.ini index 6e1f9ce956..194e134b36 100644 --- a/tox.ini +++ b/tox.ini @@ -81,7 +81,7 @@ passenv = # Can be used to specify which tests to run, e.g.: tox -- -s commands = python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')" - pip install 'apache-airflow==2.9.1' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.1/constraints-3.8.txt" + pip install 'apache-airflow==2.9.2' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.2/constraints-3.8.txt" pip install 'torch==2.0.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' pip install 'torchvision==0.15.2+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' pip install 'dill>=0.3.8' @@ -92,9 +92,9 @@ depends = {py38,py39,py310,p311}: clean [testenv:runcoverage] -description = run unit tests with coverage +description = run unit tests with coverage commands = - pytest --cov=sagemaker --cov-append {posargs} + pytest --cov=sagemaker --cov-append {posargs} {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86 [testenv:flake8] @@ -194,4 +194,4 @@ commands = # this needs to succeed for tests to display in some IDEs deps = .[test] commands = - pytest --collect-only + pytest --collect-only \ No newline at end of file From f1bc99e92c6524baf692c28c04a36d3d0938e9ad Mon Sep 17 00:00:00 2001 From: Adam Kozdrowicz Date: Wed, 3 Jul 2024 13:53:22 -0400 Subject: [PATCH 28/45] feat: Support Alt Configs for Public & Curated Hub (#1505) * feat: add alt config support for public & curated hub --- src/sagemaker/jumpstart/hub/interfaces.py | 91 +++++++++++++++++++ src/sagemaker/jumpstart/hub/parsers.py | 8 ++ src/sagemaker/jumpstart/types.py | 18 ++-- tests/unit/sagemaker/jumpstart/constants.py | 64 ++++++++++++- .../jumpstart/hub/test_interfaces.py | 75 ++++++++++++++- tests/unit/sagemaker/jumpstart/test_utils.py | 2 - 6 files changed, 246 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index 2748409927..d987216872 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -13,14 +13,20 @@ """This module stores types related to SageMaker JumpStart HubAPI requests and responses.""" from __future__ import absolute_import +from enum import Enum import re import json import datetime from typing import Any, Dict, List, Union, Optional +from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import ( HubContentType, HubArnExtractedInfo, + JumpStartConfigComponent, + JumpStartConfigRanking, + JumpStartMetadataConfig, + JumpStartMetadataConfigs, JumpStartPredictorSpecs, JumpStartHyperparameter, JumpStartDataHolderType, @@ -34,6 +40,13 @@ ) +class _ComponentType(str, Enum): + """Enum for different component types.""" + + INFERENCE = "Inference" + TRAINING = "Training" + + class HubDataHolderType(JumpStartDataHolderType): 
"""Base class for many Hub API interfaces.""" @@ -456,6 +469,9 @@ class HubModelDocument(HubDataHolderType): "hosting_use_script_uri", "hosting_eula_uri", "hosting_model_package_arn", + "inference_configs", + "inference_config_components", + "inference_config_rankings", "training_artifact_s3_data_type", "training_artifact_compression_type", "training_model_package_artifact_uri", @@ -467,6 +483,9 @@ class HubModelDocument(HubDataHolderType): "training_ecr_uri", "training_metrics", "training_artifact_uri", + "training_configs", + "training_config_components", + "training_config_rankings", "inference_dependencies", "training_dependencies", "default_inference_instance_type", @@ -566,6 +585,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + + self.inference_config_rankings = self._get_config_rankings(json_obj) + self.inference_config_components = self._get_config_components(json_obj) + self.inference_configs = self._get_configs(json_obj) + self.default_inference_instance_type: Optional[str] = json_obj.get( "DefaultInferenceInstanceType" ) @@ -667,6 +691,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: "TrainingMetrics", None ) self.training_artifact_uri: Optional[str] = json_obj.get("TrainingArtifactUri") + + self.training_config_rankings = self._get_config_rankings( + json_obj, _ComponentType.TRAINING + ) + self.training_config_components = self._get_config_components( + json_obj, _ComponentType.TRAINING + ) + self.training_configs = self._get_configs(json_obj, _ComponentType.TRAINING) + self.training_dependencies: Optional[str] = json_obj.get("TrainingDependencies") self.default_training_instance_type: Optional[str] = json_obj.get( "DefaultTrainingInstanceType" @@ -707,6 +740,64 @@ def get_region(self) -> str: """Returns hub region.""" return self._region + def _get_config_rankings( + self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE + ) -> Optional[Dict[str, JumpStartConfigRanking]]: + """Returns config rankings.""" + config_rankings = json_obj.get(f"{component_type.value}ConfigRankings") + return ( + { + alias: JumpStartConfigRanking(ranking, is_hub_content=True) + for alias, ranking in config_rankings.items() + } + if config_rankings + else None + ) + + def _get_config_components( + self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE + ) -> Optional[Dict[str, JumpStartConfigComponent]]: + """Returns config components.""" + config_components = json_obj.get(f"{component_type.value}ConfigComponents") + return ( + { + alias: JumpStartConfigComponent(alias, config, is_hub_content=True) + for alias, config in config_components.items() + } + if config_components + else None + ) + + def _get_configs( + self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE + ) -> Optional[JumpStartMetadataConfigs]: + """Returns configs.""" + if not (configs := json_obj.get(f"{component_type.value}Configs")): + return None + + configs_dict = {} + for alias, config in configs.items(): + config_components = None + if isinstance(config, dict) and (component_names := config.get("ComponentNames")): + config_components = { + name: getattr(self, f"{component_type.value.lower()}_config_components").get( + name + ) + for name in component_names + } + configs_dict[alias] = JumpStartMetadataConfig( + alias, config, json_obj, config_components, is_hub_content=True + ) + + if component_type 
== _ComponentType.INFERENCE: + config_rankings = self.inference_config_rankings + scope = JumpStartScriptScope.INFERENCE + else: + config_rankings = self.training_config_rankings + scope = JumpStartScriptScope.TRAINING + + return JumpStartMetadataConfigs(configs_dict, config_rankings, scope) + class HubNotebookDocument(HubDataHolderType): """Data class for notebook type HubContentDocument from session.describe_hub_content().""" diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 8226a380fd..28c2d9b32d 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -142,6 +142,9 @@ def make_model_specs_from_describe_hub_content_response( hub_model_document.incremental_training_supported ) specs["hosting_ecr_uri"] = hub_model_document.hosting_ecr_uri + specs["inference_configs"] = hub_model_document.inference_configs + specs["inference_config_components"] = hub_model_document.inference_config_components + specs["inference_config_rankings"] = hub_model_document.inference_config_rankings hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable hub_model_document.hosting_artifact_uri @@ -233,6 +236,11 @@ def make_model_specs_from_describe_hub_content_response( training_script_key, ) = parse_s3_url(hub_model_document.training_script_uri) specs["training_script_key"] = training_script_key + + specs["training_configs"] = hub_model_document.training_configs + specs["training_config_components"] = hub_model_document.training_config_components + specs["training_config_rankings"] = hub_model_document.training_config_rankings + specs["training_dependencies"] = hub_model_document.training_dependencies specs["default_training_instance_type"] = hub_model_document.default_training_instance_type specs["supported_training_instance_types"] = ( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 6ed2c4fdb9..ddc7943650 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1169,12 +1169,14 @@ class JumpStartConfigRanking(JumpStartDataHolderType): __slots__ = ["description", "rankings"] - def __init__(self, spec: Optional[Dict[str, Any]]): + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False): """Initializes a JumpStartConfigRanking object. Args: spec (Dict[str, Any]): Dictionary representation of training config ranking. """ + if is_hub_content: + spec = {camel_to_snake(key): val for key, val in spec.items()} self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1285,7 +1287,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj.get("incremental_training_supported", False) ) if self._is_hub_content: - self.hosting_ecr_uri: Optional[str] = json_obj["hosting_ecr_uri"] + self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri") self._non_serializable_slots.append("hosting_ecr_specs") else: self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( @@ -1491,9 +1493,7 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields): __slots__ = slots + JumpStartMetadataBaseFields.__slots__ def __init__( - self, - component_name: str, - component: Optional[Dict[str, Any]], + self, component_name: str, component: Optional[Dict[str, Any]], is_hub_content=False ): """Initializes a JumpStartConfigComponent object from its json representation. @@ -1504,8 +1504,10 @@ def __init__( Raises: ValueError: If the component field is invalid. 
""" - super().__init__(component) + if is_hub_content: + component = walk_and_apply_json(component, camel_to_snake) self.component_name = component_name + super().__init__(component, is_hub_content) self.from_json(component) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1542,6 +1544,7 @@ def __init__( config: Dict[str, Any], base_fields: Dict[str, Any], config_components: Dict[str, JumpStartConfigComponent], + is_hub_content=False, ): """Initializes a JumpStartMetadataConfig object from its json representation. @@ -1554,6 +1557,9 @@ def __init__( config_components (Dict[str, JumpStartConfigComponent]): The list of components that are used to construct the resolved config. """ + if is_hub_content: + config = walk_and_apply_json(config, camel_to_snake) + base_fields = walk_and_apply_json(base_fields, camel_to_snake) self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 1ae489acf8..9bff5cce67 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -8703,7 +8703,17 @@ "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, }, }, - "neuron-budget": {"inference_environment_variables": {"BUDGET": "1234"}}, + "neuron-budget": { + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + } + ], + }, "gpu-inference": { "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/", @@ -9816,6 +9826,58 @@ "DynamicContainerDeploymentSupported": True, "TrainingModelPackageArtifactUri": None, "Dependencies": [], + "InferenceConfigRankings": { + "overall": {"Description": "default", "Rankings": ["variant1"]} + }, + "InferenceConfigs": { + "variant1": { + "ComponentNames": ["variant1"], + "BenchmarkMetrics": { + "ml.g5.12xlarge": [ + {"Name": "latency", "Unit": "sec", "Value": "0.19", "Concurrency": "1"}, + ] + }, + }, + }, + "InferenceConfigComponents": { + "variant1": { + "HostingEcrUri": "123456789012.ecr.us-west-2.amazon.com/repository", + "HostingArtifactUri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/variant1/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-monarch-test-hub-bucket/monarch-curated-hub-1714579993.88695/curated_models/meta-textgeneration-llama-2-7b/4.0.0/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", # noqa: E501 + "InferenceDependencies": [], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + } + ], + "HostingAdditionalDataSources": { + "speculative_decoding": [ + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_1", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/1", + }, + }, + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_2", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/2", + }, + }, + ] + }, + }, + }, }, "meta-textgeneration-llama-2-70b": { 
"Url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", diff --git a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py index c4b95443ec..11798bc854 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py @@ -15,9 +15,13 @@ import pytest import numpy as np from sagemaker.jumpstart.types import ( + JumpStartConfigComponent, + JumpStartConfigRanking, JumpStartHyperparameter, JumpStartInstanceTypeVariants, JumpStartEnvironmentVariable, + JumpStartMetadataConfig, + JumpStartMetadataConfigs, JumpStartPredictorSpecs, JumpStartSerializablePayload, ) @@ -32,9 +36,8 @@ def test_hub_content_document_from_json_obj(): region = "us-west-2" - gemma_model_document = HubModelDocument( - json_obj=HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"], region=region - ) + json_obj = HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"] + gemma_model_document = HubModelDocument(json_obj=json_obj, region=region) assert gemma_model_document.url == "https://huggingface.co/google/gemma-2b-it" assert gemma_model_document.min_sdk_version == "2.189.0" assert gemma_model_document.training_supported is True @@ -979,3 +982,69 @@ def test_hub_content_document_from_json_obj(): assert gemma_model_document.dynamic_container_deployment_supported is True assert gemma_model_document.training_model_package_artifact_uri is None assert gemma_model_document.dependencies == [] + + inference_config_rankings = { + "overall": JumpStartConfigRanking( + {"Description": "default", "Rankings": ["variant1"]}, is_hub_content=True + ) + } + + inference_config_components = { + "variant1": JumpStartConfigComponent( + "variant1", + { + "HostingEcrUri": "123456789012.ecr.us-west-2.amazon.com/repository", + "HostingArtifactUri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/variant1/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-monarch-test-hub-bucket/monarch-curated-hub-1714579993.88695/curated_models/meta-textgeneration-llama-2-7b/4.0.0/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", # noqa: E501 + "InferenceDependencies": [], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + } + ], + "HostingAdditionalDataSources": { + "speculative_decoding": [ + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_1", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/1", + }, + }, + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_2", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/2", + }, + }, + ] + }, + }, + is_hub_content=True, + ) + } + + inference_configs_dict = { + "variant1": JumpStartMetadataConfig( + "variant1", + json_obj["InferenceConfigs"]["variant1"], + json_obj, + inference_config_components, + is_hub_content=True, + ) + } + + inference_configs = JumpStartMetadataConfigs(inference_configs_dict, inference_config_rankings) + + assert gemma_model_document.inference_config_rankings == inference_config_rankings + assert gemma_model_document.inference_config_components == inference_config_components + assert gemma_model_document.inference_configs == inference_configs diff --git 
a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 6cb8fbaa14..204f1d2d29 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -2059,6 +2059,4 @@ def test_has_instance_rate_stat(stats, expected): ) def test_deployment_config_response_data(data, expected): out = utils.deployment_config_response_data(data) - - print(out) assert out == expected From 15e26c49140c2614a3c762fc06fb34bd8381f442 Mon Sep 17 00:00:00 2001 From: Jacky Lee Date: Wed, 3 Jul 2024 11:28:34 -0700 Subject: [PATCH 29/45] fix: make telemetry logger persist certain information (#1500) * refactor telemetry logger * refactor * refactor * pylint + UT * add tag * add remove tags * handle tags again * pylint --------- Co-authored-by: Jacky Lee --- src/sagemaker/model.py | 9 ++ .../serve/builder/jumpstart_builder.py | 22 ++++- src/sagemaker/serve/builder/model_builder.py | 17 +++- src/sagemaker/serve/utils/telemetry_logger.py | 94 +++++++++---------- src/sagemaker/utils.py | 27 ++++++ .../serve/utils/test_telemetry_logger.py | 11 +-- tests/unit/test_utils.py | 22 +++++ 7 files changed, 136 insertions(+), 66 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ce8142e43d..5d7ee5b378 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -74,6 +74,7 @@ Tags, _resolve_routing_config, _validate_new_tags, + remove_tag_with_key, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor @@ -426,6 +427,14 @@ def add_tags(self, tags: Tags) -> None: """ self._tags = _validate_new_tags(tags, self._tags) + def remove_tag_with_key(self, key: str) -> None: + """Remove a tag with the given key from the list of tags. + + Args: + key (str): The key of the tag to remove. + """ + self._tags = remove_tag_with_key(key, self._tags) + @classmethod def attach( cls, diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index e051e4340d..962b01f650 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -116,7 +116,9 @@ def __init__(self): self.model_metadata = None self.role_arn = None self.is_fine_tuned = None - self.is_gated = None + self.is_compiled = False + self.is_quantized = False + self.speculative_decoding_draft_model_source = None @abstractmethod def _prepare_for_mode(self): @@ -503,6 +505,18 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None: self.pysdk_model.set_deployment_config(config_name, instance_type) + self.instance_type = instance_type + + # JS-benchmarked models only include SageMaker-provided SD models + if self.pysdk_model.additional_model_data_sources: + self.speculative_decoding_draft_model_source = "sagemaker" + self.pysdk_model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"}, + ) + self.pysdk_model.remove_tag_with_key(Tag.OPTIMIZATION_JOB_NAME) + self.pysdk_model.remove_tag_with_key(Tag.FINE_TUNING_MODEL_PATH) + self.pysdk_model.remove_tag_with_key(Tag.FINE_TUNING_JOB_NAME) + def get_deployment_config(self) -> Optional[Dict[str, Any]]: """Gets the deployment config to apply to the model. 
@@ -775,10 +789,8 @@ def _is_gated_model(self, model=None) -> bool: s3_uri = s3_uri.get("S3DataSource").get("S3Uri") if s3_uri is None: - self.is_gated = False - else: - self.is_gated = "private" in s3_uri - return self.is_gated + return False + return "private" in s3_uri def _set_additional_model_source( self, diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index d58b0618b7..7b290ebb69 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -23,6 +23,7 @@ from pathlib import Path +from sagemaker.enums import Tag from sagemaker.s3 import S3Downloader from sagemaker import Session @@ -69,6 +70,7 @@ from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model from sagemaker.serve.utils.optimize_utils import ( _generate_optimized_model, + _extract_speculative_draft_model_provider, ) from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.utils.hardware_detector import ( @@ -647,11 +649,6 @@ def _handle_mlflow_input(self): mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH) artifact_path = self._get_artifact_path(mlflow_model_path) if not self._mlflow_metadata_exists(artifact_path): - logger.info( - "MLflow model metadata not detected in %s. ModelBuilder is not " - "handling MLflow model input", - mlflow_model_path, - ) return self._initialize_for_mlflow(artifact_path) @@ -1144,6 +1141,12 @@ def _model_builder_optimize_wrapper( Returns: Model: A deployable ``Model`` object. """ + self.is_compiled = compilation_config is not None + self.is_quantized = quantization_config is not None + self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider( + speculative_decoding_config + ) + if quantization_config and compilation_config: raise ValueError("Quantization config and compilation config are mutually exclusive.") @@ -1180,4 +1183,8 @@ def _model_builder_optimize_wrapper( job_status = self.sagemaker_session.wait_for_optimization_job(job_name) return _generate_optimized_model(self.pysdk_model, job_status) + self.pysdk_model.remove_tag_with_key(Tag.OPTIMIZATION_JOB_NAME) + if not speculative_decoding_config: + self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER) + return self.pysdk_model diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py index 9a74f4b828..0ea6ec3f26 100644 --- a/src/sagemaker/serve/utils/telemetry_logger.py +++ b/src/sagemaker/serve/utils/telemetry_logger.py @@ -94,15 +94,38 @@ def wrapper(self, *args, **kwargs): logger.info(TELEMETRY_OPT_OUT_MESSAGING) response = None caught_ex = None - + status = "1" + failure_reason = None + failure_type = None extra = f"{func_name}" + start_timer = perf_counter() + try: + response = func(self, *args, **kwargs) + except ( + ModelBuilderException, + exceptions.CapacityError, + exceptions.UnexpectedStatusException, + exceptions.AsyncInferenceError, + ) as e: + status = "0" + caught_ex = e + failure_reason = str(e) + failure_type = e.__class__.__name__ + except Exception as e: # pylint: disable=W0703 + raise e + + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + if self.model_server: extra += f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}" if self.image_uri: image_uri_tail = self.image_uri.split("/")[1] - image_uri_option = _get_image_uri_option(self.image_uri, self._is_custom_image_uri) + image_uri_option = 
_get_image_uri_option( + self.image_uri, getattr(self, "_is_custom_image_uri", False) + ) if self.image_uri: extra += f"&x-imageTag={image_uri_tail}" @@ -128,63 +151,36 @@ def wrapper(self, *args, **kwargs): if getattr(self, "is_fine_tuned", False): extra += "&x-fineTuned=1" - if getattr(self, "is_gated", False): - extra += "&x-gated=1" - if kwargs.get("compilation_config"): + if getattr(self, "is_compiled", False): extra += "&x-compiled=1" - if kwargs.get("quantization_config"): + if getattr(self, "is_quantized", False): extra += "&x-quantized=1" - if kwargs.get("speculative_decoding_config"): - model_provider = kwargs["speculative_decoding_config"]["ModelProvider"] + if getattr(self, "speculative_decoding_draft_model_source", False): model_provider_enum = ( SpeculativeDecodingDraftModelSource.SAGEMAKER - if model_provider.lower() == "sagemaker" + if self.speculative_decoding_draft_model_source == "sagemaker" else SpeculativeDecodingDraftModelSource.CUSTOM ) model_provider_value = SD_DRAFT_MODEL_SOURCE_TO_CODE[str(model_provider_enum)] extra += f"&x-sdDraftModelSource={model_provider_value}" - start_timer = perf_counter() - try: - response = func(self, *args, **kwargs) - stop_timer = perf_counter() - elapsed = stop_timer - start_timer - extra += f"&x-latency={round(elapsed, 2)}" - if not self.serve_settings.telemetry_opt_out: - _send_telemetry( - "1", - MODE_TO_CODE[str(self.mode)], - self.sagemaker_session, - None, - None, - extra, - ) - except ( - ModelBuilderException, - exceptions.CapacityError, - exceptions.UnexpectedStatusException, - exceptions.AsyncInferenceError, - ) as e: - stop_timer = perf_counter() - elapsed = stop_timer - start_timer - extra += f"&x-latency={round(elapsed, 2)}" - if not self.serve_settings.telemetry_opt_out: - _send_telemetry( - "0", - MODE_TO_CODE[str(self.mode)], - self.sagemaker_session, - str(e), - e.__class__.__name__, - extra, - ) - caught_ex = e - except Exception as e: # pylint: disable=W0703 - caught_ex = e - finally: - if caught_ex: - raise caught_ex - return response # pylint: disable=W0150 + extra += f"&x-latency={round(elapsed, 2)}" + + if not self.serve_settings.telemetry_opt_out: + _send_telemetry( + status, + MODE_TO_CODE[str(self.mode)], + self.sagemaker_session, + failure_reason, + failure_type, + extra, + ) + + if caught_ex: + raise caught_ex + + return response return wrapper diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index d20e72fc1f..834d193ec1 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1873,3 +1873,30 @@ def _validate_new_tags(new_tags: Optional[Tags], curr_tags: Optional[Tags]) -> O curr_tags.append(new_tag) return curr_tags + + +def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]: + """Remove a tag with the given key from the list of tags. + + Args: + key (str): The key of the tag to remove. + tags (Optional[Tags]): The current list of tags. + + Returns: + Optional[Tags]: The updated list of tags with the tag removed. 
+ """ + if tags is None: + return tags + if isinstance(tags, dict): + tags = [tags] + + updated_tags = [] + for tag in tags: + if tag["Key"] != key: + updated_tags.append(tag) + + if not updated_tags: + return None + if len(updated_tags) == 1: + return updated_tags[0] + return updated_tags diff --git a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py index 563e0f8f20..4729efbda4 100644 --- a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py +++ b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py @@ -314,17 +314,15 @@ def test_capture_telemetry_decorator_optimize_with_custom_configs(self, mock_sen mock_model_builder.model_server = ModelServer.TORCHSERVE mock_model_builder.sagemaker_session.endpoint_arn = None mock_model_builder.is_fine_tuned = True - mock_model_builder.is_gated = True + mock_model_builder.is_compiled = True + mock_model_builder.is_quantized = True + mock_model_builder.speculative_decoding_draft_model_source = "sagemaker" mock_speculative_decoding_config = MagicMock() mock_config = {"ModelProvider": "sagemaker"} mock_speculative_decoding_config.__getitem__.side_effect = mock_config.__getitem__ - mock_model_builder.mock_optimize( - quantization_config=Mock(), - compilation_config=Mock(), - speculative_decoding_config=mock_speculative_decoding_config, - ) + mock_model_builder.mock_optimize() args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] @@ -333,7 +331,6 @@ def test_capture_telemetry_decorator_optimize_with_custom_configs(self, mock_sen "&x-modelServer=1" f"&x-sdkVersion={SDK_VERSION}" f"&x-fineTuned=1" - f"&x-gated=1" f"&x-compiled=1" f"&x-quantized=1" f"&x-sdDraftModelSource=1" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 63263a7920..11cbc120b3 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -57,6 +57,7 @@ _resolve_routing_config, tag_exists, _validate_new_tags, + remove_tag_with_key, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -2124,3 +2125,24 @@ def test_new_add_tags(self): new_tag = {"Key": "project-2", "Value": "my-project-2"} self.assertEqual(_validate_new_tags(new_tag, None), new_tag) + + def test_remove_existing_tag(self): + original_tags = [ + {"Key": "Tag1", "Value": "Value1"}, + {"Key": "Tag2", "Value": "Value2"}, + {"Key": "Tag3", "Value": "Value3"}, + ] + expected_output = [{"Key": "Tag1", "Value": "Value1"}, {"Key": "Tag3", "Value": "Value3"}] + self.assertEqual(remove_tag_with_key("Tag2", original_tags), expected_output) + + def test_remove_non_existent_tag(self): + original_tags = [ + {"Key": "Tag1", "Value": "Value1"}, + {"Key": "Tag2", "Value": "Value2"}, + {"Key": "Tag3", "Value": "Value3"}, + ] + self.assertEqual(remove_tag_with_key("NonExistentTag", original_tags), original_tags) + + def test_remove_only_tag(self): + original_tags = [{"Key": "Tag1", "Value": "Value1"}] + self.assertIsNone(remove_tag_with_key("Tag1", original_tags)) From 26c8696f69250d4a01843e3fe93e3d344cd126d2 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Wed, 3 Jul 2024 12:58:52 -0700 Subject: [PATCH 30/45] Optimize support for hf models (#1499) * HF support * refactoring * Refactoring * Refactoing * HF Refactoring * Refactoring * UT * Fix UT * Resolving PR comments * HF Token * Resolving PR comments * Fix UT * Fix JS ModelServer deploy wrapper override * Fix tests * 
fix UT * Resolve PR comments * fix doc --------- Co-authored-by: Jonathan Makunga --- requirements/extras/test_requirements.txt | 1 + src/sagemaker/huggingface/llm_utils.py | 25 +++ src/sagemaker/serve/builder/djl_builder.py | 8 +- .../serve/builder/jumpstart_builder.py | 79 ++++---- src/sagemaker/serve/builder/model_builder.py | 184 ++++++++++++++++-- src/sagemaker/serve/builder/tei_builder.py | 4 +- src/sagemaker/serve/builder/tgi_builder.py | 4 +- .../serve/builder/transformers_builder.py | 4 +- .../serve/mode/sagemaker_endpoint_mode.py | 32 ++- .../serve/model_server/djl_serving/server.py | 60 +++--- .../model_server/multi_model_server/server.py | 59 +++--- .../serve/model_server/tei/server.py | 63 +++--- .../model_server/tensorflow_serving/server.py | 39 ++-- .../serve/model_server/tgi/server.py | 60 +++--- .../serve/model_server/torchserve/server.py | 38 ++-- .../serve/model_server/triton/server.py | 40 ++-- src/sagemaker/serve/utils/optimize_utils.py | 105 +++++++++- .../sagemaker/huggingface/test_llm_utils.py | 27 ++- .../serve/builder/test_model_builder.py | 32 +++ .../serve/model_server/tei/test_server.py | 1 + .../tensorflow_serving/test_tf_server.py | 1 + .../serve/model_server/triton/test_server.py | 1 + .../serve/utils/test_optimize_utils.py | 162 ++++++++++++++- 23 files changed, 804 insertions(+), 225 deletions(-) diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 60904c51b0..56805ebc7a 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -38,3 +38,4 @@ accelerate>=0.24.1,<=0.27.0 schema==0.7.5 tensorflow>=2.1,<=2.16 mlflow>=2.12.2,<2.13 +huggingface_hub>=0.23.4 diff --git a/src/sagemaker/huggingface/llm_utils.py b/src/sagemaker/huggingface/llm_utils.py index 9927d1d293..c7a1316760 100644 --- a/src/sagemaker/huggingface/llm_utils.py +++ b/src/sagemaker/huggingface/llm_utils.py @@ -13,7 +13,9 @@ """Functions for generating ECR image URIs for pre-built SageMaker Docker images.""" from __future__ import absolute_import +import os from typing import Optional +import importlib.util import urllib.request from urllib.error import HTTPError, URLError @@ -123,3 +125,26 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = "Did not find model metadata for the following HuggingFace Model ID %s" % model_id ) return hf_model_metadata_json + + +def download_huggingface_model_metadata( + model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None +) -> None: + """Downloads the HuggingFace Model snapshot via HuggingFace API. + + Args: + model_id (str): The HuggingFace Model ID + model_local_path (str): The local path to save the HuggingFace Model snapshot. + hf_hub_token (str): The HuggingFace Hub Token + + Raises: + ImportError: If huggingface_hub is not installed. 
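+
+    Example (a minimal sketch; the model ID, local path, and token are placeholders):
+        >>> download_huggingface_model_metadata(
+        ...     "org/model-id", "/tmp/model", hf_hub_token="hf_..."
+        ... )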
+ """ + if not importlib.util.find_spec("huggingface_hub"): + raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed") + + from huggingface_hub import snapshot_download + + os.makedirs(model_local_path, exist_ok=True) + logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path) + snapshot_download(repo_id=model_id, local_dir=model_local_path, token=hf_hub_token) diff --git a/src/sagemaker/serve/builder/djl_builder.py b/src/sagemaker/serve/builder/djl_builder.py index 72437c0fbb..75acd0d1fe 100644 --- a/src/sagemaker/serve/builder/djl_builder.py +++ b/src/sagemaker/serve/builder/djl_builder.py @@ -24,6 +24,7 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) +from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.tuning import ( _serial_benchmark, _concurrent_benchmark, @@ -214,9 +215,10 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa del kwargs["role"] # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = self._prepare_for_mode() - self.env_vars.update(env_vars) - self.pysdk_model.env.update(self.env_vars) + if not _is_optimized(self.pysdk_model): + self.pysdk_model.model_data, env_vars = self._prepare_for_mode() + self.env_vars.update(env_vars) + self.pysdk_model.env.update(self.env_vars) # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 962b01f650..ccfe795004 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -42,10 +42,11 @@ _update_environment_variables, _extract_speculative_draft_model_provider, _is_image_compatible_with_optimization_job, - _extracts_and_validates_speculative_model_source, _generate_channel_name, - _generate_additional_model_data_sources, - _is_s3_uri, + _extract_optimization_config_and_env, + _is_optimized, + _custom_speculative_decoding, + SPECULATIVE_DRAFT_MODEL, ) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, @@ -121,7 +122,7 @@ def __init__(self): self.speculative_decoding_draft_model_source = None @abstractmethod - def _prepare_for_mode(self): + def _prepare_for_mode(self, **kwargs): """Placeholder docstring""" @abstractmethod @@ -130,6 +131,9 @@ def _get_client_translators(self): def _is_jumpstart_model_id(self) -> bool: """Placeholder docstring""" + if self.model is None: + return False + try: model_uris.retrieve(model_id=self.model, model_version="*", model_scope=_JS_SCOPE) except KeyError: @@ -141,8 +145,9 @@ def _is_jumpstart_model_id(self) -> bool: def _create_pre_trained_js_model(self) -> Type[Model]: """Placeholder docstring""" - pysdk_model = JumpStartModel(self.model, vpc_config=self.vpc_config) - pysdk_model.sagemaker_session = self.sagemaker_session + pysdk_model = JumpStartModel( + self.model, vpc_config=self.vpc_config, sagemaker_session=self.sagemaker_session + ) self._original_deploy = pysdk_model.deploy pysdk_model.deploy = self._js_builder_deploy_wrapper @@ -151,6 +156,7 @@ def _create_pre_trained_js_model(self) -> Type[Model]: @_capture_telemetry("jumpstart.deploy") def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: """Placeholder docstring""" + env = {} if "mode" in kwargs and kwargs.get("mode") != self.mode: overwrite_mode = kwargs.get("mode") # mode 
overwritten by customer during model.deploy() @@ -167,7 +173,8 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: or not hasattr(self, "prepared_for_tgi") or not hasattr(self, "prepared_for_mms") ): - self.pysdk_model.model_data, env = self._prepare_for_mode() + if not _is_optimized(self.pysdk_model): + self.pysdk_model.model_data, env = self._prepare_for_mode() elif overwrite_mode == Mode.LOCAL_CONTAINER: self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER @@ -198,7 +205,6 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: ) self._prepare_for_mode() - env = {} else: raise ValueError("Mode %s is not supported!" % overwrite_mode) @@ -726,25 +732,17 @@ def _optimize_for_jumpstart( ) model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula) - - optimization_config = {} - if quantization_config: - optimization_config["ModelQuantizationConfig"] = quantization_config - pysdk_model_env_vars = _update_environment_variables( - pysdk_model_env_vars, quantization_config["OverrideEnvironment"] - ) - if compilation_config: - optimization_config["ModelCompilationConfig"] = compilation_config - pysdk_model_env_vars = _update_environment_variables( - pysdk_model_env_vars, compilation_config["OverrideEnvironment"] - ) + optimization_config, env = _extract_optimization_config_and_env( + quantization_config, compilation_config + ) + pysdk_model_env_vars = _update_environment_variables(pysdk_model_env_vars, env) output_config = {"S3OutputLocation": output_path} if kms_key: output_config["KmsKeyId"] = kms_key if not instance_type: - instance_type = self.pysdk_model.deployment_config.get("DeploymentArgs").get( - "InstanceType" + instance_type = self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( + "InstanceType", _get_nb_instance() ) create_optimization_job_args = { @@ -771,6 +769,10 @@ def _optimize_for_jumpstart( self.pysdk_model.env.update(pysdk_model_env_vars) if accept_eula: self.pysdk_model.accept_eula = accept_eula + if isinstance(self.pysdk_model.model_data, dict): + self.pysdk_model.model_data["S3DataSource"]["ModelAccessConfig"] = { + "AcceptEula": True + } if quantization_config or compilation_config: return create_optimization_job_args @@ -806,7 +808,6 @@ def _set_additional_model_source( if speculative_decoding_config: model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) - speculative_draft_model = f"/opt/ml/additional-model-data-sources/{channel_name}" if model_provider == "sagemaker": additional_model_data_sources = self.pysdk_model.deployment_config.get( @@ -825,32 +826,18 @@ def _set_additional_model_source( raise ValueError( "Cannot find deployment config compatible for optimization job." 
) + + self.pysdk_model.env.update( + {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"} + ) + self.pysdk_model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"}, + ) else: - model_source = _extracts_and_validates_speculative_model_source( - speculative_decoding_config + self.pysdk_model = _custom_speculative_decoding( + self.pysdk_model, speculative_decoding_config, accept_eula ) - if _is_s3_uri(model_source): - self.pysdk_model.additional_model_data_sources = ( - _generate_additional_model_data_sources( - model_source, channel_name, accept_eula - ) - ) - else: - speculative_draft_model = model_source - - self.pysdk_model.env = _update_environment_variables( - self.pysdk_model.env, - {"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model}, - ) - self.pysdk_model.add_tags( - {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": model_provider}, - ) - if accept_eula and isinstance(self.pysdk_model.model_data, dict): - self.pysdk_model.model_data["S3DataSource"]["ModelAccessConfig"] = { - "AcceptEula": True - } - def _find_compatible_deployment_config( self, speculative_decoding_config: Optional[Dict] = None ) -> Optional[Dict[str, Any]]: diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 7b290ebb69..aa56c89182 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -70,6 +70,12 @@ from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model from sagemaker.serve.utils.optimize_utils import ( _generate_optimized_model, + _generate_model_source, + _update_environment_variables, + _extract_optimization_config_and_env, + _is_s3_uri, + _normalize_local_model_path, + _custom_speculative_decoding, _extract_speculative_draft_model_provider, ) from sagemaker.serve.utils.predictors import _get_local_mode_predictor @@ -97,6 +103,7 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.huggingface.llm_utils import ( get_huggingface_model_metadata, + download_huggingface_model_metadata, ) logger = logging.getLogger(__name__) @@ -192,7 +199,11 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, new models without task metadata in the Hub, adding unsupported task types will throw an exception. ``MLFLOW_MODEL_PATH`` is available for providing local path or s3 path to MLflow artifacts. However, ``MLFLOW_MODEL_PATH`` is experimental and is not - intended for production use at this moment. + intended for production use at this moment. ``CUSTOM_MODEL_PATH`` is available for + providing local path or s3 path to model artifacts. ``FINE_TUNING_MODEL_PATH`` is + available for providing s3 path to fine-tuned model artifacts. ``FINE_TUNING_JOB_NAME`` + is available for providing the fine-tuning job name. ``FINE_TUNING_MODEL_PATH`` and + ``FINE_TUNING_JOB_NAME`` are mutually exclusive. """ model_path: Optional[str] = field( @@ -293,9 +304,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, default=None, metadata={ "help": "Define the model metadata to override, currently supports `HF_TASK`, " - "`MLFLOW_MODEL_PATH`, `FINE_TUNING_MODEL_PATH`, and `FINE_TUNING_JOB_NAME`. HF_TASK " - "should be set for new models without task metadata in the Hub, Adding unsupported " - "task types will throw an exception." + "`MLFLOW_MODEL_PATH`, `FINE_TUNING_MODEL_PATH`, `FINE_TUNING_JOB_NAME`, and " + "`CUSTOM_MODEL_PATH`. 
HF_TASK should be set for new models without task metadata " + "in the Hub, adding unsupported task types will throw an exception." }, ) @@ -386,8 +397,15 @@ def _get_serve_setting(self): sagemaker_session=self.sagemaker_session, ) - def _prepare_for_mode(self, should_upload_artifacts: bool = False): - """Placeholder docstring""" + def _prepare_for_mode( + self, model_path: Optional[str] = None, should_upload_artifacts: Optional[bool] = False + ): + """Prepare this `Model` for serving. + + Args: + model_path (Optional[str]): Model path. + should_upload_artifacts (Optional[bool]): Whether to upload artifacts to S3. + """ # TODO: move mode specific prepare steps under _model_builder_deploy_wrapper self.s3_upload_path = None if self.mode == Mode.SAGEMAKER_ENDPOINT: @@ -398,16 +416,15 @@ def _prepare_for_mode(self, should_upload_artifacts: bool = False): self.s3_upload_path, env_vars_sagemaker = self.modes[ str(Mode.SAGEMAKER_ENDPOINT) ].prepare( - self.model_path, + (model_path or self.model_path), self.secret_key, self.serve_settings.s3_model_data_url, self.sagemaker_session, self.image_uri, getattr(self, "model_hub", None) == ModelHub.JUMPSTART, - should_upload=should_upload_artifacts, + should_upload_artifacts=should_upload_artifacts, ) - if env_vars_sagemaker: - self.env_vars.update(env_vars_sagemaker) + self.env_vars.update(env_vars_sagemaker) return self.s3_upload_path, env_vars_sagemaker if self.mode == Mode.LOCAL_CONTAINER: # init the LocalContainerMode object @@ -822,10 +839,17 @@ def build( # pylint: disable=R0911 self.mode = mode if role_arn: self.role_arn = role_arn - self.sagemaker_session = sagemaker_session or Session() + + if not self.sagemaker_session: + self.sagemaker_session = sagemaker_session or Session() self.sagemaker_session.settings._local_download_dir = self.model_path + # DJL expects `HF_TOKEN` key. This allows backward compatibility + # until we deprecate HUGGING_FACE_HUB_TOKEN. + if self.env_vars.get("HUGGING_FACE_HUB_TOKEN") and not self.env_vars.get("HF_TOKEN"): + self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") + # https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258 # decorate to_string() due to # https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015 @@ -841,7 +865,7 @@ def build( # pylint: disable=R0911 self._build_validations() - if self.model_server: + if not self._is_jumpstart_model_id() and self.model_server: return self._build_for_model_server() if isinstance(self.model, str): @@ -1087,6 +1111,9 @@ def optimize(self, *args, **kwargs) -> Model: Returns: Model: A deployable ``Model`` object. 
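+
+        Example (a minimal sketch; the bucket, instance type, and override value are
+        placeholders, and the quantization config shape follows the
+        ``OverrideEnvironment`` form used elsewhere in this change):
+            >>> optimized_model = model_builder.optimize(
+            ...     instance_type="ml.g5.12xlarge",
+            ...     output_path="s3://my-bucket/optimized/",
+            ...     quantization_config={
+            ...         "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"},
+            ...     },
+            ... )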
""" + if self.mode != Mode.SAGEMAKER_ENDPOINT: + raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.") + # need to get telemetry_opt_out info before telemetry decorator is called self.serve_settings = self._get_serve_setting() @@ -1160,7 +1187,6 @@ def _model_builder_optimize_wrapper( self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}" - input_args = None if self._is_jumpstart_model_id(): input_args = self._optimize_for_jumpstart( output_path=output_path, @@ -1177,6 +1203,21 @@ def _model_builder_optimize_wrapper( kms_key=kms_key, max_runtime_in_sec=max_runtime_in_sec, ) + else: + input_args = self._optimize_for_hf( + output_path=output_path, + instance_type=instance_type, + role_arn=self.role_arn, + tags=tags, + job_name=job_name, + quantization_config=quantization_config, + compilation_config=compilation_config, + speculative_decoding_config=speculative_decoding_config, + env_vars=env_vars, + vpc_config=vpc_config, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + ) if input_args: self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args) @@ -1188,3 +1229,120 @@ def _model_builder_optimize_wrapper( self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER) return self.pysdk_model + + def _optimize_for_hf( + self, + output_path: str, + instance_type: Optional[str] = None, + role_arn: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = None, + ) -> Optional[Dict[str, Any]]: + """Runs a model optimization job. + + Args: + output_path (str): Specifies where to store the compiled/quantized model. + instance_type (Optional[str]): Target deployment instance type that + the model is optimized for. + role_arn (Optional[str]): Execution role. Defaults to ``None``. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + ``None``. + + Returns: + Dict[str, Any]: Model optimization job input arguments. 
+ """ + if self.model_server != ModelServer.DJL_SERVING: + logger.info("Overwriting model server to DJL.") + self.model_server = ModelServer.DJL_SERVING + + optimization_env_vars = env_vars + pysdk_model_env_vars = env_vars + + if quantization_config or compilation_config: + self.instance_type = instance_type or self.instance_type + + self._optimize_prepare_for_hf() + model_source = _generate_model_source(self.pysdk_model.model_data, False) + + self.pysdk_model = _custom_speculative_decoding( + self.pysdk_model, speculative_decoding_config, False + ) + + optimization_config, env = _extract_optimization_config_and_env( + quantization_config, compilation_config + ) + pysdk_model_env_vars = _update_environment_variables(pysdk_model_env_vars, env) + + output_config = {"S3OutputLocation": output_path} + if kms_key: + output_config["KmsKeyId"] = kms_key + + create_optimization_job_args = { + "OptimizationJobName": job_name, + "ModelSource": model_source, + "DeploymentInstanceType": self.instance_type, + "OptimizationConfigs": [optimization_config], + "OutputConfig": output_config, + "RoleArn": role_arn, + } + + if optimization_env_vars: + create_optimization_job_args["OptimizationEnvironment"] = optimization_env_vars + if max_runtime_in_sec: + create_optimization_job_args["StoppingCondition"] = { + "MaxRuntimeInSeconds": max_runtime_in_sec + } + if tags: + create_optimization_job_args["Tags"] = tags + if vpc_config: + create_optimization_job_args["VpcConfig"] = vpc_config + + if pysdk_model_env_vars: + self.pysdk_model.env.update(pysdk_model_env_vars) + + return create_optimization_job_args + return None + + def _optimize_prepare_for_hf(self): + """Prepare huggingface model data for optimization.""" + custom_model_path: str = ( + self.model_metadata.get("CUSTOM_MODEL_PATH") if self.model_metadata else None + ) + if _is_s3_uri(custom_model_path): + # Remove slash by the end of s3 uri, as it may lead to / subfolder during upload. 
+ custom_model_path = ( + custom_model_path[:-1] if custom_model_path.endswith("/") else custom_model_path + ) + else: + if not custom_model_path: + custom_model_path = f"/tmp/sagemaker/model-builder/{self.model}/code" + download_huggingface_model_metadata( + self.model, + custom_model_path, + self.env_vars.get("HUGGING_FACE_HUB_TOKEN"), + ) + custom_model_path = _normalize_local_model_path(custom_model_path) + + self.pysdk_model.model_data, env = self._prepare_for_mode( + model_path=custom_model_path, + should_upload_artifacts=True, + ) + self.pysdk_model.env.update(env) diff --git a/src/sagemaker/serve/builder/tei_builder.py b/src/sagemaker/serve/builder/tei_builder.py index e251eb4f81..a1f4567eb9 100644 --- a/src/sagemaker/serve/builder/tei_builder.py +++ b/src/sagemaker/serve/builder/tei_builder.py @@ -25,6 +25,7 @@ _get_nb_instance, ) from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure +from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.predictors import TeiLocalModePredictor from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode @@ -162,7 +163,8 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - self._prepare_for_mode() + if not _is_optimized(self.pysdk_model): + self._prepare_for_mode() # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: diff --git a/src/sagemaker/serve/builder/tgi_builder.py b/src/sagemaker/serve/builder/tgi_builder.py index e6cbe41c90..558a560a74 100644 --- a/src/sagemaker/serve/builder/tgi_builder.py +++ b/src/sagemaker/serve/builder/tgi_builder.py @@ -25,6 +25,7 @@ LocalModelInvocationException, SkipTuningComboException, ) +from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.tuning import ( _serial_benchmark, _concurrent_benchmark, @@ -201,7 +202,8 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - self._prepare_for_mode() + if not _is_optimized(self.pysdk_model): + self._prepare_for_mode() # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index dded7bd0bd..e618b54e44 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -27,6 +27,7 @@ from sagemaker.serve.model_server.multi_model_server.prepare import ( _create_dir_structure, ) +from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.predictors import TransformersLocalModePredictor from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode @@ -223,7 +224,8 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - self._prepare_for_mode() + if not _is_optimized(self.pysdk_model): + self._prepare_for_mode() if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py index d0022ae74c..6f9bf8307f 100644 --- 
a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py +++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py @@ -59,7 +59,7 @@ def prepare( sagemaker_session: Session = None, image: str = None, jumpstart: bool = False, - should_upload: bool = False, + should_upload_artifacts: bool = False, ): """Placeholder docstring""" try: @@ -78,6 +78,7 @@ def prepare( secret_key=secret_key, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=True, ) if self.model_server == ModelServer.TRITON: @@ -87,6 +88,7 @@ def prepare( secret_key=secret_key, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=True, ) if self.model_server == ModelServer.DJL_SERVING: @@ -95,40 +97,50 @@ def prepare( sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=True, ) - if self.model_server == ModelServer.TGI and should_upload: - upload_artifacts = self._upload_tgi_artifacts( + if self.model_server == ModelServer.TENSORFLOW_SERVING: + upload_artifacts = self._upload_tensorflow_serving_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, + secret_key=secret_key, s3_model_data_url=s3_model_data_url, image=image, - jumpstart=jumpstart, + should_upload_artifacts=True, ) - if self.model_server == ModelServer.MMS and should_upload: - upload_artifacts = self._upload_server_artifacts( + # By default, we do not want to upload artifacts in S3 for the below server. + # In Case of Optimization, artifacts need to be uploaded into s3. + # In that case, `should_upload_artifacts` arg needs to come from + # the caller of prepare. + + if self.model_server == ModelServer.TGI: + upload_artifacts = self._upload_tgi_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, image=image, + jumpstart=jumpstart, + should_upload_artifacts=should_upload_artifacts, ) - if self.model_server == ModelServer.TENSORFLOW_SERVING: - upload_artifacts = self._upload_tensorflow_serving_artifacts( + if self.model_server == ModelServer.MMS: + upload_artifacts = self._upload_server_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, - secret_key=secret_key, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=should_upload_artifacts, ) - if self.model_server == ModelServer.TEI and should_upload: + if self.model_server == ModelServer.TEI: upload_artifacts = self._tei_serving._upload_tei_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=should_upload_artifacts, ) if upload_artifacts or isinstance(self.model_server, ModelServer): diff --git a/src/sagemaker/serve/model_server/djl_serving/server.py b/src/sagemaker/serve/model_server/djl_serving/server.py index 80214332b0..4ba7dd227d 100644 --- a/src/sagemaker/serve/model_server/djl_serving/server.py +++ b/src/sagemaker/serve/model_server/djl_serving/server.py @@ -12,6 +12,7 @@ from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host +from sagemaker.serve.utils.optimize_utils import _is_s3_uri logger = logging.getLogger(__name__) MODE_DIR_BINDING = "/opt/ml/model/" @@ -91,39 +92,48 @@ def _upload_djl_artifacts( s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): """Placeholder docstring""" - if s3_model_data_url: - bucket, key_prefix 
= parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - code_dir = Path(model_path).joinpath("code") + code_dir = Path(model_path).joinpath("code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - logger.debug("Uploading DJL Model Resources uncompressed to: %s", s3_location) + logger.debug("Uploading DJL Model Resources uncompressed to: %s", s3_location) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, - ) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } } - } + if model_data_url + else None + ) return (model_data, _update_env_vars(env_vars)) diff --git a/src/sagemaker/serve/model_server/multi_model_server/server.py b/src/sagemaker/serve/model_server/multi_model_server/server.py index b78e01f5c3..91d585b4cf 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/server.py +++ b/src/sagemaker/serve/model_server/multi_model_server/server.py @@ -11,6 +11,7 @@ from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host +from sagemaker.serve.utils.optimize_utils import _is_s3_uri MODE_DIR_BINDING = "/opt/ml/model/" _DEFAULT_ENV_VARS = {} @@ -84,38 +85,48 @@ def _upload_server_artifacts( s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + code_dir = Path(model_path).joinpath("code") - code_dir = Path(model_path).joinpath("code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + 
logger.debug("Uploading Multi Model Server Resources uncompressed to: %s", s3_location) - logger.debug("Uploading Multi Model Server Resources uncompressed to: %s", s3_location) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } + } + if model_data_url + else None ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", - } - } return model_data, _update_env_vars(env_vars) diff --git a/src/sagemaker/serve/model_server/tei/server.py b/src/sagemaker/serve/model_server/tei/server.py index 25c27e6dda..54abbea0da 100644 --- a/src/sagemaker/serve/model_server/tei/server.py +++ b/src/sagemaker/serve/model_server/tei/server.py @@ -12,7 +12,7 @@ from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host - +from sagemaker.serve.utils.optimize_utils import _is_s3_uri MODE_DIR_BINDING = "/opt/ml/model/" _SHM_SIZE = "2G" @@ -107,6 +107,7 @@ def _upload_tei_artifacts( s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): """Uploads the model artifacts to S3. @@ -116,38 +117,48 @@ def _upload_tei_artifacts( s3_model_data_url: S3 model data URL image: Image to use env_vars: Environment variables to set + model_data_s3_path: S3 path to model data + should_upload_artifacts: Whether to upload artifacts """ - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - code_dir = Path(model_path).joinpath("code") + code_dir = Path(model_path).joinpath("code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location) + logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, - ) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } } - } + if model_data_url + else None + ) return (model_data, _update_env_vars(env_vars)) diff --git 
a/src/sagemaker/serve/model_server/tensorflow_serving/server.py b/src/sagemaker/serve/model_server/tensorflow_serving/server.py index 2392287c61..45931e9afc 100644 --- a/src/sagemaker/serve/model_server/tensorflow_serving/server.py +++ b/src/sagemaker/serve/model_server/tensorflow_serving/server.py @@ -7,6 +7,7 @@ import platform from pathlib import Path from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.utils.optimize_utils import _is_s3_uri from sagemaker.session import Session from sagemaker.serve.utils.exceptions import LocalModelInvocationException from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url @@ -101,6 +102,7 @@ def _upload_tensorflow_serving_artifacts( secret_key: str, s3_model_data_url: str = None, image: str = None, + should_upload_artifacts: bool = False, ): """Uploads the model artifacts to S3. @@ -110,23 +112,30 @@ def _upload_tensorflow_serving_artifacts( secret_key: Secret key to use for authentication s3_model_data_url: S3 model data URL image: Image to use + should_upload_artifacts: Whether to upload artifacts to S3 """ - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + s3_upload_path = None + if _is_s3_uri(model_path): + s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session ) - logger.debug( - "Uploading the model resources to bucket=%s, key_prefix=%s.", bucket, code_key_prefix - ) - s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) - logger.debug("Model resources uploaded to: %s", s3_upload_path) + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) env_vars = { "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", diff --git a/src/sagemaker/serve/model_server/tgi/server.py b/src/sagemaker/serve/model_server/tgi/server.py index 75cf3bd402..4d9686a89c 100644 --- a/src/sagemaker/serve/model_server/tgi/server.py +++ b/src/sagemaker/serve/model_server/tgi/server.py @@ -12,6 +12,7 @@ from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host +from sagemaker.serve.utils.optimize_utils import _is_s3_uri MODE_DIR_BINDING = "/opt/ml/model/" _SHM_SIZE = "2G" @@ -111,38 +112,47 @@ def _upload_tgi_artifacts( s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + 
model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - code_dir = Path(model_path).joinpath("code") + code_dir = Path(model_path).joinpath("code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - logger.debug("Uploading TGI Model Resources uncompressed to: %s", s3_location) + logger.debug("Uploading TGI Model Resources uncompressed to: %s", s3_location) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, - ) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } } - } + if model_data_url + else None + ) if jumpstart: return (model_data, {}) return (model_data, _update_env_vars(env_vars)) diff --git a/src/sagemaker/serve/model_server/torchserve/server.py b/src/sagemaker/serve/model_server/torchserve/server.py index 5aef136355..74e37cd70b 100644 --- a/src/sagemaker/serve/model_server/torchserve/server.py +++ b/src/sagemaker/serve/model_server/torchserve/server.py @@ -7,6 +7,7 @@ import platform from pathlib import Path from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.utils.optimize_utils import _is_s3_uri from sagemaker.session import Session from sagemaker.serve.utils.exceptions import LocalModelInvocationException from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url @@ -84,24 +85,31 @@ def _upload_torchserve_artifacts( secret_key: str, s3_model_data_url: str = None, image: str = None, + should_upload_artifacts: bool = False, ): """Tar the model artifact and upload to S3 bucket, then prepare for the environment variables""" - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + s3_upload_path = None + if _is_s3_uri(model_path): + s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - logger.debug( - "Uploading the model resources to bucket=%s, key_prefix=%s.", bucket, code_key_prefix - ) - s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) - logger.debug("Model resources uploaded to: %s", s3_upload_path) + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + s3_upload_path = 
upload(sagemaker_session, model_path, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) env_vars = { "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", diff --git a/src/sagemaker/serve/model_server/triton/server.py b/src/sagemaker/serve/model_server/triton/server.py index 62dfb4759a..e2f3c20d7a 100644 --- a/src/sagemaker/serve/model_server/triton/server.py +++ b/src/sagemaker/serve/model_server/triton/server.py @@ -9,6 +9,7 @@ from sagemaker import fw_utils from sagemaker import Session from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.utils.optimize_utils import _is_s3_uri from sagemaker.serve.utils.uploader import upload from sagemaker.serve.utils.exceptions import LocalModelInvocationException from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url @@ -115,25 +116,32 @@ def _upload_triton_artifacts( secret_key: str, s3_model_data_url: str = None, image: str = None, + should_upload_artifacts: bool = False, ): """Tar triton artifacts and upload to s3""" - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + s3_upload_path = None + if _is_s3_uri(model_path): + s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - logger.debug( - "Uploading the model resources to bucket=%s, key_prefix=%s.", bucket, code_key_prefix - ) - model_repository = model_path + "/model_repository" - s3_upload_path = upload(sagemaker_session, model_repository, bucket, code_key_prefix) - logger.debug("Model resources uploaded to: %s", s3_upload_path) + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + model_repository = model_path + "/model_repository" + s3_upload_path = upload(sagemaker_session, model_repository, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) env_vars = { "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model", diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index e4313a4321..853974aae0 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -15,15 +15,17 @@ import re import logging -from typing import Dict, Any, Optional, Union, List +from typing import Dict, Any, Optional, Union, List, Tuple from sagemaker import Model from sagemaker.enums import Tag - logger = logging.getLogger(__name__) +SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources" + + def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool: """Checks whether an instance is compatible with an optimization job. @@ -74,6 +76,25 @@ def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) - return pysdk_model +def _is_optimized(pysdk_model: Model) -> bool: + """Checks whether an optimization model is optimized. + + Args: + pysdk_model (Model): A PySDK model. 
+ + Return: + bool: Whether the given model type is optimized. + """ + optimized_tags = [Tag.OPTIMIZATION_JOB_NAME, Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER] + if hasattr(pysdk_model, "_tags") and pysdk_model._tags: + if isinstance(pysdk_model._tags, dict): + return pysdk_model._tags.get("Key") in optimized_tags + for tag in pysdk_model._tags: + if tag.get("Key") in optimized_tags: + return True + return False + + def _generate_model_source( model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool] ) -> Optional[Dict[str, Any]]: @@ -224,3 +245,83 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool: return False return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None + + +def _extract_optimization_config_and_env( + quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None +) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]: + """Extracts optimization config and environment variables. + + Args: + quantization_config (Optional[Dict]): The quantization config. + compilation_config (Optional[Dict]): The compilation config. + + Returns: + Optional[Tuple[Optional[Dict], Optional[Dict]]]: + The optimization config and environment variables. + """ + if quantization_config: + return {"ModelQuantizationConfig": quantization_config}, quantization_config.get( + "OverrideEnvironment" + ) + if compilation_config: + return {"ModelCompilationConfig": compilation_config}, compilation_config.get( + "OverrideEnvironment" + ) + return None, None + + +def _normalize_local_model_path(local_model_path: Optional[str]) -> Optional[str]: + """Normalizes the local model path. + + Args: + local_model_path (Optional[str]): The local model path. + + Returns: + Optional[str]: The normalized model path. + """ + if local_model_path is None: + return local_model_path + + # Removes /code or /code/ path at the end of local_model_path, + # as it is appended during artifacts upload. + pattern = r"/code/?$" + if re.search(pattern, local_model_path): + return re.sub(pattern, "", local_model_path) + return local_model_path + + +def _custom_speculative_decoding( + model: Model, + speculative_decoding_config: Optional[Dict], + accept_eula: Optional[bool] = False, +) -> Model: + """Modifies the given model for speculative decoding config with custom provider. + + Args: + model (Model): The model. + speculative_decoding_config (Optional[Dict]): The speculative decoding config. + accept_eula (Optional[bool]): Whether to accept eula or not. 
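+
+    Returns:
+        Model: The model, with ``OPTION_SPECULATIVE_DRAFT_MODEL`` set in its environment
+            and a ``SPECULATIVE_DRAFT_MODEL_PROVIDER: customer`` tag added when a
+            speculative decoding config is given; otherwise the model is returned unchanged.
+
+    Example config (illustrative; the S3 URI is a placeholder):
+        {"ModelSource": "s3://my-bucket/draft-model/"}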
+ """ + + if speculative_decoding_config: + additional_model_source = _extracts_and_validates_speculative_model_source( + speculative_decoding_config + ) + + if _is_s3_uri(additional_model_source): + channel_name = _generate_channel_name(model.additional_model_data_sources) + speculative_draft_model = f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}" + + model.additional_model_data_sources = _generate_additional_model_data_sources( + additional_model_source, channel_name, accept_eula + ) + else: + speculative_draft_model = additional_model_source + + model.env.update({"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model}) + model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "customer"}, + ) + + return model diff --git a/tests/unit/sagemaker/huggingface/test_llm_utils.py b/tests/unit/sagemaker/huggingface/test_llm_utils.py index 3c4cdde3f6..675a6fd885 100644 --- a/tests/unit/sagemaker/huggingface/test_llm_utils.py +++ b/tests/unit/sagemaker/huggingface/test_llm_utils.py @@ -15,7 +15,10 @@ from unittest import TestCase from urllib.error import HTTPError from unittest.mock import Mock, patch -from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata +from sagemaker.huggingface.llm_utils import ( + get_huggingface_model_metadata, + download_huggingface_model_metadata, +) MOCK_HF_ID = "mock_hf_id" MOCK_HF_HUB_TOKEN = "mock_hf_hub_token" @@ -74,3 +77,25 @@ def test_huggingface_model_metadata_general_exception(self, mock_urllib): f"Did not find model metadata for the following HuggingFace Model ID {MOCK_HF_ID}" ) self.assertEquals(expected_error_msg, str(context.exception)) + + @patch("huggingface_hub.snapshot_download") + def test_download_huggingface_model_metadata(self, mock_snapshot_download): + mock_snapshot_download.side_effect = None + + download_huggingface_model_metadata(MOCK_HF_ID, "local_path", MOCK_HF_HUB_TOKEN) + + mock_snapshot_download.assert_called_once_with( + repo_id=MOCK_HF_ID, local_dir="local_path", token=MOCK_HF_HUB_TOKEN + ) + + @patch("importlib.util.find_spec") + def test_download_huggingface_model_metadata_ex(self, mock_find_spec): + mock_find_spec.side_effect = lambda *args, **kwargs: False + + self.assertRaisesRegex( + ImportError, + "Unable to import huggingface_hub, check if huggingface_hub is installed", + lambda: download_huggingface_model_metadata( + MOCK_HF_ID, "local_path", MOCK_HF_HUB_TOKEN + ), + ) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 19a06dd5bb..5aea0fc6d4 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -245,6 +245,10 @@ def test_model_server_override_transformers_with_model( mock_build_for_ts.assert_called_once() @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -263,6 +267,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -349,6 +354,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( 
self.assertEqual(build_result.serve_settings, mock_setting_object) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -367,6 +376,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -551,6 +561,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( self.assertEqual(build_result.serve_settings, mock_setting_object) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -569,6 +583,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -653,6 +668,10 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( self.assertEqual("sample agent ModelBuilder", user_agent) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder.save_xgboost") @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @@ -673,6 +692,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( mock_prepare_for_torchserve, mock_detect_fw_version, mock_save_xgb, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -1001,6 +1021,10 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo ) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -1021,6 +1045,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_fw_version.return_value = framework, version @@ -2221,6 +2246,10 @@ def test_build_mlflow_model_s3_input_tensorflow_serving_local_mode_happy( assert isinstance(predictor, TensorflowServingLocalPredictor) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving") @patch("sagemaker.serve.builder.model_builder.S3Downloader.list") @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @@ -2249,6 +2278,7 @@ 
def test_build_tensorflow_serving_non_mlflow_case( mock_detect_fw_version, mock_s3_downloader, mock_prepare_for_tf_serving, + mock_is_jumpstart_model_id, ): mock_s3_downloader.return_value = [] mock_detect_container.return_value = mock_image_uri @@ -2298,6 +2328,8 @@ def test_build_tensorflow_serving_non_mlflow_case( mock_session, ) + # builder.build(sagemaker_session=mock_session, role_arn=mock_role_arn, mode=Mode.SAGEMAKER_ENDPOINT) + @pytest.mark.skip(reason="Implementation not completed") @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") diff --git a/tests/unit/sagemaker/serve/model_server/tei/test_server.py b/tests/unit/sagemaker/serve/model_server/tei/test_server.py index 2344a61fbc..cc1226702f 100644 --- a/tests/unit/sagemaker/serve/model_server/tei/test_server.py +++ b/tests/unit/sagemaker/serve/model_server/tei/test_server.py @@ -135,6 +135,7 @@ def test_upload_artifacts_sagemaker_tei_server(self, mock_uploader): sagemaker_session=mock_session, s3_model_data_url=S3_URI, image=TEI_IMAGE, + should_upload_artifacts=True, ) mock_uploader.upload.assert_called_once() diff --git a/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py index 3d3bac0935..b9cce13dbb 100644 --- a/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py +++ b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py @@ -92,6 +92,7 @@ def test_upload_artifacts_sagemaker_triton_server(self, mock_upload, mock_platfo secret_key=SECRET_KEY, s3_model_data_url=S3_URI, image=CPU_TF_IMAGE, + should_upload_artifacts=True, ) mock_upload.assert_called_once_with(mock_session, MODEL_PATH, "mock_model_data_uri", ANY) diff --git a/tests/unit/sagemaker/serve/model_server/triton/test_server.py b/tests/unit/sagemaker/serve/model_server/triton/test_server.py index c80c4296e7..3f571424ed 100644 --- a/tests/unit/sagemaker/serve/model_server/triton/test_server.py +++ b/tests/unit/sagemaker/serve/model_server/triton/test_server.py @@ -172,6 +172,7 @@ def test_upload_artifacts_sagemaker_triton_server(self, mock_upload, mock_platfo secret_key=SECRET_KEY, s3_model_data_url=S3_URI, image=GPU_TRITON_IMAGE, + should_upload_artifacts=True, ) mock_upload.assert_called_once_with(mock_session, MODEL_REPO, "mock_model_data_uri", ANY) diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index f0e18186b7..80a3217c1d 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -12,7 +12,8 @@ # language governing permissions and limitations under the License. 
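# A minimal sketch of the upload gating that the `should_upload_artifacts=True`
# arguments above exercise. The helper name and signature here are illustrative
# assumptions, not the SDK's actual private API; `S3Uploader.upload` is the real
# SDK utility.
from sagemaker.s3 import S3Uploader

def upload_artifacts_if_requested(session, model_path, s3_model_data_url, should_upload_artifacts=False):
    """Upload local artifacts to S3 only when the caller opts in explicitly."""
    if not should_upload_artifacts:
        # Assume the artifacts already live at the destination URI.
        return s3_model_data_url
    return S3Uploader.upload(model_path, s3_model_data_url, sagemaker_session=session)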
from __future__ import absolute_import -from unittest.mock import Mock +import unittest +from unittest.mock import Mock, patch import pytest @@ -26,6 +27,10 @@ _is_s3_uri, _generate_additional_model_data_sources, _generate_channel_name, + _extract_optimization_config_and_env, + _normalize_local_model_path, + _is_optimized, + _custom_speculative_decoding, ) mock_optimization_job_output = { @@ -136,6 +141,19 @@ def test_generate_optimized_model(): ) +def test_is_optimized(): + model = Mock() + + model._tags = {"Key": Tag.OPTIMIZATION_JOB_NAME} + assert _is_optimized(model) is True + + model._tags = [{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER}] + assert _is_optimized(model) is True + + model._tags = [{"Key": Tag.FINE_TUNING_MODEL_PATH}] + assert _is_optimized(model) is False + + @pytest.mark.parametrize( "env, new_env, output_env", [ @@ -233,3 +251,145 @@ def test_generate_additional_model_data_sources(): ) def test_is_s3_uri(s3_uri, expected): assert _is_s3_uri(s3_uri) == expected + + +@pytest.mark.parametrize( + "quantization_config, compilation_config, expected_config, expected_env", + [ + ( + None, + { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + { + "ModelCompilationConfig": { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + }, + { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, + ), + ( + { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + None, + { + "ModelQuantizationConfig": { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + }, + { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, + ), + (None, None, None, None), + ], +) +def test_extract_optimization_config_and_env( + quantization_config, compilation_config, expected_config, expected_env +): + assert _extract_optimization_config_and_env(quantization_config, compilation_config) == ( + expected_config, + expected_env, + ) + + +@pytest.mark.parametrize( + "my_path, expected_path", + [ + ("local/path/llama/code", "local/path/llama"), + ("local/path/llama/code/", "local/path/llama"), + ("local/path/llama/", "local/path/llama/"), + ("local/path/llama", "local/path/llama"), + ], +) +def test_normalize_local_model_path(my_path, expected_path): + assert _normalize_local_model_path(my_path) == expected_path + + +class TestCustomSpeculativeDecodingConfig(unittest.TestCase): + + @patch("sagemaker.model.Model") + def test_with_s3_hf(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "s3://bucket/djl-inference-2024-07-02-00-03-32-127/code" + } + + res_model = _custom_speculative_decoding(mock_model, speculative_decoding_config) + + mock_model.add_tags.assert_called_once_with( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "customer"} + ) + + self.assertEqual( + res_model.env, + {"OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model"}, + ) + self.assertEqual( + res_model.additional_model_data_sources, + [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": "s3://bucket/djl-inference-2024-07-02-00-03-32-127/code", + "S3DataType": "S3Prefix", + "CompressionType": "None", + }, + } + ], + ) + + @patch("sagemaker.model.Model") + def test_with_s3_js(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "s3://bucket/huggingface-pytorch-tgi-inference" + } + + res_model = _custom_speculative_decoding(mock_model, 
speculative_decoding_config, True) + + self.assertEqual( + res_model.additional_model_data_sources, + [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": "s3://bucket/huggingface-pytorch-tgi-inference", + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"ACCEPT_EULA": True}, + }, + } + ], + ) + + @patch("sagemaker.model.Model") + def test_with_non_s3(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = {"ModelSource": "huggingface-pytorch-tgi-inference"} + + res_model = _custom_speculative_decoding(mock_model, speculative_decoding_config, False) + + self.assertIsNone(res_model.additional_model_data_sources) + self.assertEqual( + res_model.env, + {"OPTION_SPECULATIVE_DRAFT_MODEL": "huggingface-pytorch-tgi-inference"}, + ) + + mock_model.add_tags.assert_called_once_with( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "customer"} + ) From 6687c56d4c8df5bcb055863fb28fa32dd4eb7ecd Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Fri, 5 Jul 2024 09:36:44 -0700 Subject: [PATCH 31/45] Fixing bugs (#1506) * Fixing bugs * Refactoring * Increase coverage * Fix UT * Fix UT * Increase coverage * Fix UT * Refactoring * Fix UT --------- Co-authored-by: Jonathan Makunga --- .../serve/builder/jumpstart_builder.py | 44 +-- src/sagemaker/serve/builder/model_builder.py | 47 +-- src/sagemaker/serve/utils/optimize_utils.py | 6 - .../serve/builder/test_djl_builder.py | 4 + .../serve/builder/test_js_builder.py | 305 ++++++++++++++++++ .../serve/builder/test_model_builder.py | 152 +++++++++ tests/unit/sagemaker/serve/constants.py | 8 +- .../serve/utils/test_optimize_utils.py | 5 - 8 files changed, 515 insertions(+), 56 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index ccfe795004..88328472a5 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -714,12 +714,9 @@ def _optimize_for_jumpstart( f"Model '{self.model}' requires accepting end-user license agreement (EULA)." 
) - optimization_env_vars = env_vars - pysdk_model_env_vars = env_vars - + pysdk_model_env_vars = dict() if compilation_config: - neuron_env = self._get_neuron_model_env_vars(instance_type) - optimization_env_vars = _update_environment_variables(neuron_env, optimization_env_vars) + pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type) if speculative_decoding_config: self._set_additional_model_source(speculative_decoding_config) @@ -730,28 +727,34 @@ def _optimize_for_jumpstart( config_name=deployment_config.get("DeploymentConfigName"), instance_type=deployment_config.get("InstanceType"), ) + pysdk_model_env_vars = self.pysdk_model.env model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula) - optimization_config, env = _extract_optimization_config_and_env( + optimization_env_vars = _update_environment_variables(pysdk_model_env_vars, env_vars) + + optimization_config, override_env = _extract_optimization_config_and_env( quantization_config, compilation_config ) - pysdk_model_env_vars = _update_environment_variables(pysdk_model_env_vars, env) output_config = {"S3OutputLocation": output_path} if kms_key: output_config["KmsKeyId"] = kms_key - if not instance_type: - instance_type = self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( - "InstanceType", _get_nb_instance() - ) + + deployment_config_instance_type = ( + self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get("InstanceType") + if self.pysdk_model.deployment_config + else None + ) + self.instance_type = instance_type or deployment_config_instance_type or _get_nb_instance() + self.role_arn = role_arn or self.role_arn create_optimization_job_args = { "OptimizationJobName": job_name, "ModelSource": model_source, - "DeploymentInstanceType": instance_type, + "DeploymentInstanceType": self.instance_type, "OptimizationConfigs": [optimization_config], "OutputConfig": output_config, - "RoleArn": role_arn, + "RoleArn": self.role_arn, } if optimization_env_vars: @@ -765,8 +768,6 @@ def _optimize_for_jumpstart( if vpc_config: create_optimization_job_args["VpcConfig"] = vpc_config - if pysdk_model_env_vars: - self.pysdk_model.env.update(pysdk_model_env_vars) if accept_eula: self.pysdk_model.accept_eula = accept_eula if isinstance(self.pysdk_model.model_data, dict): @@ -775,6 +776,9 @@ def _optimize_for_jumpstart( } if quantization_config or compilation_config: + self.pysdk_model.env = _update_environment_variables( + optimization_env_vars, override_env + ) return create_optimization_job_args return None @@ -810,9 +814,13 @@ def _set_additional_model_source( channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) if model_provider == "sagemaker": - additional_model_data_sources = self.pysdk_model.deployment_config.get( - "DeploymentArgs", {} - ).get("AdditionalDataSources") + additional_model_data_sources = ( + self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( + "AdditionalDataSources" + ) + if self.pysdk_model.deployment_config + else None + ) if additional_model_data_sources is None: deployment_config = self._find_compatible_deployment_config( speculative_decoding_config diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index aa56c89182..9293362e65 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -71,7 +71,6 @@ from sagemaker.serve.utils.optimize_utils import ( _generate_optimized_model, _generate_model_source, - 
_update_environment_variables,
     _extract_optimization_config_and_env,
     _is_s3_uri,
     _normalize_local_model_path,
@@ -840,8 +839,7 @@ def build(  # pylint: disable=R0911
         if role_arn:
             self.role_arn = role_arn
 
-        if not self.sagemaker_session:
-            self.sagemaker_session = sagemaker_session or Session()
+        self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
 
         self.sagemaker_session.settings._local_download_dir = self.model_path
@@ -1111,8 +1109,6 @@ def optimize(self, *args, **kwargs) -> Model:
         Returns:
             Model: A deployable ``Model`` object.
         """
-        if self.mode != Mode.SAGEMAKER_ENDPOINT:
-            raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")
 
         # need to get telemetry_opt_out info before telemetry decorator is called
         self.serve_settings = self._get_serve_setting()
@@ -1174,6 +1170,9 @@ def _model_builder_optimize_wrapper(
                 speculative_decoding_config
             )
 
+        if self.mode != Mode.SAGEMAKER_ENDPOINT:
+            raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")
+
         if quantization_config and compilation_config:
             raise ValueError("Quantization config and compilation config are mutually exclusive.")
@@ -1273,39 +1272,39 @@ def _optimize_for_hf(
             logger.info("Overwriting model server to DJL.")
             self.model_server = ModelServer.DJL_SERVING
 
-        optimization_env_vars = env_vars
-        pysdk_model_env_vars = env_vars
+        self.role_arn = role_arn or self.role_arn
+        self.instance_type = instance_type or self.instance_type
 
         if quantization_config or compilation_config:
-            self.instance_type = instance_type or self.instance_type
+            create_optimization_job_args = {
+                "OptimizationJobName": job_name,
+                "DeploymentInstanceType": self.instance_type,
+                "RoleArn": self.role_arn,
+            }
+
+            if env_vars:
+                self.pysdk_model.env.update(env_vars)
+                create_optimization_job_args["OptimizationEnvironment"] = env_vars
 
             self._optimize_prepare_for_hf()
             model_source = _generate_model_source(self.pysdk_model.model_data, False)
+            create_optimization_job_args["ModelSource"] = model_source
 
             self.pysdk_model = _custom_speculative_decoding(
                 self.pysdk_model, speculative_decoding_config, False
             )
 
-            optimization_config, env = _extract_optimization_config_and_env(
+            optimization_config, override_env = _extract_optimization_config_and_env(
                 quantization_config, compilation_config
             )
-            pysdk_model_env_vars = _update_environment_variables(pysdk_model_env_vars, env)
+            create_optimization_job_args["OptimizationConfigs"] = [optimization_config]
+            self.pysdk_model.env.update(override_env)
 
             output_config = {"S3OutputLocation": output_path}
             if kms_key:
                 output_config["KmsKeyId"] = kms_key
+            create_optimization_job_args["OutputConfig"] = output_config
 
-            create_optimization_job_args = {
-                "OptimizationJobName": job_name,
-                "ModelSource": model_source,
-                "DeploymentInstanceType": self.instance_type,
-                "OptimizationConfigs": [optimization_config],
-                "OutputConfig": output_config,
-                "RoleArn": role_arn,
-            }
-
-            if optimization_env_vars:
-                create_optimization_job_args["OptimizationEnvironment"] = optimization_env_vars
             if max_runtime_in_sec:
                 create_optimization_job_args["StoppingCondition"] = {
                     "MaxRuntimeInSeconds": max_runtime_in_sec
@@ -1315,8 +1314,10 @@ def _optimize_for_hf(
             if vpc_config:
                 create_optimization_job_args["VpcConfig"] = vpc_config
 
-            if pysdk_model_env_vars:
-                self.pysdk_model.env.update(pysdk_model_env_vars)
+            # HF_MODEL_ID must not be present; otherwise,
+            # HF model artifacts will be re-downloaded during deployment
+            if "HF_MODEL_ID" in self.pysdk_model.env:
+                del 
self.pysdk_model.env["HF_MODEL_ID"] return create_optimization_job_args return None diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 853974aae0..be4c7fd993 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -54,17 +54,11 @@ def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) - recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get( "RecommendedInferenceImage" ) - optimized_environment = optimization_response.get("OptimizationEnvironment") s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation") deployment_instance_type = optimization_response.get("DeploymentInstanceType") if recommended_image_uri: pysdk_model.image_uri = recommended_image_uri - if optimized_environment: - if pysdk_model.env: - pysdk_model.env.update(optimized_environment) - else: - pysdk_model.env = optimized_environment if s3_uri: pysdk_model.model_data["S3DataSource"]["S3Uri"] = s3_uri if deployment_instance_type: diff --git a/tests/unit/sagemaker/serve/builder/test_djl_builder.py b/tests/unit/sagemaker/serve/builder/test_djl_builder.py index 7b0c67f326..474403498c 100644 --- a/tests/unit/sagemaker/serve/builder/test_djl_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_djl_builder.py @@ -188,6 +188,7 @@ def test_tune_for_djl_local_container_deep_ping_ex( tuned_model = model.tune() assert tuned_model.env == mock_default_configs + @patch("sagemaker.serve.builder.djl_builder._get_model_config_properties_from_hf") @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", @@ -211,7 +212,10 @@ def test_tune_for_djl_local_container_load_ex( mock_get_ram_usage_mb, mock_is_jumpstart_model, mock_telemetry, + mock_get_model_config_properties_from_hf, ): + mock_get_model_config_properties_from_hf.return_value = {} + builder = ModelBuilder( model=mock_model_id, schema_builder=mock_schema_builder, diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 6c2e03b683..248955c273 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -87,6 +87,74 @@ "/artifacts/inference-prepack/v1.0.0/" ) +mock_optimization_job_response = { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:optimization-job" + "/modelbuilderjob-c9b28846f963497ca540010b2aa2ec8d", + "OptimizationJobStatus": "COMPLETED", + "OptimizationStartTime": "", + "OptimizationEndTime": "", + "CreationTime": "", + "LastModifiedTime": "", + "OptimizationJobName": "modelbuilderjob-c9b28846f963497ca540010b2aa2ec8d", + "ModelSource": { + "S3": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b-instruct/artifacts/inference-prepack/v1.1.0/" + } + }, + "OptimizationEnvironment": { + "ENDPOINT_SERVER_TIMEOUT": "3600", + "HF_MODEL_ID": "/opt/ml/model", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "OPTION_DTYPE": "fp16", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "SAGEMAKER_PROGRAM": "inference.py", + }, + "DeploymentInstanceType": "ml.inf2.48xlarge", + 
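+    # A single compilation config follows; its OverrideEnvironment mirrors the
+    # Neuron settings that the tests below assert on the optimized model's env.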
"OptimizationConfigs": [ + { + "ModelCompilationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-neuronx-sdk2.18.2", + "OverrideEnvironment": { + "OPTION_DTYPE": "fp16", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, + } + } + ], + "OutputConfig": { + "S3OutputLocation": "s3://dont-delete-ss-jarvis-integ-test-312206380606-us-west-2/" + "code/a75a061aba764f2aa014042bcdc1464b/" + }, + "OptimizationOutput": { + "RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "djl-inference:0.28.0-neuronx-sdk2.18.2" + }, + "RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "StoppingCondition": {"MaxRuntimeInSeconds": 36000}, + "ResponseMetadata": { + "RequestId": "704c7bcd-41e2-4d73-8039-262ff6a3f38b", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "704c7bcd-41e2-4d73-8039-262ff6a3f38b", + "content-type": "application/x-amz-json-1.1", + "content-length": "1787", + "date": "Thu, 04 Jul 2024 16:55:50 GMT", + }, + "RetryAttempts": 0, + }, +} + class TestJumpStartBuilder(unittest.TestCase): @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @@ -1078,3 +1146,240 @@ def test_fine_tuned_model_with_fine_tuning_job_name( {"key": Tag.FINE_TUNING_MODEL_PATH, "value": mock_fine_tuning_model_path}, ] ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_quantize_for_jumpstart( + self, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + + mock_pysdk_model = Mock() + mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"} + mock_pysdk_model.model_data = mock_model_data + mock_pysdk_model.image_uri = mock_tgi_image_uri + mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS + mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + model_builder.pysdk_model = mock_pysdk_model + + out_put = model_builder._optimize_for_jumpstart( + accept_eula=True, + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + env_vars={ + "OPTION_TENSOR_PARALLEL_DEGREE": "1", + "OPTION_MAX_ROLLING_BATCH_SIZE": "2", + }, + output_path="s3://bucket/code/", + ) + + self.assertIsNotNone(out_put) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model") + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_optimize_compile_for_jumpstart_without_neuron_env( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_pre_trained_model.return_value = MagicMock() + mock_pre_trained_model.return_value.env = dict() + mock_pre_trained_model.return_value.model_data = mock_model_data + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.return_value = ( + DEPLOYMENT_CONFIGS + ) + mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value._metadata_configs = None + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.inf2.48xlarge", + compilation_config={ + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + }, + output_path="s3://bucket/code/", + ) + + self.assertEqual( + optimized_model.image_uri, + mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"], + ) + self.assertEqual( + optimized_model.model_data["S3DataSource"]["S3Uri"], + mock_optimization_job_response["OutputConfig"]["S3OutputLocation"], + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model") + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_optimize_compile_for_jumpstart_with_neuron_env( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_metadata_config = Mock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_metadata_config.resolved_config = { + "supported_inference_instance_types": ["ml.inf2.48xlarge"], + "hosting_neuron_model_id": "neuron_model_id", + } + + mock_js_model.return_value = MagicMock() + mock_js_model.return_value.env = dict() + + mock_pre_trained_model.return_value = MagicMock() + mock_pre_trained_model.return_value.env = dict() + mock_pre_trained_model.return_value.config_name = "config_name" + mock_pre_trained_model.return_value.model_data = mock_model_data + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.return_value = ( + DEPLOYMENT_CONFIGS + ) + mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value._metadata_configs = { + "config_name": mock_metadata_config + } + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.inf2.48xlarge", + compilation_config={ + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + }, + output_path="s3://bucket/code/", + ) + + self.assertEqual( + optimized_model.image_uri, + mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"], + ) + self.assertEqual( + optimized_model.model_data["S3DataSource"]["S3Uri"], + mock_optimization_job_response["OutputConfig"]["S3OutputLocation"], + ) + self.assertEqual(optimized_model.env["OPTION_TENSOR_PARALLEL_DEGREE"], "2") + self.assertEqual(optimized_model.env["OPTION_N_POSITIONS"], "2048") + self.assertEqual(optimized_model.env["OPTION_DTYPE"], "fp16") + self.assertEqual(optimized_model.env["OPTION_ROLLING_BATCH"], "auto") + self.assertEqual(optimized_model.env["OPTION_MAX_ROLLING_BATCH_SIZE"], "4") + self.assertEqual(optimized_model.env["OPTION_NEURON_OPTIMIZE_LEVEL"], "2") diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 5aea0fc6d4..d6798f48c4 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -2595,3 +2595,155 @@ def test_set_tracking_arn_mlflow_not_installed(self): builder.set_tracking_arn, tracking_arn, ) + + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_local_mode(self, mock_get_serve_setting): + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", mode=Mode.LOCAL_CONTAINER + ) + + self.assertRaisesRegex( + ValueError, + "Model optimization is only supported in Sagemaker Endpoint Mode.", + lambda: model_builder.optimize( + quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}} + ), + ) + + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_exclusive_args(self, mock_get_serve_setting): + mock_sagemaker_session = Mock() + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + sagemaker_session=mock_sagemaker_session, + ) + + self.assertRaisesRegex( + ValueError, + "Quantization config and compilation config are mutually exclusive.", + lambda: model_builder.optimize( + quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + ), + ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_for_hf_with_custom_s3_path( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": "bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-8B-Instruc"} + + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HUGGING_FACE_HUB_TOKEN": "token"}, + model_metadata={ + 
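+                # CUSTOM_MODEL_PATH points the optimization at pre-staged S3
+                # artifacts, so no Hugging Face Hub download is needed (the next
+                # test, without it, patches download_huggingface_model_metadata).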
"CUSTOM_MODEL_PATH": "s3://bucket/path/",
+            },
+        )
+
+        model_builder.pysdk_model = mock_pysdk_model
+
+        out_put = model_builder._optimize_for_hf(
+            job_name="job_name-123",
+            instance_type="ml.g5.2xlarge",
+            role_arn="role-arn",
+            quantization_config={
+                "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
+            },
+            output_path="s3://bucket/code/",
+        )
+
+        self.assertEqual(model_builder.role_arn, "role-arn")
+        self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge")
+        self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq")
+        self.assertEqual(
+            out_put,
+            {
+                "OptimizationJobName": "job_name-123",
+                "DeploymentInstanceType": "ml.g5.2xlarge",
+                "RoleArn": "role-arn",
+                "ModelSource": {"S3": {"S3Uri": "s3://bucket/code/code/"}},
+                "OptimizationConfigs": [
+                    {"ModelQuantizationConfig": {"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}}
+                ],
+                "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"},
+            },
+        )
+
+    @patch(
+        "sagemaker.serve.builder.model_builder.download_huggingface_model_metadata", autospec=True
+    )
+    @patch.object(ModelBuilder, "_prepare_for_mode")
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_for_hf_without_custom_s3_path(
+        self,
+        mock_get_serve_setting,
+        mock_prepare_for_mode,
+        mock_download_huggingface_model_metadata,
+    ):
+        mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
+            {
+                "S3DataSource": {
+                    "CompressionType": "None",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://bucket/code/code/",
+                }
+            },
+            {"DTYPE": "bfloat16"},
+        )
+
+        mock_pysdk_model = Mock()
+        mock_pysdk_model.model_data = None
+        mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-8B-Instruc"}
+
+        model_builder = ModelBuilder(
+            model="meta-llama/Meta-Llama-3-8B-Instruct",
+            env_vars={"HUGGING_FACE_HUB_TOKEN": "token"},
+        )
+
+        model_builder.pysdk_model = mock_pysdk_model
+
+        out_put = model_builder._optimize_for_hf(
+            job_name="job_name-123",
+            instance_type="ml.g5.2xlarge",
+            role_arn="role-arn",
+            quantization_config={
+                "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
+            },
+            output_path="s3://bucket/code/",
+        )
+
+        self.assertEqual(model_builder.role_arn, "role-arn")
+        self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge")
+        self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq")
+        self.assertEqual(
+            out_put,
+            {
+                "OptimizationJobName": "job_name-123",
+                "DeploymentInstanceType": "ml.g5.2xlarge",
+                "RoleArn": "role-arn",
+                "ModelSource": {"S3": {"S3Uri": "s3://bucket/code/code/"}},
+                "OptimizationConfigs": [
+                    {"ModelQuantizationConfig": {"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}}
+                ],
+                "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"},
+            },
+        ) diff --git a/tests/unit/sagemaker/serve/constants.py b/tests/unit/sagemaker/serve/constants.py
index 5c40c1bf64..5a4679747b 100644
--- a/tests/unit/sagemaker/serve/constants.py
+++ b/tests/unit/sagemaker/serve/constants.py
@@ -22,7 +22,7 @@
         {"name": "Latency", "value": "100", "unit": "Tokens/S"},
         {"name": "Throughput", "value": "1867", "unit": "Tokens/S"},
     ],
-    "DeploymentConfig": {
+    "DeploymentArgs": {
         "ModelDataDownloadTimeout": 1200,
         "ContainerStartupHealthCheckTimeout": 1200,
         "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4"
@@ -59,7 +59,7 @@
         {"name": "Latency", "value": "100", "unit": "Tokens/S"},
         {"name": "Throughput", "value": "1867", "unit": "Tokens/S"},
     ],
-    "DeploymentConfig": {
+    "DeploymentArgs": {
         "ModelDataDownloadTimeout": 1200,
"ContainerStartupHealthCheckTimeout": 1200, "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" @@ -96,7 +96,7 @@ {"name": "Latency", "value": "100", "unit": "Tokens/S"}, {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, ], - "DeploymentConfig": { + "DeploymentArgs": { "ModelDataDownloadTimeout": 1200, "ContainerStartupHealthCheckTimeout": 1200, "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" @@ -133,7 +133,7 @@ {"name": "Latency", "value": "100", "unit": "Tokens/S"}, {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, ], - "DeploymentConfig": { + "DeploymentArgs": { "ModelDataDownloadTimeout": 1200, "ContainerStartupHealthCheckTimeout": 1200, "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index 80a3217c1d..a82e508ee2 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -116,7 +116,6 @@ def test_generate_optimized_model(): "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/" } } - pysdk_model.env = {"OPTION_QUANTIZE": "awq"} optimized_model = _generate_optimized_model(pysdk_model, mock_optimization_job_output) @@ -124,10 +123,6 @@ def test_generate_optimized_model(): optimized_model.image_uri == mock_optimization_job_output["OptimizationOutput"]["RecommendedInferenceImage"] ) - assert optimized_model.env == { - "OPTION_QUANTIZE": "awq", - **mock_optimization_job_output["OptimizationEnvironment"], - } assert ( optimized_model.model_data["S3DataSource"]["S3Uri"] == mock_optimization_job_output["OutputConfig"]["S3OutputLocation"] From 7993b77ec42b9123dd70a5593faf00f8fb0b48cf Mon Sep 17 00:00:00 2001 From: Jonathan Makunga <54963715+makungaj1@users.noreply.github.com> Date: Sun, 7 Jul 2024 19:33:15 -0700 Subject: [PATCH 32/45] Fix public optimize api signature (#1507) * Fix public optimize api signature * JS Compilation fix * Refactoring * Refactoring --------- Co-authored-by: Jonathan Makunga --- .../serve/builder/jumpstart_builder.py | 24 +++++-- src/sagemaker/serve/builder/model_builder.py | 69 ++++++++++++++----- src/sagemaker/serve/utils/optimize_utils.py | 23 ++++++- .../serve/utils/test_optimize_utils.py | 13 ++++ 4 files changed, 103 insertions(+), 26 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 88328472a5..07885792d2 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -47,6 +47,7 @@ _is_optimized, _custom_speculative_decoding, SPECULATIVE_DRAFT_MODEL, + _is_inferentia_or_trainium, ) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, @@ -714,10 +715,25 @@ def _optimize_for_jumpstart( f"Model '{self.model}' requires accepting end-user license agreement (EULA)." 
) + is_compilation = (quantization_config is None) and ( + (compilation_config is not None) or _is_inferentia_or_trainium(instance_type) + ) + pysdk_model_env_vars = dict() - if compilation_config: + if is_compilation: pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type) + optimization_config, override_env = _extract_optimization_config_and_env( + quantization_config, compilation_config + ) + if not optimization_config and is_compilation: + override_env = override_env or pysdk_model_env_vars + optimization_config = { + "ModelCompilationConfig": { + "OverrideEnvironment": override_env, + } + } + if speculative_decoding_config: self._set_additional_model_source(speculative_decoding_config) else: @@ -732,10 +748,6 @@ def _optimize_for_jumpstart( model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula) optimization_env_vars = _update_environment_variables(pysdk_model_env_vars, env_vars) - optimization_config, override_env = _extract_optimization_config_and_env( - quantization_config, compilation_config - ) - output_config = {"S3OutputLocation": output_path} if kms_key: output_config["KmsKeyId"] = kms_key @@ -775,7 +787,7 @@ def _optimize_for_jumpstart( "AcceptEula": True } - if quantization_config or compilation_config: + if quantization_config or is_compilation: self.pysdk_model.env = _update_environment_variables( optimization_env_vars, override_env ) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 9293362e65..11ce087ee4 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1083,25 +1083,47 @@ def _try_fetch_gpu_info(self): f"Unable to determine single GPU size for instance: [{self.instance_type}]" ) - def optimize(self, *args, **kwargs) -> Model: - """Runs a model optimization job. + def optimize( + self, + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + role_arn: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + accept_eula: Optional[bool] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = 36000, + sagemaker_session: Optional[Session] = None, + ) -> Model: + """Create an optimized deployable ``Model`` instance with ``ModelBuilder``. Args: - instance_type (Optional[str]): Target deployment instance type that the - model is optimized for. - output_path (Optional[str]): Specifies where to store the compiled/quantized model. - role_arn (Optional[str]): Execution role. Defaults to ``None``. + output_path (str): Specifies where to store the compiled/quantized model. + instance_type (str): Target deployment instance type that the model is optimized for. + role_arn (Optional[str]): Execution role arn. Defaults to ``None``. tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). 
quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` env_vars (Optional[Dict]): Additional environment variables to run the optimization container. Defaults to ``None``. vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading to S3. Defaults to ``None``. max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to - ``None``. + 36000 seconds. sagemaker_session (Optional[Session]): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the function creates one using the default AWS configuration chain. @@ -1113,7 +1135,22 @@ def optimize(self, *args, **kwargs) -> Model: # need to get telemetry_opt_out info before telemetry decorator is called self.serve_settings = self._get_serve_setting() - return self._model_builder_optimize_wrapper(*args, **kwargs) + return self._model_builder_optimize_wrapper( + output_path=output_path, + instance_type=instance_type, + role_arn=role_arn, + tags=tags, + job_name=job_name, + accept_eula=accept_eula, + quantization_config=quantization_config, + compilation_config=compilation_config, + speculative_decoding_config=speculative_decoding_config, + env_vars=env_vars, + vpc_config=vpc_config, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + sagemaker_session=sagemaker_session, + ) @_capture_telemetry("optimize") def _model_builder_optimize_wrapper( @@ -1178,10 +1215,8 @@ def _model_builder_optimize_wrapper( self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() - if instance_type: - self.instance_type = instance_type - if role_arn: - self.role_arn = role_arn + self.instance_type = instance_type or self.instance_type + self.role_arn = role_arn or self.role_arn self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}" @@ -1266,7 +1301,7 @@ def _optimize_for_hf( ``None``. Returns: - Dict[str, Any]: Model optimization job input arguments. + Optional[Dict[str, Any]]: Model optimization job input arguments. 
""" if self.model_server != ModelServer.DJL_SERVING: logger.info("Overwriting model server to DJL.") @@ -1275,6 +1310,10 @@ def _optimize_for_hf( self.role_arn = role_arn or self.role_arn self.instance_type = instance_type or self.instance_type + self.pysdk_model = _custom_speculative_decoding( + self.pysdk_model, speculative_decoding_config, False + ) + if quantization_config or compilation_config: create_optimization_job_args = { "OptimizationJobName": job_name, @@ -1290,10 +1329,6 @@ def _optimize_for_hf( model_source = _generate_model_source(self.pysdk_model.model_data, False) create_optimization_job_args["ModelSource"] = model_source - self.pysdk_model = _custom_speculative_decoding( - self.pysdk_model, speculative_decoding_config, False - ) - optimization_config, override_env = _extract_optimization_config_and_env( quantization_config, compilation_config ) diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index be4c7fd993..83978e252a 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -26,6 +26,23 @@ SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources" +def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool: + """Checks whether an instance is compatible with Inferentia. + + Args: + instance_type (str): The instance type used for the compilation job. + + Returns: + bool: Whether the given instance type is Inferentia or Trainium. + """ + if isinstance(instance_type, str): + match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + if match: + if match[1].startswith("inf") or match[1].startswith("trn"): + return True + return False + + def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool: """Checks whether an instance is compatible with an optimization job. @@ -169,11 +186,11 @@ def _extracts_and_validates_speculative_model_source( Raises: ValueError: If model source is none. 
""" - s3_uri: str = speculative_decoding_config.get("ModelSource") + model_source: str = speculative_decoding_config.get("ModelSource") - if not s3_uri: + if not model_source: raise ValueError("ModelSource must be provided in speculative decoding config.") - return s3_uri + return model_source def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) -> str: diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index a82e508ee2..bdd59b0497 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -31,6 +31,7 @@ _normalize_local_model_path, _is_optimized, _custom_speculative_decoding, + _is_inferentia_or_trainium, ) mock_optimization_job_output = { @@ -81,6 +82,18 @@ } +@pytest.mark.parametrize( + "instance, expected", + [ + ("ml.trn1.2xlarge", True), + ("ml.inf2.xlarge", True), + ("ml.c7gd.4xlarge", False), + ], +) +def test_is_inferentia_or_trainium(instance, expected): + assert _is_inferentia_or_trainium(instance) == expected + + @pytest.mark.parametrize( "image_uri, expected", [ From d997612e62ec454cc42bc9ae390e707f730082d4 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Mon, 8 Jul 2024 00:54:28 -0700 Subject: [PATCH 33/45] Refactoring --- VERSION | 2 +- src/sagemaker/jumpstart/factory/estimator.py | 1 - src/sagemaker/jumpstart/factory/model.py | 1 - src/sagemaker/jumpstart/hub/interfaces.py | 2 +- src/sagemaker/jumpstart/types.py | 3 ++- src/sagemaker/jumpstart/utils.py | 15 ------------- src/sagemaker/serve/builder/model_builder.py | 14 ------------- tests/unit/sagemaker/jumpstart/constants.py | 21 ------------------- .../sagemaker/jumpstart/model/test_model.py | 2 +- 9 files changed, 5 insertions(+), 56 deletions(-) diff --git a/VERSION b/VERSION index 5fa95d3edc..07acbaddb8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.224.5.dev0 \ No newline at end of file +2.224.5.dev0 diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 2735893751..8540f53ca4 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -62,7 +62,6 @@ from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, add_jumpstart_model_info_tags, - add_jumpstart_model_id_version_tags, get_eula_message, get_default_jumpstart_session_with_user_agent_suffix, update_dict_if_key_not_present, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index bce66a2062..e759adec5e 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -49,7 +49,6 @@ from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, add_jumpstart_model_info_tags, - add_jumpstart_model_id_version_tags, get_default_jumpstart_session_with_user_agent_suffix, get_neo_content_bucket, update_dict_if_key_not_present, diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index c3ca307444..d987216872 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -585,7 +585,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") - + self.inference_config_rankings = self._get_config_rankings(json_obj) self.inference_config_components = 
self._get_config_components(json_obj) self.inference_configs = self._get_configs(json_obj) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 92591c7599..fb4c157a67 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1975,7 +1975,7 @@ def extract_region_from_arn(arn: str) -> Optional[str]: if match: hub_region = match.group(2) return hub_region - + match = re.match(HUB_ARN_REGEX, arn) if match: hub_region = match.group(2) @@ -1983,6 +1983,7 @@ def extract_region_from_arn(arn: str) -> Optional[str]: return hub_region + class JumpStartCachedContentValue(JumpStartDataHolderType): """Data class for the s3 cached content values.""" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 41316845cb..7a00efa8e1 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -446,21 +446,6 @@ def add_hub_content_arn_tags( return tags -def add_hub_content_arn_tags( - tags: Optional[List[TagsDict]], - hub_arn: str, -) -> Optional[List[TagsDict]]: - """Adds custom Hub arn tag to JumpStart related resources.""" - - tags = add_single_jumpstart_tag( - hub_arn, - enums.JumpStartTag.HUB_CONTENT_ARN, - tags, - is_uri=False, - ) - return tags - - def add_jumpstart_uri_tags( tags: Optional[List[TagsDict]] = None, inference_model_uri: Optional[Union[str, dict]] = None, diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index eea56950c8..01b2b96f68 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -309,20 +309,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, }, ) - def _build_validations(self): - """Validations needed for model server overrides, or auto-detection or fallback""" - if self.mode == Mode.IN_PROCESS: - raise ValueError("IN_PROCESS mode is not supported yet!") - - if self.inference_spec and self.model: - raise ValueError("Can only set one of the following: model, inference_spec.") - - if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None: - raise ValueError( - "Model_server must be set when non-first-party image_uri is set. " - + "Supported model servers: %s" % supported_model_servers - ) - def _save_model_inference_spec(self): """Placeholder docstring""" # check if path exists and create if not diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 21970e59eb..9117b2d26d 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -9805,27 +9805,6 @@ "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, - "ContextualHelp": { - "HubFormatTrainData": [ - "A train and an optional validation directories. Each directory contains a CSV/JSON/TXT. 
", - "- For CSV/JSON files, the text data is used from the column called 'text' or the first column if no column called 'text' is found", # noqa: E501 - "- The number of files under train and validation (if provided) should equal to one, respectively.", - " [Learn how to setup an AWS S3 bucket.](https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html)", # noqa: E501 - ], - "HubDefaultTrainData": [ - "Dataset: [SEC](https://www.sec.gov/edgar/searchedgar/companysearch)", - "SEC filing contains regulatory documents that companies and issuers of securities must submit to the Securities and Exchange Commission (SEC) on a regular basis.", # noqa: E501 - "License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode)", - ], - }, - "ModelDataDownloadTimeout": 1200, - "ContainerStartupHealthCheckTimeout": 1200, - "EncryptInterContainerTraffic": True, - "DisableOutputCompression": True, - "MaxRuntimeInSeconds": 360000, - "DynamicContainerDeploymentSupported": True, - "TrainingModelPackageArtifactUri": None, - "Dependencies": [], "InferenceConfigRankings": { "overall": {"Description": "default", "Rankings": ["variant1"]} }, diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index f0970cc8c5..af2a61abb9 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1492,7 +1492,7 @@ def test_attach( model_version="some-version", inference_component_name="some-ic-name", ) - + mock_get_model_info_from_endpoint.assert_not_called() @mock.patch( From 7152db247bc5323f8ebad6521f4ede4a5927a439 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Mon, 8 Jul 2024 09:34:41 -0700 Subject: [PATCH 34/45] Integration tests --- src/sagemaker/jumpstart/enums.py | 2 + .../jumpstart/model/test_jumpstart_model.py | 43 +++++++++++ .../sagemaker/serve/test_serve_js_happy.py | 72 +++++++++++++++++++ 3 files changed, 117 insertions(+) diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 5446276c3a..a83964e394 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -93,8 +93,10 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" + INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name" TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name" + HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 6bc0a5c996..1563ccf2d3 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -11,7 +11,10 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import + +import io import os +import sys import time from unittest import mock @@ -349,3 +352,43 @@ def test_register_gated_jumpstart_model(setup): predictor.delete_predictor() assert response is not None + + +def test_jumpstart_model_with_deployment_configs(setup): + model_id = "meta-textgeneration-llama-2-7b-f" + + model = JumpStartModel( + model_id=model_id, + model_version="*", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + captured_output = io.StringIO() + sys.stdout = captured_output + model.display_benchmark_metrics() + sys.stdout = sys.__stdout__ + assert captured_output.getvalue() is not None + + configs = model.list_deployment_configs() + assert len(configs) > 0 + + model.set_deployment_config( + configs[0]["ConfigName"], + "ml.g5.2xlarge", + ) + assert model.config_name == configs[0]["ConfigName"] + + predictor = model.deploy( + accept_eula=True, + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + payload = { + "inputs": "some-payload", + "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}, + } + + response = predictor.predict(payload, custom_attributes="accept_eula=true") + + assert response is not None diff --git a/tests/integ/sagemaker/serve/test_serve_js_happy.py b/tests/integ/sagemaker/serve/test_serve_js_happy.py index ad0527fcc0..e97f130a59 100644 --- a/tests/integ/sagemaker/serve/test_serve_js_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_js_happy.py @@ -12,6 +12,9 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import io +import sys + import pytest from sagemaker.serve.builder.model_builder import ModelBuilder @@ -54,6 +57,19 @@ def happy_model_builder(sagemaker_session): ) +@pytest.fixture +def meta_textgeneration_llama_2_7b_f_schema(): + prompt = "Hello, I'm a language model," + response = "Hello, I'm a language model, and I'm here to help you with your English." 
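+    # These example payloads exist only so SchemaBuilder can derive request and
+    # response (de)serializers for the endpoint; the exact strings are illustrative.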
+    sample_input = {"inputs": prompt}
+    sample_output = [{"generated_text": response}]
+
+    return SchemaBuilder(
+        sample_input=sample_input,
+        sample_output=sample_output,
+    )
+
+
 @pytest.fixture
 def happy_mms_model_builder(sagemaker_session):
     iam_client = sagemaker_session.boto_session.client("iam")
@@ -125,3 +141,59 @@ def test_happy_mms_sagemaker_endpoint(happy_mms_model_builder, gpu_instance_type
         )
     if caught_ex:
         raise caught_ex
+
+
+@pytest.mark.skipif(
+    PYTHON_VERSION_IS_NOT_310,
+    reason="The goal of these tests is to test the serving components of our feature",
+)
+def test_js_model_with_deployment_configs(
+    meta_textgeneration_llama_2_7b_f_schema,
+    sagemaker_session,
+):
+    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
+    caught_ex = None
+    iam_client = sagemaker_session.boto_session.client("iam")
+    role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
+
+    model_builder = ModelBuilder(
+        model="meta-textgeneration-llama-2-7b-f",
+        schema_builder=meta_textgeneration_llama_2_7b_f_schema,
+    )
+    configs = model_builder.list_deployment_configs()
+
+    assert len(configs) > 0
+
+    captured_output = io.StringIO()
+    sys.stdout = captured_output
+    model_builder.display_benchmark_metrics()
+    sys.stdout = sys.__stdout__
+    assert captured_output.getvalue() is not None
+
+    model_builder.set_deployment_config(
+        configs[0]["ConfigName"],
+        "ml.g5.2xlarge",
+    )
+    model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
+    assert model.config_name == configs[0]["ConfigName"]
+    assert model_builder.get_deployment_config() is not None
+
+    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+        try:
+            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
+            predictor = model.deploy(accept_eula=True)
+            logger.info("Endpoint successfully deployed.")
+
+            updated_sample_input = model_builder.schema_builder.sample_input
+
+            predictor.predict(updated_sample_input)
+        except Exception as e:
+            caught_ex = e
+        finally:
+            cleanup_model_resources(
+                sagemaker_session=sagemaker_session,
+                model_name=model.name,
+                endpoint_name=model.endpoint_name,
+            )
+    if caught_ex:
+        raise caught_ex

From 0bd6aa86e7857db6045d94a97545b7f76df1adc4 Mon Sep 17 00:00:00 2001
From: Jonathan Makunga
Date: Mon, 8 Jul 2024 11:48:29 -0700
Subject: [PATCH 35/45] Skip Alt Config integ tests as metadata aren't fully
 deployed.
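
The deployment-config tests added in the previous commit are disabled until the
alternative-config metadata is fully rolled out. For reference, a minimal sketch
of the flow those tests exercise (assuming the meta-textgeneration-llama-2-13b
metadata is available in the target region and the caller's role can create
endpoints; names mirror the integ test above):

    from sagemaker.jumpstart.model import JumpStartModel

    model = JumpStartModel(model_id="meta-textgeneration-llama-2-13b")

    # Benchmark metrics and deployment configs come from the model metadata.
    model.display_benchmark_metrics()
    configs = model.list_deployment_configs()

    # Pin a config and a supported instance type before deploying.
    model.set_deployment_config(configs[0]["ConfigName"], "ml.g5.2xlarge")

    predictor = model.deploy(accept_eula=True)
    predictor.predict({"inputs": "some-payload"}, custom_attributes="accept_eula=true")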
---
 requirements/extras/test_requirements.txt                   | 2 +-
 .../integ/sagemaker/jumpstart/model/test_jumpstart_model.py | 6 +++++-
 tests/integ/sagemaker/serve/test_serve_js_happy.py          | 6 +++---
 3 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt
index 1266bd6fef..56805ebc7a 100644
--- a/requirements/extras/test_requirements.txt
+++ b/requirements/extras/test_requirements.txt
@@ -38,4 +38,4 @@ accelerate>=0.24.1,<=0.27.0
 schema==0.7.5
 tensorflow>=2.1,<=2.16
 mlflow>=2.12.2,<2.13
-huggingface_hub>=0.23.4
\ No newline at end of file
+huggingface_hub>=0.23.4
diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py
index 1563ccf2d3..5ee0abd41f 100644
--- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py
+++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py
@@ -354,8 +354,12 @@ def test_register_gated_jumpstart_model(setup):
     assert response is not None
 
 
+@pytest.mark.skipif(
+    True,
+    reason="Only enable after metadata is fully deployed.",
+)
 def test_jumpstart_model_with_deployment_configs(setup):
-    model_id = "meta-textgeneration-llama-2-7b-f"
+    model_id = "meta-textgeneration-llama-2-13b"
 
     model = JumpStartModel(
         model_id=model_id,
diff --git a/tests/integ/sagemaker/serve/test_serve_js_happy.py b/tests/integ/sagemaker/serve/test_serve_js_happy.py
index e97f130a59..807a5ad691 100644
--- a/tests/integ/sagemaker/serve/test_serve_js_happy.py
+++ b/tests/integ/sagemaker/serve/test_serve_js_happy.py
@@ -144,8 +144,8 @@ def test_happy_mms_sagemaker_endpoint(happy_mms_model_builder, gpu_instance_type
 
 
 @pytest.mark.skipif(
-    PYTHON_VERSION_IS_NOT_310,
-    reason="The goal of these tests is to test the serving components of our feature",
+    True,
+    reason="Only enable after metadata is fully deployed.",
 )
 def test_js_model_with_deployment_configs(
     meta_textgeneration_llama_2_7b_f_schema,
     sagemaker_session,
@@ -157,7 +157,7 @@ def test_js_model_with_deployment_configs(
     role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
 
     model_builder = ModelBuilder(
-        model="meta-textgeneration-llama-2-7b-f",
+        model="meta-textgeneration-llama-2-13b",
         schema_builder=meta_textgeneration_llama_2_7b_f_schema,
     )
     configs = model_builder.list_deployment_configs()

From 1e14343b53ebeae1164643d1f5563595b8c323eb Mon Sep 17 00:00:00 2001
From: Jonathan Makunga
Date: Mon, 8 Jul 2024 17:31:58 -0700
Subject: [PATCH 36/45] Fix metric column name

---
 src/sagemaker/jumpstart/utils.py             | 11 +++++++----
 tests/unit/sagemaker/jumpstart/test_utils.py | 10 +++++-----
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py
index 7a00efa8e1..a78a39ede2 100644
--- a/src/sagemaker/jumpstart/utils.py
+++ b/src/sagemaker/jumpstart/utils.py
@@ -1220,6 +1220,8 @@ def get_metrics_from_deployment_configs(
     if not deployment_configs:
         return {}
 
+    print("deployment_configs: {}".format(deployment_configs))
+
     data = {"Instance Type": [], "Config Name": [], "Concurrent Users": []}
     instance_rate_data = {}
     for index, deployment_config in enumerate(deployment_configs):
@@ -1256,7 +1258,7 @@
         instance_rate_data[instance_rate_column_name].append(instance_type_rate.value)
 
     for metric in metrics:
-        column_name = _normalize_benchmark_metric_column_name(metric.name)
+        column_name = _normalize_benchmark_metric_column_name(metric.name, metric.unit)
         data[column_name] = 
data.get(column_name, []) data[column_name].append(metric.value) @@ -1264,18 +1266,19 @@ def get_metrics_from_deployment_configs( return data -def _normalize_benchmark_metric_column_name(name: str) -> str: +def _normalize_benchmark_metric_column_name(name: str, unit: str) -> str: """Normalizes benchmark metric column name. Args: name (str): Name of the metric. + unit (str): Unit of the metric. Returns: str: Normalized metric column name. """ if "latency" in name.lower(): - name = "Latency for each user (TTFT in ms)" + name = f"Latency, TTFT (P50 in {unit.lower()})" elif "throughput" in name.lower(): - name = "Throughput per user (token/seconds)" + name = f"Throughput (P50 in {unit.lower()}/user)" return name diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 204f1d2d29..533483a497 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1979,14 +1979,14 @@ def test__normalize_benchmark_metrics(): @pytest.mark.parametrize( - "name, expected", + "name, unit, expected", [ - ("latency", "Latency for each user (TTFT in ms)"), - ("throughput", "Throughput per user (token/seconds)"), + ("latency", "sec", "Latency, TTFT (P50 in sec)"), + ("throughput", "tokens/sec", "Throughput (P50 in tokens/sec/user)"), ], ) -def test__normalize_benchmark_metric_column_name(name, expected): - out = utils._normalize_benchmark_metric_column_name(name) +def test_normalize_benchmark_metric_column_name(name, unit, expected): + out = utils._normalize_benchmark_metric_column_name(name, unit) assert out == expected From 9409031a09738cbdcaa57ec4f2235706f8eed32a Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Mon, 8 Jul 2024 19:33:29 -0700 Subject: [PATCH 37/45] Refactoring --- src/sagemaker/jumpstart/utils.py | 2 -- src/sagemaker/utils.py | 4 ++-- tests/unit/test_utils.py | 10 +++++----- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index a78a39ede2..9905d66f3a 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1220,8 +1220,6 @@ def get_metrics_from_deployment_configs( if not deployment_configs: return {} - print("deployment_configs: {}".format(deployment_configs)) - data = {"Instance Type": [], "Config Name": [], "Concurrent Users": []} instance_rate_data = {} for index, deployment_config in enumerate(deployment_configs): diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 834d193ec1..18e604691d 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1791,9 +1791,9 @@ def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[ if value is not None: value = str(round(float(value), 3)) return { - "unit": f"{currency}/{price.get('unit', 'Hrs')}", + "unit": f"{currency}/Hr", "value": value, - "name": "Instance Rate", + "name": "On-demand Instance Rate", } return None diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 11cbc120b3..deb295e6e1 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1929,7 +1929,7 @@ def test_deep_override_skip_keys(self): } ] }, - {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9"}, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.9"}, ), ( "ml.t4g.nano", @@ -1947,7 +1947,7 @@ def test_deep_override_skip_keys(self): '"termAttributes": {}}}}}' ] }, - {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + {"name": "On-demand Instance 
Rate", "unit": "USD/Hr", "value": "0.008"}, ), ( "ml.t4g.nano", @@ -1965,7 +1965,7 @@ def test_deep_override_skip_keys(self): '"termAttributes": {}}}}}' ] }, - {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.008"}, ), ( "ml.t4g.nano", @@ -1983,7 +1983,7 @@ def test_deep_override_skip_keys(self): '"termAttributes": {}}}}}' ] }, - {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.008"}, ), ], ) @@ -2024,7 +2024,7 @@ def test_get_instance_rate_per_hour( } } }, - {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9"}, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.9"}, ), ], ) From ba3d49c0a92773b811bc571aef06f67d74e1fb0f Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 9 Jul 2024 11:54:39 -0700 Subject: [PATCH 38/45] Display API --- src/sagemaker/jumpstart/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 9905d66f3a..83425d62b3 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1236,6 +1236,7 @@ def get_metrics_from_deployment_configs( instance_type_to_display = ( f"{current_instance_type} (Default)" if index == 0 + and concurrent_user and int(concurrent_user) == 1 and current_instance_type == deployment_config.deployment_args.default_instance_type @@ -1295,7 +1296,7 @@ def _normalize_benchmark_metrics( instance_type_rate = None concurrent_users = {} for current_instance_type_metric in benchmark_metric_stats: - if current_instance_type_metric.name.lower() == "instance rate": + if "instance rate" in current_instance_type_metric.name.lower(): instance_type_rate = current_instance_type_metric elif current_instance_type_metric.concurrency not in concurrent_users: concurrent_users[current_instance_type_metric.concurrency] = [ From 2edb2e62865bd50b1fa2aa6be3fc73201dd7e844 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 9 Jul 2024 12:24:32 -0700 Subject: [PATCH 39/45] Relax set deployment error handling --- src/sagemaker/jumpstart/factory/model.py | 6 +----- tests/unit/sagemaker/jumpstart/model/test_model.py | 6 ------ 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index e759adec5e..188b4786d7 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -674,11 +674,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config supported_instance_types = resolved_config.get("supported_inference_instance_types", []) if kwargs.instance_type not in supported_instance_types: - raise ValueError( - f"Instance type {kwargs.instance_type} " - f"is not supported for config {kwargs.config_name}." 
-        )
-
+        JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)
     return kwargs
 
 
diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py
index af2a61abb9..56eaa0b660 100644
--- a/tests/unit/sagemaker/jumpstart/model/test_model.py
+++ b/tests/unit/sagemaker/jumpstart/model/test_model.py
@@ -1941,12 +1941,6 @@ def test_model_set_deployment_config_incompatible_instance_type_or_name(
     mock_get_model_specs.reset_mock()
     mock_model_deploy.reset_mock()
     mock_get_model_specs.side_effect = get_prototype_spec_with_configs
-    with pytest.raises(ValueError) as error:
-        model.set_deployment_config("neuron-inference", "ml.inf2.32xlarge")
-    assert (
-        "Instance type ml.inf2.32xlarge is not supported for config neuron-inference."
-        in str(error)
-    )
 
     with pytest.raises(ValueError) as error:
         model.set_deployment_config("neuron-inference-unknown-name", "ml.inf2.32xlarge")

From 4dca186fe473e2a7cde3d32ffd0438f2ce627fe1 Mon Sep 17 00:00:00 2001
From: Jonathan Makunga
Date: Tue, 9 Jul 2024 12:49:23 -0700
Subject: [PATCH 40/45] Override region for draft model data source

---
 src/sagemaker/jumpstart/factory/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py
index 188b4786d7..61fcff242f 100644
--- a/src/sagemaker/jumpstart/factory/model.py
+++ b/src/sagemaker/jumpstart/factory/model.py
@@ -697,7 +697,7 @@ def _add_additional_model_data_sources_to_kwargs(
     # Append speculative decoding data source from metadata
     speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
     for data_source in speculative_decoding_data_sources:
-        data_source.s3_data_source.set_bucket(get_neo_content_bucket())
+        data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region))
     api_shape_additional_model_data_sources = (
         [
             camel_case_to_pascal_case(data_source.to_json())

From 44741199a1c5480bd642012a384c8ce872cee096 Mon Sep 17 00:00:00 2001
From: Jonathan Makunga
Date: Tue, 9 Jul 2024 13:14:05 -0700
Subject: [PATCH 41/45] use latest boto3

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 9242e69cfd..b9486bbf18 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,7 @@ def read_requirements(filename):
 # Declare minimal set for installation
 required_packages = [
     "attrs>=23.1.0,<24",
-    "boto3>=1.33.3,<2.0",
+    "boto3>=1.34.142,<2.0",
     "cloudpickle==2.2.1",
     "google-pasta",
     "numpy>=1.9.0,<2.0",

From a7d1bae1063229d65a10c5847c895c8b77a7e138 Mon Sep 17 00:00:00 2001
From: Jonathan Makunga
Date: Tue, 9 Jul 2024 13:45:52 -0700
Subject: [PATCH 42/45] EBS Volume

---
 src/sagemaker/utils.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index 18e604691d..45509f65f6 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -1454,10 +1454,15 @@ def volume_size_supported(instance_type: str) -> bool:
         if len(parts) != 2:
             raise ValueError(f"Failed to parse instance type '{instance_type}'")
 
-        # Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) + g5
-        # does not support attaching an EBS volume.
+        # Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc)
+        # + g5 or g6 or p5 does not support attaching an EBS volume.
family = parts[0] - return "d" not in family and not family.startswith("g5") + return ( + "d" not in family + and not family.startswith("g5") + and not family.startswith("g6") + and not family.startswith("p5") + ) except Exception as e: raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}") From 661a4156786aecb903cc8a2b0d0ef8fca8709405 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 9 Jul 2024 14:03:17 -0700 Subject: [PATCH 43/45] model tags --- src/sagemaker/serve/utils/optimize_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 83978e252a..35a937407e 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -332,7 +332,7 @@ def _custom_speculative_decoding( model.env.update({"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model}) model.add_tags( - {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "customer"}, + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"}, ) return model From 59edbfb24141c15499207f178ad4b5431089c466 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 9 Jul 2024 14:43:47 -0700 Subject: [PATCH 44/45] UT --- tests/unit/sagemaker/serve/utils/test_optimize_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index bdd59b0497..3caec9c334 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -399,5 +399,5 @@ def test_with_non_s3(self, mock_model): ) mock_model.add_tags.assert_called_once_with( - {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "customer"} + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"} ) From b85e2c32510af64f905e5543a0a5fb9c97d1b98b Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 9 Jul 2024 15:13:22 -0700 Subject: [PATCH 45/45] FIX UT --- tests/unit/sagemaker/serve/utils/test_optimize_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index 3caec9c334..712382f068 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -338,7 +338,7 @@ def test_with_s3_hf(self, mock_model): res_model = _custom_speculative_decoding(mock_model, speculative_decoding_config) mock_model.add_tags.assert_called_once_with( - {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "customer"} + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"} ) self.assertEqual(
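
For context on the last three patches above (the provider tag rename from
"customer" to "custom"), a minimal sketch of the behavior the updated unit
tests pin down; the speculative_decoding_config keys and the Tag import path
below are assumptions for illustration, not the SDK's exact fixtures:

    from unittest import mock

    from sagemaker.serve.utils.optimize_utils import _custom_speculative_decoding
    from sagemaker.serve.utils.types import Tag  # import path assumed

    mock_model = mock.Mock()
    # Hypothetical config shape; the real unit-test fixtures define the keys.
    speculative_decoding_config = {
        "ModelProvider": "Custom",
        "ModelSource": "s3://my-bucket/draft-model/",
    }

    _custom_speculative_decoding(mock_model, speculative_decoding_config)

    # After the rename, the draft-model provider tag value is "custom".
    mock_model.add_tags.assert_called_once_with(
        {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"}
    )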