Skip to content

Commit d22e5ed

Browse files
authored
chore(genai): update model names in tests (#1187)
Use modern names as some previous ones are deprecated. Also parameterize some tests
1 parent e221d19 commit d22e5ed

File tree

9 files changed

+87
-55
lines changed

9 files changed

+87
-55
lines changed

libs/genai/tests/integration_tests/test_callbacks.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from typing import Any
22

3+
import pytest
34
from langchain_core.callbacks import BaseCallbackHandler
45
from langchain_core.outputs import LLMResult
56
from langchain_core.prompts import PromptTemplate
67

78
from langchain_google_genai import ChatGoogleGenerativeAI
89

10+
model_names = ["gemini-2.5-flash"]
11+
912

1013
class StreamingLLMCallbackHandler(BaseCallbackHandler):
1114
def __init__(self, **kwargs: Any) -> None:
@@ -20,10 +23,14 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
2023
self.generations.append(response.generations[0][0].text)
2124

2225

23-
def test_streaming_callback() -> None:
26+
@pytest.mark.parametrize(
27+
"model_name",
28+
model_names,
29+
)
30+
def test_streaming_callback(model_name: str) -> None:
2431
prompt_template = "Tell me details about the Company {name} with 2 bullet point?"
2532
cb = StreamingLLMCallbackHandler()
26-
llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001", callbacks=[cb])
33+
llm = ChatGoogleGenerativeAI(model=model_name, callbacks=[cb])
2734
llm_chain = PromptTemplate.from_template(prompt_template) | llm
2835
for _t in llm_chain.stream({"name": "Google"}):
2936
pass

libs/genai/tests/integration_tests/test_function_call.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44

5+
import pytest
56
from langchain_core.messages import AIMessage
67
from langchain_core.tools import tool
78
from pydantic import BaseModel
@@ -10,8 +11,14 @@
1011
ChatGoogleGenerativeAI,
1112
)
1213

14+
model_names = ["gemini-2.5-flash"]
1315

14-
def test_function_call() -> None:
16+
17+
@pytest.mark.parametrize(
18+
"model_name",
19+
model_names,
20+
)
21+
def test_function_call(model_name: str) -> None:
1522
functions = [
1623
{
1724
"name": "get_weather",
@@ -29,9 +36,7 @@ def test_function_call() -> None:
2936
},
3037
}
3138
]
32-
llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001").bind(
33-
functions=functions
34-
)
39+
llm = ChatGoogleGenerativeAI(model=model_name).bind(functions=functions)
3540
res = llm.invoke("what weather is today in san francisco?")
3641
assert res
3742
assert res.additional_kwargs
@@ -43,15 +48,17 @@ def test_function_call() -> None:
4348
assert "location" in arguments
4449

4550

46-
def test_tool_call() -> None:
51+
@pytest.mark.parametrize(
52+
"model_name",
53+
model_names,
54+
)
55+
def test_tool_call(model_name: str) -> None:
4756
@tool
4857
def search_tool(query: str) -> str:
4958
"""Searches the web for `query` and returns the result."""
5059
raise NotImplementedError
5160

52-
llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001").bind(
53-
functions=[search_tool]
54-
)
61+
llm = ChatGoogleGenerativeAI(model=model_name).bind(functions=[search_tool])
5562
response = llm.invoke("weather in san francisco")
5663
assert isinstance(response, AIMessage)
5764
assert isinstance(response.content, str)
@@ -70,10 +77,12 @@ class MyModel(BaseModel):
7077
age: int
7178

7279

73-
def test_pydantic_call() -> None:
74-
llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001").bind(
75-
functions=[MyModel]
76-
)
80+
@pytest.mark.parametrize(
81+
"model_name",
82+
model_names,
83+
)
84+
def test_pydantic_call(model_name: str) -> None:
85+
llm = ChatGoogleGenerativeAI(model=model_name).bind(functions=[MyModel])
7786
response = llm.invoke("my name is Erick and I am 27 years old")
7887
assert isinstance(response, AIMessage)
7988
assert isinstance(response.content, str)

libs/genai/tests/integration_tests/test_llms.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory
1212

13-
model_names = ["gemini-1.5-flash-latest"]
13+
model_names = ["gemini-2.5-flash"]
1414

1515

