Skip to content

Commit 066221f

Browse files
authored
Merge branch 'aws:master' into master-mc-bug-fixes
2 parents 8562a22 + f32f615 commit 066221f

29 files changed

+1128
-14668
lines changed

sagemaker-train/src/sagemaker/ai_registry/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def _create_lambda_function(cls, name: str, source_file: str, role: Optional[str
381381

382382
# Create Lambda function
383383
lambda_client = boto3.client("lambda")
384-
function_name = f"SageMaker-evaluator-{name}"
384+
function_name = f"SageMaker-evaluator-{name}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
385385
handler_name = f"{os.path.splitext(os.path.basename(source_file))[0]}.lambda_handler"
386386

387387
try:

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 33 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def __init__(self, sagemaker_session=None):
5858
5959
Args:
6060
sagemaker_session: SageMaker session to use for API calls.
61-
If None, will be created with beta endpoint if configured.
61+
If None, will be created with endpoint if configured.
6262
"""
6363
self.sagemaker_session = sagemaker_session
64-
self._beta_endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
64+
self._endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
6565

6666
def resolve_model_info(
6767
self,
@@ -188,8 +188,21 @@ def _resolve_model_package_object(self, model_package: 'ModelPackage') -> _Model
188188
base_model_name = hub_content_name
189189
if hasattr(container.base_model, 'hub_content_arn'):
190190
base_model_arn = container.base_model.hub_content_arn
191+
192+
# If hub_content_arn is not present, construct it from hub_content_name and version
193+
if not base_model_arn and hasattr(container.base_model, 'hub_content_version'):
194+
hub_content_version = container.base_model.hub_content_version
195+
model_pkg_arn = getattr(model_package, 'model_package_arn', None)
196+
197+
if hub_content_name and hub_content_version and model_pkg_arn:
198+
# Extract region from model package ARN
199+
arn_parts = model_pkg_arn.split(':')
200+
if len(arn_parts) >= 4:
201+
region = arn_parts[3]
202+
# Construct hub content ARN for SageMaker public hub
203+
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
191204

192-
# If we couldn't extract base model ARN, this is not a supported model package
205+
# If we couldn't extract or construct base model ARN, this is not a supported model package
193206
if not base_model_arn:
194207
raise ValueError(
195208
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
@@ -234,82 +247,23 @@ def _resolve_model_package_arn(self, model_package_arn: str) -> _ModelInfo:
234247
# Validate ARN format
235248
self._validate_model_package_arn(model_package_arn)
236249

237-
# TODO: Switch to sagemaker_core ModelPackage.get() once the bug is fixed
238-
# Currently, ModelPackage.get() has a Pydantic validation issue where
239-
# the transform() function doesn't include model_package_name in the response,
240-
# causing: "1 validation error for ModelPackage - model_package_name: Field required"
241-
# Using boto3 directly as a workaround.
242-
243-
# Use the sagemaker client from the session (which has the correct endpoint configured)
244-
sm_client = session.sagemaker_client if hasattr(session, 'sagemaker_client') else session.boto_session.client('sagemaker')
245-
response = sm_client.describe_model_package(ModelPackageName=model_package_arn)
246-
247-
# Extract base model info from response
248-
base_model_name = None
249-
base_model_arn = None
250-
hub_content_name = None
250+
# Use sagemaker.core ModelPackage.get() to retrieve model package information
251+
from sagemaker.core.resources import ModelPackage
251252

252-
# Check inference specification
253-
if 'InferenceSpecification' not in response:
254-
raise ValueError(
255-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
256-
f"The provided model package (ARN: {model_package_arn}) "
257-
f"does not have an inference_specification."
258-
)
253+
import logging
254+
logger = logging.getLogger(__name__)
259255

260-
inf_spec = response['InferenceSpecification']
261-
if 'Containers' not in inf_spec or len(inf_spec['Containers']) == 0:
262-
raise ValueError(
263-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
264-
f"The provided model package (ARN: {model_package_arn}) "
265-
f"does not have any containers in its inference_specification."
266-
)
267-
268-
container = inf_spec['Containers'][0]
269-
270-
# Extract base model info
271-
if 'BaseModel' not in container:
272-
raise ValueError(
273-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
274-
f"The provided model package (ARN: {model_package_arn}) "
275-
f"does not have base_model metadata in its inference_specification.containers[0]. "
276-
f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
277-
)
278-
279-
base_model_info = container['BaseModel']
280-
hub_content_name = base_model_info.get('HubContentName')
281-
hub_content_version = base_model_info.get('HubContentVersion')
282-
base_model_arn = base_model_info.get('HubContentArn')
283-
284-
# If HubContentArn is None, construct it from HubContentName and version
285-
# This handles cases where the API doesn't return the full ARN
286-
if not base_model_arn and hub_content_name and hub_content_version:
287-
# Extract region from model_package_arn
288-
arn_parts = model_package_arn.split(':')
289-
if len(arn_parts) >= 4:
290-
region = arn_parts[3]
291-
# Construct hub content ARN for SageMaker public hub
292-
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
293-
294-
if not base_model_arn:
295-
raise ValueError(
296-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
297-
f"The provided model package (ARN: {model_package_arn}) "
298-
f"does not have base_model metadata with HubContentArn or sufficient information to construct it. "
299-
f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
300-
)
256+
# Get the model package using sagemaker.core
257+
model_package = ModelPackage.get(
258+
model_package_name=model_package_arn,
259+
session=session.boto_session,
260+
region=session.boto_session.region_name
261+
)
301262

302-
# Use hub_content_name as base_model_name
303-
base_model_name = hub_content_name if hub_content_name else response.get('ModelPackageGroupName', 'unknown')
263+
logger.info(f"Retrieved ModelPackage in region: {session.boto_session.region_name}")
304264

305-
return _ModelInfo(
306-
base_model_name=base_model_name,
307-
base_model_arn=base_model_arn,
308-
source_model_package_arn=model_package_arn,
309-
model_type=_ModelType.FINE_TUNED,
310-
hub_content_name=hub_content_name,
311-
additional_metadata={}
312-
)
265+
# Now use the existing _resolve_model_package_object method to extract base model info
266+
return self._resolve_model_package_object(model_package)
313267

314268
except ValueError:
315269
# Re-raise ValueError as-is (our custom error messages)
@@ -342,7 +296,7 @@ def _validate_model_package_arn(self, arn: str) -> bool:
342296

343297
def _get_session(self):
344298
"""
345-
Get or create SageMaker session with beta endpoint support.
299+
Get or create SageMaker session with endpoint support.
346300
347301
Returns:
348302
SageMaker session
@@ -352,12 +306,11 @@ def _get_session(self):
352306

353307
from sagemaker.core.helper.session_helper import Session
354308

355-
# Check for beta endpoint in environment variable
356-
if self._beta_endpoint:
309+
# Check for endpoint in environment variable
310+
if self._endpoint:
357311
sm_client = boto3.client(
358312
'sagemaker',
359-
endpoint_url=self._beta_endpoint,
360-
region_name=os.environ.get('AWS_REGION', 'us-west-2')
313+
endpoint_url=self._endpoint
361314
)
362315
return Session(sagemaker_client=sm_client)
363316

0 commit comments

Comments
 (0)