Skip to content

feat: fixing instance to config auto resolution support #5251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 112 additions & 2 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,61 @@ def _retrieve_default_environment_variables(
sagemaker_session=sagemaker_session,
)

# Auto-detect config_name based on instance_type, even if a default config was provided
auto_detected_config_name = config_name

# For any instance type, check all available configs to find the best match
if instance_type:
from sagemaker.utils import get_instance_type_family
instance_type_family = get_instance_type_family(instance_type)

# Get all available configs to check
temp_model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=script,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=None, # Get default config first
model_type=model_type,
)

if temp_model_specs.inference_configs:
# Get config rankings to prioritize correctly
config_rankings = []
if hasattr(temp_model_specs, 'inference_config_rankings') and temp_model_specs.inference_config_rankings:
overall_rankings = temp_model_specs.inference_config_rankings.get('overall')
if overall_rankings and hasattr(overall_rankings, 'rankings'):
config_rankings = overall_rankings.rankings

# Check configs in ranking priority order (highest to lowest priority)
matching_configs = []
for config_name_candidate, config in temp_model_specs.inference_configs.configs.items():
config_resolved = config.resolved_config

if 'hosting_instance_type_variants' in config_resolved and config_resolved['hosting_instance_type_variants']:
from sagemaker.jumpstart.types import JumpStartInstanceTypeVariants
variants_dict = config_resolved['hosting_instance_type_variants']
variants = JumpStartInstanceTypeVariants(variants_dict)

# Check if this config specifically supports this instance type or family
if (variants.variants and
(instance_type in variants.variants or instance_type_family in variants.variants)):
matching_configs.append(config_name_candidate)

# Select the highest priority matching config based on rankings
if matching_configs and config_rankings:
for ranked_config in config_rankings:
if ranked_config in matching_configs:
auto_detected_config_name = ranked_config
break
elif matching_configs:
# Fallback to first match if no rankings available
auto_detected_config_name = matching_configs[0]

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
Expand All @@ -94,7 +149,7 @@ def _retrieve_default_environment_variables(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
config_name=auto_detected_config_name,
model_type=model_type,
)

Expand Down Expand Up @@ -225,6 +280,61 @@ def _retrieve_gated_model_uri_env_var_value(
sagemaker_session=sagemaker_session,
)

# Auto-detect config_name based on instance_type, even if a default config was provided
auto_detected_config_name = config_name

# For any instance type, check all available configs to find the best match
if instance_type:
from sagemaker.utils import get_instance_type_family
instance_type_family = get_instance_type_family(instance_type)

# Get all available configs to check
temp_model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=JumpStartScriptScope.TRAINING,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=None, # Get default config first
model_type=model_type,
)

if temp_model_specs.inference_configs:
# Get config rankings to prioritize correctly
config_rankings = []
if hasattr(temp_model_specs, 'inference_config_rankings') and temp_model_specs.inference_config_rankings:
overall_rankings = temp_model_specs.inference_config_rankings.get('overall')
if overall_rankings and hasattr(overall_rankings, 'rankings'):
config_rankings = overall_rankings.rankings

# Check configs in ranking priority order (highest to lowest priority)
matching_configs = []
for config_name_candidate, config in temp_model_specs.inference_configs.configs.items():
config_resolved = config.resolved_config

if 'hosting_instance_type_variants' in config_resolved and config_resolved['hosting_instance_type_variants']:
from sagemaker.jumpstart.types import JumpStartInstanceTypeVariants
variants_dict = config_resolved['hosting_instance_type_variants']
variants = JumpStartInstanceTypeVariants(variants_dict)

# Check if this config specifically supports this instance type or family
if (variants.variants and
(instance_type in variants.variants or instance_type_family in variants.variants)):
matching_configs.append(config_name_candidate)

