
Commit 6f7d1a3

jsondai authored and copybara-github committed
feat: GenAI Client(evals) - Support CustomCodeExecution metric in Vertex Gen AI Eval Service
PiperOrigin-RevId: 839489822
1 parent 89db338 commit 6f7d1a3

6 files changed: +520 -2 lines changed
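
The commit wires a new remote_custom_function field on types.Metric through to the Eval Service's custom code execution spec. A minimal usage sketch, assembled from the replay test added below (the client object is assumed to be an initialized Vertex GenAI client, as provided by the test fixture; it is not constructed here):

import pandas as pd
from vertexai._genai import types

# User-defined scoring function, shipped as source and executed remotely by the
# Eval Service (exact-match against the reference, mirroring the test below).
code_snippet = """
def evaluate(instance):
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0
"""

exact_match_metric = types.Metric(
    name="my_custom_code_metric",
    remote_custom_function=code_snippet,
)

eval_dataset = types.EvaluationDataset(
    eval_dataset_df=pd.DataFrame(
        {
            "prompt": ["What is 2+2?"],
            "response": ["4"],
            "reference": ["4"],
        }
    ),
)

# evaluation_result = client.evals.evaluate(
#     dataset=eval_dataset,
#     metrics=[exact_match_metric],
# )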
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types
import pandas as pd


def test_custom_code_execution(client):
    """Tests that custom code execution metric produces a correctly structured EvaluationResult."""

    code_snippet = """
def evaluate(instance):
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0
"""

    custom_metric = types.Metric(
        name="my_custom_code_metric",
        remote_custom_function=code_snippet,
    )

    prompts_df = pd.DataFrame(
        {
            "prompt": ["What is 2+2?", "What is 3+3?"],
            "response": ["4", "5"],
            "reference": ["4", "6"],
        }
    )

    eval_dataset = types.EvaluationDataset(
        eval_dataset_df=prompts_df,
        candidate_name="test_model",
    )

    evaluation_result = client.evals.evaluate(
        dataset=eval_dataset,
        metrics=[custom_metric],
    )

    assert isinstance(evaluation_result, types.EvaluationResult)

    assert evaluation_result.summary_metrics is not None
    assert evaluation_result.summary_metrics
    for summary in evaluation_result.summary_metrics:
        assert isinstance(summary, types.AggregatedMetricResult)
        assert summary.metric_name == "my_custom_code_metric"

    assert evaluation_result.eval_case_results is not None
    assert evaluation_result.eval_case_results
    for case_result in evaluation_result.eval_case_results:
        assert isinstance(case_result, types.EvalCaseResult)
        assert case_result.eval_case_index is not None
        assert case_result.response_candidate_results is not None


def test_custom_code_execution_batch_evaluate(client):
    """Tests that batch_evaluate() works with custom code execution metric."""

    code_snippet = """
def evaluate(instance):
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0
"""

    custom_metric = types.Metric(
        name="my_custom_code_metric",
        remote_custom_function=code_snippet,
    )

    eval_dataset = types.EvaluationDataset(
        gcs_source=types.GcsSource(
            uris=["gs://genai-eval-sdk-replay-test/test_data/inference_results.jsonl"]
        ),
    )

    evaluation_result = client.evals.batch_evaluate(
        dataset=eval_dataset,
        metrics=[custom_metric],
        dest="gs://genai-eval-sdk-replay-test/test_data/batch_eval_output",
    )

    assert evaluation_result is not None


pytestmark = pytest_helper.setup(
    file=__file__,
    globals_for_file=globals(),
    test_method="evals.evaluate",
)
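
As the snippet embedded in remote_custom_function suggests, the user function appears to take a single instance dict (with keys such as prompt, response, and reference) and return a float score. A purely local sketch of the same exact-match logic over the test data above, just to show the expected per-case scores (the real execution happens remotely in the Eval Service sandbox):

def evaluate(instance):
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0


instances = [
    {"prompt": "What is 2+2?", "response": "4", "reference": "4"},
    {"prompt": "What is 3+3?", "response": "5", "reference": "6"},
]

# The first case matches its reference, the second does not.
print([evaluate(i) for i in instances])  # [1.0, 0.0]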

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 142 additions & 0 deletions
@@ -1099,7 +1099,148 @@ def aggregate(
        )


class CustomCodeExecutionMetricHandler(MetricHandler):
    """Metric handler for custom code execution metrics."""

    def __init__(self, module: "evals.Evals", metric: types.Metric):
        super().__init__(module=module, metric=metric)

        if not self.metric.remote_custom_function:
            raise ValueError(
                f"CustomCodeExecutionMetricHandler for '{self.metric.name}' needs"
                " Metric.remote_custom_function to be set."
            )

    def _build_request_payload(
        self, eval_case: types.EvalCase, response_index: int
    ) -> dict[str, Any]:
        """Builds the request parameters for the evaluate instances request."""
        if not eval_case.responses or response_index >= len(eval_case.responses):
            raise IndexError(f"response_index {response_index} is out of bounds.")

        response_content = eval_case.responses[response_index].response
        if not response_content:
            raise ValueError(
                f"Response content missing for candidate {response_index}."
            )

        reference_instance_data = None
        if eval_case.reference:
            reference_instance_data = PredefinedMetricHandler._content_to_instance_data(
                eval_case.reference.response
            )

        prompt_instance_data = PredefinedMetricHandler._content_to_instance_data(
            eval_case.prompt
        )

        instance_payload = types.EvaluationInstance(
            prompt=prompt_instance_data,
            response=PredefinedMetricHandler._content_to_instance_data(
                response_content
            ),
            reference=reference_instance_data,
        )

        return {
            "instance": instance_payload,
        }

    @override
    def get_metric_result(
        self, eval_case: types.EvalCase, response_index: int
    ) -> types.EvalCaseMetricResult:
        """Processes a single evaluation case for a specific custom code execution metric."""
        metric_name = self.metric.name
        try:
            payload = self._build_request_payload(eval_case, response_index)
            for attempt in range(_MAX_RETRIES):
                try:
                    api_response = self.module._evaluate_instances(
                        metrics=[self.metric],
                        instance=payload.get("instance"),
                    )
                    break
                except genai_errors.ClientError as e:
                    if e.code == 429:
                        logger.warning(
                            "Resource Exhausted error on attempt %d/%d: %s. Retrying in %s"
                            " seconds...",
                            attempt + 1,
                            _MAX_RETRIES,
                            e,
                            2**attempt,
                        )
                        if attempt == _MAX_RETRIES - 1:
                            return types.EvalCaseMetricResult(
                                metric_name=metric_name,
                                error_message=f"Resource exhausted after {_MAX_RETRIES} retries: {e}",
                            )
                        time.sleep(2**attempt)
                    else:
                        raise e

            if (
                api_response
                and hasattr(api_response, "metric_results")
                and api_response.metric_results
            ):
                result_data = api_response.metric_results[0]

                error_message = None
                if result_data.error and getattr(result_data.error, "code", None):
                    error_message = f"Error in metric result: {result_data.error}"
                return types.EvalCaseMetricResult(
                    metric_name=metric_name,
                    score=result_data.score,
                    explanation=result_data.explanation,
                    error_message=error_message,
                )
            else:
                logger.error(
                    "Metric results missing in API response for metric '%s'."
                    " API response: %s",
                    metric_name,
                    (
                        api_response.model_dump_json(exclude_none=True)
                        if api_response
                        else "None"
                    ),
                )
                return types.EvalCaseMetricResult(
                    metric_name=metric_name,
                    error_message="Metric results missing in API response.",
                )
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Error processing metric %s for case %s: %s",
                metric_name,
                eval_case.eval_case_id,
                e,
                exc_info=True,
            )
            return types.EvalCaseMetricResult(
                metric_name=metric_name, error_message=str(e)
            )

    @override
    def aggregate(
        self, eval_case_metric_results: list[types.EvalCaseMetricResult]
    ) -> types.AggregatedMetricResult:
        """Aggregates the metric results for a custom code execution metric."""
        logger.debug(
            "Aggregating results for custom code execution metric: %s", self.metric.name
        )
        return _default_aggregate_scores(
            self.metric.name, eval_case_metric_results, calculate_pass_rate=True
        )


_METRIC_HANDLER_MAPPING = [
    (
        lambda m: hasattr(m, "remote_custom_function") and m.remote_custom_function,
        CustomCodeExecutionMetricHandler,
    ),
    (
        lambda m: m.custom_function and isinstance(m.custom_function, Callable),
        CustomMetricHandler,
@@ -1125,6 +1266,7 @@ def aggregate(
    TranslationMetricHandler,
    LLMMetricHandler,
    CustomMetricHandler,
    CustomCodeExecutionMetricHandler,
    PredefinedMetricHandler,
)
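
The 429 handling in get_metric_result above is a plain exponential backoff: retry up to _MAX_RETRIES times, sleep 2**attempt seconds between attempts, and surface an error result if the quota never frees up. A self-contained sketch of that pattern (the call argument, the ResourceExhaustedError class, and the _MAX_RETRIES value here are stand-ins for illustration, not the SDK's actual symbols):

import time

_MAX_RETRIES = 3  # assumed value; the handler reads a module-level constant


class ResourceExhaustedError(Exception):
    """Stand-in for genai_errors.ClientError with code == 429."""


def call_with_backoff(call):
    """Calls call() and retries on rate-limit errors with exponential backoff."""
    for attempt in range(_MAX_RETRIES):
        try:
            return call()
        except ResourceExhaustedError as e:
            if attempt == _MAX_RETRIES - 1:
                raise RuntimeError(
                    f"Resource exhausted after {_MAX_RETRIES} retries: {e}"
                ) from e
            time.sleep(2**attempt)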

vertexai/_genai/_transformers.py

Lines changed: 7 additions & 0 deletions
@@ -60,6 +60,13 @@ def t_metrics(
                "metric_spec_name": metric_name,
                "metric_spec_parameters": metric.metric_spec_parameters,
            }
        # Custom Code Execution Metric
        elif (
            hasattr(metric, "remote_custom_function") and metric.remote_custom_function
        ):
            metric_payload_item["custom_code_execution_spec"] = {
                "evaluation_function": metric.remote_custom_function
            }
        # Pointwise metrics
        elif hasattr(metric, "prompt_template") and metric.prompt_template:
            pointwise_spec = {"metric_prompt_template": metric.prompt_template}
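
Per this branch, a metric with remote_custom_function set is translated into a custom_code_execution_spec whose evaluation_function carries the user's source string. A hand-built sketch of that payload fragment (only the key added by this commit is shown; the other keys of metric_payload_item are omitted and not known from the diff):

import json

code_snippet = """
def evaluate(instance):
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0
"""

# Shape produced by the new elif branch; surrounding keys omitted.
metric_payload_item = {
    "custom_code_execution_spec": {"evaluation_function": code_snippet},
}
print(json.dumps(metric_payload_item, indent=2))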
