81 changes: 32 additions & 49 deletions sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py
@@ -35,10 +35,8 @@ class _Benchmark(str, Enum):
MATH = "math"
STRONG_REJECT = "strong_reject"
IFEVAL = "ifeval"
GEN_QA = "gen_qa"
MMMU = "mmmu"
LLM_JUDGE = "llm_judge"
INFERENCE_ONLY = "inference_only"


# Internal benchmark configuration mapping - using plain dictionaries
@@ -138,14 +136,6 @@ class _Benchmark(str, Enum):
"subtask_available": False,
"subtasks": None
},
_Benchmark.GEN_QA: {
"modality": "Multi-Modal (image)",
"description": "Custom Dataset Evaluation – Lets you supply your own dataset for benchmarking, comparing model outputs to reference answers with metrics such as ROUGE and BLEU. gen_qa supports image inference for models which have multimodal support.",
"metrics": ["all"],
"strategy": "gen_qa",
"subtask_available": False,
"subtasks": None
},
_Benchmark.MMMU: {
"modality": "Multi-Modal",
"description": "Massive Multidiscipline Multimodal Understanding (MMMU) – College-level benchmark comprising multiple-choice and open-ended questions from 30 disciplines.",
@@ -171,14 +161,6 @@ class _Benchmark(str, Enum):
"subtask_available": False,
"subtasks": None
},
_Benchmark.INFERENCE_ONLY: {
"modality": "Text",
"description": "Lets you supply your own dataset to generate inference responses which can be used with the llm_judge task. No metrics are computed for this task.",
"metrics": ["N/A"],
"strategy": "--",
"subtask_available": False,
"subtasks": None
},
}
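Reviewer note: with the `gen_qa` and `inference_only` entries dropped, `_BENCHMARK_CONFIG` covers only the built-in benchmarks, and every remaining entry keeps the same shape. A minimal sketch of that shape and a guarded lookup (the MMLU values here are illustrative, not copied from the file):

```python
from enum import Enum
from typing import Any, Dict, Optional


class _Benchmark(str, Enum):
    MMLU = "mmlu"  # illustrative subset of the real enum


# Every entry left in _BENCHMARK_CONFIG follows this shape.
_BENCHMARK_CONFIG: Dict[_Benchmark, Dict[str, Any]] = {
    _Benchmark.MMLU: {
        "modality": "Text",
        "description": "Massive Multitask Language Understanding",
        "metrics": ["accuracy"],  # assumed metric name
        "strategy": "mmlu",       # assumed strategy key
        "subtask_available": True,
        "subtasks": ["abstract_algebra", "anatomy"],  # truncated for the sketch
    },
}


def _config_for(benchmark: _Benchmark) -> Optional[Dict[str, Any]]:
    # .get() returns None for members without an entry, so any caller
    # still holding a removed benchmark value must handle the None case.
    return _BENCHMARK_CONFIG.get(benchmark)
```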


Expand Down Expand Up @@ -278,10 +260,6 @@ class BenchMarkEvaluator(BaseEvaluator):
Optional. If not provided, the system will attempt to resolve it using the default
MLflow app experience (checks domain match, account default, or creates a new app).
Format: arn:aws:sagemaker:region:account:mlflow-tracking-server/name
dataset (Union[str, Any]): Evaluation dataset. Required. Accepts:
- S3 URI (str): e.g., 's3://bucket/path/dataset.jsonl'
- Dataset ARN (str): e.g., 'arn:aws:sagemaker:...:hub-content/AIRegistry/DataSet/...'
- DataSet object: sagemaker.ai_registry.dataset.DataSet instance (ARN inferred automatically)
evaluate_base_model (bool): Whether to evaluate the base model in addition to the custom
model. Set to False to skip base model evaluation and only evaluate the custom model.
Defaults to True (evaluates both models).
@@ -309,7 +287,6 @@ class BenchMarkEvaluator(BaseEvaluator):
benchmark=Benchmark.MMLU,
subtasks=["abstract_algebra", "anatomy", "astronomy"],
model="llama3-2-1b-instruct",
dataset="s3://bucket/eval-data.jsonl",
s3_output_path="s3://bucket/outputs/",
mlflow_resource_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/my-server"
)
@@ -327,16 +304,8 @@ class BenchMarkEvaluator(BaseEvaluator):
_hyperparameters: Optional[Any] = None

# Template-required fields
dataset: Union[str, Any]
evaluate_base_model: bool = True

@validator('dataset', pre=True)
def _resolve_dataset(cls, v):
"""Resolve dataset to string (S3 URI or ARN) and validate format.

Uses BaseEvaluator's common validation logic to avoid code duplication.
"""
return BaseEvaluator._validate_and_resolve_dataset(v)
evaluate_base_model: bool = False


@validator('benchmark')
def _validate_benchmark_model_compatibility(cls, v, values):
@@ -385,15 +354,21 @@ def _validate_subtasks(cls, v, values):
f"Subtask list cannot be empty for benchmark '{benchmark.value}'. "
f"Provide at least one subtask or use 'ALL'."
)

if len(v) > 1:
raise ValueError(
f"Currently only one subtask is supported for benchmark '{benchmark.value}'. "
f"Provide only one subtask or use 'ALL'."
)

# TODO: Support a list of subtasks.
# Validate each subtask in the list
for subtask in v:
if not isinstance(subtask, str):
raise ValueError(
f"All subtasks in the list must be strings. "
f"Found {type(subtask).__name__}: {subtask}"
)

# Validate against available subtasks if defined
if config.get("subtasks") and subtask not in config["subtasks"]:
raise ValueError(
@@ -527,23 +502,32 @@ def _resolve_subtask_for_evaluation(self, subtask: Optional[Union[str, List[str]
"""
# Use provided subtask or fall back to constructor subtasks
eval_subtask = subtask if subtask is not None else self.subtasks


if eval_subtask is None or (isinstance(eval_subtask, str) and eval_subtask.upper() == "ALL"):
    # TODO: Check ALL vs None subtask for evaluation
    return None

# Validate the subtask
config = _BENCHMARK_CONFIG.get(self.benchmark)
if config and config.get("subtask_available"):
if isinstance(eval_subtask, list):
for st in eval_subtask:
if config.get("subtasks") and st not in config["subtasks"] and st.upper() != "ALL":
raise ValueError(
f"Invalid subtask '{st}' for benchmark '{self.benchmark.value}'. "
f"Available subtasks: {', '.join(config['subtasks'])}"
)
elif isinstance(eval_subtask, str):
if isinstance(eval_subtask, str):
if eval_subtask.upper() != "ALL" and config.get("subtasks") and eval_subtask not in config["subtasks"]:
raise ValueError(
f"Invalid subtask '{eval_subtask}' for benchmark '{self.benchmark.value}'. "
f"Available subtasks: {', '.join(config['subtasks'])}"
)
elif isinstance(eval_subtask, list):
if len(eval_subtask) == 0:
raise ValueError(
f"Subtask list cannot be empty for benchmark '{self.benchmark.value}'. "
f"Provide at least one subtask or use 'ALL'."
)
if len(eval_subtask) > 1:
raise ValueError(
f"Currently only one subtask is supported for benchmark '{self.benchmark.value}'. "
f"Provide only one subtask or use 'ALL'."
)


return eval_subtask
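Reviewer note: after this rewrite, `_resolve_subtask_for_evaluation` accepts `None`, `"ALL"`, a single subtask string, or a one-element list, and rejects longer lists. A sketch of the intended behavior, assuming `evaluator` was built for a benchmark with subtasks:

```python
# Intended behavior after this change (illustrative, not a real test):
evaluator._resolve_subtask_for_evaluation(None)         # -> None (run all subtasks)
evaluator._resolve_subtask_for_evaluation("ALL")        # -> None
evaluator._resolve_subtask_for_evaluation("anatomy")    # -> "anatomy" (membership-checked against config)
evaluator._resolve_subtask_for_evaluation(["anatomy"])  # -> ["anatomy"] (length-checked only)
evaluator._resolve_subtask_for_evaluation(["a", "b"])   # raises ValueError: only one subtask supported
```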

@@ -573,10 +557,12 @@ def _get_benchmark_template_additions(self, eval_subtask: Optional[Union[str, Li
'task': self.benchmark.value,
'strategy': config["strategy"],
metric_key: config["metrics"][0] if config.get("metrics") else 'accuracy',
'subtask': eval_subtask if isinstance(eval_subtask, str) else ','.join(eval_subtask) if eval_subtask else '',
'evaluate_base_model': self.evaluate_base_model,
}

if isinstance(eval_subtask, str):
benchmark_context['subtask'] = eval_subtask

# Add all configured hyperparameters
for key in configured_params.keys():
benchmark_context[key] = configured_params[key]
@@ -604,7 +590,6 @@ def evaluate(self, subtask: Optional[Union[str, List[str]]] = None) -> Evaluatio
benchmark=Benchmark.MMLU,
subtasks="ALL",
model="llama3-2-1b-instruct",
dataset="s3://bucket/data.jsonl",
s3_output_path="s3://bucket/outputs/"
)

@@ -645,9 +630,7 @@ def evaluate(self, subtask: Optional[Union[str, List[str]]] = None) -> Evaluatio
model_package_group_arn=model_package_group_arn,
resolved_model_artifact_arn=artifacts['resolved_model_artifact_arn']
)

# Add dataset URI
template_context['dataset_uri'] = self.dataset


# Add benchmark-specific template additions
benchmark_additions = self._get_benchmark_template_additions(eval_subtask, config)
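Reviewer note: taken together, the evaluator is now constructed without a `dataset` argument, only one subtask may be requested per run, and `evaluate_base_model` defaults to `False`. A minimal usage sketch under those assumptions (import path and bucket names are placeholders):

```python
from sagemaker.train.evaluate import BenchMarkEvaluator, Benchmark  # assumed import path

evaluator = BenchMarkEvaluator(
    benchmark=Benchmark.MMLU,
    subtasks=["abstract_algebra"],          # at most one subtask, or "ALL"
    model="llama3-2-1b-instruct",
    s3_output_path="s3://bucket/outputs/",  # placeholder bucket
    evaluate_base_model=True,               # opt back in; the default is now False
)
result = evaluator.evaluate()
```

One inconsistency worth flagging: the class docstring context still says `evaluate_base_model` "Defaults to True (evaluates both models)", while the field now defaults to `False`.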
36 changes: 18 additions & 18 deletions sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py
@@ -129,7 +129,7 @@
{% if kms_key_id %},
"KmsKeyId": "{{ kms_key_id }}"
{% endif %}
},
}{% if dataset_uri %},
"InputDataConfig": [
{
"ChannelName": "train",
@@ -144,7 +144,7 @@
}
}{% endif %}
}
]{% if vpc_config %},
]{% endif %}{% if vpc_config %},
"VpcConfig": {
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
"Subnets": {{ vpc_subnets | tojson }}
@@ -191,7 +191,7 @@
{% if kms_key_id %},
"KmsKeyId": "{{ kms_key_id }}"
{% endif %}
},
}{% if dataset_uri %},
"InputDataConfig": [
{
"ChannelName": "train",
@@ -206,7 +206,7 @@
}
}{% endif %}
}
]{% if vpc_config %},
]{% endif %}{% if vpc_config %},
"VpcConfig": {
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
"Subnets": {{ vpc_subnets | tojson }}
@@ -358,7 +358,7 @@
{% if kms_key_id %},
"KmsKeyId": "{{ kms_key_id }}"
{% endif %}
},
}{% if dataset_uri %},
"InputDataConfig": [
{
"ChannelName": "train",
@@ -373,7 +373,7 @@
}
}{% endif %}
}
]{% if vpc_config %},
]{% endif %}{% if vpc_config %},
"VpcConfig": {
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
"Subnets": {{ vpc_subnets | tojson }}
@@ -500,7 +500,7 @@
{% if kms_key_id %},
"KmsKeyId": "{{ kms_key_id }}"
{% endif %}
},
}{% if dataset_uri %},
"InputDataConfig": [
{
"ChannelName": "train",
@@ -515,7 +515,7 @@
}
}{% endif %}
}
]{% if vpc_config %},
]{% endif %}{% if vpc_config %},
"VpcConfig": {
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
"Subnets": {{ vpc_subnets | tojson }}
@@ -650,7 +650,7 @@
{% if kms_key_id %},
"KmsKeyId": "{{ kms_key_id }}"
{% endif %}
},
}{% if dataset_uri %},
"InputDataConfig": [
{
"ChannelName": "train",
@@ -665,7 +665,7 @@
}
}{% endif %}
}
]{% if vpc_config %},
]{% endif %}{% if vpc_config %},
"VpcConfig": {
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
"Subnets": {{ vpc_subnets | tojson }}
@@ -713,7 +713,7 @@
{% if kms_key_id %},
"KmsKeyId": "{{ kms_key_id }}"
{% endif %}
},
}{% if dataset_uri %},
"InputDataConfig": [
{
"ChannelName": "train",
@@ -728,7 +728,7 @@
}
}{% endif %}
}
]{% if vpc_config %},
]{% endif %}{% if vpc_config %},
"VpcConfig": {
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
"Subnets": {{ vpc_subnets | tojson }}
@@ -892,7 +892,7 @@
{% if kms_key_id %},
"KmsKeyId": "{{ kms_key_id }}"
{% endif %}
},
}{% if dataset_uri %},
"InputDataConfig": [
{
"ChannelName": "train",
@@ -907,7 +907,7 @@
}
}{% endif %}
}
]{% if vpc_config %},
]{% endif %}{% if vpc_config %},
"VpcConfig": {
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
"Subnets": {{ vpc_subnets | tojson }}
@@ -1032,7 +1032,7 @@
"ModelPackageConfig": {
"ModelPackageGroupArn": "{{ model_package_group_arn }}",
"SourceModelPackageArn": "{{ source_model_package_arn }}"
},
}{% if dataset_uri %},
"InputDataConfig": [
{
"ChannelName": "train",
@@ -1047,7 +1047,7 @@
}
}{% endif %}
}
]{% if vpc_config %},
]{% endif %}{% if vpc_config %},
"VpcConfig": {
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
"Subnets": {{ vpc_subnets | tojson }}
@@ -1086,7 +1086,7 @@
"ModelPackageConfig": {
"ModelPackageGroupArn": "{{ model_package_group_arn }}",
"SourceModelPackageArn": "{{ source_model_package_arn }}"
},
}{% if dataset_uri %},
"InputDataConfig": [
{
"ChannelName": "train",
@@ -1101,7 +1101,7 @@
}
}{% endif %}
}
]{% if vpc_config %},
]{% endif %}{% if vpc_config %},
"VpcConfig": {
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
"Subnets": {{ vpc_subnets | tojson }}
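Reviewer note: every hunk in this file repeats one fix: the comma after the preceding object and the whole `InputDataConfig` array are now wrapped in `{% if dataset_uri %}…{% endif %}`, and the comma for the `vpc_config` block moved after that `{% endif %}`, so templates rendered without a dataset still produce valid JSON. A reduced, runnable sketch of the pattern using plain `jinja2` (template text abridged):

```python
import json

from jinja2 import Template

# Abridged version of the conditional pattern used throughout
# pipeline_templates.py.
snippet = Template("""
{
  "OutputDataConfig": {
    "S3OutputPath": "{{ s3_output_path }}"
  }{% if dataset_uri %},
  "InputDataConfig": [
    {
      "ChannelName": "train",
      "DataSource": {
        "S3DataSource": {"S3Uri": "{{ dataset_uri }}"}
      }
    }
  ]{% endif %}{% if vpc_config %},
  "VpcConfig": {{ vpc_config | tojson }}{% endif %}
}
""")

# Without dataset_uri, the comma and the whole array disappear together,
# so the output still parses as JSON.
print(json.loads(snippet.render(s3_output_path="s3://bucket/out")))

# With both optional pieces present, the comma chain stays intact.
print(json.loads(snippet.render(
    s3_output_path="s3://bucket/out",
    dataset_uri="s3://bucket/data.jsonl",
    vpc_config={"SecurityGroupIds": ["sg-123"], "Subnets": ["subnet-abc"]},
)))
```

Before this change, the trailing comma was emitted unconditionally, so rendering without a `dataset_uri` produced `},` followed directly by `]` or `}` and the pipeline definition failed to parse.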