Skip to content

Commit 55abc74

Browse files
committed
Merge remote-tracking branch 'upstream/master' into hmac-key-fix-v3
2 parents 14417bb + 60574e5 commit 55abc74

39 files changed

+1996
-14788
lines changed

sagemaker-train/src/sagemaker/ai_registry/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def _create_lambda_function(cls, name: str, source_file: str, role: Optional[str
381381

382382
# Create Lambda function
383383
lambda_client = boto3.client("lambda")
384-
function_name = f"SageMaker-evaluator-{name}"
384+
function_name = f"SageMaker-evaluator-{name}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
385385
handler_name = f"{os.path.splitext(os.path.basename(source_file))[0]}.lambda_handler"
386386

387387
try:

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -343,36 +343,27 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
343343
recipes_with_template = [r for r in matching_recipes if r.get("SmtjRecipeTemplateS3Uri")]
344344

345345
if not recipes_with_template:
346-
raise ValueError(f"No recipes found with SmtjRecipeTemplateS3Uri for technique: {customization_technique}")
347-
348-
# If multiple recipes, filter by training_type (peft key)
349-
if len(recipes_with_template) > 1:
350-
351-
if isinstance(training_type, TrainingType) and training_type == TrainingType.LORA:
352-
# Filter recipes that have peft key for LORA
353-
lora_recipes = [r for r in recipes_with_template if r.get("Peft")]
354-
if lora_recipes:
355-
recipes_with_template = lora_recipes
356-
elif len(recipes_with_template) > 1:
357-
raise ValueError(f"Multiple recipes found for LORA training but none have peft key")
358-
elif isinstance(training_type, TrainingType) and training_type == TrainingType.FULL:
359-
# For FULL training, if multiple recipes exist, throw error
360-
if len(recipes_with_template) > 1:
361-
raise ValueError(f"Multiple recipes found for FULL training - cannot determine which to use")
362-
363-
# If still multiple recipes after filtering, throw error
364-
if len(recipes_with_template) > 1:
365-
raise ValueError(f"Multiple recipes found after filtering - cannot determine which to use")
366-
367-
recipe = recipes_with_template[0]
368-
369-
if recipe and recipe.get("SmtjOverrideParamsS3Uri"):
346+
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")
347+
348+
# Select recipe based on training type
349+
recipe = None
350+
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
351+
recipe = next((r for r in recipes_with_template if r.get("Peft")), None)
352+
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
353+
recipe = next((r for r in recipes_with_template if not r.get("Peft")), None)
354+
355+
if not recipe:
356+
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
357+
358+
elif recipe and recipe.get("SmtjOverrideParamsS3Uri"):
370359
s3_uri = recipe["SmtjOverrideParamsS3Uri"]
371360
s3 = boto3.client("s3")
372361
bucket, key = s3_uri.replace("s3://", "").split("/", 1)
373362
obj = s3.get_object(Bucket=bucket, Key=key)
374363
options_dict = json.loads(obj["Body"].read())
375364
return FineTuningOptions(options_dict), model_arn, is_gated_model
365+
else:
366+
return FineTuningOptions({}), model_arn, is_gated_model
376367

377368
except Exception as e:
378369
logger.error("Exception getting fine-tuning options: %s", e)
@@ -612,6 +603,9 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None
612603
# Use default S3 output path if none provided
613604
if s3_output_path is None:
614605
s3_output_path = _get_default_s3_output_path(sagemaker_session)
606+
607+
# Validate S3 path exists
608+
_validate_s3_path_exists(s3_output_path, sagemaker_session)
615609

616610
return OutputDataConfig(
617611
s3_output_path=s3_output_path,
@@ -696,3 +690,43 @@ def _validate_eula_for_gated_model(model, accept_eula, is_gated_model):
696690
)
697691

698692
return accept_eula
693+
694+
695+
def _validate_s3_path_exists(s3_path: str, sagemaker_session):
696+
"""Validate if S3 path exists and is accessible."""
697+
if not s3_path.startswith("s3://"):
698+
raise ValueError(f"Invalid S3 path format: {s3_path}")
699+
700+
# Parse S3 URI
701+
s3_parts = s3_path.replace("s3://", "").split("/", 1)
702+
bucket_name = s3_parts[0]
703+
prefix = s3_parts[1] if len(s3_parts) > 1 else ""
704+
705+
s3_client = sagemaker_session.boto_session.client('s3')
706+
707+
try:
708+
# Check if bucket exists and is accessible
709+
s3_client.head_bucket(Bucket=bucket_name)
710+
711+
# If prefix is provided, check if it exists
712+
if prefix:
713+
response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1)
714+
if 'Contents' not in response:
715+
raise ValueError(f"S3 prefix '{prefix}' does not exist in bucket '{bucket_name}'")
716+
717+
except Exception as e:
718+
if "NoSuchBucket" in str(e):
719+
raise ValueError(f"S3 bucket '{bucket_name}' does not exist or is not accessible")
720+
raise ValueError(f"Failed to validate S3 path '{s3_path}': {str(e)}")
721+
722+
723+
def _validate_hyperparameter_values(hyperparameters: dict):
724+
"""Validate hyperparameter values for allowed characters."""
725+
import re
726+
allowed_chars = r"^[a-zA-Z0-9/_.:,\-\s'\"\[\]]*$"
727+
for key, value in hyperparameters.items():
728+
if isinstance(value, str) and not re.match(allowed_chars, value):
729+
raise ValueError(
730+
f"Hyperparameter '{key}' value '{value}' contains invalid characters. "
731+
f"Only a-z, A-Z, 0-9, /, _, ., :, \\, -, space, ', \", [, ] and , are allowed."
732+
)

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 33 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def __init__(self, sagemaker_session=None):
5858
5959
Args:
6060
sagemaker_session: SageMaker session to use for API calls.
61-
If None, will be created with beta endpoint if configured.
61+
If None, will be created with endpoint if configured.
6262
"""
6363
self.sagemaker_session = sagemaker_session
64-
self._beta_endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
64+
self._endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
6565

6666
def resolve_model_info(
6767
self,
@@ -188,8 +188,21 @@ def _resolve_model_package_object(self, model_package: 'ModelPackage') -> _Model
188188
base_model_name = hub_content_name
189189
if hasattr(container.base_model, 'hub_content_arn'):
190190
base_model_arn = container.base_model.hub_content_arn
191+
192+
# If hub_content_arn is not present, construct it from hub_content_name and version
193+
if not base_model_arn and hasattr(container.base_model, 'hub_content_version'):
194+
hub_content_version = container.base_model.hub_content_version
195+
model_pkg_arn = getattr(model_package, 'model_package_arn', None)
196+
197+
if hub_content_name and hub_content_version and model_pkg_arn:
198+
# Extract region from model package ARN
199+
arn_parts = model_pkg_arn.split(':')
200+
if len(arn_parts) >= 4:
201+
region = arn_parts[3]
202+
# Construct hub content ARN for SageMaker public hub
203+
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
191204

192-
# If we couldn't extract base model ARN, this is not a supported model package
205+
# If we couldn't extract or construct base model ARN, this is not a supported model package
193206
if not base_model_arn:
194207
raise ValueError(
195208
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
@@ -234,82 +247,23 @@ def _resolve_model_package_arn(self, model_package_arn: str) -> _ModelInfo:
234247
# Validate ARN format
235248
self._validate_model_package_arn(model_package_arn)
236249

237-
# TODO: Switch to sagemaker_core ModelPackage.get() once the bug is fixed
238-
# Currently, ModelPackage.get() has a Pydantic validation issue where
239-
# the transform() function doesn't include model_package_name in the response,
240-
# causing: "1 validation error for ModelPackage - model_package_name: Field required"
241-
# Using boto3 directly as a workaround.
242-
243-
# Use the sagemaker client from the session (which has the correct endpoint configured)
244-
sm_client = session.sagemaker_client if hasattr(session, 'sagemaker_client') else session.boto_session.client('sagemaker')
245-
response = sm_client.describe_model_package(ModelPackageName=model_package_arn)
246-
247-
# Extract base model info from response
248-
base_model_name = None
249-
base_model_arn = None
250-
hub_content_name = None
250+
# Use sagemaker.core ModelPackage.get() to retrieve model package information
251+
from sagemaker.core.resources import ModelPackage
251252

252-
# Check inference specification
253-
if 'InferenceSpecification' not in response:
254-
raise ValueError(
255-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
256-
f"The provided model package (ARN: {model_package_arn}) "
257-
f"does not have an inference_specification."
258-
)
253+
import logging
254+
logger = logging.getLogger(__name__)
259255

260-
inf_spec = response['InferenceSpecification']
261-
if 'Containers' not in inf_spec or len(inf_spec['Containers']) == 0:
262-
raise ValueError(
263-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
264-
f"The provided model package (ARN: {model_package_arn}) "
265-
f"does not have any containers in its inference_specification."
266-
)
267-
268-
container = inf_spec['Containers'][0]
269-
270-
# Extract base model info
271-
if 'BaseModel' not in container:
272-
raise ValueError(
273-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
274-
f"The provided model package (ARN: {model_package_arn}) "
275-
f"does not have base_model metadata in its inference_specification.containers[0]. "
276-
f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
277-
)
278-
279-
base_model_info = container['BaseModel']
280-
hub_content_name = base_model_info.get('HubContentName')
281-
hub_content_version = base_model_info.get('HubContentVersion')
282-
base_model_arn = base_model_info.get('HubContentArn')
283-
284-
# If HubContentArn is None, construct it from HubContentName and version
285-
# This handles cases where the API doesn't return the full ARN
286-
if not base_model_arn and hub_content_name and hub_content_version:
287-
# Extract region from model_package_arn
288-
arn_parts = model_package_arn.split(':')
289-
if len(arn_parts) >= 4:
290-
region = arn_parts[3]
291-
# Construct hub content ARN for SageMaker public hub
292-
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
293-
294-
if not base_model_arn:
295-
raise ValueError(
296-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
297-
f"The provided model package (ARN: {model_package_arn}) "
298-
f"does not have base_model metadata with HubContentArn or sufficient information to construct it. "
299-
f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
300-
)
256+
# Get the model package using sagemaker.core
257+
model_package = ModelPackage.get(
258+
model_package_name=model_package_arn,
259+
session=session.boto_session,
260+
region=session.boto_session.region_name
261+
)
301262

302-
# Use hub_content_name as base_model_name
303-
base_model_name = hub_content_name if hub_content_name else response.get('ModelPackageGroupName', 'unknown')
263+
logger.info(f"Retrieved ModelPackage in region: {session.boto_session.region_name}")
304264

305-
return _ModelInfo(
306-
base_model_name=base_model_name,
307-
base_model_arn=base_model_arn,
308-
source_model_package_arn=model_package_arn,
309-
model_type=_ModelType.FINE_TUNED,
310-
hub_content_name=hub_content_name,
311-
additional_metadata={}
312-
)
265+
# Now use the existing _resolve_model_package_object method to extract base model info
266+
return self._resolve_model_package_object(model_package)
313267

314268
except ValueError:
315269
# Re-raise ValueError as-is (our custom error messages)
@@ -342,7 +296,7 @@ def _validate_model_package_arn(self, arn: str) -> bool:
342296

343297
def _get_session(self):
344298
"""
345-
Get or create SageMaker session with beta endpoint support.
299+
Get or create SageMaker session with endpoint support.
346300
347301
Returns:
348302
SageMaker session
@@ -352,12 +306,11 @@ def _get_session(self):
352306

353307
from sagemaker.core.helper.session_helper import Session
354308

355-
# Check for beta endpoint in environment variable
356-
if self._beta_endpoint:
309+
# Check for endpoint in environment variable
310+
if self._endpoint:
357311
sm_client = boto3.client(
358312
'sagemaker',
359-
endpoint_url=self._beta_endpoint,
360-
region_name=os.environ.get('AWS_REGION', 'us-west-2')
313+
endpoint_url=self._endpoint
361314
)
362315
return Session(sagemaker_client=sm_client)
363316

0 commit comments

Comments
 (0)