Skip to content

Commit f43e8a0

Browse files
authored
Merge branch 'master' into fix-pr-checks
2 parents be2a47f + 60574e5 commit f43e8a0

36 files changed

+1850
-14610
lines changed

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,13 +352,18 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
352352
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
353353
recipe = next((r for r in recipes_with_template if not r.get("Peft")), None)
354354

355-
if recipe and recipe.get("SmtjOverrideParamsS3Uri"):
355+
if not recipe:
356+
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
357+
358+
elif recipe and recipe.get("SmtjOverrideParamsS3Uri"):
356359
s3_uri = recipe["SmtjOverrideParamsS3Uri"]
357360
s3 = boto3.client("s3")
358361
bucket, key = s3_uri.replace("s3://", "").split("/", 1)
359362
obj = s3.get_object(Bucket=bucket, Key=key)
360363
options_dict = json.loads(obj["Body"].read())
361364
return FineTuningOptions(options_dict), model_arn, is_gated_model
365+
else:
366+
return FineTuningOptions({}), model_arn, is_gated_model
362367

363368
except Exception as e:
364369
logger.error("Exception getting fine-tuning options: %s", e)
@@ -598,6 +603,9 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None
598603
# Use default S3 output path if none provided
599604
if s3_output_path is None:
600605
s3_output_path = _get_default_s3_output_path(sagemaker_session)
606+
607+
# Validate S3 path exists
608+
_validate_s3_path_exists(s3_output_path, sagemaker_session)
601609

602610
return OutputDataConfig(
603611
s3_output_path=s3_output_path,
@@ -682,3 +690,43 @@ def _validate_eula_for_gated_model(model, accept_eula, is_gated_model):
682690
)
683691

