
Commit 419c173

feat(groq): openai-oss (#32411)
Use the new openai-oss model for integration tests, set module-level testing model names, and improve the robustness of the tool tests.
1 parent: 4011257
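Stated as a pattern, the change is small but useful: every test now reads its model name from a module-level constant, so pointing the whole suite at a different model is a one-line edit. A minimal sketch of the pattern (the constant values come from the diff below; the test body is illustrative, not part of the commit):

from langchain_groq import ChatGroq

# Module-level model names shared by the integration tests (values from this commit).
DEFAULT_MODEL_NAME = "openai/gpt-oss-20b"
REASONING_MODEL_NAME = "deepseek-r1-distill-llama-70b"


def test_invoke() -> None:
    # Each test builds its ChatGroq client from the shared constant,
    # so swapping models for the whole suite means editing one line above.
    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)
    result = chat.invoke("Hello!")
    assert isinstance(result.content, str)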

File tree

1 file changed: +35 -32 lines

libs/partners/groq/tests/integration_tests/test_chat_models.py

Lines changed: 35 additions & 32 deletions
@@ -24,7 +24,8 @@
     FakeCallbackHandlerWithChatStart,
 )

-MODEL_NAME = "llama-3.3-70b-versatile"
+DEFAULT_MODEL_NAME = "openai/gpt-oss-20b"
+REASONING_MODEL_NAME = "deepseek-r1-distill-llama-70b"


 #
@@ -34,7 +35,7 @@
 def test_invoke() -> None:
     """Test Chat wrapper."""
     chat = ChatGroq(
-        model=MODEL_NAME,
+        model=DEFAULT_MODEL_NAME,
         temperature=0.7,
         base_url=None,
         groq_proxy=None,
@@ -55,7 +56,7 @@ def test_invoke() -> None:
 @pytest.mark.scheduled
 async def test_ainvoke() -> None:
     """Test ainvoke tokens from ChatGroq."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)

     result = await chat.ainvoke("Welcome to the Groqetship!", config={"tags": ["foo"]})
     assert isinstance(result, BaseMessage)
@@ -65,7 +66,7 @@ async def test_ainvoke() -> None:
 @pytest.mark.scheduled
 def test_batch() -> None:
     """Test batch tokens from ChatGroq."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)

     result = chat.batch(["Hello!", "Welcome to the Groqetship!"])
     for token in result:
@@ -76,7 +77,7 @@ def test_batch() -> None:
 @pytest.mark.scheduled
 async def test_abatch() -> None:
     """Test abatch tokens from ChatGroq."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)

     result = await chat.abatch(["Hello!", "Welcome to the Groqetship!"])
     for token in result:
@@ -87,7 +88,7 @@ async def test_abatch() -> None:
 @pytest.mark.scheduled
 async def test_stream() -> None:
     """Test streaming tokens from Groq."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)

     for token in chat.stream("Welcome to the Groqetship!"):
         assert isinstance(token, BaseMessageChunk)
@@ -97,7 +98,7 @@ async def test_stream() -> None:
 @pytest.mark.scheduled
 async def test_astream() -> None:
     """Test streaming tokens from Groq."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)

     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
@@ -136,7 +137,7 @@ async def test_astream() -> None:
 def test_generate() -> None:
     """Test sync generate."""
     n = 1
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)
     message = HumanMessage(content="Hello", n=1)
     response = chat.generate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -155,7 +156,7 @@ def test_generate() -> None:
 async def test_agenerate() -> None:
     """Test async generation."""
     n = 1
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10, n=1)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10, n=1)
     message = HumanMessage(content="Hello")
     response = await chat.agenerate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -178,7 +179,7 @@ def test_invoke_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandler()
     chat = ChatGroq(
-        model=MODEL_NAME,
+        model=DEFAULT_MODEL_NAME,
         max_tokens=2,
         streaming=True,
         temperature=0,
@@ -195,7 +196,7 @@ async def test_agenerate_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandlerWithChatStart()
     chat = ChatGroq(
-        model=MODEL_NAME,
+        model=DEFAULT_MODEL_NAME,
         max_tokens=10,
         streaming=True,
         temperature=0,
@@ -222,7 +223,7 @@ async def test_agenerate_streaming() -> None:
 def test_reasoning_output_invoke() -> None:
     """Test reasoning output from ChatGroq with invoke."""
     chat = ChatGroq(
-        model="deepseek-r1-distill-llama-70b",
+        model=REASONING_MODEL_NAME,
         reasoning_format="parsed",
     )
     message = [
@@ -241,7 +242,7 @@ def test_reasoning_output_invoke() -> None:
 def test_reasoning_output_stream() -> None:
     """Test reasoning output from ChatGroq with stream."""
     chat = ChatGroq(
-        model="deepseek-r1-distill-llama-70b",
+        model=REASONING_MODEL_NAME,
         reasoning_format="parsed",
     )
     message = [
@@ -300,7 +301,7 @@ def on_llm_end(

     callback = _FakeCallback()
     chat = ChatGroq(
-        model=MODEL_NAME,
+        model="llama-3.1-8b-instant",  # Use a model that properly streams content
         max_tokens=2,
         temperature=0,
         callbacks=[callback],
@@ -314,7 +315,7 @@ def on_llm_end(

 def test_system_message() -> None:
     """Test ChatGroq wrapper with system message."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)
     system_message = SystemMessage(content="You are to chat with the user.")
     human_message = HumanMessage(content="Hello")
     response = chat.invoke([system_message, human_message])
@@ -324,15 +325,15 @@ def test_system_message() -> None:

 def test_tool_choice() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq(model=MODEL_NAME)
+    llm = ChatGroq(model=DEFAULT_MODEL_NAME)

     class MyTool(BaseModel):
         name: str
         age: int

     with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")

-    resp = with_tool.invoke("Who was the 27 year old named Erick?")
+    resp = with_tool.invoke("Who was the 27 year old named Erick? Use the tool.")
     assert isinstance(resp, AIMessage)
     assert resp.content == ""  # should just be tool call
     tool_calls = resp.additional_kwargs["tool_calls"]
@@ -354,15 +355,15 @@ class MyTool(BaseModel):

 def test_tool_choice_bool() -> None:
     """Test that tool choice is respected just passing in True."""
-    llm = ChatGroq(model=MODEL_NAME)
+    llm = ChatGroq(model=DEFAULT_MODEL_NAME)

     class MyTool(BaseModel):
         name: str
         age: int

     with_tool = llm.bind_tools([MyTool], tool_choice=True)

-    resp = with_tool.invoke("Who was the 27 year old named Erick?")
+    resp = with_tool.invoke("Who was the 27 year old named Erick? Use the tool.")
     assert isinstance(resp, AIMessage)
     assert resp.content == ""  # should just be tool call
     tool_calls = resp.additional_kwargs["tool_calls"]
@@ -379,7 +380,7 @@ class MyTool(BaseModel):
 @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_streaming_tool_call() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq(model=MODEL_NAME)
+    llm = ChatGroq(model=DEFAULT_MODEL_NAME)

     class MyTool(BaseModel):
         name: str
@@ -417,7 +418,7 @@ class MyTool(BaseModel):
 @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 async def test_astreaming_tool_call() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq(model=MODEL_NAME)
+    llm = ChatGroq(model=DEFAULT_MODEL_NAME)

     class MyTool(BaseModel):
         name: str
@@ -462,7 +463,9 @@ class Joke(BaseModel):
         setup: str = Field(description="question to set up a joke")
         punchline: str = Field(description="answer to resolve the joke")

-    chat = ChatGroq(model=MODEL_NAME).with_structured_output(Joke, method="json_mode")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME).with_structured_output(
+        Joke, method="json_mode"
+    )
     result = chat.invoke(
         "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
     )
@@ -476,38 +479,38 @@ def test_setting_service_tier_class() -> None:
     message = HumanMessage(content="Welcome to the Groqetship")

     # Initialization
-    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="auto")
     assert chat.service_tier == "auto"
     response = chat.invoke([message])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
     assert response.response_metadata.get("service_tier") == "auto"

-    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="flex")
     assert chat.service_tier == "flex"
     response = chat.invoke([message])
     assert response.response_metadata.get("service_tier") == "flex"

-    chat = ChatGroq(model=MODEL_NAME, service_tier="on_demand")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="on_demand")
     assert chat.service_tier == "on_demand"
     response = chat.invoke([message])
     assert response.response_metadata.get("service_tier") == "on_demand"

-    chat = ChatGroq(model=MODEL_NAME)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME)
     assert chat.service_tier == "on_demand"
     response = chat.invoke([message])
     assert response.response_metadata.get("service_tier") == "on_demand"

     with pytest.raises(ValueError):
-        ChatGroq(model=MODEL_NAME, service_tier=None)  # type: ignore[arg-type]
+        ChatGroq(model=DEFAULT_MODEL_NAME, service_tier=None)  # type: ignore[arg-type]
     with pytest.raises(ValueError):
-        ChatGroq(model=MODEL_NAME, service_tier="invalid")  # type: ignore[arg-type]
+        ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="invalid")  # type: ignore[arg-type]


 def test_setting_service_tier_request() -> None:
     """Test setting service tier defined at request level."""
     message = HumanMessage(content="Welcome to the Groqetship")
-    chat = ChatGroq(model=MODEL_NAME)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME)

     response = chat.invoke(
         [message],
@@ -537,7 +540,7 @@ def test_setting_service_tier_request() -> None:

     # If an `invoke` call is made with no service tier, we fall back to the class level
     # setting
-    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="auto")
     response = chat.invoke(
         [message],
     )
@@ -564,15 +567,15 @@ def test_setting_service_tier_request() -> None:

 def test_setting_service_tier_streaming() -> None:
     """Test service tier settings for streaming calls."""
-    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="flex")
     chunks = list(chat.stream("Why is the sky blue?", service_tier="auto"))

     assert chunks[-1].response_metadata.get("service_tier") == "auto"


 async def test_setting_service_tier_request_async() -> None:
     """Test async setting of service tier at the request level."""
-    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="flex")
     response = await chat.ainvoke("Hello!", service_tier="on_demand")

     assert response.response_metadata.get("service_tier") == "on_demand"
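
The tool-test robustness change is purely prompt-level: appending an explicit "Use the tool." instruction makes the model far more likely to emit a tool call instead of a prose answer. A minimal sketch of the pattern the updated tests exercise (it mirrors the diff above; the assertions follow the tests' OpenAI-style additional_kwargs access):

from langchain_groq import ChatGroq
from pydantic import BaseModel


class MyTool(BaseModel):
    name: str
    age: int


llm = ChatGroq(model="openai/gpt-oss-20b")
with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")

# The trailing "Use the tool." nudges the model toward calling the tool,
# which is what the tightened assertions depend on.
resp = with_tool.invoke("Who was the 27 year old named Erick? Use the tool.")
assert resp.content == ""  # should just be a tool call
tool_calls = resp.additional_kwargs["tool_calls"]
assert tool_calls[0]["function"]["name"] == "MyTool"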
