1313
1414from pydantic import BaseModel , validator
1515
16- from sagemaker .core .resources import ModelPackageGroup
16+ from sagemaker .core .resources import ModelPackageGroup , ModelPackage
1717from sagemaker .core .shapes import VpcConfig
1818
1919if TYPE_CHECKING :
2020 from sagemaker .core .helper .session_helper import Session
2121
22+ from sagemaker .train .base_trainer import BaseTrainer
2223# Module-level logger
2324_logger = logging .getLogger (__name__ )
2425
@@ -53,6 +54,7 @@ class BaseEvaluator(BaseModel):
5354 - JumpStart model ID (str): e.g., 'llama3-2-1b-instruct'
5455 - ModelPackage object: A fine-tuned model package
5556 - ModelPackage ARN (str): e.g., 'arn:aws:sagemaker:region:account:model-package/name/version'
57+ - BaseTrainer object: A completed training job (i.e., it must have _latest_training_job with output_model_package_arn populated)
5658 base_eval_name (Optional[str]): Optional base name for evaluation jobs. This name is used
5759 as the PipelineExecutionDisplayName when creating the SageMaker pipeline execution.
5860 The actual display name will be "{base_eval_name}-{timestamp}". This parameter can
@@ -86,7 +88,7 @@ class BaseEvaluator(BaseModel):
8688
8789 region : Optional [str ] = None
8890 sagemaker_session : Optional [Any ] = None
89- model : Union [str , Any ]
91+ model : Union [str , BaseTrainer , ModelPackage ]
9092 base_eval_name : Optional [str ] = None
9193 s3_output_path : str
9294 mlflow_resource_arn : Optional [str ] = None
@@ -278,7 +280,7 @@ def _validate_mlflow_arn_format(cls, v: Optional[str]) -> Optional[str]:
278280 return v
279281
280282 @validator ('model' )
281- def _resolve_model_info (cls , v : Union [str , Any ], values : dict ) -> Union [str , Any ]:
283+ def _resolve_model_info (cls , v : Union [str , BaseTrainer , ModelPackage ], values : dict ) -> Union [str , Any ]:
282284 """Resolve model information from various input types.
283285
284286 This validator uses the common model resolution utility to extract:
@@ -289,7 +291,7 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any
289291 The resolved information is stored in private attributes for use by subclasses.
290292
291293 Args:
292- v (Union[str, Any ]): Model identifier (JumpStart ID, ModelPackage, or ARN ).
294+ v (Union[str, BaseTrainer, ModelPackage ]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer ).
293295 values (dict): Dictionary of already-validated fields.
294296
295297 Returns:
0 commit comments