684692
return accept_eula
693+
694+
695+
def _validate_s3_path_exists(s3_path: str, sagemaker_session):
696+
"""Validate if S3 path exists and is accessible."""
697+
if not s3_path.startswith("s3://"):
698+
raise ValueError(f"Invalid S3 path format: {s3_path}")
699+
700+
# Parse S3 URI
701+
s3_parts = s3_path.replace("s3://", "").split("/", 1)
702+
bucket_name = s3_parts[0]
703+
prefix = s3_parts[1] if len(s3_parts) > 1 else ""
704+
705+
s3_client = sagemaker_session.boto_session.client('s3')
706+
707+
try:
708+
# Check if bucket exists and is accessible
709+
s3_client.head_bucket(Bucket=bucket_name)
710+
711+
# If prefix is provided, check if it exists
712+
if prefix:
713+
response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1)
714+
if 'Contents' not in response:
715+
raise ValueError(f"S3 prefix '{prefix}' does not exist in bucket '{bucket_name}'")
716+
717+
except Exception as e:
718+
if "NoSuchBucket" in str(e):
719+
raise ValueError(f"S3 bucket '{bucket_name}' does not exist or is not accessible")
720+
raise ValueError(f"Failed to validate S3 path '{s3_path}': {str(e)}")
721+
722+
723+
def _validate_hyperparameter_values(hyperparameters: dict):
724+
"""Validate hyperparameter values for allowed characters."""
725+
import re
726+
allowed_chars = r"^[a-zA-Z0-9/_.:,\-\s'\"\[\]]*$"
727+
for key, value in hyperparameters.items():
728+
if isinstance(value, str) and not re.match(allowed_chars, value):
729+
raise ValueError(
730+
f"Hyperparameter '{key}' value '{value}' contains invalid characters. "
731+
f"Only a-z, A-Z, 0-9, /, _, ., :, \\, -, space, ', \", [, ] and , are allowed."
732+
)

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 33 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def __init__(self, sagemaker_session=None):
5858
5959
Args:
6060
sagemaker_session: SageMaker session to use for API calls.
61-
If None, will be created with beta endpoint if configured.
61+
If None, will be created with endpoint if configured.
6262
"""
6363
self.sagemaker_session = sagemaker_session
64-
self._beta_endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
64+
self._endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
6565

6666
def resolve_model_info(
6767
self,
@@ -188,8 +188,21 @@ def _resolve_model_package_object(self, model_package: 'ModelPackage') -> _Model
188188
base_model_name = hub_content_name
189189
if hasattr(container.base_model, 'hub_content_arn'):
190190
base_model_arn = container.base_model.hub_content_arn
191+
192+
# If hub_content_arn is not present, construct it from hub_content_name and version
193+
if not base_model_arn and hasattr(container.base_model, 'hub_content_version'):
194+
hub_content_version = container.base_model.hub_content_version
195+
model_pkg_arn = getattr(model_package, 'model_package_arn', None)
196+
197+
if hub_content_name and hub_content_version and model_pkg_arn:
198+
# Extract region from model package ARN
199+
arn_parts = model_pkg_arn.split(':')
200+
if len(arn_parts) >= 4:
201+
region = arn_parts[3]
202+
# Construct hub content ARN for SageMaker public hub
203+
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
191204

192-
# If we couldn't extract base model ARN, this is not a supported model package
205+
# If we couldn't extract or construct base model ARN, this is not a supported model package
193206
if not base_model_arn:
194207
raise ValueError(
195208
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
@@ -234,82 +247,23 @@ def _resolve_model_package_arn(self, model_package_arn: str) -> _ModelInfo:
234247
# Validate ARN format
235248
self._validate_model_package_arn(model_package_arn)
236249

237-
# TODO: Switch to sagemaker_core ModelPackage.get() once the bug is fixed
238-
# Currently, ModelPackage.get() has a Pydantic validation issue where
239-
# the transform() function doesn't include model_package_name in the response,
240-
# causing: "1 validation error for ModelPackage - model_package_name: Field required"
241-
# Using boto3 directly as a workaround.
242-
243-
# Use the sagemaker client from the session (which has the correct endpoint configured)
244-
sm_client = session.sagemaker_client if hasattr(session, 'sagemaker_client') else session.boto_session.client('sagemaker')
245-
response = sm_client.describe_model_package(ModelPackageName=model_package_arn)
246-
247-
# Extract base model info from response
248-
base_model_name = None
249-
base_model_arn = None
250-
hub_content_name = None
250+
# Use sagemaker.core ModelPackage.get() to retrieve model package information
251+
from sagemaker.core.resources import ModelPackage
251252

252-
# Check inference specification
253-
if 'InferenceSpecification' not in response:
254-
raise ValueError(
255-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
256-
f"The provided model package (ARN: {model_package_arn}) "
257-
f"does not have an inference_specification."
258-
)
253+
import logging
254+
logger = logging.getLogger(__name__)
259255

260-
inf_spec = response['InferenceSpecification']
261-
if 'Containers' not in inf_spec or len(inf_spec['Containers']) == 0:
262-
raise ValueError(
263-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
264-
f"The provided model package (ARN: {model_package_arn}) "
265-
f"does not have any containers in its inference_specification."
266-
)
267-
268-
container = inf_spec['Containers'][0]
269-
270-
# Extract base model info
271-
if 'BaseModel' not in container:
272-
raise ValueError(
273-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
274-
f"The provided model package (ARN: {model_package_arn}) "
275-
f"does not have base_model metadata in its inference_specification.containers[0]. "
276-
f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
277-
)
278-
279-
base_model_info = container['BaseModel']
280-
hub_content_name = base_model_info.get('HubContentName')
281-
hub_content_version = base_model_info.get('HubContentVersion')
282-
base_model_arn = base_model_info.get('HubContentArn')
283-
284-
# If HubContentArn is None, construct it from HubContentName and version
285-
# This handles cases where the API doesn't return the full ARN
286-
if not base_model_arn and hub_content_name and hub_content_version:
287-
# Extract region from model_package_arn
288-
arn_parts = model_package_arn.split(':')
289-
if len(arn_parts) >= 4:
290-
region = arn_parts[3]
291-
# Construct hub content ARN for SageMaker public hub
292-
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
293-
294-
if not base_model_arn:
295-
raise ValueError(
296-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
297-
f"The provided model package (ARN: {model_package_arn}) "
298-
f"does not have base_model metadata with HubContentArn or sufficient information to construct it. "
299-
f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
300-
)
256+
# Get the model package using sagemaker.core
257+
model_package = ModelPackage.get(
258+
model_package_name=model_package_arn,
259+
session=session.boto_session,
260+
region=session.boto_session.region_name
261+
)
301262

302-
# Use hub_content_name as base_model_name
303-
base_model_name = hub_content_name if hub_content_name else response.get('ModelPackageGroupName', 'unknown')
263+
logger.info(f"Retrieved ModelPackage in region: {session.boto_session.region_name}")
304264

305-
return _ModelInfo(
306-
base_model_name=base_model_name,
307-
base_model_arn=base_model_arn,
308-
source_model_package_arn=model_package_arn,
309-
model_type=_ModelType.FINE_TUNED,
310-
hub_content_name=hub_content_name,
311-
additional_metadata={}
312-
)
265+
# Now use the existing _resolve_model_package_object method to extract base model info
266+
return self._resolve_model_package_object(model_package)
313267

314268
except ValueError:
315269
# Re-raise ValueError as-is (our custom error messages)
@@ -342,7 +296,7 @@ def _validate_model_package_arn(self, arn: str) -> bool:
342296

343297
def _get_session(self):
344298
"""
345-
Get or create SageMaker session with beta endpoint support.
299+
Get or create SageMaker session with endpoint support.
346300
347301
Returns:
348302
SageMaker session
@@ -352,12 +306,11 @@ def _get_session(self):
352306

353307
from sagemaker.core.helper.session_helper import Session
354308