1616
@pytest.mark.parametrize(
@@ -45,27 +45,43 @@ def test_google_generativeai_generate(model_name: str) -> None:
4545
assert len(generation_info.get("usage_metadata", {})) > 0
4646

4747

48-
async def test_google_generativeai_agenerate() -> None:
49-
llm = GoogleGenerativeAI(temperature=0, model="models/gemini-2.0-flash-001")
48+
@pytest.mark.parametrize(
49+
"model_name",
50+
model_names,
51+
)
52+
async def test_google_generativeai_agenerate(model_name: str) -> None:
53+
llm = GoogleGenerativeAI(temperature=0, model=model_name)
5054
output = await llm.agenerate(["Please say foo:"])
5155
assert isinstance(output, LLMResult)
5256

5357

54-
def test_generativeai_stream() -> None:
55-
llm = GoogleGenerativeAI(temperature=0, model="gemini-1.5-flash-latest")
58+
@pytest.mark.parametrize(
59+
"model_name",
60+
model_names,
61+
)
62+
def test_generativeai_stream(model_name: str) -> None:
63+
llm = GoogleGenerativeAI(temperature=0, model=model_name)
5664
outputs = list(llm.stream("Please say foo:"))
5765
assert isinstance(outputs[0], str)
5866

5967

60-
def test_generativeai_get_num_tokens_gemini() -> None:
61-
llm = GoogleGenerativeAI(temperature=0, model="gemini-1.5-flash-latest")
68+
@pytest.mark.parametrize(
69+
"model_name",
70+
model_names,
71+
)
72+
def test_generativeai_get_num_tokens_gemini(model_name: str) -> None:
73+
llm = GoogleGenerativeAI(temperature=0, model=model_name)
6274
output = llm.get_num_tokens("How are you?")
6375
assert output == 4
6476

6577

66-
def test_safety_settings_gemini() -> None:
78+
@pytest.mark.parametrize(
79+
"model_name",
80+
model_names,
81+
)
82+
def test_safety_settings_gemini(model_name: str) -> None:
6783
# test with blocked prompt
68-
llm = GoogleGenerativeAI(temperature=0, model="gemini-1.5-flash-latest")
84+
llm = GoogleGenerativeAI(temperature=0, model=model_name)
6985
output = llm.generate(prompts=["how to make a bomb?"])
7086
assert isinstance(output, LLMResult)
7187
assert len(output.generations[0]) > 0
@@ -88,7 +104,7 @@ def test_safety_settings_gemini() -> None:
88104

89105
# test with safety filters on instantiation
90106
llm = GoogleGenerativeAI(
91-
model="gemini-1.5-flash-latest",
107+
model=model_name,
92108
safety_settings=safety_settings,
93109
temperature=0,
94110
)

libs/genai/tests/integration_tests/test_standard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def chat_model_class(self) -> type[BaseChatModel]:
2424
@property
2525
def chat_model_params(self) -> dict:
2626
return {
27-
"model": "models/gemini-2.5-flash",
27+
"model": "gemini-2.5-flash",
2828
"rate_limiter": rate_limiter,
2929
}
3030

@@ -107,7 +107,7 @@ def chat_model_class(self) -> type[BaseChatModel]:
107107
@property
108108
def chat_model_params(self) -> dict:
109109
return {
110-
"model": "models/gemini-2.5-pro",
110+
"model": "gemini-2.5-pro",
111111
"rate_limiter": rate_limiter,
112112
}
113113

libs/genai/tests/integration_tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_multiple_tools() -> None:
2525
tools = [check_weather, check_live_traffic, check_tennis_score]
2626

2727
model = ChatGoogleGenerativeAI(
28-
model="gemini-2.0-flash-001",
28+
model="gemini-2.5-flash",
2929
)
3030

3131
model_with_tools = model.bind_tools(tools)

libs/genai/tests/unit_tests/__snapshots__/test_standard.ambr

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
}),
1919
'max_output_tokens': 100,
2020
'max_retries': 2,
21-
'model': 'models/gemini-1.5-pro-001',
21+
'model': 'models/gemini-2.5-flash',
2222
'n': 1,
2323
'stop': list([
2424
]),
@@ -49,7 +49,7 @@
4949
}),
5050
'max_output_tokens': 100,
5151
'max_retries': 2,
52-
'model': 'models/gemini-1.0-pro-001',
52+
'model': 'models/gemini-2.5-flash',
5353
'n': 1,
5454
'stop': list([
5555
]),

