
Commit f8fd56d

test: Add comprehensive test coverage for Gemini provider session handling and event recording
Co-Authored-By: Alex Reibman <meta.alex.r@gmail.com>
1 parent bff477c commit f8fd56d

4 files changed: +128 additions, -81 deletions

agentops/llms/providers/gemini.py

Lines changed: 42 additions & 26 deletions
@@ -1,7 +1,7 @@
 from typing import Optional, Generator, Any, Dict, Union

 from agentops.llms.providers.base import BaseProvider
-from agentops.event import LLMEvent
+from agentops.event import LLMEvent, ErrorEvent
 from agentops.session import Session
 from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
 from agentops.log_config import logger
@@ -44,8 +44,7 @@ def handle_response(
             For streaming responses: A generator yielding response chunks

         Note:
-            Token counts are not currently provided by the Gemini API.
-            Future versions may add token counting functionality.
+            Token counts are extracted from usage_metadata if available.
         """
         llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
         if session is not None:
@@ -56,13 +55,14 @@ def handle_response(
             accumulated_text = []  # Use list to accumulate text chunks

             def handle_stream_chunk(chunk):
-                if llm_event.returns is None:
-                    llm_event.returns = chunk
-                llm_event.agent_id = check_call_stack_for_agent_id()
-                llm_event.model = getattr(chunk, "model", "gemini-1.5-flash")  # Default if not provided
-                llm_event.prompt = kwargs.get("prompt") or kwargs.get("contents", [])
-
+                nonlocal llm_event
                 try:
+                    if llm_event.returns is None:
+                        llm_event.returns = chunk
+                    llm_event.agent_id = check_call_stack_for_agent_id()
+                    llm_event.model = getattr(chunk, "model", "gemini-1.5-flash")
+                    llm_event.prompt = kwargs.get("prompt", kwargs.get("contents", []))
+
                     if hasattr(chunk, "text") and chunk.text:
                         accumulated_text.append(chunk.text)

@@ -79,25 +79,31 @@ def handle_stream_chunk(chunk):
                         self._safe_record(session, llm_event)

                 except Exception as e:
+                    if session is not None:
+                        self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
                     logger.warning(
-                        f"Unable to parse chunk for Gemini LLM call. Skipping upload to AgentOps\n"
-                        f"Error: {str(e)}\n"
+                        f"Unable to parse chunk for Gemini LLM call. Error: {str(e)}\n"
                         f"Chunk: {chunk}\n"
                         f"kwargs: {kwargs}\n"
                     )

             def stream_handler(stream):
-                for chunk in stream:
-                    handle_stream_chunk(chunk)
-                    yield chunk
+                try:
+                    for chunk in stream:
+                        handle_stream_chunk(chunk)
+                        yield chunk
+                except Exception as e:
+                    if session is not None:
+                        self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
+                    raise  # Re-raise after recording error

             return stream_handler(response)

         # For synchronous responses
         try:
             llm_event.returns = response
             llm_event.agent_id = check_call_stack_for_agent_id()
-            llm_event.prompt = kwargs.get("prompt") or kwargs.get("contents", [])
+            llm_event.prompt = kwargs.get("prompt", kwargs.get("contents", []))
             llm_event.completion = response.text
             llm_event.model = getattr(response, "model", "gemini-1.5-flash")

@@ -110,9 +116,10 @@ def stream_handler(stream):
             llm_event.end_timestamp = get_ISO_time()
             self._safe_record(session, llm_event)
         except Exception as e:
+            if session is not None:
+                self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
             logger.warning(
-                f"Unable to parse response for Gemini LLM call. Skipping upload to AgentOps\n"
-                f"Error: {str(e)}\n"
+                f"Unable to parse response for Gemini LLM call. Error: {str(e)}\n"
                 f"Response: {response}\n"
                 f"kwargs: {kwargs}\n"
             )
@@ -136,24 +143,33 @@ def override(self):

         def patched_function(self, *args, **kwargs):
             init_timestamp = get_ISO_time()
-            session = kwargs.pop("session", None)  # Always try to pop session, returns None if not present
+
+            # Extract and remove session from kwargs if present
+            session = kwargs.pop("session", None)

             # Handle positional prompt argument
             event_kwargs = kwargs.copy()  # Create a copy for event tracking
             if args and len(args) > 0:
                 # First argument is the prompt
+                prompt = args[0]
                 if "contents" not in kwargs:
-                    kwargs["contents"] = args[0]
-                event_kwargs["prompt"] = args[0]  # Store original prompt
+                    kwargs["contents"] = prompt
+                event_kwargs["prompt"] = prompt  # Store original prompt for event tracking
                 args = args[1:]  # Remove prompt from args since we moved it to kwargs

             # Call original method and track event
-            if "generate_content" in _ORIGINAL_METHODS:
-                result = _ORIGINAL_METHODS["generate_content"](self, *args, **kwargs)
-                return provider.handle_response(result, event_kwargs, init_timestamp, session=session)
-            else:
-                logger.error("Original generate_content method not found. Cannot proceed with override.")
-                return None
+            try:
+                if "generate_content" in _ORIGINAL_METHODS:
+                    result = _ORIGINAL_METHODS["generate_content"](self, *args, **kwargs)
+                    return provider.handle_response(result, event_kwargs, init_timestamp, session=session)
+                else:
+                    logger.error("Original generate_content method not found. Cannot proceed with override.")
+                    return None
+            except Exception as e:
+                logger.error(f"Error in Gemini generate_content: {str(e)}")
+                if session is not None:
+                    provider._safe_record(session, ErrorEvent(exception=e))
+                raise  # Re-raise the exception after recording

         # Override the method at class level
         genai.GenerativeModel.generate_content = patched_function
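
The streaming changes above follow one pattern throughout: wrap the underlying generator, handle each chunk inside a try block, record an ErrorEvent for the session when something fails, and re-raise so the caller still sees the original exception. A minimal, self-contained sketch of that pattern (wrap_stream and failing_stream are illustrative names, not part of AgentOps or the Gemini SDK):

def wrap_stream(stream, on_chunk, on_error):
    """Yield chunks from `stream`, reporting any failure before re-raising it."""
    try:
        for chunk in stream:
            on_chunk(chunk)   # e.g. accumulate chunk.text and update the LLMEvent
            yield chunk
    except Exception as exc:
        on_error(exc)         # analogous to recording an ErrorEvent for the session
        raise                 # re-raise so the caller still sees the original error


def failing_stream():
    yield "partial "
    raise RuntimeError("stream interrupted")


seen = []
try:
    for piece in wrap_stream(failing_stream(), seen.append, print):
        pass
except RuntimeError:
    pass

print(seen)  # ['partial '] -- chunks received before the failure are still kept

Because the wrapper is itself a generator, nothing is consumed and nothing is recorded until the caller starts iterating, which matches how stream_handler(response) is returned lazily in the diff above.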

examples/gemini_examples/create_notebook.py

Lines changed: 4 additions & 7 deletions
@@ -30,15 +30,14 @@

 init = """\
 # Initialize AgentOps and Gemini model
-ao_client = agentops.init()
+agentops.init() # Provider detection happens automatically
 model = genai.GenerativeModel("gemini-1.5-flash")"""

 sync_test = """\
 # Test synchronous generation
 print("Testing synchronous generation:")
 response = model.generate_content(
-    "What are the three laws of robotics?",
-    session=ao_client
+    "What are the three laws of robotics?"
 )
 print(response.text)"""

@@ -47,8 +46,7 @@
 print("\\nTesting streaming generation:")
 response = model.generate_content(
     "Explain the concept of machine learning in simple terms.",
-    stream=True,
-    session=ao_client
+    stream=True
 )

 for chunk in response:
@@ -58,8 +56,7 @@
 # Test another synchronous generation
 print("\\nTesting another synchronous generation:")
 response = model.generate_content(
-    "What is the difference between supervised and unsupervised learning?",
-    session=ao_client
+    "What is the difference between supervised and unsupervised learning?"
 )
 print(response.text)"""

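
With session=ao_client removed from every call, the example cells only need agentops.init() before the model is created. Roughly how the generated notebook code reads once assembled (a sketch; the genai.configure line and the GEMINI_API_KEY / AGENTOPS_API_KEY environment variables are assumptions, not shown in this diff):

import os

import agentops
import google.generativeai as genai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])  # assumed to happen elsewhere in the notebook

# Initialize AgentOps and Gemini model
agentops.init()  # Provider detection happens automatically
model = genai.GenerativeModel("gemini-1.5-flash")

# Test synchronous generation
print("Testing synchronous generation:")
response = model.generate_content("What are the three laws of robotics?")
print(response.text)

# Test streaming generation
print("\nTesting streaming generation:")
response = model.generate_content(
    "Explain the concept of machine learning in simple terms.", stream=True
)
for chunk in response:
    print(chunk.text, end="")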

examples/gemini_examples/test_notebook.py

Lines changed: 4 additions & 4 deletions
@@ -37,7 +37,7 @@


 # Initialize AgentOps and Gemini model
-ao_client = agentops.init()
+agentops.init() # Provider detection happens automatically
 model = genai.GenerativeModel("gemini-1.5-flash")


@@ -46,7 +46,7 @@

 # Test synchronous generation
 print("Testing synchronous generation:")
-response = model.generate_content("What are the three laws of robotics?", session=ao_client)
+response = model.generate_content("What are the three laws of robotics?")
 print(response.text)


@@ -56,7 +56,7 @@
 # Test streaming generation
 print("\nTesting streaming generation:")
 response = model.generate_content(
-    "Explain the concept of machine learning in simple terms.", stream=True, session=ao_client
+    "Explain the concept of machine learning in simple terms.", stream=True
 )

 for chunk in response:
@@ -66,7 +66,7 @@
 # Test another synchronous generation
 print("\nTesting another synchronous generation:")
 response = model.generate_content(
-    "What is the difference between supervised and unsupervised learning?", session=ao_client
+    "What is the difference between supervised and unsupervised learning?"
 )
 print(response.text)

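
The commit title mentions test coverage for session handling and event recording; the test file itself is presumably the fourth changed file and is not shown in this view. As an illustration only, such a test could look roughly like the sketch below. The provider class name (GeminiProvider), its constructor argument, and the fake-response shape are assumptions; only the handle_response call signature and the completion/_safe_record behavior are taken from the diff above.

from types import SimpleNamespace
from unittest.mock import MagicMock

from agentops.llms.providers.gemini import GeminiProvider  # class name assumed


def test_sync_response_records_llm_event():
    provider = GeminiProvider(client=MagicMock())  # constructor argument assumed
    provider._safe_record = MagicMock()            # capture events instead of uploading

    fake_response = SimpleNamespace(text="Hello!", model="gemini-1.5-flash")
    provider.handle_response(
        fake_response,
        {"contents": "Say hello"},
        "2024-01-01T00:00:00Z",  # init_timestamp
        session=MagicMock(),
    )

    # The synchronous path should record an LLMEvent whose completion is response.text
    assert provider._safe_record.called
    recorded_event = provider._safe_record.call_args[0][1]
    assert recorded_event.completion == "Hello!"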
