-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: optimization technique related validations. #4921
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
7ec16e6
cf70f59
fcb5092
9489b8d
5512c26
c94a78b
d10c475
8fb27a0
779f6d6
aef3a90
b7b15b8
748ea4b
a7feb54
694b4f2
7b6aef1
ce47be5
1f75072
277e0b1
8f0083b
8b73f34
c06aef0
09a54dc
3b147cd
65cb5b3
4d1e12b
bf706ad
f121eb0
9148e70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,9 +73,8 @@ def _deployment_config_contains_draft_model(deployment_config: Optional[Dict]) - | |
return False | ||
deployment_args = deployment_config.get("DeploymentArgs", {}) | ||
additional_data_sources = deployment_args.get("AdditionalDataSources") | ||
if not additional_data_sources: | ||
return False | ||
return additional_data_sources.get("speculative_decoding", False) | ||
|
||
return "speculative_decoding" in additional_data_sources if additional_data_sources else False | ||
|
||
|
||
def _is_draft_model_jumpstart_provided(deployment_config: Optional[Dict]) -> bool: | ||
|
@@ -207,15 +206,15 @@ def _extract_speculative_draft_model_provider( | |
if speculative_decoding_config is None: | ||
return None | ||
|
||
if speculative_decoding_config.get("ModelProvider").lower() == "jumpstart": | ||
model_provider = speculative_decoding_config.get("ModelProvider", "").lower() | ||
|
||
if model_provider == "jumpstart": | ||
return "jumpstart" | ||
|
||
if speculative_decoding_config.get( | ||
"ModelProvider" | ||
).lower() == "custom" or speculative_decoding_config.get("ModelSource"): | ||
if model_provider == "custom" or speculative_decoding_config.get("ModelSource"): | ||
return "custom" | ||
|
||
if speculative_decoding_config.get("ModelProvider").lower() == "sagemaker": | ||
if model_provider == "sagemaker": | ||
return "sagemaker" | ||
|
||
return "auto" | ||
|
@@ -238,7 +237,7 @@ def _extract_additional_model_data_source_s3_uri( | |
): | ||
return None | ||
|
||
return additional_model_data_source.get("S3DataSource").get("S3Uri", None) | ||
return additional_model_data_source.get("S3DataSource").get("S3Uri") | ||
|
||
|
||
def _extract_deployment_config_additional_model_data_source_s3_uri( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Duplicate of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Deployment config uses Pascal case while the PySDK model will use snake case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ack. ToDo: We need to find a different way of closing these differences. |
||
|
@@ -272,7 +271,7 @@ def _is_draft_model_gated( | |
Returns: | ||
bool: Whether the draft model is gated or not. | ||
""" | ||
return draft_model_config.get("hosting_eula_key", None) | ||
return "hosting_eula_key" in draft_model_config if draft_model_config else False | ||
|
||
|
||
def _extracts_and_validates_speculative_model_source( | ||
|
@@ -371,7 +370,7 @@ def _extract_optimization_config_and_env( | |
compilation_config (Optional[Dict]): The compilation config. | ||
|
||
Returns: | ||
Optional[Tuple[Optional[Dict], Optional[Dict]]]: | ||
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]: | ||
The optimization config and environment variables. | ||
""" | ||
optimization_config = {} | ||
|
@@ -388,7 +387,7 @@ def _extract_optimization_config_and_env( | |
if compilation_config is not None: | ||
optimization_config["ModelCompilationConfig"] = compilation_config | ||
|
||
# Return both dicts and environment variable if either is present | ||
# Return optimization config dict and environment variables if either is present | ||
if optimization_config: | ||
return optimization_config, quantization_override_env, compilation_override_env | ||
|
||
|
This file was deleted.
Uh oh!
There was an error while loading. Please reload this page.