Commit 911b0b6

groq: Add service tier option to ChatGroq (#31801)
- Allows users to select a [flex processing](https://console.groq.com/docs/flex-processing) service tier
1 parent 10ec5c8 commit 911b0b6
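
For orientation, a minimal usage sketch of the option this commit adds (not part of the diff itself); it assumes GROQ_API_KEY is set in the environment and uses an illustrative model name:

    from langchain_groq import ChatGroq

    # Class-level setting: every request from this instance uses the flex tier.
    chat = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")
    response = chat.invoke("Why is the sky blue?")

    # The tier used for the request is surfaced in the response metadata.
    print(response.response_metadata.get("service_tier"))  # "flex"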

File tree

3 files changed: +137 -3 lines changed

libs/partners/groq/langchain_groq/chat_models.py

Lines changed: 28 additions & 3 deletions
@@ -375,6 +375,21 @@ class Joke(BaseModel):
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
+    service_tier: Literal["on_demand", "flex", "auto"] = Field(default="on_demand")
+    """Optional parameter that you can include to specify the service tier you'd like to
+    use for requests.
+
+    - ``'on_demand'``: Default.
+    - ``'flex'``: On-demand processing when capacity is available, with rapid timeouts
+      if resources are constrained. Provides balance between performance and reliability
+      for workloads that don't require guaranteed processing.
+    - ``'auto'``: Uses on-demand rate limits, then falls back to ``'flex'`` if those
+      limits are exceeded
+
+    See the `Groq documentation
+    <https://console.groq.com/docs/flex-processing>`__ for more details and a list of
+    service tiers and descriptions.
+    """
     default_headers: Union[Mapping[str, str], None] = None
     default_query: Union[Mapping[str, object], None] = None
     # Configure a custom httpx client. See the
@@ -534,7 +549,7 @@ def _generate(
             **kwargs,
         }
         response = self.client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, params)

     async def _agenerate(
         self,
@@ -555,7 +570,7 @@ async def _agenerate(
             **kwargs,
         }
         response = await self.async_client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, params)

     def _stream(
         self,
@@ -582,6 +597,8 @@ def _stream(
                 generation_info["model_name"] = self.model_name
                 if system_fingerprint := chunk.get("system_fingerprint"):
                     generation_info["system_fingerprint"] = system_fingerprint
+                service_tier = params.get("service_tier") or self.service_tier
+                generation_info["service_tier"] = service_tier
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
@@ -623,6 +640,8 @@ async def _astream(
                 generation_info["model_name"] = self.model_name
                 if system_fingerprint := chunk.get("system_fingerprint"):
                     generation_info["system_fingerprint"] = system_fingerprint
+                service_tier = params.get("service_tier") or self.service_tier
+                generation_info["service_tier"] = service_tier
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
@@ -653,13 +672,16 @@ def _default_params(self) -> dict[str, Any]:
             "stop": self.stop,
             "reasoning_format": self.reasoning_format,
             "reasoning_effort": self.reasoning_effort,
+            "service_tier": self.service_tier,
             **self.model_kwargs,
         }
         if self.max_tokens is not None:
             params["max_tokens"] = self.max_tokens
         return params

-    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
+    def _create_chat_result(
+        self, response: Union[dict, BaseModel], params: dict
+    ) -> ChatResult:
         generations = []
         if not isinstance(response, dict):
             response = response.model_dump()
@@ -689,6 +711,7 @@ def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
             "model_name": self.model_name,
             "system_fingerprint": response.get("system_fingerprint", ""),
         }
+        llm_output["service_tier"] = params.get("service_tier") or self.service_tier
         return ChatResult(generations=generations, llm_output=llm_output)

     def _create_message_dicts(
@@ -719,6 +742,8 @@ def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
         combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
         if system_fingerprint:
             combined["system_fingerprint"] = system_fingerprint
+        if self.service_tier:
+            combined["service_tier"] = self.service_tier
         return combined

     @deprecated(
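
To illustrate the params.get("service_tier") or self.service_tier fallback introduced above, a hedged sketch of request-level overrides (same assumptions as the earlier example: GROQ_API_KEY set, illustrative model name):

    from langchain_groq import ChatGroq

    chat = ChatGroq(model="llama-3.1-8b-instant", service_tier="on_demand")

    # A service_tier kwarg on invoke() takes precedence over the class-level value ...
    response = chat.invoke("Hello!", service_tier="flex")
    assert response.response_metadata.get("service_tier") == "flex"

    # ... while omitting it (or passing None) falls back to the class-level tier.
    response = chat.invoke("Hello!")
    assert response.response_metadata.get("service_tier") == "on_demand"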

libs/partners/groq/tests/integration_tests/test_chat_models.py

Lines changed: 108 additions & 0 deletions
@@ -4,6 +4,7 @@
 from typing import Any, Optional, cast

 import pytest
+from groq import BadRequestError
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -467,6 +468,113 @@ class Joke(BaseModel):
     assert len(result.punchline) != 0


+def test_setting_service_tier_class() -> None:
+    """Test setting service tier defined at ChatGroq level."""
+    message = HumanMessage(content="Welcome to the Groqetship")
+
+    # Initialization
+    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    assert chat.service_tier == "auto"
+    response = chat.invoke([message])
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    assert chat.service_tier == "flex"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "flex"
+
+    chat = ChatGroq(model=MODEL_NAME, service_tier="on_demand")
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    chat = ChatGroq(model=MODEL_NAME)
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke([message])
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    with pytest.raises(ValueError):
+        ChatGroq(model=MODEL_NAME, service_tier=None)  # type: ignore
+    with pytest.raises(ValueError):
+        ChatGroq(model=MODEL_NAME, service_tier="invalid")  # type: ignore
+
+
+def test_setting_service_tier_request() -> None:
+    """Test setting service tier defined at request level."""
+    message = HumanMessage(content="Welcome to the Groqetship")
+    chat = ChatGroq(model=MODEL_NAME)
+
+    response = chat.invoke(
+        [message],
+        service_tier="auto",
+    )
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    response = chat.invoke(
+        [message],
+        service_tier="flex",
+    )
+    assert response.response_metadata.get("service_tier") == "flex"
+
+    response = chat.invoke(
+        [message],
+        service_tier="on_demand",
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    assert chat.service_tier == "on_demand"
+    response = chat.invoke(
+        [message],
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    # If an `invoke` call is made with no service tier, we fall back to the class level
+    # setting
+    chat = ChatGroq(model=MODEL_NAME, service_tier="auto")
+    response = chat.invoke(
+        [message],
+    )
+    assert response.response_metadata.get("service_tier") == "auto"
+
+    response = chat.invoke(
+        [message],
+        service_tier="on_demand",
+    )
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+    with pytest.raises(BadRequestError):
+        response = chat.invoke(
+            [message],
+            service_tier="invalid",
+        )
+
+    response = chat.invoke(
+        [message],
+        service_tier=None,
+    )
+    assert response.response_metadata.get("service_tier") == "auto"
+
+
+def test_setting_service_tier_streaming() -> None:
+    """Test service tier settings for streaming calls."""
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    chunks = list(chat.stream("Why is the sky blue?", service_tier="auto"))
+
+    assert chunks[-1].response_metadata.get("service_tier") == "auto"
+
+
+async def test_setting_service_tier_request_async() -> None:
+    """Test async setting of service tier at the request level."""
+    chat = ChatGroq(model=MODEL_NAME, service_tier="flex")
+    response = await chat.ainvoke("Hello!", service_tier="on_demand")
+
+    assert response.response_metadata.get("service_tier") == "on_demand"
+
+
 # Groq does not currently support N > 1
 # @pytest.mark.scheduled
 # def test_chat_multiple_completions() -> None:

libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
   'model_name': 'llama-3.1-8b-instant',
   'n': 1,
   'request_timeout': 60.0,
+  'service_tier': 'on_demand',
   'stop': list([
   ]),
   'temperature': 1e-08,