libs/genai/tests/unit_tests/test_chat_models.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
def test_integration_initialization() -> None:
4545
"""Test chat model initialization."""
4646
llm = ChatGoogleGenerativeAI(
47-
model="gemini-nano",
47+
model="gemini-2.5-flash",
4848
google_api_key=SecretStr("..."),
4949
top_k=2,
5050
top_p=1,
@@ -54,27 +54,27 @@ def test_integration_initialization() -> None:
5454
ls_params = llm._get_ls_params()
5555
assert ls_params == {
5656
"ls_provider": "google_genai",
57-
"ls_model_name": "gemini-nano",
57+
"ls_model_name": "gemini-2.5-flash",
5858
"ls_model_type": "chat",
5959
"ls_temperature": 0.7,
6060
}
6161

6262
llm = ChatGoogleGenerativeAI(
63-
model="gemini-nano",
63+
model="gemini-2.5-flash",
6464
google_api_key=SecretStr("..."),
6565
max_output_tokens=10,
6666
)
6767
ls_params = llm._get_ls_params()
6868
assert ls_params == {
6969
"ls_provider": "google_genai",
70-
"ls_model_name": "gemini-nano",
70+
"ls_model_name": "gemini-2.5-flash",
7171
"ls_model_type": "chat",
7272
"ls_temperature": 0.7,
7373
"ls_max_tokens": 10,
7474
}
7575

7676
ChatGoogleGenerativeAI(
77-
model="gemini-nano",
77+
model="gemini-2.5-flash",
7878
api_key=SecretStr("..."),
7979
top_k=2,
8080
top_p=1,
@@ -86,13 +86,13 @@ def test_integration_initialization() -> None:
8686
with warnings.catch_warnings():
8787
warnings.simplefilter("ignore", UserWarning)
8888
llm = ChatGoogleGenerativeAI(
89-
model="gemini-nano",
89+
model="gemini-2.5-flash",
9090
google_api_key=SecretStr("..."),
9191
safety_setting={
9292
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_LOW_AND_ABOVE"
9393
}, # Invalid arg
9494
)
95-
assert llm.model == "models/gemini-nano"
95+
assert llm.model == "models/gemini-2.5-flash"
9696
mock_warning.assert_called_once()
9797
call_args = mock_warning.call_args[0][0]
9898
assert "Unexpected argument 'safety_setting'" in call_args
@@ -105,14 +105,14 @@ def test_initialization_inside_threadpool() -> None:
105105
with ThreadPoolExecutor() as executor:
106106
executor.submit(
107107
ChatGoogleGenerativeAI,
108-
model="gemini-nano",
108+
model="gemini-2.5-flash",
109109
google_api_key=SecretStr("secret-api-key"),
110110
).result()
111111

112112

113113
def test_initalization_without_async() -> None:
114114
chat = ChatGoogleGenerativeAI(
115-
model="gemini-nano",
115+
model="gemini-2.5-flash",
116116
google_api_key=SecretStr("secret-api-key"),
117117
)
118118
assert chat.async_client is None
@@ -121,7 +121,7 @@ def test_initalization_without_async() -> None:
121121
def test_initialization_with_async() -> None:
122122
async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI:
123123
model = ChatGoogleGenerativeAI(
124-
model="gemini-nano",
124+
model="gemini-2.5-flash",
125125
google_api_key=SecretStr("secret-api-key"),
126126
)
127127
_ = model.async_client
@@ -133,7 +133,7 @@ async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI:
133133

134134
def test_api_key_is_string() -> None:
135135
chat = ChatGoogleGenerativeAI(
136-
model="gemini-nano",
136+
model="gemini-2.5-flash",
137137
google_api_key=SecretStr("secret-api-key"),
138138
)
139139
assert isinstance(chat.google_api_key, SecretStr)
@@ -143,7 +143,7 @@ def test_api_key_masked_when_passed_via_constructor(
143143
capsys: pytest.CaptureFixture,
144144
) -> None:
145145
chat = ChatGoogleGenerativeAI(
146-
model="gemini-nano",
146+
model="gemini-2.5-flash",
147147
google_api_key=SecretStr("secret-api-key"),
148148
)
149149
print(chat.google_api_key, end="") # noqa: T201
@@ -349,7 +349,7 @@ def test_additional_headers_support(headers: Optional[dict[str, str]]) -> None:
349349
mock_client,
350350
):
351351
chat = ChatGoogleGenerativeAI(
352-
model="gemini-pro",
352+
model="gemini-2.5-flash",
353353
google_api_key=param_secret_api_key,
354354
client_options=param_client_options,
355355
transport=param_transport,
@@ -387,7 +387,7 @@ def test_default_metadata_field_alias() -> None:
387387
# error
388388
# This is the main issue: LangSmith Playground passes None to default_metadata_input
389389
chat1 = ChatGoogleGenerativeAI(
390-
model="gemini-pro",
390+
model="gemini-2.5-flash",
391391
google_api_key=SecretStr("test-key"),
392392
default_metadata_input=None,
393393
)
@@ -398,7 +398,7 @@ def test_default_metadata_field_alias() -> None:
398398
# Test with empty list for default_metadata_input (should not cause validation
399399
# error)
400400
chat2 = ChatGoogleGenerativeAI(
401-
model="gemini-pro",
401+
model="gemini-2.5-flash",
402402
google_api_key=SecretStr("test-key"),
403403
default_metadata_input=[],
404404
)
@@ -407,7 +407,7 @@ def test_default_metadata_field_alias() -> None:
407407

408408
# Test with tuple for default_metadata_input (should not cause validation error)
409409
chat3 = ChatGoogleGenerativeAI(
410-
model="gemini-pro",
410+
model="gemini-2.5-flash",
411411
google_api_key=SecretStr("test-key"),
412412
default_metadata_input=[("X-Test", "test")],
413413
)
@@ -716,7 +716,7 @@ def test_parse_response_candidate(raw_candidate: dict, expected: AIMessage) -> N
716716

717717

718718
def test_serialize() -> None:
719-
llm = ChatGoogleGenerativeAI(model="gemini-pro-1.5", google_api_key="test-key")
719+
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key="test-key")
720720
serialized = dumps(llm)
721721
llm_loaded = loads(
722722
serialized,

0 commit comments

Comments
 (0)