@@ -375,6 +375,21 @@ class Joke(BaseModel):
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
+    service_tier: Literal["on_demand", "flex", "auto"] = Field(default="on_demand")
+    """Optional parameter to specify the service tier you'd like to use for
+    requests.
+
+    - ``'on_demand'``: Default.
+    - ``'flex'``: On-demand processing when capacity is available, with rapid timeouts
+      if resources are constrained. Provides a balance between performance and
+      reliability for workloads that don't require guaranteed processing.
+    - ``'auto'``: Uses on-demand rate limits, then falls back to ``'flex'`` if those
+      limits are exceeded.
+
+    See the `Groq documentation
+    <https://console.groq.com/docs/flex-processing>`__ for more details and a list of
+    service tiers and descriptions.
+    """
     default_headers: Union[Mapping[str, str], None] = None
     default_query: Union[Mapping[str, object], None] = None
     # Configure a custom httpx client. See the
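A minimal usage sketch for the new field (the model name and prompt are illustrative, not part of this diff). Because `_generate` merges `**kwargs` into `params` before the request is sent, the tier can also be overridden per call:

```python
from langchain_groq import ChatGroq

# Instance-level default: every request uses Groq's flex tier.
llm = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")

# Per-call override: extra kwargs flow into `params`, so this single
# request uses "auto" while the instance default stays "flex".
msg = llm.invoke("Tell me a joke", service_tier="auto")
```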
@@ -534,7 +549,7 @@ def _generate(
             **kwargs,
         }
         response = self.client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, params)

     async def _agenerate(
         self,
@@ -555,7 +570,7 @@ async def _agenerate(
             **kwargs,
         }
         response = await self.async_client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, params)

     def _stream(
         self,
@@ -582,6 +597,8 @@ def _stream(
                 generation_info["model_name"] = self.model_name
                 if system_fingerprint := chunk.get("system_fingerprint"):
                     generation_info["system_fingerprint"] = system_fingerprint
+                service_tier = params.get("service_tier") or self.service_tier
+                generation_info["service_tier"] = service_tier
                 logprobs = choice.get("logprobs")
                 if logprobs:
                     generation_info["logprobs"] = logprobs
@@ -623,6 +640,8 @@ async def _astream(
                 generation_info["model_name"] = self.model_name
                 if system_fingerprint := chunk.get("system_fingerprint"):
                     generation_info["system_fingerprint"] = system_fingerprint
+                service_tier = params.get("service_tier") or self.service_tier
+                generation_info["service_tier"] = service_tier
                 logprobs = choice.get("logprobs")
                 if logprobs:
                     generation_info["logprobs"] = logprobs
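Both streaming paths record the tier with the same fallback: a per-call value in `params` wins, otherwise the instance field applies. A toy illustration of that `or` expression (values are hypothetical):

```python
# Hypothetical values showing the fallback used in both streaming hunks.
instance_default = "on_demand"

params: dict = {}  # caller passed no service_tier
assert (params.get("service_tier") or instance_default) == "on_demand"

params = {"service_tier": "flex"}  # caller overrode the tier
assert (params.get("service_tier") or instance_default) == "flex"
```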
@@ -653,13 +672,16 @@ def _default_params(self) -> dict[str, Any]:
             "stop": self.stop,
             "reasoning_format": self.reasoning_format,
             "reasoning_effort": self.reasoning_effort,
+            "service_tier": self.service_tier,
             **self.model_kwargs,
         }
         if self.max_tokens is not None:
             params["max_tokens"] = self.max_tokens
         return params

-    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
+    def _create_chat_result(
+        self, response: Union[dict, BaseModel], params: dict
+    ) -> ChatResult:
         generations = []
         if not isinstance(response, dict):
             response = response.model_dump()
@@ -689,6 +711,7 @@ def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
             "model_name": self.model_name,
             "system_fingerprint": response.get("system_fingerprint", ""),
         }
+        llm_output["service_tier"] = params.get("service_tier") or self.service_tier
         return ChatResult(generations=generations, llm_output=llm_output)

     def _create_message_dicts(
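With `params` threaded into `_create_chat_result`, the effective tier surfaces in the result's `llm_output`. A sketch, reusing the illustrative model name from above:

```python
from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq

llm = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")
result = llm.generate([[HumanMessage("Tell me a joke")]])

# _create_chat_result copies the effective tier into llm_output.
print(result.llm_output["service_tier"])  # -> "flex"
```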
@@ -719,6 +742,8 @@ def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
         combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
         if system_fingerprint:
             combined["system_fingerprint"] = system_fingerprint
+        if self.service_tier:
+            combined["service_tier"] = self.service_tier
         return combined

     @deprecated(
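`_combine_llm_outputs` runs when a single `generate` call covers several prompt lists. Note that it reports the instance-level tier, not any per-call override, since `params` is not threaded through this path. Sketch under the same illustrative setup:

```python
from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq

llm = ChatGroq(model="llama-3.1-8b-instant", service_tier="flex")
batch = llm.generate([[HumanMessage("One joke")], [HumanMessage("Two jokes")]])

# The merged llm_output carries the instance default from self.service_tier.
print(batch.llm_output["service_tier"])  # -> "flex"
```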