
Commit 4edd6a2

jsondai authored and copybara-github committed
fix: GenAI client(evals) - Fix TypeError in _build_generate_content_config
PiperOrigin-RevId: 853003335
1 parent: 4b42c76

1 file changed (+47, −21 lines)


vertexai/_genai/_evals_common.py

Lines changed: 47 additions & 21 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 #
 """Common utilities for evals."""
+
 import asyncio
 import base64
 import collections
@@ -24,8 +25,8 @@
 import os
 import threading
 import time
-import uuid
 from typing import Any, Callable, Literal, Optional, Union
+import uuid

 from google.api_core import exceptions as api_exceptions
 import vertexai
@@ -41,7 +42,6 @@
 from . import _evals_metric_loaders
 from . import _evals_utils
 from . import _gcs_utils
-
 from . import evals
 from . import types

@@ -90,7 +90,8 @@ def _get_api_client_with_location(
     return api_client

   logger.info(
-      "Model endpoint location set to %s, overriding client location %s for this API call.",
+      "Model endpoint location set to %s, overriding client location %s for"
+      " this API call.",
       location,
       api_client.location,
   )
@@ -208,7 +209,7 @@ def _generate_content_with_retry(


 def _build_generate_content_config(
-    request_dict: dict[str, Any],
+    request_dict: Union[dict[str, Any], str],
     global_config: Optional[genai_types.GenerateContentConfig] = None,
 ) -> genai_types.GenerateContentConfig:
   """Builds a GenerateContentConfig from the request dictionary or provided config."""
@@ -220,6 +221,9 @@ def _build_generate_content_config(
   else:
     merged_config_dict = {}

+  if not isinstance(request_dict, dict):
+    return genai_types.GenerateContentConfig(**merged_config_dict)
+
   for key in [
       "system_instruction",
       "tools",
@@ -264,7 +268,11 @@ def _execute_inference_concurrently(
     agent_engine: Optional[Union[str, types.AgentEngine]] = None,
     agent: Optional[LlmAgent] = None,
 ) -> list[
-    Union[genai_types.GenerateContentResponse, dict[str, Any], list[dict[str, Any]]]
+    Union[
+        genai_types.GenerateContentResponse,
+        dict[str, Any],
+        list[dict[str, Any]],
+    ]
 ]:
   """Internal helper to run inference with concurrency."""
   logger.info(
@@ -384,7 +392,11 @@ def _run_gemini_inference(
     prompt_dataset: pd.DataFrame,
     config: Optional[genai_types.GenerateContentConfig] = None,
 ) -> list[
-    Union[genai_types.GenerateContentResponse, dict[str, Any], list[dict[str, Any]]]
+    Union[
+        genai_types.GenerateContentResponse,
+        dict[str, Any],
+        list[dict[str, Any]],
+    ]
 ]:
   """Internal helper to run inference using Gemini model with concurrency."""
   return _execute_inference_concurrently(
410422
)
411423

412424

413-
def _convert_prompt_row_to_litellm_messages(row: pd.Series) -> list[dict[str, Any]]:
425+
def _convert_prompt_row_to_litellm_messages(
426+
row: pd.Series,
427+
) -> list[dict[str, Any]]:
414428
"""Converts a DataFrame row into LiteLLM's messages format by detecting the input schema."""
415429
messages: list[dict[str, Any]] = []
416430
row_dict = row.to_dict()
@@ -442,10 +456,10 @@ def _convert_prompt_row_to_litellm_messages(row: pd.Series) -> list[dict[str, Any]]:
     return [{"role": USER_AUTHOR, "content": row_dict["prompt"]}]

   raise ValueError(
-      "Could not determine prompt/messages format from input row. "
-      "Expected OpenAI request body with a 'messages' key, or a 'request' key"
-      " with OpenAI request body, or Gemini request body with a 'contents'"
-      f" key, or a 'prompt' key with a raw string. Found keys: {list(row_dict.keys())}"
+      "Could not determine prompt/messages format from input row. Expected"
+      " OpenAI request body with a 'messages' key, or a 'request' key with"
+      " OpenAI request body, or Gemini request body with a 'contents' key, or"
+      f" a 'prompt' key with a raw string. Found keys: {list(row_dict.keys())}"
   )
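The reflowed ValueError above also documents the row shapes _convert_prompt_row_to_litellm_messages recognizes. A hedged illustration of those shapes follows; the values are made up, and only the 'prompt' case's return value is actually visible in this diff.

import pandas as pd

# Hypothetical rows matching the four formats named in the error message.
openai_row = pd.Series({"messages": [{"role": "user", "content": "Hi"}]})
request_row = pd.Series({"request": {"messages": [{"role": "user", "content": "Hi"}]}})
gemini_row = pd.Series({"contents": [{"role": "user", "parts": [{"text": "Hi"}]}]})
prompt_row = pd.Series({"prompt": "Hi"})  # returned as [{"role": USER_AUTHOR, "content": "Hi"}]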
@@ -466,7 +480,8 @@ def _run_litellm_inference(
 ) -> list[Optional[dict[str, Any]]]:
   """Runs inference using LiteLLM with concurrency."""
   logger.info(
-      "Generating responses for %d prompts using LiteLLM for third party model: %s",
+      "Generating responses for %d prompts using LiteLLM for third party"
+      " model: %s",
       len(prompt_dataset),
       model,
   )
@@ -776,8 +791,8 @@ def _execute_inference(

   Args:
     api_client: The API client.
-    src: The source of the dataset. Can be a string (path to a local file,
-      a GCS path, or a BigQuery table) or a Pandas DataFrame.
+    src: The source of the dataset. Can be a string (path to a local file, a
+      GCS path, or a BigQuery table) or a Pandas DataFrame.
     model: The model to use for inference. Can be a callable function or a
       string representing a model.
     agent_engine: The agent engine to use for inference. Can be a resource
@@ -847,8 +862,8 @@ def _execute_inference(
       raise TypeError(
           f"Unsupported agent_engine type: {type(agent_engine)}. Expecting a"
           " string (agent engine resource name in"
-          " 'projects/{project_id}/locations/{location_id}/reasoningEngines/{reasoning_engine_id}' format)"
-          " or a types.AgentEngine instance."
+          " 'projects/{project_id}/locations/{location_id}/reasoningEngines/{reasoning_engine_id}'"
+          " format) or a types.AgentEngine instance."
       )
   if (
       _evals_constant.INTERMEDIATE_EVENTS in prompt_dataset.columns
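For reference, a concrete agent_engine resource name matching the template in that error message would look like the following; the project, location, and engine ID are made up for illustration.

agent_engine = "projects/my-project/locations/us-central1/reasoningEngines/1234567890"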
@@ -1193,7 +1208,8 @@ def _execute_evaluation(  # type: ignore[no-untyped-def]
       validated_agent_info = agent_info
     else:
       raise TypeError(
-          f"agent_info values must be of type types.evals.AgentInfo or dict, but got {type(agent_info)}'"
+          "agent_info values must be of type types.evals.AgentInfo or dict,"
+          f" but got {type(agent_info)}'"
       )

   processed_eval_dataset, num_response_candidates = _resolve_dataset_inputs(
@@ -1323,8 +1339,9 @@ def _run_agent_internal(
       processed_responses
   ) != len(processed_intermediate_events):
     raise RuntimeError(
-        "Critical prompt/response/intermediate_events count mismatch: %d prompts vs %d vs %d"
-        " responses. This indicates an issue in response collection."
+        "Critical prompt/response/intermediate_events count mismatch: %d"
+        " prompts vs %d vs %d responses. This indicates an issue in response"
+        " collection."
         % (
             len(prompt_dataset),
             len(processed_responses),
@@ -1354,7 +1371,11 @@ def _run_agent(
     agent: Optional[LlmAgent],
     prompt_dataset: pd.DataFrame,
 ) -> list[
-    Union[list[dict[str, Any]], dict[str, Any], genai_types.GenerateContentResponse]
+    Union[
+        list[dict[str, Any]],
+        dict[str, Any],
+        genai_types.GenerateContentResponse,
+    ]
 ]:
   """Internal helper to run inference using Gemini model with concurrency."""
   if agent_engine:
@@ -1645,12 +1666,15 @@ def _convert_request_to_dataset_row(
     return dict_row


-def _transform_dataframe(rows: list[dict[str, Any]]) -> list[types.EvaluationDataset]:
+def _transform_dataframe(
+    rows: list[dict[str, Any]],
+) -> list[types.EvaluationDataset]:
   """Transforms rows to a list of EvaluationDatasets.

   Args:
     rows: A list of rows, each row is a dictionary of candidate name to response
       text.
+
   Returns:
     A list of EvaluationDatasets, one for each candidate.
   """
@@ -1677,6 +1701,7 @@ def _get_eval_cases_eval_dfs_from_eval_items(
   Args:
     api_client: The API client.
     evaluation_set_name: The name of the evaluation set.
+
   Returns:
     A tuple of two lists:
     - eval_case_results: A list of EvalCaseResults, one for each evaluation
@@ -1748,6 +1773,7 @@ def _get_eval_result_from_eval_items(
   Args:
     results: The EvaluationRunResults object.
     eval_items: The list of EvaluationItems.
+
   Returns:
     An EvaluationResult object.
   """
