     FakeCallbackHandlerWithChatStart,
 )

-MODEL_NAME = "llama-3.3-70b-versatile"
+DEFAULT_MODEL_NAME = "openai/gpt-oss-20b"
+REASONING_MODEL_NAME = "deepseek-r1-distill-llama-70b"


 #
@@ -34,7 +35,7 @@
 def test_invoke() -> None:
     """Test Chat wrapper."""
     chat = ChatGroq(
-        model=MODEL_NAME,
+        model=DEFAULT_MODEL_NAME,
         temperature=0.7,
         base_url=None,
         groq_proxy=None,
@@ -55,7 +56,7 @@ def test_invoke() -> None:
 @pytest.mark.scheduled
 async def test_ainvoke() -> None:
     """Test ainvoke tokens from ChatGroq."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)

     result = await chat.ainvoke("Welcome to the Groqetship!", config={"tags": ["foo"]})
     assert isinstance(result, BaseMessage)
@@ -65,7 +66,7 @@ async def test_ainvoke() -> None:
 @pytest.mark.scheduled
 def test_batch() -> None:
     """Test batch tokens from ChatGroq."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)

     result = chat.batch(["Hello!", "Welcome to the Groqetship!"])
     for token in result:
@@ -76,7 +77,7 @@ def test_batch() -> None:
 @pytest.mark.scheduled
 async def test_abatch() -> None:
     """Test abatch tokens from ChatGroq."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)

     result = await chat.abatch(["Hello!", "Welcome to the Groqetship!"])
     for token in result:
@@ -87,7 +88,7 @@ async def test_abatch() -> None:
 @pytest.mark.scheduled
 async def test_stream() -> None:
     """Test streaming tokens from Groq."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)

     for token in chat.stream("Welcome to the Groqetship!"):
         assert isinstance(token, BaseMessageChunk)
@@ -97,7 +98,7 @@ async def test_stream() -> None:
 @pytest.mark.scheduled
 async def test_astream() -> None:
     """Test streaming tokens from Groq."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)

     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
@@ -136,7 +137,7 @@ async def test_astream() -> None:
 def test_generate() -> None:
     """Test sync generate."""
     n = 1
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)
     message = HumanMessage(content="Hello", n=1)
     response = chat.generate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -155,7 +156,7 @@ def test_generate() -> None:
 async def test_agenerate() -> None:
     """Test async generation."""
     n = 1
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10, n=1)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10, n=1)
     message = HumanMessage(content="Hello")
     response = await chat.agenerate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -178,7 +179,7 @@ def test_invoke_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandler()
     chat = ChatGroq(
-        model=MODEL_NAME,
+        model=DEFAULT_MODEL_NAME,
         max_tokens=2,
         streaming=True,
         temperature=0,
@@ -195,7 +196,7 @@ async def test_agenerate_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     callback_handler = FakeCallbackHandlerWithChatStart()
     chat = ChatGroq(
-        model=MODEL_NAME,
+        model=DEFAULT_MODEL_NAME,
         max_tokens=10,
         streaming=True,
         temperature=0,
@@ -222,7 +223,7 @@ async def test_agenerate_streaming() -> None:
 def test_reasoning_output_invoke() -> None:
     """Test reasoning output from ChatGroq with invoke."""
     chat = ChatGroq(
-        model="deepseek-r1-distill-llama-70b",
+        model=REASONING_MODEL_NAME,
         reasoning_format="parsed",
     )
     message = [
@@ -241,7 +242,7 @@ def test_reasoning_output_invoke() -> None:
 def test_reasoning_output_stream() -> None:
     """Test reasoning output from ChatGroq with stream."""
     chat = ChatGroq(
-        model="deepseek-r1-distill-llama-70b",
+        model=REASONING_MODEL_NAME,
         reasoning_format="parsed",
     )
     message = [
@@ -300,7 +301,7 @@ def on_llm_end(

     callback = _FakeCallback()
     chat = ChatGroq(
-        model=MODEL_NAME,
+        model="llama-3.1-8b-instant",  # Use a model that properly streams content
         max_tokens=2,
         temperature=0,
         callbacks=[callback],
@@ -314,7 +315,7 @@ def on_llm_end(

 def test_system_message() -> None:
     """Test ChatGroq wrapper with system message."""
-    chat = ChatGroq(model=MODEL_NAME, max_tokens=10)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)
     system_message = SystemMessage(content="You are to chat with the user.")
     human_message = HumanMessage(content="Hello")
     response = chat.invoke([system_message, human_message])
@@ -324,15 +325,15 @@ def test_system_message() -> None:

 def test_tool_choice() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq(model=MODEL_NAME)
+    llm = ChatGroq(model=DEFAULT_MODEL_NAME)

     class MyTool(BaseModel):
         name: str
         age: int

     with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")

-    resp = with_tool.invoke("Who was the 27 year old named Erick?")
+    resp = with_tool.invoke("Who was the 27 year old named Erick? Use the tool.")
     assert isinstance(resp, AIMessage)
     assert resp.content == ""  # should just be tool call
     tool_calls = resp.additional_kwargs["tool_calls"]
@@ -354,15 +355,15 @@ class MyTool(BaseModel):

 def test_tool_choice_bool() -> None:
     """Test that tool choice is respected just passing in True."""
-    llm = ChatGroq(model=MODEL_NAME)
+    llm = ChatGroq(model=DEFAULT_MODEL_NAME)

     class MyTool(BaseModel):
         name: str
         age: int

     with_tool = llm.bind_tools([MyTool], tool_choice=True)

-    resp = with_tool.invoke("Who was the 27 year old named Erick?")
+    resp = with_tool.invoke("Who was the 27 year old named Erick? Use the tool.")
     assert isinstance(resp, AIMessage)
     assert resp.content == ""  # should just be tool call
     tool_calls = resp.additional_kwargs["tool_calls"]
@@ -379,7 +380,7 @@ class MyTool(BaseModel):
 @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 def test_streaming_tool_call() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq(model=MODEL_NAME)
+    llm = ChatGroq(model=DEFAULT_MODEL_NAME)

     class MyTool(BaseModel):
         name: str
@@ -417,7 +418,7 @@ class MyTool(BaseModel):
 @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
 async def test_astreaming_tool_call() -> None:
     """Test that tool choice is respected."""
-    llm = ChatGroq(model=MODEL_NAME)
+    llm = ChatGroq(model=DEFAULT_MODEL_NAME)

     class MyTool(BaseModel):
         name: str
@@ -462,7 +463,9 @@ class Joke(BaseModel):
     setup: str = Field(description="question to set up a joke")
     punchline: str = Field(description="answer to resolve the joke")

-    chat = ChatGroq(model=MODEL_NAME).with_structured_output(Joke, method="json_mode")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME).with_structured_output(
+        Joke, method="json_mode"
+    )
     result = chat.invoke(
         "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
     )
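
A note on the `json_mode` path exercised above: `with_structured_output(..., method="json_mode")` asks the model for raw JSON and validates the reply against the schema, so the prompt itself still has to name the expected keys, as the test does. A minimal standalone sketch of the same pattern, assuming `langchain_groq` is installed and `GROQ_API_KEY` is set in the environment:

```python
from langchain_groq import ChatGroq
from pydantic import BaseModel, Field


class Joke(BaseModel):
    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")


# json_mode: the model emits plain JSON, which is parsed into Joke;
# the prompt must spell out the keys the schema expects.
chat = ChatGroq(model="openai/gpt-oss-20b").with_structured_output(
    Joke, method="json_mode"
)
joke = chat.invoke(
    "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
)
print(joke.setup, "->", joke.punchline)
```
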
@@ -476,38 +479,38 @@ def test_setting_service_tier_class() -> None:
     message = HumanMessage(content="Welcome to the Groqetship")

     # Initialization
-    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="auto")
     assert chat.service_tier == "auto"
     response = chat.invoke([message])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
     assert response.response_metadata.get("service_tier") == "auto"

-    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="flex")
     assert chat.service_tier == "flex"
     response = chat.invoke([message])
     assert response.response_metadata.get("service_tier") == "flex"

-    chat = ChatGroq(model=MODEL_NAME, service_tier="on_demand")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="on_demand")
     assert chat.service_tier == "on_demand"
     response = chat.invoke([message])
     assert response.response_metadata.get("service_tier") == "on_demand"

-    chat = ChatGroq(model=MODEL_NAME)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME)
     assert chat.service_tier == "on_demand"
     response = chat.invoke([message])
     assert response.response_metadata.get("service_tier") == "on_demand"

     with pytest.raises(ValueError):
-        ChatGroq(model=MODEL_NAME, service_tier=None)  # type: ignore[arg-type]
+        ChatGroq(model=DEFAULT_MODEL_NAME, service_tier=None)  # type: ignore[arg-type]
     with pytest.raises(ValueError):
-        ChatGroq(model=MODEL_NAME, service_tier="invalid")  # type: ignore[arg-type]
+        ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="invalid")  # type: ignore[arg-type]


 def test_setting_service_tier_request() -> None:
     """Test setting service tier defined at request level."""
     message = HumanMessage(content="Welcome to the Groqetship")
-    chat = ChatGroq(model=MODEL_NAME)
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME)

     response = chat.invoke(
         [message],
@@ -537,7 +540,7 @@ def test_setting_service_tier_request() -> None:

     # If an `invoke` call is made with no service tier, we fall back to the class level
     # setting
-    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="auto")
     response = chat.invoke(
         [message],
     )
@@ -564,15 +567,15 @@ def test_setting_service_tier_request() -> None:


 def test_setting_service_tier_streaming() -> None:
     """Test service tier settings for streaming calls."""
-    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="flex")
     chunks = list(chat.stream("Why is the sky blue?", service_tier="auto"))

     assert chunks[-1].response_metadata.get("service_tier") == "auto"


 async def test_setting_service_tier_request_async() -> None:
     """Test async setting of service tier at the request level."""
-    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    chat = ChatGroq(model=DEFAULT_MODEL_NAME, service_tier="flex")
     response = await chat.ainvoke("Hello!", service_tier="on_demand")

     assert response.response_metadata.get("service_tier") == "on_demand"
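
The net effect of the diff: every hardcoded model string becomes one of two module-level constants, so the next model deprecation is a one-line change. A minimal sketch of the resulting usage outside the test suite, assuming `langchain_groq` is installed, `GROQ_API_KEY` is exported, and that parsed reasoning lands in `additional_kwargs` as the reasoning tests in this file expect:

```python
from langchain_groq import ChatGroq

# Constants introduced by this diff: one default chat model, one
# reasoning-capable model for the reasoning_format tests.
DEFAULT_MODEL_NAME = "openai/gpt-oss-20b"
REASONING_MODEL_NAME = "deepseek-r1-distill-llama-70b"

# Plain chat completion against the default model.
chat = ChatGroq(model=DEFAULT_MODEL_NAME, max_tokens=10)
print(chat.invoke("Hello!").content)

# Reasoning model with parsed reasoning output; the reasoning text is
# assumed to surface under additional_kwargs["reasoning_content"].
reasoning_chat = ChatGroq(model=REASONING_MODEL_NAME, reasoning_format="parsed")
msg = reasoning_chat.invoke("What is 2 + 2?")
print(msg.additional_kwargs.get("reasoning_content"))
```
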