Commit 8099646

feat: Add Gemini support (#15)
1 parent 6b962df commit 8099646

6 files changed: +151 −38 lines changed

models/model.py

Lines changed: 3 additions & 0 deletions
@@ -301,6 +301,7 @@ async def _generate_text(self, message: dict, run_params: dict, error_tracker: E
         if self.inference_type in (
             constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION,
             constants.OPENAI_CHAT_COMPLETION,
+            constants.GEMINI_CHAT_COMPLETION,
         ):
             for i in range(num_chunks):
                 start = i * max_samples
@@ -381,6 +382,7 @@ async def _generate_text(self, message: dict, run_params: dict, error_tracker: E
         if self.inference_type in (
             constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION,
             constants.OPENAI_CHAT_COMPLETION,
+            constants.GEMINI_CHAT_COMPLETION,
         ):
             # Cut to first 30s, then process as chat completion
             if audio_array is not None and len(audio_array) > 0:
@@ -472,6 +474,7 @@ async def _handle_multi_turn(self, message: dict, run_params: dict, error_tracke
         if self.inference_type not in (
             constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION,
             constants.OPENAI_CHAT_COMPLETION,
+            constants.GEMINI_CHAT_COMPLETION,
         ):
             raise ValueError("Multi-turn conversations only supported for chat completion inference types")
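For context, the loop in `_generate_text` walks audio in fixed windows of `max_samples`. A minimal sketch of that chunking arithmetic, assuming `num_chunks` comes from ceiling division (the helper name and the ceiling assumption are illustrative, not taken from the repo):

    import math

    def chunk_bounds(total_samples: int, max_samples: int):
        """Yield (start, end) sample indices covering the audio in fixed windows."""
        num_chunks = math.ceil(total_samples / max_samples)  # assumed derivation of num_chunks
        for i in range(num_chunks):
            start = i * max_samples  # mirrors `start = i * max_samples` in the diff
            yield start, min(start + max_samples, total_samples)

    # 75 s of 16 kHz audio in 30 s windows -> [(0, 480000), (480000, 960000), (960000, 1200000)]
    print(list(chunk_bounds(75 * 16000, 30 * 16000)))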

models/request_resp_handler.py

Lines changed: 35 additions & 4 deletions
@@ -5,6 +5,8 @@
 import inspect
 
 import httpx
+from google.auth import default
+from google.auth.transport.requests import Request
 from openai import AsyncAzureOpenAI, AsyncOpenAI
 from models.model_response import ModelResponse, ErrorTracker
 from utils import constants
@@ -20,6 +22,8 @@ def __init__(self, inference_type: str, model_info: dict, generation_params: dic
         self.model_info = model_info
         self.api = model_info.get("url")
         self.auth = model_info.get("auth_token", "")
+        self.location = model_info.get("location", "")
+        self.project_id = model_info.get("project_id", "")
         self.api_version = model_info.get("api_version", "")
         self.client = None
         self.timeout = timeout
@@ -153,6 +157,25 @@ def set_client(self, verify_ssl: bool, timeout: int):
                     http_client=httpx.AsyncClient(verify=verify_ssl),
                 )
             )
+        elif self.inference_type == constants.GEMINI_CHAT_COMPLETION:
+            # Gemini endpoints
+
+            # Set an API host for Gemini on Vertex AI
+            api_host = "aiplatform.googleapis.com"
+            if self.location != "global":
+                api_host = f"{self.location}-aiplatform.googleapis.com"
+
+            credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
+            credentials.refresh(Request())
+
+            self.client = AsyncOpenAI(
+                base_url=f"https://{api_host}/v1/projects/{self.project_id}/locations/{self.location}/endpoints/openapi",
+                api_key=credentials.token,
+                timeout=timeout,
+                max_retries=0,
+                default_headers={"Connection": "close"},
+                http_client=httpx.AsyncClient(verify=verify_ssl),
+            )
@@ -187,19 +210,27 @@ async def request_server(self, msg_body, tools=None, error_tracker: ErrorTracker
         2. Any exception is wrapped in a `ModelResponse` with ``response_code = 500``.
         """
         model_name: str | None = self.model_info.get("model")
+        reasoning_effort = self.model_info.get("reasoning_effort", None)
         if tools:
             tools = self.convert_to_tool(tools)
 
         start_time = time.time()
         # Re-create a fresh client for this request to avoid closed-loop issues
         self.set_client(verify_ssl=True, timeout=self.timeout)
         try:
-            if self.inference_type == constants.OPENAI_CHAT_COMPLETION or self.inference_type == constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION:
+            if self.inference_type in (constants.OPENAI_CHAT_COMPLETION, constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION, constants.GEMINI_CHAT_COMPLETION):
                 # openai chat completions, vllm chat completions
                 self.generation_params = self.validated_safe_generation_params(self.generation_params)
-                prediction = await self.client.chat.completions.create(
-                    model=model_name, messages=msg_body, tools=tools, **self.generation_params
-                )
+
+                if reasoning_effort:
+                    prediction = await self.client.chat.completions.create(
+                        model=model_name, messages=msg_body, tools=tools, reasoning_effort=reasoning_effort, **self.generation_params
+                    )
+                else:
+                    prediction = await self.client.chat.completions.create(
+                        model=model_name, messages=msg_body, tools=tools, **self.generation_params
+                    )
+
                 raw_response: str = self._extract_response_data(prediction)
                 llm_response: str = raw_response['choices'][0]['message']['content'] or " "
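For reference, the same Vertex AI client setup can be exercised standalone. A minimal sketch, assuming Application Default Credentials are available (e.g. via `gcloud auth application-default login`) and that the environment variables and prompt are placeholders:

    import asyncio
    import os

    from google.auth import default
    from google.auth.transport.requests import Request
    from openai import AsyncOpenAI

    async def main():
        project_id = os.environ["GOOGLE_CLOUD_PROJECT"]          # placeholder env var
        location = os.environ.get("GOOGLE_CLOUD_LOCATION", "global")

        # Same host selection as set_client(): regional host unless location is "global".
        api_host = "aiplatform.googleapis.com"
        if location != "global":
            api_host = f"{location}-aiplatform.googleapis.com"

        # Exchange application-default credentials for a short-lived access token.
        credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
        credentials.refresh(Request())

        client = AsyncOpenAI(
            base_url=f"https://{api_host}/v1/projects/{project_id}/locations/{location}/endpoints/openapi",
            api_key=credentials.token,
        )
        prediction = await client.chat.completions.create(
            model="google/gemini-2.5-flash",
            messages=[{"role": "user", "content": "Say hello."}],  # placeholder prompt
        )
        print(prediction.choices[0].message.content)

    asyncio.run(main())

Note that the token from `credentials.refresh()` is short-lived, which is one reason the handler re-creates the client per request.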

requirements.txt

Lines changed: 5 additions & 1 deletion
@@ -51,4 +51,8 @@ nest_asyncio==1.6.0
 immutabledict==4.2.1
 
 # Debugging
-debugpy==1.8.0 # VSCode debugging support
+debugpy==1.8.0 # VSCode debugging support
+
+# Gemini libraries
+google-auth==2.40.3
+google-genai==1.38.0

sample_config.yaml

Lines changed: 42 additions & 31 deletions
@@ -25,64 +25,75 @@ filter:
   num_samples: 100 # number of samples to run(remove for all)
   length_filter: [0.0, 30.0] #optional - filters for only audio samples in this length(seconds) - only supported for general and callhome preprocessors
 
-judge_properties:
+judge_settings:
   judge_concurrency: 8 #judge call(optional)
-  judge_model: "gpt-4o-mini" #optional
-  judge_type: "openai" # mandatory (vllm or openai)
-  judge_api_version: "${API_VERSION}" # optional(needed for openai)
-  judge_api_endpoint: "${ENDPOINT_URL}" # mandatory
-  judge_api_key: "${AUTH_TOKEN}" # mandatory
+  judge_model: gpt-4o-mini #optional
+  judge_type: openai # mandatory (vllm or openai)
+  judge_api_version: ${API_VERSION} # optional(needed for openai)
+  judge_api_endpoint: ${ENDPOINT_URL} # mandatory
+  judge_api_key: ${AUTH_TOKEN} # mandatory
   judge_temperature: 0.1 # optional
-  judge_prompt_model_override: "gpt-4o-mini-enhanced" # optional
+  judge_prompt_model_override: gpt-4o-mini-enhanced # optional
 
 logging:
   log_file: "audiobench.log" # Path to the main log file
 
-
 models:
-  - name: "gpt-4o-mini-audio-preview-1" # mandatory - must be unique
-    inference_type: "openai" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription)
-    url: "${ENDPOINT_URL}" # mandatory - endpoint url
+  - name: gpt-4o-mini-audio-preview-1 # must be unique
+    inference_type: openai # you can use vllm, openai, gemini or transcription
+    url: ${ENDPOINT_URL} # endpoint url
     delay: 100
     retry_attempts: 8
     timeout: 30
-    model: "gpt-4o-mini-audio-preview" # mandatory - only needed for vllm
-    auth_token: "${AUTH_TOKEN}"
-    api_version: "${API_VERSION}"
+    model: gpt-4o-mini-audio-preview
+    auth_token: ${AUTH_TOKEN}
+    api_version: ${API_VERSION}
     batch_size: 300 # Optional - batch eval size
     chunk_size: 30 # Optional - max audio length in seconds fed to model
 
-  - name: "gpt-4o-mini-audio-preview-2" # mandatory - must be unique
-    inference_type: "openai" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription)
-    url: "${ENDPOINT_URL}" # mandatory - endpoint url
+  - name: gpt-4o-mini-audio-preview-2 # must be unique
+    inference_type: openai # you can use vllm, openai, gemini or transcription
+    url: ${ENDPOINT_URL} # endpoint url
     delay: 100
     retry_attempts: 8
    timeout: 30
-    model: "gpt-4o-mini-audio-preview" # mandatory - only needed for vllm
-    auth_token: "${AUTH_TOKEN}"
-    api_version: "${API_VERSION}"
-    batch_size: 100 # Optional - batch eval size
+    model: gpt-4o-mini-audio-preview
+    auth_token: ${AUTH_TOKEN}
+    api_version: ${API_VERSION}
+    batch_size: 300 # Optional - batch eval size
     chunk_size: 30 # Optional - max audio length in seconds fed to model
 
-  - name: "qwen-2.5-omni"
-    inference_type: "vllm" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription)
-    url: "${ENDPOINT_URL}" # mandatory - endpoint url
+  - name: gemini-2.5-flash # must be unique
+    inference_type: gemini # you can use vllm, openai, gemini or transcription
+    location: ${GOOGLE_CLOUD_LOCATION} # GCP Vertex AI configuration
+    project_id: ${GOOGLE_CLOUD_PROJECT} # GCP Vertex AI configuration
+    reasoning_effort: medium # Optional - reasoning effort for supported reasoning models like gemini-2.5-flash, gpt-5, ...
+    delay: 100
+    retry_attempts: 5
+    timeout: 300
+    model: google/gemini-2.5-flash
+    batch_size: 100 # Optional - batch eval size
+    chunk_size: 30240 # Optional - max audio length in seconds fed to model
+
+  - name: qwen-2.5-omni # must be unique
+    inference_type: vllm # you can use vllm, openai, gemini or transcription
+    url: ${ENDPOINT_URL} # endpoint url
     delay: 100
     retry_attempts: 8
     timeout: 30
-    model: "qwen-2.5-omni" # mandatory - only needed for vllm
-    auth_token: "${AUTH_TOKEN}"
+    model: qwen-2.5-omni
+    auth_token: ${AUTH_TOKEN}
     batch_size: 200 # Optional - batch eval size
     chunk_size: 40 # Optional - max audio length in seconds fed to model
 
-  - name: "whisper-large-3"
-    inference_type: "vllm" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription)
-    url: "${ENDPOINT_URL}" # mandatory - endpoint url
+  - name: whisper-large-3 # must be unique
+    inference_type: transcription # you can use vllm, openai, gemini or transcription
+    url: ${ENDPOINT_URL} # endpoint url
     delay: 100
     retry_attempts: 8
     timeout: 30
-    model: "whisper-large-3" # mandatory - only needed for vllm
-    auth_token: "${AUTH_TOKEN}"
+    model: whisper-large-3
+    auth_token: ${AUTH_TOKEN}
     batch_size: 100 # Optional - batch eval size
    chunk_size: 30 # Optional - max audio length in seconds fed to model
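The unquoted `${...}` placeholders above are resolved at load time by the new helpers in utils/util.py (shown below). A quick sketch of the substitution on a config fragment, assuming the repo root is on `sys.path` and using dummy values:

    import os
    import yaml

    from utils.util import _process_nested_env_vars  # added in this commit

    os.environ["ENDPOINT_URL"] = "https://example.endpoint"  # dummy value
    os.environ["AUTH_TOKEN"] = "dummy-token"                 # dummy value

    fragment = yaml.safe_load("""
    models:
      - name: gpt-4o-mini-audio-preview-1
        url: ${ENDPOINT_URL}
        auth_token: ${AUTH_TOKEN}
    """)
    resolved = _process_nested_env_vars(fragment)
    print(resolved["models"][0]["url"])  # -> https://example.endpoint

Dropping the quotes is what makes this work: YAML parses the bare `${ENDPOINT_URL}` as a plain string, which the loader then substitutes.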

utils/constants.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 # Inference server types
 INFERENCE_SERVER_VLLM_CHAT_COMPLETION = 'vllm'
 OPENAI_CHAT_COMPLETION = 'openai'
+GEMINI_CHAT_COMPLETION = 'gemini'
 TRANSCRIPTION = 'transcription'
 
 # WER/CER metrics constants

utils/util.py

Lines changed: 65 additions & 2 deletions
@@ -1,6 +1,6 @@
 
 import importlib
-import json
+import re
 import logging
 import os
 import statistics
@@ -219,7 +219,7 @@ def _validate_models(config: Dict) -> None:
         ValueError: If the models section is invalid
     """
     def validate_required_fields(info: Dict, index: int) -> None:
-        required_fields = ['name', 'model', 'inference_type', 'url']
+        required_fields = ['name', 'model', 'inference_type']
         for field in required_fields:
             if not info.get(field) or not isinstance(info[field], str) or not info[field].strip():
                 raise ValueError(f"Model {index}: '{field}' must be a non-empty string")
@@ -411,9 +411,68 @@ def setup_logging(log_file: str):
     # Set httpx logger to WARNING level to reduce noise
     logging.getLogger("httpx").setLevel(logging.WARNING)
 
+def _replace_env_vars(value):
+    """
+    Replace environment variables in strings.
+    Supports ${ENV_VAR} and $ENV_VAR syntax.
+
+    Args:
+        value: String value that may contain environment variables
+
+    Returns:
+        String with environment variables substituted
+    """
+    if not isinstance(value, str):
+        return value
+
+    # Replace ${VAR} format
+    pattern1 = re.compile(r'\${([^}^{]+)}')
+    matches = pattern1.findall(value)
+    if matches:
+        for match in matches:
+            env_var = os.environ.get(match)
+            if env_var is not None:
+                value = value.replace(f"${{{match}}}", env_var)
+            else:
+                logger.warning(f"Environment variable '{match}' not found when processing config")
+
+    # Replace $VAR format
+    pattern2 = re.compile(r'(?<!\\)\$([a-zA-Z0-9_]+)')
+    matches = pattern2.findall(value)
+    if matches:
+        for match in matches:
+            env_var = os.environ.get(match)
+            if env_var is not None:
+                value = value.replace(f"${match}", env_var)
+            else:
+                logger.warning(f"Environment variable '{match}' not found when processing config")
+
+    return value
+
+def _process_nested_env_vars(data):
+    """
+    Process all values in a nested dictionary/list structure,
+    replacing environment variables in string values.
+
+    Args:
+        data: Dict, list, or scalar value
+
+    Returns:
+        Data with environment variables substituted in string values
+    """
+    if isinstance(data, dict):
+        return {k: _process_nested_env_vars(v) for k, v in data.items()}
+    elif isinstance(data, list):
+        return [_process_nested_env_vars(item) for item in data]
+    elif isinstance(data, str):
+        return _replace_env_vars(data)
+    else:
+        return data
+
 def read_config(cfg_path: str):
     """
     Read configuration file and set up logging.
+    Supports environment variable substitution in the format ${ENV_VAR} or $ENV_VAR.
 
     Args:
         cfg_path: Path to configuration file
@@ -424,6 +483,10 @@ def read_config(cfg_path: str):
     # Set up logging
     with open(cfg_path, encoding='utf-8') as f:
         raw_cfg = yaml.safe_load(f)
+
+    # Process environment variables in the config
+    raw_cfg = _process_nested_env_vars(raw_cfg)
+
     log_file = raw_cfg.get("logging", {}).get("log_file", "default.log")
     setup_logging(log_file)
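A few quick checks of the substitution behavior above, assuming the repo root is on `sys.path`, `FOO` is set, and `MISSING` is not (names and values are illustrative):

    import os
    from utils.util import _replace_env_vars

    os.environ["FOO"] = "bar"
    os.environ.pop("MISSING", None)

    print(_replace_env_vars("${FOO}/path"))  # -> bar/path   (braced form)
    print(_replace_env_vars("$FOO/path"))    # -> bar/path   (bare form)
    print(_replace_env_vars(r"\$FOO"))       # -> \$FOO      (backslash-escaped $ is left alone)
    print(_replace_env_vars("${MISSING}"))   # -> ${MISSING} (kept as-is, with a warning logged)
    print(_replace_env_vars(123))            # -> 123        (non-strings pass through unchanged)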

0 commit comments

Comments
 (0)