
Commit fef17bd

Merge remote-tracking branch 'origin/fix-pr-checks' into fix-pr-checks
2 parents: c2c72c6 + 10467f9

File tree: 4 files changed, +81 / -101 lines changed

sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py

Lines changed: 30 additions & 39 deletions
@@ -35,10 +35,8 @@ class _Benchmark(str, Enum):
     MATH = "math"
     STRONG_REJECT = "strong_reject"
     IFEVAL = "ifeval"
-    GEN_QA = "gen_qa"
     MMMU = "mmmu"
     LLM_JUDGE = "llm_judge"
-    INFERENCE_ONLY = "inference_only"


 # Internal benchmark configuration mapping - using plain dictionaries
# Internal benchmark configuration mapping - using plain dictionaries
@@ -138,14 +136,6 @@ class _Benchmark(str, Enum):
         "subtask_available": False,
         "subtasks": None
     },
-    _Benchmark.GEN_QA: {
-        "modality": "Multi-Modal (image)",
-        "description": "Custom Dataset Evaluation – Lets you supply your own dataset for benchmarking, comparing model outputs to reference answers with metrics such as ROUGE and BLEU. gen_qa supports image inference for models which have multimodal support.",
-        "metrics": ["all"],
-        "strategy": "gen_qa",
-        "subtask_available": False,
-        "subtasks": None
-    },
     _Benchmark.MMMU: {
         "modality": "Multi-Modal",
         "description": "Massive Multidiscipline Multimodal Understanding (MMMU) – College-level benchmark comprising multiple-choice and open-ended questions from 30 disciplines.",
@@ -171,14 +161,6 @@ class _Benchmark(str, Enum):
         "subtask_available": False,
         "subtasks": None
     },
-    _Benchmark.INFERENCE_ONLY: {
-        "modality": "Text",
-        "description": "Lets you supply your own dataset to generate inference responses which can be used with the llm_judge task. No metrics are computed for this task.",
-        "metrics": ["N/A"],
-        "strategy": "--",
-        "subtask_available": False,
-        "subtasks": None
-    },
 }


@@ -278,10 +260,6 @@ class BenchMarkEvaluator(BaseEvaluator):
             Optional. If not provided, the system will attempt to resolve it using the default
             MLflow app experience (checks domain match, account default, or creates a new app).
             Format: arn:aws:sagemaker:region:account:mlflow-tracking-server/name
-        dataset (Union[str, Any]): Evaluation dataset. Required. Accepts:
-            - S3 URI (str): e.g., 's3://bucket/path/dataset.jsonl'
-            - Dataset ARN (str): e.g., 'arn:aws:sagemaker:...:hub-content/AIRegistry/DataSet/...'
-            - DataSet object: sagemaker.ai_registry.dataset.DataSet instance (ARN inferred automatically)
         evaluate_base_model (bool): Whether to evaluate the base model in addition to the custom
             model. Set to False to skip base model evaluation and only evaluate the custom model.
             Defaults to True (evaluates both models).
@@ -309,7 +287,6 @@ class BenchMarkEvaluator(BaseEvaluator):
             benchmark=Benchmark.MMLU,
             subtasks=["abstract_algebra", "anatomy", "astronomy"],
             model="llama3-2-1b-instruct",
-            dataset="s3://bucket/eval-data.jsonl",
             s3_output_path="s3://bucket/outputs/",
             mlflow_resource_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/my-server"
         )
@@ -383,15 +360,21 @@ def _validate_subtasks(cls, v, values):
                     f"Subtask list cannot be empty for benchmark '{benchmark.value}'. "
                     f"Provide at least one subtask or use 'ALL'."
                 )
-
+
+            if len(v) > 1:
+                raise ValueError(
+                    f"Currently only one subtask is supported for benchmark '{benchmark.value}'. "
+                    f"Provide only one subtask or use 'ALL'."
+                )
+
+            # TODO: Should support list of subtasks.
             # Validate each subtask in the list
             for subtask in v:
                 if not isinstance(subtask, str):
                     raise ValueError(
                         f"All subtasks in the list must be strings. "
                         f"Found {type(subtask).__name__}: {subtask}"
                     )
-
+
                 # Validate against available subtasks if defined
                 if config.get("subtasks") and subtask not in config["subtasks"]:
                     raise ValueError(
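The net effect of this hunk is that `_validate_subtasks` now accepts at most one subtask per run. A minimal standalone sketch of the new rules, for illustration only (the helper name is hypothetical, not the library API):

```python
# Illustrative sketch: replicates the subtask rules enforced after this commit.
from typing import List, Optional, Union


def validate_subtasks(v: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
    """Mirror of the constraints in _validate_subtasks after this change."""
    if v is None or (isinstance(v, str) and v.upper() == "ALL"):
        return v  # None / "ALL" mean: run every available subtask
    if isinstance(v, list):
        if len(v) == 0:
            raise ValueError("Subtask list cannot be empty. Provide at least one subtask or use 'ALL'.")
        if len(v) > 1:
            raise ValueError("Currently only one subtask is supported. Provide only one subtask or use 'ALL'.")
        if not all(isinstance(s, str) for s in v):
            raise ValueError("All subtasks in the list must be strings.")
    return v


validate_subtasks(["abstract_algebra"])                 # OK: single-item list
validate_subtasks("ALL")                                # OK: run all subtasks
# validate_subtasks(["abstract_algebra", "anatomy"])    # now raises ValueError
```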
@@ -525,23 +508,32 @@ def _resolve_subtask_for_evaluation(self, subtask: Optional[Union[str, List[str]
         """
         # Use provided subtask or fall back to constructor subtasks
         eval_subtask = subtask if subtask is not None else self.subtasks
-
+
+        if eval_subtask is None or eval_subtask.upper() == "ALL":
+            # TODO: Check "ALL" vs. None subtask for evaluation
+            return None
+
         # Validate the subtask
         config = _BENCHMARK_CONFIG.get(self.benchmark)
         if config and config.get("subtask_available"):
-            if isinstance(eval_subtask, list):
-                for st in eval_subtask:
-                    if config.get("subtasks") and st not in config["subtasks"] and st.upper() != "ALL":
-                        raise ValueError(
-                            f"Invalid subtask '{st}' for benchmark '{self.benchmark.value}'. "
-                            f"Available subtasks: {', '.join(config['subtasks'])}"
-                        )
-            elif isinstance(eval_subtask, str):
+            if isinstance(eval_subtask, str):
                 if eval_subtask.upper() != "ALL" and config.get("subtasks") and eval_subtask not in config["subtasks"]:
                     raise ValueError(
                         f"Invalid subtask '{eval_subtask}' for benchmark '{self.benchmark.value}'. "
                         f"Available subtasks: {', '.join(config['subtasks'])}"
                     )
+            elif isinstance(eval_subtask, list):
+                if len(eval_subtask) == 0:
+                    raise ValueError(
+                        f"Subtask list cannot be empty for benchmark '{self.benchmark.value}'. "
+                        f"Provide at least one subtask or use 'ALL'."
+                    )
+                if len(eval_subtask) > 1:
+                    raise ValueError(
+                        f"Currently only one subtask is supported for benchmark '{self.benchmark.value}'. "
+                        f"Provide only one subtask or use 'ALL'."
+                    )
+

         return eval_subtask

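After this hunk, resolution short-circuits: `None` or `"ALL"` maps to `None` (meaning every available subtask), a known string passes through, and a list is re-checked against the one-subtask limit. A hedged sketch of that flow, with a toy dict standing in for `_BENCHMARK_CONFIG` (note the sketch adds an `isinstance` guard before `.upper()`, which the committed code calls unconditionally):

```python
# Illustrative only: the resolution flow after this commit.
config = {"subtask_available": True, "subtasks": ["anatomy", "astronomy"]}


def resolve_subtask(eval_subtask):
    if eval_subtask is None or (isinstance(eval_subtask, str) and eval_subtask.upper() == "ALL"):
        return None  # None downstream means "evaluate all subtasks"
    if isinstance(eval_subtask, str):
        if eval_subtask not in config["subtasks"]:
            raise ValueError(f"Invalid subtask '{eval_subtask}'")
    elif isinstance(eval_subtask, list) and len(eval_subtask) != 1:
        raise ValueError("Provide exactly one subtask or use 'ALL'")
    return eval_subtask


assert resolve_subtask("ALL") is None
assert resolve_subtask("anatomy") == "anatomy"
assert resolve_subtask(["anatomy"]) == ["anatomy"]
```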
@@ -571,10 +563,12 @@ def _get_benchmark_template_additions(self, eval_subtask: Optional[Union[str, Li
             'task': self.benchmark.value,
             'strategy': config["strategy"],
             metric_key: config["metrics"][0] if config.get("metrics") else 'accuracy',
-            'subtask': eval_subtask if isinstance(eval_subtask, str) else ','.join(eval_subtask) if eval_subtask else '',
             'evaluate_base_model': self.evaluate_base_model,
         }

+        if isinstance(eval_subtask, str):
+            benchmark_context['subtask'] = eval_subtask
+
         # Add all configured hyperparameters
         for key in configured_params.keys():
             benchmark_context[key] = configured_params[key]
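The follow-on change in `_get_benchmark_template_additions`: the `'subtask'` key is now emitted only for plain strings, instead of joining list values with commas. A small sketch of the difference (illustrative values, not a real template context):

```python
# Illustrative only: how the 'subtask' key is populated after this commit.
benchmark_context = {"task": "mmlu", "evaluate_base_model": True}

eval_subtask = "anatomy"
if isinstance(eval_subtask, str):
    benchmark_context["subtask"] = eval_subtask  # set only for strings now

# Before this commit, a list was comma-joined instead:
# benchmark_context["subtask"] = ",".join(["anatomy", "astronomy"])  # old behavior
```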
@@ -602,7 +596,6 @@ def evaluate(self, subtask: Optional[Union[str, List[str]]] = None) -> Evaluatio
                 benchmark=Benchmark.MMLU,
                 subtasks="ALL",
                 model="llama3-2-1b-instruct",
-                dataset="s3://bucket/data.jsonl",
                 s3_output_path="s3://bucket/outputs/"
             )

@@ -643,9 +636,7 @@ def evaluate(self, subtask: Optional[Union[str, List[str]]] = None) -> Evaluatio
             model_package_group_arn=model_package_group_arn,
             resolved_model_artifact_arn=artifacts['resolved_model_artifact_arn']
         )
-
-        # Add dataset URI
-        template_context['dataset_uri'] = self.dataset
+

         # Add benchmark-specific template additions
         benchmark_additions = self._get_benchmark_template_additions(eval_subtask, config)

sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py

Lines changed: 18 additions & 18 deletions
@@ -129,7 +129,7 @@
       {% if kms_key_id %},
       "KmsKeyId": "{{ kms_key_id }}"
       {% endif %}
-    },
+    }{% if dataset_uri %},
     "InputDataConfig": [
       {
         "ChannelName": "train",
@@ -144,7 +144,7 @@
           }
         }{% endif %}
       }
-    ]{% if vpc_config %},
+    ]{% endif %}{% if vpc_config %},
     "VpcConfig": {
       "SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
       "Subnets": {{ vpc_subnets | tojson }}
@@ -191,7 +191,7 @@
       {% if kms_key_id %},
       "KmsKeyId": "{{ kms_key_id }}"
       {% endif %}
-    },
+    }{% if dataset_uri %},
     "InputDataConfig": [
       {
         "ChannelName": "train",
@@ -206,7 +206,7 @@
           }
         }{% endif %}
       }
-    ]{% if vpc_config %},
+    ]{% endif %}{% if vpc_config %},
     "VpcConfig": {
       "SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
       "Subnets": {{ vpc_subnets | tojson }}
@@ -358,7 +358,7 @@
       {% if kms_key_id %},
       "KmsKeyId": "{{ kms_key_id }}"
       {% endif %}
-    },
+    }{% if dataset_uri %},
     "InputDataConfig": [
      {
         "ChannelName": "train",
@@ -373,7 +373,7 @@
           }
         }{% endif %}
       }
-    ]{% if vpc_config %},
+    ]{% endif %}{% if vpc_config %},
     "VpcConfig": {
       "SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
       "Subnets": {{ vpc_subnets | tojson }}
@@ -500,7 +500,7 @@
       {% if kms_key_id %},
       "KmsKeyId": "{{ kms_key_id }}"
       {% endif %}
-    },
+    }{% if dataset_uri %},
     "InputDataConfig": [
       {
         "ChannelName": "train",
@@ -515,7 +515,7 @@
           }
         }{% endif %}
       }
-    ]{% if vpc_config %},
+    ]{% endif %}{% if vpc_config %},
     "VpcConfig": {
       "SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
       "Subnets": {{ vpc_subnets | tojson }}
@@ -650,7 +650,7 @@
       {% if kms_key_id %},
       "KmsKeyId": "{{ kms_key_id }}"
       {% endif %}
-    },
+    }{% if dataset_uri %},
     "InputDataConfig": [
       {
         "ChannelName": "train",
@@ -665,7 +665,7 @@
           }
         }{% endif %}
       }
-    ]{% if vpc_config %},
+    ]{% endif %}{% if vpc_config %},
     "VpcConfig": {
       "SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
       "Subnets": {{ vpc_subnets | tojson }}
@@ -713,7 +713,7 @@
       {% if kms_key_id %},
       "KmsKeyId": "{{ kms_key_id }}"
       {% endif %}
-    },
+    }{% if dataset_uri %},
     "InputDataConfig": [
       {
         "ChannelName": "train",
@@ -728,7 +728,7 @@
           }
         }{% endif %}
       }
-    ]{% if vpc_config %},
+    ]{% endif %}{% if vpc_config %},
     "VpcConfig": {
       "SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
       "Subnets": {{ vpc_subnets | tojson }}
@@ -892,7 +892,7 @@
       {% if kms_key_id %},
       "KmsKeyId": "{{ kms_key_id }}"
       {% endif %}
-    },
+    }{% if dataset_uri %},
     "InputDataConfig": [
       {
         "ChannelName": "train",
@@ -907,7 +907,7 @@
           }
         }{% endif %}
       }
-    ]{% if vpc_config %},
+    ]{% endif %}{% if vpc_config %},
     "VpcConfig": {
       "SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
       "Subnets": {{ vpc_subnets | tojson }}
@@ -1032,7 +1032,7 @@
     "ModelPackageConfig": {
       "ModelPackageGroupArn": "{{ model_package_group_arn }}",
       "SourceModelPackageArn": "{{ source_model_package_arn }}"
-    },
+    }{% if dataset_uri %},
     "InputDataConfig": [
       {
         "ChannelName": "train",
@@ -1047,7 +1047,7 @@
           }
         }{% endif %}
       }
-    ]{% if vpc_config %},
+    ]{% endif %}{% if vpc_config %},
     "VpcConfig": {
       "SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
       "Subnets": {{ vpc_subnets | tojson }}
@@ -1086,7 +1086,7 @@
     "ModelPackageConfig": {
       "ModelPackageGroupArn": "{{ model_package_group_arn }}",
       "SourceModelPackageArn": "{{ source_model_package_arn }}"
-    },
+    }{% if dataset_uri %},
     "InputDataConfig": [
       {
         "ChannelName": "train",
@@ -1101,7 +1101,7 @@
           }
         }{% endif %}
       }
-    ]{% if vpc_config %},
+    ]{% endif %}{% if vpc_config %},
     "VpcConfig": {
       "SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
       "Subnets": {{ vpc_subnets | tojson }}
