Skip to content

Commit bff477c

Browse files
fix: Update Gemini provider to properly handle prompt extraction and improve test coverage
Co-Authored-By: Alex Reibman <meta.alex.r@gmail.com>
1 parent 9c9af3a commit bff477c

File tree

2 files changed

+25
-21
lines changed

2 files changed

+25
-21
lines changed

agentops/llms/providers/gemini.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,6 @@ class GeminiProvider(BaseProvider):
1919
is called and the google.generativeai package is imported. No manual
2020
initialization is required."""
2121

22-
"""Provider for Google's Gemini API.
23-
24-
This provider is automatically detected and initialized when agentops.init()
25-
is called and the google.generativeai package is imported. No manual
26-
initialization is required."""
27-
2822
def __init__(self, client=None):
2923
"""Initialize the Gemini provider.
3024
@@ -66,7 +60,7 @@ def handle_stream_chunk(chunk):
6660
llm_event.returns = chunk
6761
llm_event.agent_id = check_call_stack_for_agent_id()
6862
llm_event.model = getattr(chunk, "model", "gemini-1.5-flash") # Default if not provided
69-
llm_event.prompt = kwargs.get("contents", [])
63+
llm_event.prompt = kwargs.get("prompt") or kwargs.get("contents", [])
7064

7165
try:
7266
if hasattr(chunk, "text") and chunk.text:
@@ -103,7 +97,7 @@ def stream_handler(stream):
10397
try:
10498
llm_event.returns = response
10599
llm_event.agent_id = check_call_stack_for_agent_id()
106-
llm_event.prompt = kwargs.get("contents", [])
100+
llm_event.prompt = kwargs.get("prompt") or kwargs.get("contents", [])
107101
llm_event.completion = response.text
108102
llm_event.model = getattr(response, "model", "gemini-1.5-flash")
109103

@@ -144,10 +138,19 @@ def patched_function(self, *args, **kwargs):
144138
init_timestamp = get_ISO_time()
145139
session = kwargs.pop("session", None) # Always try to pop session, returns None if not present
146140

141+
# Handle positional prompt argument
142+
event_kwargs = kwargs.copy() # Create a copy for event tracking
143+
if args and len(args) > 0:
144+
# First argument is the prompt
145+
if "contents" not in kwargs:
146+
kwargs["contents"] = args[0]
147+
event_kwargs["prompt"] = args[0] # Store original prompt
148+
args = args[1:] # Remove prompt from args since we moved it to kwargs
149+
147150
# Call original method and track event
148151
if "generate_content" in _ORIGINAL_METHODS:
149152
result = _ORIGINAL_METHODS["generate_content"](self, *args, **kwargs)
150-
return provider.handle_response(result, kwargs, init_timestamp, session=session)
153+
return provider.handle_response(result, event_kwargs, init_timestamp, session=session)
151154
else:
152155
logger.error("Original generate_content method not found. Cannot proceed with override.")
153156
return None

tests/unit/test_llms/providers/test_gemini.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import google.generativeai as genai
22
import agentops
3-
from agentops.llms.providers.gemini import GeminiProvider
3+
from agentops.llms.providers.gemini import GeminiProvider, _ORIGINAL_METHODS
44
from agentops.event import LLMEvent
55

66
# Configure the API key from environment variable
@@ -43,7 +43,7 @@ def test_gemini_provider():
4343
provider = GeminiProvider(model)
4444
assert provider.client == model
4545
assert provider.provider_name == "Gemini"
46-
assert provider.original_generate is None
46+
assert "generate_content" not in _ORIGINAL_METHODS
4747

4848

4949
def test_gemini_sync_generation():
@@ -97,31 +97,31 @@ class InvalidClient:
9797

9898
# Test override with None client
9999
provider.override() # Should log warning and return
100-
assert provider.original_generate is None
100+
assert "generate_content" not in _ORIGINAL_METHODS
101101

102102
# Test override with uninitialized generate_content
103103
provider.client = InvalidClient()
104104
provider.override() # Should log warning about missing generate_content
105-
assert provider.original_generate is None
105+
assert "generate_content" not in _ORIGINAL_METHODS
106106

107-
# Test patched function with None original_generate
107+
# Test patched function with missing original method
108108
model = genai.GenerativeModel("gemini-1.5-flash")
109109
provider = GeminiProvider(model)
110-
provider.original_generate = None
111110
provider.override()
112111

113-
# Should log error and return None
112+
# Should log error and return None when original method is missing
113+
if "generate_content" in _ORIGINAL_METHODS:
114+
del _ORIGINAL_METHODS["generate_content"]
114115
result = model.generate_content("test prompt")
115116
assert result is None
116117

117118
# Test undo_override with None client
118119
provider.client = None
119120
provider.undo_override() # Should handle None client gracefully
120121

121-
# Test undo_override with None original_generate
122+
# Test undo_override with missing original method
122123
provider.client = model
123-
provider.original_generate = None
124-
provider.undo_override() # Should handle None original_generate gracefully
124+
provider.undo_override() # Should handle missing original method gracefully
125125

126126
# Test automatic provider detection
127127
agentops.init()
@@ -503,6 +503,7 @@ def test_undo_override():
503503
provider.undo_override()
504504
assert model.generate_content == original_generate
505505

506-
# Test undo_override when original_generate is None
507-
provider.original_generate = None
506+
# Test undo_override with missing original method
507+
if "generate_content" in _ORIGINAL_METHODS:
508+
del _ORIGINAL_METHODS["generate_content"]
508509
provider.undo_override() # Should not raise any errors

0 commit comments

Comments
 (0)