355-
# Check for beta endpoint in environment variable
356-
if self._beta_endpoint:
309+
# Check for endpoint in environment variable
310+
if self._endpoint:
357311
sm_client = boto3.client(
358312
'sagemaker',
359-
endpoint_url=self._beta_endpoint,
360-
region_name=os.environ.get('AWS_REGION', 'us-west-2')
313+
endpoint_url=self._endpoint
361314
)
362315
return Session(sagemaker_client=sm_client)
363316

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
_create_serverless_config,
1818
_create_mlflow_config,
1919
_create_model_package_config,
20-
_validate_eula_for_gated_model
20+
_validate_eula_for_gated_model,
21+
_validate_hyperparameter_values
2122
)
2223
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
2324
from sagemaker.core.telemetry.constants import Feature
@@ -137,8 +138,38 @@ def __init__(
137138

138139
))
139140

141+
# Process hyperparameters
142+
self._process_hyperparameters()
143+
140144
# Validate and set EULA acceptance
141145
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
146+
147+
def _process_hyperparameters(self):
148+
"""Remove hyperparameter keys that are handled by constructor inputs."""
149+
if self.hyperparameters:
150+
# Remove keys that are handled by constructor inputs
151+
if hasattr(self.hyperparameters, 'data_path'):
152+
delattr(self.hyperparameters, 'data_path')
153+
self.hyperparameters._specs.pop('data_path', None)
154+
if hasattr(self.hyperparameters, 'output_path'):
155+
delattr(self.hyperparameters, 'output_path')
156+
self.hyperparameters._specs.pop('output_path', None)
157+
if hasattr(self.hyperparameters, 'data_s3_path'):
158+
delattr(self.hyperparameters, 'data_s3_path')
159+
self.hyperparameters._specs.pop('data_s3_path', None)
160+
if hasattr(self.hyperparameters, 'output_s3_path'):
161+
delattr(self.hyperparameters, 'output_s3_path')
162+
self.hyperparameters._specs.pop('output_s3_path', None)
163+
if hasattr(self.hyperparameters, 'training_data_name'):
164+
delattr(self.hyperparameters, 'training_data_name')
165+
self.hyperparameters._specs.pop('training_data_name', None)
166+
if hasattr(self.hyperparameters, 'validation_data_name'):
167+
delattr(self.hyperparameters, 'validation_data_name')
168+
self.hyperparameters._specs.pop('validation_data_name', None)
169+
if hasattr(self.hyperparameters, 'validation_data_path'):
170+
delattr(self.hyperparameters, 'validation_data_path')
171+
self.hyperparameters._specs.pop('validation_data_path', None)
172+
142173
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DPOTrainer.train")
143174
def train(self,
144175
training_dataset: Optional[Union[str, DataSet]] = None,
@@ -198,6 +229,7 @@ def train(self,
198229
)
199230

200231
final_hyperparameters = self.hyperparameters.to_dict()
232+
_validate_hyperparameter_values(final_hyperparameters)
201233

202234
model_package_config = _create_model_package_config(
203235
model_package_group_name=self.model_package_group_name,

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,8 @@ def _get_or_create_artifact_arn(self, source_uri: str, region: str) -> str:
546546
properties['HubContentArn'] = source_uri
547547
else:
548548
properties['SourceUri'] = source_uri
549+
550+
_logger.info(f"source_uri: {source_uri}, region: {region}, properties: {properties}")
549551

550552
# Create artifact using Artifact.create()
551553
artifact = Artifact.create(

sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ def _get_custom_scorer_template_additions(self, evaluator_config: dict) -> dict:
308308
'evaluator_arn': evaluator_config['evaluator_arn'],
309309
}
310310

311+
# Add lambda_type for Nova models
312+
if is_nova:
313+
custom_scorer_context['lambda_type'] = 'rft'
314+
311315
# Add preset_reward_function if present
312316
if evaluator_config['preset_reward_function']:
313317
custom_scorer_context['preset_reward_function'] = evaluator_config['preset_reward_function']

sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,8 @@
632632
"task": "{{ task }}",
633633
"strategy": "{{ strategy }}"{% if metric is defined %},
634634
"metric": "{{ metric }}"{% elif evaluation_metric is defined %},
635-
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
635+
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
636+
"lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
636637
"max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
637638
"temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
638639
"top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
@@ -694,7 +695,8 @@
694695
"task": "{{ task }}",
695696
"strategy": "{{ strategy }}"{% if metric is defined %},
696697
"metric": "{{ metric }}"{% elif evaluation_metric is defined %},
697-
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
698+
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
699+
"lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
698700
"max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
699701
"temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
700702
"top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
@@ -872,7 +874,8 @@
872874
"task": "{{ task }}",
873875
"strategy": "{{ strategy }}"{% if metric is defined %},
874876
"metric": "{{ metric }}"{% elif evaluation_metric is defined %},
875-
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
877+
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
878+
"lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
876879
"max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
877880
"temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
878881
"top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},

0 commit comments

Comments
 (0)