# Select the highest priority matching config based on rankings
if matching_configs and config_rankings:
for ranked_config in config_rankings:
if ranked_config in matching_configs:
auto_detected_config_name = ranked_config
break
elif matching_configs:
# Fallback to first match if no rankings available
auto_detected_config_name = matching_configs[0]

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
Expand All @@ -234,7 +344,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
config_name=auto_detected_config_name,
model_type=model_type,
)

Expand Down
78 changes: 77 additions & 1 deletion src/sagemaker/jumpstart/artifacts/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.utils import get_instance_type_family
from sagemaker.session import Session


Expand Down Expand Up @@ -88,6 +89,60 @@ def _retrieve_image_uri(
sagemaker_session=sagemaker_session,
)

# Auto-detect config_name based on instance_type, even if a default config was provided
auto_detected_config_name = config_name

# For any instance type, check all available configs to find the best match
if instance_type:
instance_type_family = get_instance_type_family(instance_type)

# Get all available configs to check
temp_model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=image_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=None, # Get default config first
model_type=model_type,
)

if temp_model_specs.inference_configs:
# Get config rankings to prioritize correctly
config_rankings = []
if hasattr(temp_model_specs, 'inference_config_rankings') and temp_model_specs.inference_config_rankings:
overall_rankings = temp_model_specs.inference_config_rankings.get('overall')
if overall_rankings and hasattr(overall_rankings, 'rankings'):
config_rankings = overall_rankings.rankings

# Check configs in ranking priority order (highest to lowest priority)
matching_configs = []
for config_name_candidate, config in temp_model_specs.inference_configs.configs.items():
config_resolved = config.resolved_config

if 'hosting_instance_type_variants' in config_resolved and config_resolved['hosting_instance_type_variants']:
from sagemaker.jumpstart.types import JumpStartInstanceTypeVariants
variants_dict = config_resolved['hosting_instance_type_variants']
variants = JumpStartInstanceTypeVariants(variants_dict)

# Check if this config specifically supports this instance type or family
if (variants.variants and
(instance_type in variants.variants or instance_type_family in variants.variants)):
matching_configs.append(config_name_candidate)

# Select the highest priority matching config based on rankings
if matching_configs and config_rankings:
for ranked_config in config_rankings:
if ranked_config in matching_configs:
auto_detected_config_name = ranked_config
break
elif matching_configs:
# Fallback to first match if no rankings available
auto_detected_config_name = matching_configs[0]

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
Expand All @@ -97,7 +152,7 @@ def _retrieve_image_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
config_name=auto_detected_config_name,
model_type=model_type,
)

Expand All @@ -109,6 +164,27 @@ def _retrieve_image_uri(
)
if image_uri is not None:
return image_uri

# If the default config doesn't have the instance type, try other configs
if model_specs.inference_configs and instance_type:
instance_type_family = get_instance_type_family(instance_type)

# Try to find a config that supports this instance type
for config_name, config in model_specs.inference_configs.configs.items():
resolved_config = config.resolved_config

if 'hosting_instance_type_variants' in resolved_config and resolved_config['hosting_instance_type_variants']:
from sagemaker.jumpstart.types import JumpStartInstanceTypeVariants
variants_dict = resolved_config['hosting_instance_type_variants']
variants = JumpStartInstanceTypeVariants(variants_dict)

# Check if this config supports the instance type or instance type family
if (variants.variants and
(instance_type in variants.variants or instance_type_family in variants.variants)):
image_uri = variants.get_image_uri(instance_type=instance_type, region=region)
if image_uri is not None:
return image_uri

if hub_arn:
ecr_uri = model_specs.hosting_ecr_uri
return ecr_uri
Expand Down
57 changes: 56 additions & 1 deletion src/sagemaker/jumpstart/artifacts/model_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,61 @@ def _retrieve_model_uri(
sagemaker_session=sagemaker_session,
)

# Auto-detect config_name based on instance_type, even if a default config was provided
auto_detected_config_name = config_name

# For any instance type, check all available configs to find the best match
if instance_type:
from sagemaker.utils import get_instance_type_family
instance_type_family = get_instance_type_family(instance_type)

