
Commit 33fe72a

jsondai authored and copybara-github committed
feat: GenAI Client(evals) - Add support for inference_configs in create_evaluation_run.
PiperOrigin-RevId: 856324409
1 parent b1b900e commit 33fe72a

2 files changed: +111 −12 lines

tests/unit/vertexai/genai/replays/test_create_evaluation_run.py

Lines changed: 85 additions & 10 deletions
@@ -20,11 +20,11 @@
 import pytest

 GCS_DEST = "gs://lakeyk-limited-bucket/eval_run_output"
-UNIVERSAL_AR_METRIC = types.EvaluationRunMetric(
-    metric="universal_ar_v1",
+GENERAL_QUALITY_METRIC = types.EvaluationRunMetric(
+    metric="general_quality_v1",
     metric_config=types.UnifiedMetric(
         predefined_metric_spec=types.PredefinedMetricSpec(
-            metric_spec_name="universal_ar_v1",
+            metric_spec_name="general_quality_v1",
         )
     ),
 )
@@ -71,7 +71,7 @@ def test_create_eval_run_data_source_evaluation_set(client):
         ),
         dest=GCS_DEST,
         metrics=[
-            UNIVERSAL_AR_METRIC,
+            GENERAL_QUALITY_METRIC,
             types.RubricMetric.FINAL_RESPONSE_QUALITY,
             LLM_METRIC,
         ],
@@ -94,7 +94,7 @@ def test_create_eval_run_data_source_evaluation_set(client):
         output_config=genai_types.OutputConfig(
             gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
         ),
-        metrics=[UNIVERSAL_AR_METRIC, FINAL_RESPONSE_QUALITY_METRIC, LLM_METRIC],
+        metrics=[GENERAL_QUALITY_METRIC, FINAL_RESPONSE_QUALITY_METRIC, LLM_METRIC],
     )
     assert evaluation_run.inference_configs[
         "agent-1"
@@ -131,7 +131,7 @@ def test_create_eval_run_data_source_bigquery_request_set(client):
         ),
         labels={"label1": "value1"},
         dest=GCS_DEST,
-        metrics=[UNIVERSAL_AR_METRIC],
+        metrics=[GENERAL_QUALITY_METRIC],
     )
     assert isinstance(evaluation_run, types.EvaluationRun)
     assert evaluation_run.display_name == "test5"
@@ -152,7 +152,7 @@ def test_create_eval_run_data_source_bigquery_request_set(client):
         output_config=genai_types.OutputConfig(
             gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
         ),
-        metrics=[UNIVERSAL_AR_METRIC],
+        metrics=[GENERAL_QUALITY_METRIC],
     )
     assert evaluation_run.inference_configs is None
     assert evaluation_run.labels == {
@@ -161,6 +161,43 @@ def test_create_eval_run_data_source_bigquery_request_set(client):
     assert evaluation_run.error is None


+def test_create_eval_run_with_inference_configs(client):
+    """Tests that create_evaluation_run() creates a correctly structured EvaluationRun with inference_configs."""
+    client._api_client._http_options.api_version = "v1beta1"
+    inference_config = types.EvaluationRunInferenceConfig(
+        model="projects/503583131166/locations/us-central1/publishers/google/models/gemini-2.5-flash"
+    )
+    evaluation_run = client.evals.create_evaluation_run(
+        name="test_inference_config",
+        display_name="test_inference_config",
+        dataset=types.EvaluationRunDataSource(
+            evaluation_set="projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
+        ),
+        dest=GCS_DEST,
+        metrics=[GENERAL_QUALITY_METRIC],
+        inference_configs={"model_1": inference_config},
+        labels={"label1": "value1"},
+    )
+    assert isinstance(evaluation_run, types.EvaluationRun)
+    assert evaluation_run.display_name == "test_inference_config"
+    assert evaluation_run.state == types.EvaluationRunState.PENDING
+    assert isinstance(evaluation_run.data_source, types.EvaluationRunDataSource)
+    assert evaluation_run.data_source.evaluation_set == (
+        "projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
+    )
+    assert evaluation_run.evaluation_config == types.EvaluationRunConfig(
+        output_config=genai_types.OutputConfig(
+            gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
+        ),
+        metrics=[GENERAL_QUALITY_METRIC],
+    )
+    assert evaluation_run.inference_configs["model_1"] == inference_config
+    assert evaluation_run.labels == {
+        "label1": "value1",
+    }
+    assert evaluation_run.error is None
+
+
 # Test fails in replay mode because of UUID generation mismatch.
 # def test_create_eval_run_data_source_evaluation_dataset(client):
 # """Tests that create_evaluation_run() creates a correctly structured EvaluationRun with EvaluationDataset."""
@@ -217,7 +254,7 @@ def test_create_eval_run_data_source_bigquery_request_set(client):
 # eval_dataset_df=input_df,
 # ),
 # dest=GCS_DEST,
-# metrics=[UNIVERSAL_AR_METRIC],
+# metrics=[GENERAL_QUALITY_METRIC],
 # )
 # assert isinstance(evaluation_run, types.EvaluationRun)
 # assert evaluation_run.display_name == "test6"
@@ -278,7 +315,7 @@ async def test_create_eval_run_async(client):
             )
         ),
         dest=GCS_DEST,
-        metrics=[UNIVERSAL_AR_METRIC],
+        metrics=[GENERAL_QUALITY_METRIC],
     )
     assert isinstance(evaluation_run, types.EvaluationRun)
     assert evaluation_run.display_name == "test8"
@@ -295,7 +332,7 @@ async def test_create_eval_run_async(client):
         output_config=genai_types.OutputConfig(
             gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
         ),
-        metrics=[UNIVERSAL_AR_METRIC],
+        metrics=[GENERAL_QUALITY_METRIC],
     )
     assert evaluation_run.error is None
     assert evaluation_run.inference_configs is None
@@ -304,6 +341,44 @@ async def test_create_eval_run_async(client):
     assert evaluation_run.error is None


+@pytest.mark.asyncio
+async def test_create_eval_run_async_with_inference_configs(client):
+    """Tests that create_evaluation_run() creates a correctly structured EvaluationRun with inference_configs asynchronously."""
+    client._api_client._http_options.api_version = "v1beta1"
+    inference_config = types.EvaluationRunInferenceConfig(
+        model="projects/503583131166/locations/us-central1/publishers/google/models/gemini-2.5-flash"
+    )
+    evaluation_run = await client.aio.evals.create_evaluation_run(
+        name="test_inference_config_async",
+        display_name="test_inference_config_async",
+        dataset=types.EvaluationRunDataSource(
+            evaluation_set="projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
+        ),
+        dest=GCS_DEST,
+        metrics=[GENERAL_QUALITY_METRIC],
+        inference_configs={"model_1": inference_config},
+        labels={"label1": "value1"},
+    )
+    assert isinstance(evaluation_run, types.EvaluationRun)
+    assert evaluation_run.display_name == "test_inference_config_async"
+    assert evaluation_run.state == types.EvaluationRunState.PENDING
+    assert isinstance(evaluation_run.data_source, types.EvaluationRunDataSource)
+    assert evaluation_run.data_source.evaluation_set == (
+        "projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
+    )
+    assert evaluation_run.evaluation_config == types.EvaluationRunConfig(
+        output_config=genai_types.OutputConfig(
+            gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
+        ),
+        metrics=[GENERAL_QUALITY_METRIC],
+    )
+    assert evaluation_run.inference_configs["model_1"] == inference_config
+    assert evaluation_run.labels == {
+        "label1": "value1",
+    }
+    assert evaluation_run.error is None
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),

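For orientation, a minimal usage sketch of the new inference_configs argument, mirroring test_create_eval_run_with_inference_configs above. This is not part of the commit: it assumes the Vertex AI GenAI SDK preview client (vertexai.Client) and that the types module is importable as vertexai.types; the project, location, bucket, evaluation set ID, and labels are placeholders.

# Hedged sketch, not from this commit: pass a candidate-name -> inference-config
# map instead of agent_info. All resource names below are placeholders.
import vertexai
from vertexai import types  # assumption: types re-exported at this path

client = vertexai.Client(project="my-project", location="us-central1")

inference_config = types.EvaluationRunInferenceConfig(
    model="publishers/google/models/gemini-2.5-flash"  # placeholder model name
)

evaluation_run = client.evals.create_evaluation_run(
    name="my-eval-run",
    display_name="my-eval-run",
    dataset=types.EvaluationRunDataSource(
        evaluation_set="projects/PROJECT/locations/us-central1/evaluationSets/EVAL_SET_ID"
    ),
    dest="gs://my-bucket/eval_run_output",
    metrics=[types.RubricMetric.FINAL_RESPONSE_QUALITY],
    inference_configs={"candidate-1": inference_config},  # candidate name -> config
    labels={"team": "evals"},
)
print(evaluation_run.state)  # expected to start as PENDING per the tests above
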
vertexai/_genai/evals.py

Lines changed: 26 additions & 2 deletions
@@ -1581,6 +1581,9 @@ def create_evaluation_run(
         name: Optional[str] = None,
         display_name: Optional[str] = None,
         agent_info: Optional[types.evals.AgentInfoOrDict] = None,
+        inference_configs: Optional[
+            dict[str, types.EvaluationRunInferenceConfigOrDict]
+        ] = None,
         labels: Optional[dict[str, str]] = None,
         config: Optional[types.CreateEvaluationRunConfigOrDict] = None,
     ) -> types.EvaluationRun:
@@ -1593,12 +1596,21 @@ def create_evaluation_run(
             name: The name of the evaluation run.
             display_name: The display name of the evaluation run.
             agent_info: The agent info to evaluate.
+            inference_configs: The candidate to inference config map for the evaluation run.
+                The key is the candidate name, and the value is the inference config.
+                If provided, agent_info must be None.
+                Example:
+                    {"candidate-1": types.EvaluationRunInferenceConfig(model="gemini-2.5-flash")}
             labels: The labels to apply to the evaluation run.
             config: The configuration for the evaluation run.

         Returns:
             The created evaluation run.
         """
+        if agent_info and inference_configs:
+            raise ValueError(
+                "At most one of agent_info or inference_configs can be provided."
+            )
         if agent_info and isinstance(agent_info, dict):
             agent_info = types.evals.AgentInfo.model_validate(agent_info)
         if type(dataset).__name__ == "EvaluationDataset":
@@ -1630,8 +1642,8 @@ def create_evaluation_run(
         evaluation_config = types.EvaluationRunConfig(
             output_config=output_config, metrics=resolved_metrics
         )
-        inference_configs = {}
         if agent_info:
+            inference_configs = {}
             inference_configs[agent_info.name] = types.EvaluationRunInferenceConfig(
                 agent_config=types.EvaluationRunAgentConfig(
                     developer_instruction=genai_types.Content(
@@ -2429,6 +2441,9 @@ async def create_evaluation_run(
         name: Optional[str] = None,
         display_name: Optional[str] = None,
         agent_info: Optional[types.evals.AgentInfo] = None,
+        inference_configs: Optional[
+            dict[str, types.EvaluationRunInferenceConfigOrDict]
+        ] = None,
         labels: Optional[dict[str, str]] = None,
         config: Optional[types.CreateEvaluationRunConfigOrDict] = None,
     ) -> types.EvaluationRun:
@@ -2441,12 +2456,21 @@ async def create_evaluation_run(
             name: The name of the evaluation run.
             display_name: The display name of the evaluation run.
             agent_info: The agent info to evaluate.
+            inference_configs: The candidate to inference config map for the evaluation run.
+                The key is the candidate name, and the value is the inference config.
+                If provided, agent_info must be None.
+                Example:
+                    {"candidate-1": types.EvaluationRunInferenceConfig(model="gemini-2.5-flash")}
             labels: The labels to apply to the evaluation run.
             config: The configuration for the evaluation run.

         Returns:
             The created evaluation run.
         """
+        if agent_info and inference_configs:
+            raise ValueError(
+                "At most one of agent_info or inference_configs can be provided."
+            )
         if agent_info and isinstance(agent_info, dict):
             agent_info = types.evals.AgentInfo.model_validate(agent_info)
         if type(dataset).__name__ == "EvaluationDataset":
@@ -2477,8 +2501,8 @@ async def create_evaluation_run(
         evaluation_config = types.EvaluationRunConfig(
             output_config=output_config, metrics=resolved_metrics
         )
-        inference_configs = {}
         if agent_info:
+            inference_configs = {}
             inference_configs[agent_info.name] = types.EvaluationRunInferenceConfig(
                 agent_config=types.EvaluationRunAgentConfig(
                     developer_instruction=genai_types.Content(

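A second hedged sketch (same client and import assumptions, placeholder resource names): per the diff above, agent_info and inference_configs are now mutually exclusive, and the async surface (client.aio.evals) accepts the same argument.

# Hedged sketch, not from this commit: the async path and the new guard.
import asyncio

import vertexai
from vertexai import types  # assumption: types re-exported at this path

client = vertexai.Client(project="my-project", location="us-central1")

# Supplying both agent_info and inference_configs now raises:
#   ValueError: At most one of agent_info or inference_configs can be provided.

async def main() -> None:
    run = await client.aio.evals.create_evaluation_run(
        name="my-async-eval-run",
        display_name="my-async-eval-run",
        dataset=types.EvaluationRunDataSource(
            evaluation_set="projects/PROJECT/locations/us-central1/evaluationSets/EVAL_SET_ID"
        ),
        dest="gs://my-bucket/eval_run_output",
        metrics=[types.RubricMetric.FINAL_RESPONSE_QUALITY],
        inference_configs={
            "candidate-1": types.EvaluationRunInferenceConfig(model="gemini-2.5-flash")
        },
    )
    print(run.name, run.state)

asyncio.run(main())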