68 changes: 47 additions & 21 deletions vertexai/_genai/_evals_common.py
@@ -13,6 +13,7 @@
# limitations under the License.
#
"""Common utilities for evals."""

import asyncio
import base64
import collections
@@ -24,8 +25,8 @@
import os
import threading
import time
import uuid
from typing import Any, Callable, Literal, Optional, Union
import uuid

from google.api_core import exceptions as api_exceptions
import vertexai
@@ -41,7 +42,6 @@
from . import _evals_metric_loaders
from . import _evals_utils
from . import _gcs_utils

from . import evals
from . import types

@@ -90,7 +90,8 @@ def _get_api_client_with_location(
return api_client

logger.info(
"Model endpoint location set to %s, overriding client location %s for this API call.",
"Model endpoint location set to %s, overriding client location %s for"
" this API call.",
location,
api_client.location,
)
@@ -208,7 +209,7 @@ def _generate_content_with_retry(


def _build_generate_content_config(
request_dict: dict[str, Any],
request_dict: Union[dict[str, Any], str],
global_config: Optional[genai_types.GenerateContentConfig] = None,
) -> genai_types.GenerateContentConfig:
"""Builds a GenerateContentConfig from the request dictionary or provided config."""
@@ -220,6 +221,9 @@ def _build_generate_content_config
else:
merged_config_dict = {}

if not isinstance(request_dict, dict):
return genai_types.GenerateContentConfig(**merged_config_dict)

for key in [
"system_instruction",
"tools",
@@ -264,7 +268,11 @@ def _execute_inference_concurrently(
agent_engine: Optional[Union[str, types.AgentEngine]] = None,
agent: Optional[LlmAgent] = None,
) -> list[
Union[genai_types.GenerateContentResponse, dict[str, Any], list[dict[str, Any]]]
Union[
genai_types.GenerateContentResponse,
dict[str, Any],
list[dict[str, Any]],
]
]:
"""Internal helper to run inference with concurrency."""
logger.info(
@@ -384,7 +392,11 @@ def _run_gemini_inference(
prompt_dataset: pd.DataFrame,
config: Optional[genai_types.GenerateContentConfig] = None,
) -> list[
Union[genai_types.GenerateContentResponse, dict[str, Any], list[dict[str, Any]]]
Union[
genai_types.GenerateContentResponse,
dict[str, Any],
list[dict[str, Any]],
]
]:
"""Internal helper to run inference using Gemini model with concurrency."""
return _execute_inference_concurrently(
@@ -410,7 +422,9 @@ def _run_custom_inference(
)


def _convert_prompt_row_to_litellm_messages(row: pd.Series) -> list[dict[str, Any]]:
def _convert_prompt_row_to_litellm_messages(
row: pd.Series,
) -> list[dict[str, Any]]:
"""Converts a DataFrame row into LiteLLM's messages format by detecting the input schema."""
messages: list[dict[str, Any]] = []
row_dict = row.to_dict()
@@ -442,10 +456,10 @@ def _convert_prompt_row_to_litellm_messages(row: pd.Series) -> list[dict[str, An
return [{"role": USER_AUTHOR, "content": row_dict["prompt"]}]

raise ValueError(
"Could not determine prompt/messages format from input row. "
"Expected OpenAI request body with a 'messages' key, or a 'request' key"
" with OpenAI request body, or Gemini request body with a 'contents'"
f" key, or a 'prompt' key with a raw string. Found keys: {list(row_dict.keys())}"
"Could not determine prompt/messages format from input row. Expected"
" OpenAI request body with a 'messages' key, or a 'request' key with"
" OpenAI request body, or Gemini request body with a 'contents' key, or"
f" a 'prompt' key with a raw string. Found keys: {list(row_dict.keys())}"
)


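As a quick reference for the input schemas listed in the rewrapped error message above, here is a rough illustration of the four row shapes the converter accepts (hypothetical example data, not taken from the module or its tests):

import pandas as pd

# OpenAI request body with a top-level "messages" key.
openai_row = pd.Series({"messages": [{"role": "user", "content": "Say hi"}]})
# A "request" key wrapping an OpenAI request body.
wrapped_row = pd.Series({"request": {"messages": [{"role": "user", "content": "Say hi"}]}})
# Gemini request body with a "contents" key.
gemini_row = pd.Series({"contents": [{"role": "user", "parts": [{"text": "Say hi"}]}]})
# A raw string under a "prompt" key.
raw_prompt_row = pd.Series({"prompt": "Say hi"})
# A row with none of these keys raises the ValueError shown in this hunk.
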
@@ -466,7 +480,8 @@ def _run_litellm_inference(
) -> list[Optional[dict[str, Any]]]:
"""Runs inference using LiteLLM with concurrency."""
logger.info(
"Generating responses for %d prompts using LiteLLM for third party model: %s",
"Generating responses for %d prompts using LiteLLM for third party"
" model: %s",
len(prompt_dataset),
model,
)
@@ -776,8 +791,8 @@ def _execute_inference(

Args:
api_client: The API client.
src: The source of the dataset. Can be a string (path to a local file,
a GCS path, or a BigQuery table) or a Pandas DataFrame.
src: The source of the dataset. Can be a string (path to a local file, a
GCS path, or a BigQuery table) or a Pandas DataFrame.
model: The model to use for inference. Can be a callable function or a
string representing a model.
agent_engine: The agent engine to use for inference. Can be a resource
@@ -847,8 +862,8 @@ def _execute_inference(
raise TypeError(
f"Unsupported agent_engine type: {type(agent_engine)}. Expecting a"
" string (agent engine resource name in"
" 'projects/{project_id}/locations/{location_id}/reasoningEngines/{reasoning_engine_id}' format)"
" or a types.AgentEngine instance."
" 'projects/{project_id}/locations/{location_id}/reasoningEngines/{reasoning_engine_id}'"
" format) or a types.AgentEngine instance."
)
if (
_evals_constant.INTERMEDIATE_EVENTS in prompt_dataset.columns
@@ -1193,7 +1208,8 @@ def _execute_evaluation( # type: ignore[no-untyped-def]
validated_agent_info = agent_info
else:
raise TypeError(
f"agent_info values must be of type types.evals.AgentInfo or dict, but got {type(agent_info)}'"
"agent_info values must be of type types.evals.AgentInfo or dict,"
f" but got {type(agent_info)}'"
)

processed_eval_dataset, num_response_candidates = _resolve_dataset_inputs(
@@ -1323,8 +1339,9 @@ def _run_agent_internal(
processed_responses
) != len(processed_intermediate_events):
raise RuntimeError(
"Critical prompt/response/intermediate_events count mismatch: %d prompts vs %d vs %d"
" responses. This indicates an issue in response collection."
"Critical prompt/response/intermediate_events count mismatch: %d"
" prompts vs %d vs %d responses. This indicates an issue in response"
" collection."
% (
len(prompt_dataset),
len(processed_responses),
@@ -1354,7 +1371,11 @@ def _run_agent(
agent: Optional[LlmAgent],
prompt_dataset: pd.DataFrame,
) -> list[
Union[list[dict[str, Any]], dict[str, Any], genai_types.GenerateContentResponse]
Union[
list[dict[str, Any]],
dict[str, Any],
genai_types.GenerateContentResponse,
]
]:
"""Internal helper to run inference using Gemini model with concurrency."""
if agent_engine:
@@ -1645,12 +1666,15 @@ def _convert_request_to_dataset_row(
return dict_row


def _transform_dataframe(rows: list[dict[str, Any]]) -> list[types.EvaluationDataset]:
def _transform_dataframe(
rows: list[dict[str, Any]],
) -> list[types.EvaluationDataset]:
"""Transforms rows to a list of EvaluationDatasets.

Args:
rows: A list of rows, each row is a dictionary of candidate name to response
text.

Returns:
A list of EvaluationDatasets, one for each candidate.
"""
@@ -1677,6 +1701,7 @@ def _get_eval_cases_eval_dfs_from_eval_items(
Args:
api_client: The API client.
evaluation_set_name: The name of the evaluation set.

Returns:
A tuple of two lists:
- eval_case_results: A list of EvalCaseResults, one for each evaluation
@@ -1748,6 +1773,7 @@ def _get_eval_result_from_eval_items(
Args:
results: The EvaluationRunResults object.
eval_items: The list of EvaluationItems.

Returns:
An EvaluationResult object.
"""