# Get all available configs to check
temp_model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=model_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=None, # Get default config first
model_type=model_type,
)

if temp_model_specs.inference_configs:
# Get config rankings to prioritize correctly
config_rankings = []
if hasattr(temp_model_specs, 'inference_config_rankings') and temp_model_specs.inference_config_rankings:
overall_rankings = temp_model_specs.inference_config_rankings.get('overall')
if overall_rankings and hasattr(overall_rankings, 'rankings'):
config_rankings = overall_rankings.rankings

# Check configs in ranking priority order (highest to lowest priority)
matching_configs = []
for config_name_candidate, config in temp_model_specs.inference_configs.configs.items():
config_resolved = config.resolved_config

if 'hosting_instance_type_variants' in config_resolved and config_resolved['hosting_instance_type_variants']:
from sagemaker.jumpstart.types import JumpStartInstanceTypeVariants
variants_dict = config_resolved['hosting_instance_type_variants']
variants = JumpStartInstanceTypeVariants(variants_dict)

# Check if this config specifically supports this instance type or family
if (variants.variants and
(instance_type in variants.variants or instance_type_family in variants.variants)):
matching_configs.append(config_name_candidate)

# Select the highest priority matching config based on rankings
if matching_configs and config_rankings:
for ranked_config in config_rankings:
if ranked_config in matching_configs:
auto_detected_config_name = ranked_config
break
elif matching_configs:
# Fallback to first match if no rankings available
auto_detected_config_name = matching_configs[0]

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
Expand All @@ -153,7 +208,7 @@ def _retrieve_model_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
config_name=auto_detected_config_name,
model_type=model_type,
)

Expand Down
57 changes: 56 additions & 1 deletion src/sagemaker/jumpstart/artifacts/resource_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,61 @@ def _retrieve_default_resources(
sagemaker_session=sagemaker_session,
)

# Auto-detect config_name based on instance_type, even if a default config was provided
auto_detected_config_name = config_name

# For any instance type, check all available configs to find the best match
if instance_type:
from sagemaker.utils import get_instance_type_family
instance_type_family = get_instance_type_family(instance_type)

# Get all available configs to check
temp_model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
model_type=model_type,
sagemaker_session=sagemaker_session,
config_name=None, # Get default config first
)

if temp_model_specs.inference_configs:
# Get config rankings to prioritize correctly
config_rankings = []
if hasattr(temp_model_specs, 'inference_config_rankings') and temp_model_specs.inference_config_rankings:
overall_rankings = temp_model_specs.inference_config_rankings.get('overall')
if overall_rankings and hasattr(overall_rankings, 'rankings'):
config_rankings = overall_rankings.rankings

# Check configs in ranking priority order (highest to lowest priority)
matching_configs = []
for config_name_candidate, config in temp_model_specs.inference_configs.configs.items():
config_resolved = config.resolved_config

if 'hosting_instance_type_variants' in config_resolved and config_resolved['hosting_instance_type_variants']:
from sagemaker.jumpstart.types import JumpStartInstanceTypeVariants
variants_dict = config_resolved['hosting_instance_type_variants']
variants = JumpStartInstanceTypeVariants(variants_dict)

# Check if this config specifically supports this instance type or family
if (variants.variants and
(instance_type in variants.variants or instance_type_family in variants.variants)):
matching_configs.append(config_name_candidate)

# Select the highest priority matching config based on rankings
if matching_configs and config_rankings:
for ranked_config in config_rankings:
if ranked_config in matching_configs:
auto_detected_config_name = ranked_config
break
elif matching_configs:
# Fallback to first match if no rankings available
auto_detected_config_name = matching_configs[0]

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
Expand All @@ -108,7 +163,7 @@ def _retrieve_default_resources(
tolerate_deprecated_model=tolerate_deprecated_model,
model_type=model_type,
sagemaker_session=sagemaker_session,
config_name=config_name,
config_name=auto_detected_config_name,
)

if scope == JumpStartScriptScope.INFERENCE:
Expand Down
Loading