Commit 60e497d
Fix/google genai api (#124)
* Exported namespace management, for use in other libraries.
* prep_namespace exposed for use in other libraries.
* Resolved a merge conflict in uv.lock.
* Updated the genai API.
* Fixed uv.lock.
* Ran black.

Co-authored-by: Niki van Stein <n.van.stein@liacs.leidenuniv.nl>

1 parent 3792738

File tree: 4 files changed (+2109, -1891 lines changed)

llamea/llm.py

Lines changed: 36 additions & 24 deletions
```diff
@@ -11,7 +11,8 @@
 from abc import ABC, abstractmethod

 try:
-    import google.generativeai as genai
+    from google import genai
+    from google.genai import types
 except ModuleNotFoundError:  # pragma: no cover - optional dependency
     genai = None

```

```diff
@@ -372,22 +373,21 @@ def __init__(self, api_key, model="gemini-2.0-flash", **kwargs):
         """
         if genai is None:  # pragma: no cover - optional dependency
             raise ImportError(
-                "google-generativeai is required to use Gemini_LLM. Install the 'google-generativeai' package."
+                "google-genai is required to use Gemini_LLM. Install the 'google-genai' package."
             )
         super().__init__(api_key, model, None, **kwargs)
-        genai.configure(api_key=api_key)
-        generation_config = {
+
+        self.generation_config = {
+            "system_instruction": "You are a computer scientist and excellent Python programmer.",
             "temperature": 1,
             "top_p": 0.95,
             "top_k": 64,
             "max_output_tokens": 65536,
             "response_mime_type": "text/plain",
         }
-
-        self.client = genai.GenerativeModel(
-            model_name=self.model,  # "gemini-1.5-flash","gemini-2.0-flash",
-            generation_config=generation_config,
-            system_instruction="You are a computer scientist and excellent Python programmer.",
+        self.api_key = api_key
+        self.client = genai.Client(
+            api_key=api_key,
         )

     def __getstate__(self):
```
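
The constructor change above is the core of the migration: the old `google-generativeai` package configured a module-global API key and baked settings into a `GenerativeModel`, while `google-genai` uses an explicit `Client` and passes generation settings with each call. A minimal standalone sketch of the new pattern, assuming a valid `GOOGLE_API_KEY` in the environment (the prompt is illustrative):

```python
import os

from google import genai

# One explicit client per process; no module-level configure() step.
client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])

# Generation settings now travel with each call instead of being
# baked into a GenerativeModel instance.
config = {
    "system_instruction": "You are a computer scientist and excellent Python programmer.",
    "temperature": 1,
    "top_p": 0.95,
    "top_k": 64,
    "max_output_tokens": 65536,
    "response_mime_type": "text/plain",
}

chat = client.chats.create(model="gemini-2.0-flash", config=config)
response = chat.send_message("Write a one-line docstring for a bubble sort.")
print(response.text)
```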
```diff
@@ -399,21 +399,29 @@ def __getstate__(self):
     def __setstate__(self, state):
         """Restore from a pickled state."""
         self.__dict__.update(state)  # put back the simple stuff
-        generation_config = {
-            "temperature": 1,
-            "top_p": 0.95,
-            "top_k": 64,
-            "max_output_tokens": 65536,
-            "response_mime_type": "text/plain",
-        }

-        self.client = genai.GenerativeModel(
-            model_name=self.model,  # "gemini-1.5-flash","gemini-2.0-flash",
-            generation_config=generation_config,
-            system_instruction="You are a computer scientist and excellent Python programmer.",
-        )
+        self.client = genai.Client(
+            api_key=self.api_key
+        )  # expecting implicit pull for env var GOOGLE_API_KEY, too risky to pickle.

-    def query(self, session_messages, max_retries: int = 5, default_delay: int = 10):
+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        new = cls.__new__(cls)
+        memo[id(self)] = new
+        for k, v in self.__dict__.items():
+            if k == "client":
+                continue
+            setattr(new, k, copy.deepcopy(v, memo))
+        new.client = genai.Client(api_key=new.api_key)
+        return new
+
+    def query(
+        self,
+        session_messages: list[dict[str, str]],
+        max_retries: int = 5,
+        default_delay: int = 10,
+        **kwargs,
+    ):
         """
         Sends the conversation history to Gemini, retrying on 429 ResourceExhausted exceptions.
```
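
`__setstate__` above rebuilds the client rather than restoring it, and the new `__deepcopy__` skips it entirely, because a `genai.Client` wraps a live HTTP session that cannot safely be pickled or copied. A self-contained sketch of the same copy/restore pattern, using a hypothetical `ClientHolder` class and `object()` stand-ins for the real client:

```python
import copy


class ClientHolder:
    """Wraps a non-copyable handle plus plain picklable state."""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.client = object()  # stand-in for a live genai.Client

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("client")  # drop the live handle before pickling
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.client = object()  # rebuild the handle on restore

    def __deepcopy__(self, memo):
        new = self.__class__.__new__(self.__class__)
        memo[id(self)] = new
        for k, v in self.__dict__.items():
            if k == "client":
                continue  # never deep-copy the handle itself
            setattr(new, k, copy.deepcopy(v, memo))
        new.client = object()  # fresh handle for the copy
        return new


holder = copy.deepcopy(ClientHolder("secret"))
assert holder.api_key == "secret"
```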
```diff
@@ -433,14 +441,18 @@ def query(self, session_messages, max_retries: int = 5, default_delay: int = 10)
         attempt = 0
         while True:
             try:
-                chat = self.client.start_chat(history=history)
+                config = self.generation_config.copy()
+                config.update(**kwargs)
+                chat = self.client.chats.create(
+                    model=self.model, history=history, config=config
+                )
                 response = chat.send_message(last)
                 return response.text

             except Exception as err:
                 attempt += 1
                 if attempt > max_retries:
-                    raise  # bubble out after N tries
+                    raise err  # bubble out after N tries

                 # Prefer the structured retry_delay field if present
                 delay = getattr(err, "retry_delay", None)
```
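
Because `query` now merges `**kwargs` into a copy of `self.generation_config` before each `chats.create` call, callers can override generation settings per request. A usage sketch (the API key and prompt are placeholders):

```python
from llamea.llm import Gemini_LLM

# Assumes a real API key; "..." is a placeholder.
llm = Gemini_LLM(api_key="...", model="gemini-2.0-flash")

# Per-call override: temperature is merged into a copy of
# self.generation_config, so only this query runs cold.
text = llm.query(
    [{"role": "user", "content": "Summarize quicksort in one sentence."}],
    temperature=0.2,
)
print(text)
```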

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,7 +13,7 @@ dependencies = [
     "ollama>=0.5.3,<1",
     "jsonlines>=4.0.0,<5",
     "configspace>=1.2.0,<2",
-    "google-generativeai>=0.8.1,<0.9",
+    "google-genai>=1,<2",
     "joblib>=1.4.2,<2",
     "lizard>=1.17.13,<2",
     "networkx>=3.4.2,<4",
```

tests/test_llm.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -219,13 +219,13 @@ def test_gemini_llm_retries_then_succeeds(monkeypatch):
     chat_ok.send_message.return_value = type("R", (), {"text": "OK-DONE"})

     fake_client = MagicMock()
-    fake_client.start_chat.side_effect = [chat_fail, chat_ok]
+    fake_client.chats.create.side_effect = [chat_fail, chat_ok]
     llm.client = fake_client

     reply = llm.query([{"role": "user", "content": "hello"}], max_retries=3)

     assert reply == "OK-DONE"
-    assert fake_client.start_chat.call_count == 2  # 1 failure + 1 success
+    assert fake_client.chats.create.call_count == 2  # 1 failure + 1 success
     slept.assert_called_once_with(3)  # 2 s + 1 s safety buffer


@@ -240,7 +240,7 @@ def test_gemini_llm_gives_up_after_max_retries(monkeypatch):
     chat_fail.send_message.side_effect = _resource_exhausted(1)

     fake_client = MagicMock()
-    fake_client.start_chat.return_value = chat_fail
+    fake_client.chats.create.return_value = chat_fail
     llm.client = fake_client

     with pytest.raises(Exception):
```
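
Repointing the tests from `start_chat` to `chats.create` works without extra setup because `MagicMock` materializes nested attributes on first access. A minimal standalone illustration of that mocking pattern (names are illustrative, not the repo's fixtures):

```python
from unittest.mock import MagicMock

fake_client = MagicMock()

chat_ok = MagicMock()
chat_ok.send_message.return_value = type("R", (), {"text": "OK-DONE"})

# chats.create is resolved lazily on the MagicMock, so the
# intermediate `chats` attribute needs no explicit configuration.
fake_client.chats.create.return_value = chat_ok

chat = fake_client.chats.create(model="gemini-2.0-flash", history=[])
assert chat.send_message("hello").text == "OK-DONE"
assert fake_client.chats.create.call_count == 1
```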
