feat: optimization technique related validations. #4921
Merged
28 commits
7ec16e6 Enable quantization and compilation in the same optimization job via …
cf70f59 Require EULA acceptance when using a gated 1p draft model via ModelBu…
fcb5092 add accept_draft_model_eula to JumpStartModel when deployment config …
9489b8d add map of valid optimization combinations
5512c26 Add ModelBuilder support for JumpStart-provided draft models.
c94a78b Tweak draft model EULA validations and messaging. Remove redundant de…
d10c475 Add "Auto" speculative decoding ModelProvider option; add validations…
8fb27a0 Fix JumpStartModel.AdditionalModelDataSource model access config assi…
779f6d6 move the accept eula configurations into deploy flow (gwang111)
aef3a90 Merge branch 'master' into QuicksilverV2 (gwang111)
b7b15b8 move the accept eula configurations into deploy flow (gwang111)
748ea4b Use correct bucket for SM/JS draft models and minor formatting/valida…
a7feb54 Remove obsolete docstring.
694b4f2 remove references to accept_draft_model_eula (gwang111)
7b6aef1 renaming of eula fn and error msg (gwang111)
ce47be5 Merge branch 'master' into QuicksilverV2 (gwang111)
1f75072 fix: pin testing deps (#4925) (benieric)
277e0b1 Revert "change: add TGI 2.4.0 image uri (#4922)" (#4926) (Captainia)
8f0083b fix naming and messaging (gwang111)
8b73f34 ModelBuilder speculative decoding UTs and minor fixes.
c06aef0 Merge branch 'master' into QuicksilverV2 (gwang111)
09a54dc Fix set union.
3b147cd add UTs for JumpStart deployment (gwang111)
65cb5b3 fix formatting issues (gwang111)
4d1e12b address validation comments (gwang111)
bf706ad fix doc strings (gwang111)
f121eb0 Add TRTLLM compilation + speculative decoding validation.
9148e70 address nits (gwang111)
@@ -12,6 +12,7 @@
# language governing permissions and limitations under the License.
"""This module contains utilities related to SageMaker JumpStart."""
from __future__ import absolute_import

from copy import copy
import logging
import os

@@ -22,6 +23,7 @@
from botocore.exceptions import ClientError
from packaging.version import Version
import botocore
from sagemaker_core.shapes import ModelAccessConfig
import sagemaker
from sagemaker.config.config_schema import (
    MODEL_ENABLE_NETWORK_ISOLATION_PATH,

@@ -55,6 +57,7 @@
    TagsDict,
    get_instance_rate_per_hour,
    get_domain_for_region,
    camel_case_to_pascal_case,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.user_agent import get_user_agent_extra_suffix

@@ -555,11 +558,18 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
    """Returns EULA message to display if one is available, else empty string."""
    if model_specs.hosting_eula_key is None:
        return ""
    return format_eula_message_template(
        model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key
    )


def format_eula_message_template(model_id: str, region: str, hosting_eula_key: str):
    """Returns a formatted EULA message."""
    return (
-       f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
+       f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
        f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
        f"{get_domain_for_region(region)}"
-       f"/{model_specs.hosting_eula_key} for terms of use."
+       f"/{hosting_eula_key} for terms of use."
    )

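The refactor in the hunk above separates message formatting from `JumpStartModelSpecs`, so the same text can be produced for any model ID and EULA key. A minimal sketch of the resulting string follows. The two helper functions here are hypothetical stand-ins (the real ones resolve region-specific bucket names and DNS suffixes), and the model ID, bucket name, and key are made-up values used only for illustration.

```python
def get_jumpstart_content_bucket(region: str) -> str:
    """Hypothetical stand-in: the real helper resolves the bucket for a region."""
    return f"example-jumpstart-bucket-{region}"


def get_domain_for_region(region: str) -> str:
    """Hypothetical stand-in: returns a DNS suffix for the region."""
    return "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"


def format_eula_message_template(model_id: str, region: str, hosting_eula_key: str) -> str:
    """Same message layout as the function added in the hunk above."""
    return (
        f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
        f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
        f"{get_domain_for_region(region)}"
        f"/{hosting_eula_key} for terms of use."
    )


# Illustrative values only; none of these names come from the pull request.
message = format_eula_message_template(
    model_id="example-draft-model",
    region="us-west-2",
    hosting_eula_key="eula/example.txt",
)
print(message)
```

Because the function no longer needs a specs object, callers that only hold a `HostingEulaKey` string (as the data-source validation later in this diff does) can reuse it directly.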
@@ -1525,3 +1535,63 @@ def wrapped_f(*args, **kwargs):
    if _func is None:
        return wrapper_cache
    return wrapper_cache(_func)


def _add_model_access_configs_to_model_data_sources(
    model_data_sources: List[Dict[str, any]],
    model_access_configs: Dict[str, ModelAccessConfig],
    model_id: str,
    region: str,
):
    """Sets AcceptEula to True for gated speculative decoding models"""
Review comment: this docstring is misleading; please describe what the code is doing. Something like:
    """Iterate over the accept EULA configs to ensure all channels are matched, return accepted EULA config list.

    Raise:
        ValueError if at least one channel that requires EULA acceptance was not passed.
    """

    if not model_data_sources:
        return model_data_sources

    acked_model_data_sources = []
    for model_data_source in model_data_sources:
        hosting_eula_key = model_data_source.get("HostingEulaKey")
        if hosting_eula_key:
            if not model_access_configs or not model_access_configs.get(model_id):
Review comment: what happens if the user passed
                eula_message_template = (
                    "{model_source}{base_eula_message}{model_access_configs_message}"
                )
                model_access_config_entry = (
                    '"{model_id}":ModelAccessConfig(accept_eula=True)'.format(model_id=model_id)
                )
                raise ValueError(
                    eula_message_template.format(
                        model_source="Additional " if model_data_source.get("ChannelName") else "",
                        base_eula_message=format_eula_message_template(
                            model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
                        ),
                        model_access_configs_message=(
                            " Please add a ModelAccessConfig entry:"
                            f" {model_access_config_entry} "
                            "to model_access_configs to acknowledge the EULA."
                        ),
                    )
                )
            acked_model_data_source = model_data_source.copy()
            acked_model_data_source.pop("HostingEulaKey")
            acked_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
                camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump())
            )
            acked_model_data_sources.append(acked_model_data_source)
        else:
            acked_model_data_sources.append(model_data_source)
    return acked_model_data_sources


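The validation above can be exercised in isolation. The sketch below is a simplified stand-in, not the merged implementation: `AccessConfig` and `acknowledge_eulas` are hypothetical names, the error text is shortened, and the access config is serialized by hand instead of going through `model_dump()` and `camel_case_to_pascal_case`. One deliberate difference is worth noting. `dict.copy()` is a shallow copy, so writing into the nested `"S3DataSource"` dictionary also changes the caller's input; the sketch uses `copy.deepcopy` to avoid that. Whether this matters for the SDK depends on how callers reuse the input list.

```python
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class AccessConfig:
    """Hypothetical stand-in for sagemaker_core.shapes.ModelAccessConfig."""

    accept_eula: bool = False


def acknowledge_eulas(
    model_data_sources: Optional[List[Dict[str, Any]]],
    model_access_configs: Optional[Dict[str, AccessConfig]],
    model_id: str,
) -> List[Dict[str, Any]]:
    """Return the data sources with the EULA acknowledgement attached.

    Raises:
        ValueError: if a gated source has no access config entry for ``model_id``.
    """
    acked = []
    for source in model_data_sources or []:
        if not source.get("HostingEulaKey"):
            # Ungated sources pass through unchanged.
            acked.append(source)
            continue
        config = (model_access_configs or {}).get(model_id)
        if config is None:
            raise ValueError(f"Model '{model_id}' requires accepting a EULA.")
        # Deep copy so the nested S3DataSource dict of the input is not modified.
        acked_source = deepcopy(source)
        acked_source.pop("HostingEulaKey")
        acked_source["S3DataSource"]["ModelAccessConfig"] = {
            "AcceptEula": config.accept_eula
        }
        acked.append(acked_source)
    return acked


# Illustrative input; the channel name, key, and URI are made up.
gated = {
    "ChannelName": "draft_model",
    "HostingEulaKey": "eula/example.txt",
    "S3DataSource": {"S3Uri": "s3://example-bucket/draft/"},
}
result = acknowledge_eulas(
    [gated], {"example-model": AccessConfig(accept_eula=True)}, "example-model"
)
print(result[0]["S3DataSource"])
```

Calling `acknowledge_eulas([gated], None, "example-model")` raises `ValueError`, which mirrors the branch in the diff that fires when no matching `ModelAccessConfig` entry was supplied.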
def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
    """Returns the correct content bucket for a 1p draft model."""
    neo_bucket = get_neo_content_bucket(region=region)
    if not provider:
        return neo_bucket
    provider_name = provider.get("name", "")
    if provider_name == "JumpStart":
        classification = provider.get("classification", "ungated")
        if classification == "gated":
            return get_jumpstart_gated_content_bucket(region=region)
        return get_jumpstart_content_bucket(region=region)
    return neo_bucket
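The bucket selection above has four outcomes, which is easier to see with the lookups stubbed out. In this sketch the three `get_*_bucket` helpers are hypothetical stand-ins that return made-up names, and the provider name `"SomeOtherProvider"` is an arbitrary placeholder; only the branching mirrors the function in the diff.

```python
from typing import Dict, Optional


# Hypothetical stand-ins; the real helpers look up region-specific bucket names.
def get_neo_content_bucket(region: str) -> str:
    return f"example-neo-{region}"


def get_jumpstart_content_bucket(region: str) -> str:
    return f"example-jumpstart-{region}"


def get_jumpstart_gated_content_bucket(region: str) -> str:
    return f"example-jumpstart-gated-{region}"


def get_draft_model_content_bucket(provider: Optional[Dict], region: str) -> str:
    """Same branching as the function in the diff, on top of the stubs above."""
    neo_bucket = get_neo_content_bucket(region=region)
    if not provider:
        return neo_bucket
    if provider.get("name", "") == "JumpStart":
        # A missing classification is treated as "ungated".
        if provider.get("classification", "ungated") == "gated":
            return get_jumpstart_gated_content_bucket(region=region)
        return get_jumpstart_content_bucket(region=region)
    return neo_bucket


region = "us-west-2"
buckets = {
    "no provider": get_draft_model_content_bucket(None, region),
    "JumpStart, gated": get_draft_model_content_bucket(
        {"name": "JumpStart", "classification": "gated"}, region
    ),
    "JumpStart, unclassified": get_draft_model_content_bucket({"name": "JumpStart"}, region),
    "other provider": get_draft_model_content_bucket({"name": "SomeOtherProvider"}, region),
}
for case, bucket in buckets.items():
    print(f"{case}: {bucket}")
```

Only a provider named `"JumpStart"` ever reaches the JumpStart buckets; every other input, including an empty or missing provider, falls back to the Neo bucket.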