 # limitations under the License.
 #
 """Common utilities for evals."""
+
 import asyncio
 import base64
 import collections
@@ -24,8 +25,8 @@
 import os
 import threading
 import time
-import uuid
 from typing import Any, Callable, Literal, Optional, Union
+import uuid
 
 from google.api_core import exceptions as api_exceptions
 import vertexai
@@ -41,7 +42,6 @@
 from . import _evals_metric_loaders
 from . import _evals_utils
 from . import _gcs_utils
-
 from . import evals
 from . import types
 
@@ -90,7 +90,8 @@ def _get_api_client_with_location(
     return api_client
 
   logger.info(
-      "Model endpoint location set to %s, overriding client location %s for this API call.",
+      "Model endpoint location set to %s, overriding client location %s for"
+      " this API call.",
       location,
       api_client.location,
   )
@@ -208,7 +209,7 @@ def _generate_content_with_retry(
 
 
 def _build_generate_content_config(
-    request_dict: dict[str, Any],
+    request_dict: Union[dict[str, Any], str],
     global_config: Optional[genai_types.GenerateContentConfig] = None,
 ) -> genai_types.GenerateContentConfig:
   """Builds a GenerateContentConfig from the request dictionary or provided config."""
@@ -220,6 +221,9 @@ def _build_generate_content_config(
   else:
     merged_config_dict = {}
 
+  if not isinstance(request_dict, dict):
+    return genai_types.GenerateContentConfig(**merged_config_dict)
+
   for key in [
       "system_instruction",
       "tools",
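
The functional change in this hunk is that `request_dict` may now be a bare prompt string rather than a request dict; with nothing to merge, the helper falls back to the global config. A minimal sketch of that branch, assuming `genai_types` is `google.genai.types` and skipping the per-key merge the surrounding code performs for dict requests:

```python
from typing import Any, Optional, Union

from google.genai import types as genai_types


def build_config_sketch(
    request_dict: Union[dict[str, Any], str],
    global_config: Optional[genai_types.GenerateContentConfig] = None,
) -> genai_types.GenerateContentConfig:
  """Illustrative stand-in for the dict-vs-string branch in this hunk."""
  merged_config_dict = (
      global_config.model_dump(exclude_none=True) if global_config else {}
  )
  if not isinstance(request_dict, dict):
    # A raw string prompt carries no per-request overrides, so only the
    # global config (or an empty config) applies.
    return genai_types.GenerateContentConfig(**merged_config_dict)
  # Dict requests would continue into the per-key merge shown above
  # ("system_instruction", "tools", ...).
  return genai_types.GenerateContentConfig(**merged_config_dict)
```
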
@@ -264,7 +268,11 @@ def _execute_inference_concurrently(
     agent_engine: Optional[Union[str, types.AgentEngine]] = None,
     agent: Optional[LlmAgent] = None,
 ) -> list[
-    Union[genai_types.GenerateContentResponse, dict[str, Any], list[dict[str, Any]]]
+    Union[
+        genai_types.GenerateContentResponse,
+        dict[str, Any],
+        list[dict[str, Any]],
+    ]
 ]:
   """Internal helper to run inference with concurrency."""
   logger.info(
@@ -384,7 +392,11 @@ def _run_gemini_inference(
     prompt_dataset: pd.DataFrame,
     config: Optional[genai_types.GenerateContentConfig] = None,
 ) -> list[
-    Union[genai_types.GenerateContentResponse, dict[str, Any], list[dict[str, Any]]]
+    Union[
+        genai_types.GenerateContentResponse,
+        dict[str, Any],
+        list[dict[str, Any]],
+    ]
 ]:
   """Internal helper to run inference using Gemini model with concurrency."""
   return _execute_inference_concurrently(
@@ -410,7 +422,9 @@ def _run_custom_inference(
   )
 
 
-def _convert_prompt_row_to_litellm_messages(row: pd.Series) -> list[dict[str, Any]]:
+def _convert_prompt_row_to_litellm_messages(
+    row: pd.Series,
+) -> list[dict[str, Any]]:
   """Converts a DataFrame row into LiteLLM's messages format by detecting the input schema."""
   messages: list[dict[str, Any]] = []
   row_dict = row.to_dict()
@@ -442,10 +456,10 @@ def _convert_prompt_row_to_litellm_messages(row: pd.Series) -> list[dict[str, An
     return [{"role": USER_AUTHOR, "content": row_dict["prompt"]}]
 
   raise ValueError(
-      "Could not determine prompt/messages format from input row. "
-      "Expected OpenAI request body with a 'messages' key, or a 'request' key"
-      " with OpenAI request body, or Gemini request body with a 'contents'"
-      f" key, or a 'prompt' key with a raw string. Found keys: {list(row_dict.keys())}"
+      "Could not determine prompt/messages format from input row. Expected "
+      " OpenAI request body with a 'messages' key, or a 'request' key with "
+      " OpenAI request body, or Gemini request body with a 'contents' key, or "
+      f" a 'prompt' key with a raw string. Found keys: {list(row_dict.keys())}"
  )
 
 
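For reference, the row shapes this detection logic accepts, as enumerated in the reworded error message, look roughly like the following; the top-level keys come from the code above, while the nested values and the "user" role are illustrative assumptions:

```python
# Row shapes the detector accepts (top-level keys from the hunk above; the
# nested values and the "user" role are illustrative assumptions).
openai_row = {"messages": [{"role": "user", "content": "What is 2 + 2?"}]}
wrapped_row = {"request": {"messages": [{"role": "user", "content": "What is 2 + 2?"}]}}
gemini_row = {"contents": [{"role": "user", "parts": [{"text": "What is 2 + 2?"}]}]}
string_row = {"prompt": "What is 2 + 2?"}
# Any other shape falls through to the ValueError raised in this hunk.
```
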
@@ -466,7 +480,8 @@ def _run_litellm_inference(
 ) -> list[Optional[dict[str, Any]]]:
   """Runs inference using LiteLLM with concurrency."""
   logger.info(
-      "Generating responses for %d prompts using LiteLLM for third party model: %s",
+      "Generating responses for %d prompts using LiteLLM for third party"
+      " model: %s",
       len(prompt_dataset),
       model,
   )
@@ -776,8 +791,8 @@ def _execute_inference(
 
   Args:
     api_client: The API client.
-    src: The source of the dataset. Can be a string (path to a local file,
-      a GCS path, or a BigQuery table) or a Pandas DataFrame.
+    src: The source of the dataset. Can be a string (path to a local file, a
+      GCS path, or a BigQuery table) or a Pandas DataFrame.
     model: The model to use for inference. Can be a callable function or a
       string representing a model.
     agent_engine: The agent engine to use for inference. Can be a resource
@@ -847,8 +862,8 @@ def _execute_inference(
     raise TypeError(
         f"Unsupported agent_engine type: {type(agent_engine)}. Expecting a"
         " string (agent engine resource name in"
-        " 'projects/{project_id}/locations/{location_id}/reasoningEngines/{reasoning_engine_id}' format) "
-        " or a types.AgentEngine instance."
+        " 'projects/{project_id}/locations/{location_id}/reasoningEngines/{reasoning_engine_id}'"
+        " format) or a types.AgentEngine instance."
     )
   if (
       _evals_constant.INTERMEDIATE_EVENTS in prompt_dataset.columns
@@ -1193,7 +1208,8 @@ def _execute_evaluation(  # type: ignore[no-untyped-def]
       validated_agent_info = agent_info
     else:
       raise TypeError(
-          f"agent_info values must be of type types.evals.AgentInfo or dict, but got {type(agent_info)}'"
+          "agent_info values must be of type types.evals.AgentInfo or dict,"
+          f" but got {type(agent_info)}'"
       )
 
   processed_eval_dataset, num_response_candidates = _resolve_dataset_inputs(
@@ -1323,8 +1339,9 @@ def _run_agent_internal(
       processed_responses
   ) != len(processed_intermediate_events):
     raise RuntimeError(
-        "Critical prompt/response/intermediate_events count mismatch: %d prompts vs %d vs %d"
-        " responses. This indicates an issue in response collection."
+        "Critical prompt/response/intermediate_events count mismatch: %d"
+        " prompts vs %d vs %d responses. This indicates an issue in response"
+        " collection."
         % (
             len(prompt_dataset),
             len(processed_responses),
@@ -1354,7 +1371,11 @@ def _run_agent(
     agent: Optional[LlmAgent],
     prompt_dataset: pd.DataFrame,
 ) -> list[
-    Union[list[dict[str, Any]], dict[str, Any], genai_types.GenerateContentResponse]
+    Union[
+        list[dict[str, Any]],
+        dict[str, Any],
+        genai_types.GenerateContentResponse,
+    ]
 ]:
   """Internal helper to run inference using Gemini model with concurrency."""
   if agent_engine:
@@ -1645,12 +1666,15 @@ def _convert_request_to_dataset_row(
   return dict_row
 
 
-def _transform_dataframe(rows: list[dict[str, Any]]) -> list[types.EvaluationDataset]:
+def _transform_dataframe(
+    rows: list[dict[str, Any]],
+) -> list[types.EvaluationDataset]:
   """Transforms rows to a list of EvaluationDatasets.
 
   Args:
     rows: A list of rows, each row is a dictionary of candidate name to response
       text.
+
   Returns:
     A list of EvaluationDatasets, one for each candidate.
   """
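
The docstring describes a transposition from per-row candidate-to-response dicts into one dataset per candidate. A library-free sketch of that reshaping follows; the real helper wraps each candidate's responses in `types.EvaluationDataset` rather than a plain list:

```python
# Library-free sketch of the reshaping described in the docstring above.
rows = [
    {"candidate_1": "Paris", "candidate_2": "Paris, France"},
    {"candidate_1": "4", "candidate_2": "Four"},
]
responses_per_candidate = {
    name: [row.get(name) for row in rows] for name in rows[0]
}
# {'candidate_1': ['Paris', '4'], 'candidate_2': ['Paris, France', 'Four']}
```
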
@@ -1677,6 +1701,7 @@ def _get_eval_cases_eval_dfs_from_eval_items(
   Args:
     api_client: The API client.
     evaluation_set_name: The name of the evaluation set.
+
   Returns:
     A tuple of two lists:
       - eval_case_results: A list of EvalCaseResults, one for each evaluation
@@ -1748,6 +1773,7 @@ def _get_eval_result_from_eval_items(
   Args:
     results: The EvaluationRunResults object.
     eval_items: The list of EvaluationItems.
+
   Returns:
     An EvaluationResult object.
   """