83
83
84
84
85
85
class ChatGroq (BaseChatModel ):
86
- """Groq Chat large language models API.
86
+ r """Groq Chat large language models API.
87
87
88
88
To use, you should have the
89
89
environment variable ``GROQ_API_KEY`` set with your API key.
@@ -412,7 +412,8 @@ def build_extra(cls, values: dict[str, Any]) -> Any:
412
412
extra = values .get ("model_kwargs" , {})
413
413
for field_name in list (values ):
414
414
if field_name in extra :
415
- raise ValueError (f"Found { field_name } supplied twice." )
415
+ msg = f"Found { field_name } supplied twice."
416
+ raise ValueError (msg )
416
417
if field_name not in all_required_field_names :
417
418
warnings .warn (
418
419
f"""WARNING! { field_name } is not default parameter.
@@ -423,10 +424,11 @@ def build_extra(cls, values: dict[str, Any]) -> Any:
423
424
424
425
invalid_model_kwargs = all_required_field_names .intersection (extra .keys ())
425
426
if invalid_model_kwargs :
426
- raise ValueError (
427
+ msg = (
427
428
f"Parameters { invalid_model_kwargs } should be specified explicitly. "
428
429
f"Instead they were passed in as part of `model_kwargs` parameter."
429
430
)
431
+ raise ValueError (msg )
430
432
431
433
values ["model_kwargs" ] = extra
432
434
return values
@@ -435,9 +437,11 @@ def build_extra(cls, values: dict[str, Any]) -> Any:
435
437
def validate_environment (self ) -> Self :
436
438
"""Validate that api key and python package exists in environment."""
437
439
if self .n < 1 :
438
- raise ValueError ("n must be at least 1." )
440
+ msg = "n must be at least 1."
441
+ raise ValueError (msg )
439
442
if self .n > 1 and self .streaming :
440
- raise ValueError ("n must be 1 when streaming." )
443
+ msg = "n must be 1 when streaming."
444
+ raise ValueError (msg )
441
445
if self .temperature == 0 :
442
446
self .temperature = 1e-8
443
447
@@ -470,10 +474,11 @@ def validate_environment(self) -> Self:
470
474
** client_params , ** async_specific
471
475
).chat .completions
472
476
except ImportError as exc :
473
- raise ImportError (
477
+ msg = (
474
478
"Could not import groq python package. "
475
479
"Please install it with `pip install groq`."
476
- ) from exc
480
+ )
481
+ raise ImportError (msg ) from exc
477
482
return self
478
483
479
484
#
@@ -680,7 +685,7 @@ def _default_params(self) -> dict[str, Any]:
680
685
return params
681
686
682
687
def _create_chat_result (
683
- self , response : Union [ dict , BaseModel ] , params : dict
688
+ self , response : dict | BaseModel , params : dict
684
689
) -> ChatResult :
685
690
generations = []
686
691
if not isinstance (response , dict ):
@@ -698,7 +703,7 @@ def _create_chat_result(
698
703
"total_tokens" , input_tokens + output_tokens
699
704
),
700
705
}
701
- generation_info = dict ( finish_reason = res .get ("finish_reason" ))
706
+ generation_info = { " finish_reason" : res .get ("finish_reason" )}
702
707
if "logprobs" in res :
703
708
generation_info ["logprobs" ] = res ["logprobs" ]
704
709
gen = ChatGeneration (
@@ -755,7 +760,7 @@ def bind_functions(
755
760
self ,
756
761
functions : Sequence [Union [dict [str , Any ], type [BaseModel ], Callable , BaseTool ]],
757
762
function_call : Optional [
758
- Union [_FunctionCall , str , Literal ["auto" , "none" ]]
763
+ Union [_FunctionCall , str , Literal ["auto" , "none" ]] # noqa: PYI051
759
764
] = None ,
760
765
** kwargs : Any ,
761
766
) -> Runnable [LanguageModelInput , BaseMessage ]:
@@ -777,8 +782,8 @@ def bind_functions(
777
782
(if any).
778
783
**kwargs: Any additional parameters to pass to
779
784
:meth:`~langchain_groq.chat_models.ChatGroq.bind`.
780
- """
781
785
786
+ """
782
787
formatted_functions = [convert_to_openai_function (fn ) for fn in functions ]
783
788
if function_call is not None :
784
789
function_call = (
@@ -788,18 +793,20 @@ def bind_functions(
788
793
else function_call
789
794
)
790
795
if isinstance (function_call , dict ) and len (formatted_functions ) != 1 :
791
- raise ValueError (
796
+ msg = (
792
797
"When specifying `function_call`, you must provide exactly one "
793
798
"function."
794
799
)
800
+ raise ValueError (msg )
795
801
if (
796
802
isinstance (function_call , dict )
797
803
and formatted_functions [0 ]["name" ] != function_call ["name" ]
798
804
):
799
- raise ValueError (
805
+ msg = (
800
806
f"Function call { function_call } was specified, but the only "
801
807
f"provided function was { formatted_functions [0 ]['name' ]} ."
802
808
)
809
+ raise ValueError (msg )
803
810
kwargs = {** kwargs , "function_call" : function_call }
804
811
return super ().bind (
805
812
functions = formatted_functions ,
@@ -811,7 +818,7 @@ def bind_tools(
811
818
tools : Sequence [Union [dict [str , Any ], type [BaseModel ], Callable , BaseTool ]],
812
819
* ,
813
820
tool_choice : Optional [
814
- Union [dict , str , Literal ["auto" , "any" , "none" ], bool ]
821
+ Union [dict , str , Literal ["auto" , "any" , "none" ], bool ] # noqa: PYI051
815
822
] = None ,
816
823
** kwargs : Any ,
817
824
) -> Runnable [LanguageModelInput , BaseMessage ]:
@@ -829,8 +836,8 @@ def bind_tools(
829
836
{"type": "function", "function": {"name": <<tool_name>>}}.
830
837
**kwargs: Any additional parameters to pass to the
831
838
:class:`~langchain.runnable.Runnable` constructor.
832
- """
833
839
840
+ """
834
841
formatted_tools = [convert_to_openai_tool (tool ) for tool in tools ]
835
842
if tool_choice is not None and tool_choice :
836
843
if tool_choice == "any" :
@@ -841,10 +848,11 @@ def bind_tools(
841
848
tool_choice = {"type" : "function" , "function" : {"name" : tool_choice }}
842
849
if isinstance (tool_choice , bool ):
843
850
if len (tools ) > 1 :
844
- raise ValueError (
851
+ msg = (
845
852
"tool_choice can only be True when there is one tool. Received "
846
853
f"{ len (tools )} tools."
847
854
)
855
+ raise ValueError (msg )
848
856
tool_name = formatted_tools [0 ]["function" ]["name" ]
849
857
tool_choice = {
850
858
"type" : "function" ,
@@ -861,8 +869,8 @@ def with_structured_output(
861
869
method : Literal ["function_calling" , "json_mode" ] = "function_calling" ,
862
870
include_raw : bool = False ,
863
871
** kwargs : Any ,
864
- ) -> Runnable [LanguageModelInput , Union [ dict , BaseModel ] ]:
865
- """Model wrapper that returns outputs formatted to match the given schema.
872
+ ) -> Runnable [LanguageModelInput , dict | BaseModel ]:
873
+ r """Model wrapper that returns outputs formatted to match the given schema.
866
874
867
875
Args:
868
876
schema:
@@ -895,6 +903,9 @@ def with_structured_output(
895
903
response will be returned. If an error occurs during output parsing it
896
904
will be caught and returned as well. The final output is always a dict
897
905
with keys "raw", "parsed", and "parsing_error".
906
+ kwargs:
907
+ Any additional parameters to pass to the
908
+ :class:`~langchain.runnable.Runnable` constructor.
898
909
899
910
Returns:
900
911
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
@@ -1075,21 +1086,24 @@ class AnswerWithJustification(BaseModel):
1075
1086
# },
1076
1087
# 'parsing_error': None
1077
1088
# }
1089
+
1078
1090
""" # noqa: E501
1079
1091
_ = kwargs .pop ("strict" , None )
1080
1092
if kwargs :
1081
- raise ValueError (f"Received unsupported arguments { kwargs } " )
1093
+ msg = f"Received unsupported arguments { kwargs } "
1094
+ raise ValueError (msg )
1082
1095
is_pydantic_schema = _is_pydantic_class (schema )
1083
1096
if method == "json_schema" :
1084
1097
# Some applications require that incompatible parameters (e.g., unsupported
1085
1098
# methods) be handled.
1086
1099
method = "function_calling"
1087
1100
if method == "function_calling" :
1088
1101
if schema is None :
1089
- raise ValueError (
1102
+ msg = (
1090
1103
"schema must be specified when method is 'function_calling'. "
1091
1104
"Received None."
1092
1105
)
1106
+ raise ValueError (msg )
1093
1107
formatted_tool = convert_to_openai_tool (schema )
1094
1108
tool_name = formatted_tool ["function" ]["name" ]
1095
1109
llm = self .bind_tools (
@@ -1123,10 +1137,11 @@ class AnswerWithJustification(BaseModel):
1123
1137
else JsonOutputParser ()
1124
1138
)
1125
1139
else :
1126
- raise ValueError (
1140
+ msg = (
1127
1141
f"Unrecognized method argument. Expected one of 'function_calling' or "
1128
1142
f"'json_mode'. Received: '{ method } '"
1129
1143
)
1144
+ raise ValueError (msg )
1130
1145
1131
1146
if include_raw :
1132
1147
parser_assign = RunnablePassthrough .assign (
@@ -1137,8 +1152,7 @@ class AnswerWithJustification(BaseModel):
1137
1152
[parser_none ], exception_key = "parsing_error"
1138
1153
)
1139
1154
return RunnableMap (raw = llm ) | parser_with_fallback
1140
- else :
1141
- return llm | output_parser
1155
+ return llm | output_parser
1142
1156
1143
1157
1144
1158
def _is_pydantic_class (obj : Any ) -> bool :
@@ -1160,6 +1174,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
1160
1174
1161
1175
Returns:
1162
1176
The dictionary.
1177
+
1163
1178
"""
1164
1179
message_dict : dict [str , Any ]
1165
1180
if isinstance (message , ChatMessage ):
@@ -1200,7 +1215,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
1200
1215
"tool_call_id" : message .tool_call_id ,
1201
1216
}
1202
1217
else :
1203
- raise TypeError (f"Got unknown type { message } " )
1218
+ msg = f"Got unknown type { message } "
1219
+ raise TypeError (msg )
1204
1220
if "name" in message .additional_kwargs :
1205
1221
message_dict ["name" ] = message .additional_kwargs ["name" ]
1206
1222
return message_dict
@@ -1224,7 +1240,7 @@ def _convert_chunk_to_message_chunk(
1224
1240
1225
1241
if role == "user" or default_class == HumanMessageChunk :
1226
1242
return HumanMessageChunk (content = content )
1227
- elif role == "assistant" or default_class == AIMessageChunk :
1243
+ if role == "assistant" or default_class == AIMessageChunk :
1228
1244
if reasoning := _dict .get ("reasoning" ):
1229
1245
additional_kwargs ["reasoning_content" ] = reasoning
1230
1246
if usage := (chunk .get ("x_groq" ) or {}).get ("usage" ):
@@ -1242,16 +1258,15 @@ def _convert_chunk_to_message_chunk(
1242
1258
additional_kwargs = additional_kwargs ,
1243
1259
usage_metadata = usage_metadata , # type: ignore[arg-type]
1244
1260
)
1245
- elif role == "system" or default_class == SystemMessageChunk :
1261
+ if role == "system" or default_class == SystemMessageChunk :
1246
1262
return SystemMessageChunk (content = content )
1247
- elif role == "function" or default_class == FunctionMessageChunk :
1263
+ if role == "function" or default_class == FunctionMessageChunk :
1248
1264
return FunctionMessageChunk (content = content , name = _dict ["name" ])
1249
- elif role == "tool" or default_class == ToolMessageChunk :
1265
+ if role == "tool" or default_class == ToolMessageChunk :
1250
1266
return ToolMessageChunk (content = content , tool_call_id = _dict ["tool_call_id" ])
1251
- elif role or default_class == ChatMessageChunk :
1267
+ if role or default_class == ChatMessageChunk :
1252
1268
return ChatMessageChunk (content = content , role = role )
1253
- else :
1254
- return default_class (content = content ) # type: ignore
1269
+ return default_class (content = content ) # type: ignore[call-arg]
1255
1270
1256
1271
1257
1272
def _convert_dict_to_message (_dict : Mapping [str , Any ]) -> BaseMessage :
@@ -1262,12 +1277,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
1262
1277
1263
1278
Returns:
1264
1279
The LangChain message.
1280
+
1265
1281
"""
1266
1282
id_ = _dict .get ("id" )
1267
1283
role = _dict .get ("role" )
1268
1284
if role == "user" :
1269
1285
return HumanMessage (content = _dict .get ("content" , "" ))
1270
- elif role == "assistant" :
1286
+ if role == "assistant" :
1271
1287
content = _dict .get ("content" , "" ) or ""
1272
1288
additional_kwargs : dict = {}
1273
1289
if reasoning := _dict .get ("reasoning" ):
@@ -1292,11 +1308,11 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
1292
1308
tool_calls = tool_calls ,
1293
1309
invalid_tool_calls = invalid_tool_calls ,
1294
1310
)
1295
- elif role == "system" :
1311
+ if role == "system" :
1296
1312
return SystemMessage (content = _dict .get ("content" , "" ))
1297
- elif role == "function" :
1313
+ if role == "function" :
1298
1314
return FunctionMessage (content = _dict .get ("content" , "" ), name = _dict .get ("name" )) # type: ignore[arg-type]
1299
- elif role == "tool" :
1315
+ if role == "tool" :
1300
1316
additional_kwargs = {}
1301
1317
if "name" in _dict :
1302
1318
additional_kwargs ["name" ] = _dict ["name" ]
@@ -1305,8 +1321,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
1305
1321
tool_call_id = _dict .get ("tool_call_id" ),
1306
1322
additional_kwargs = additional_kwargs ,
1307
1323
)
1308
- else :
1309
- return ChatMessage (content = _dict .get ("content" , "" ), role = role ) # type: ignore[arg-type]
1324
+ return ChatMessage (content = _dict .get ("content" , "" ), role = role ) # type: ignore[arg-type]
1310
1325
1311
1326
1312
1327
def _lc_tool_call_to_groq_tool_call (tool_call : ToolCall ) -> dict :
0 commit comments