2
2
3
3
from __future__ import annotations
4
4
5
+ import contextlib
5
6
import json
6
7
import logging
7
8
from collections .abc import AsyncIterator , Iterator , Mapping , Sequence
16
17
cast ,
17
18
)
18
19
19
- from fireworks .client import AsyncFireworks , Fireworks # type: ignore
20
+ from fireworks .client import AsyncFireworks , Fireworks # type: ignore[import-untyped]
20
21
from langchain_core ._api import deprecated
21
22
from langchain_core .callbacks import (
22
23
AsyncCallbackManagerForLLMRun ,
@@ -94,11 +95,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
94
95
95
96
Returns:
96
97
The LangChain message.
98
+
97
99
"""
98
100
role = _dict .get ("role" )
99
101
if role == "user" :
100
102
return HumanMessage (content = _dict .get ("content" , "" ))
101
- elif role == "assistant" :
103
+ if role == "assistant" :
102
104
# Fix for azure
103
105
# Also Fireworks returns None for tool invocations
104
106
content = _dict .get ("content" , "" ) or ""
@@ -122,13 +124,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
122
124
tool_calls = tool_calls ,
123
125
invalid_tool_calls = invalid_tool_calls ,
124
126
)
125
- elif role == "system" :
127
+ if role == "system" :
126
128
return SystemMessage (content = _dict .get ("content" , "" ))
127
- elif role == "function" :
129
+ if role == "function" :
128
130
return FunctionMessage (
129
131
content = _dict .get ("content" , "" ), name = _dict .get ("name" , "" )
130
132
)
131
- elif role == "tool" :
133
+ if role == "tool" :
132
134
additional_kwargs = {}
133
135
if "name" in _dict :
134
136
additional_kwargs ["name" ] = _dict ["name" ]
@@ -137,8 +139,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
137
139
tool_call_id = _dict .get ("tool_call_id" , "" ),
138
140
additional_kwargs = additional_kwargs ,
139
141
)
140
- else :
141
- return ChatMessage (content = _dict .get ("content" , "" ), role = role or "" )
142
+ return ChatMessage (content = _dict .get ("content" , "" ), role = role or "" )
142
143
143
144
144
145
def _convert_message_to_dict (message : BaseMessage ) -> dict :
@@ -149,6 +150,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
149
150
150
151
Returns:
151
152
The dictionary.
153
+
152
154
"""
153
155
message_dict : dict [str , Any ]
154
156
if isinstance (message , ChatMessage ):
@@ -191,7 +193,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
191
193
"tool_call_id" : message .tool_call_id ,
192
194
}
193
195
else :
194
- raise TypeError (f"Got unknown type { message } " )
196
+ msg = f"Got unknown type { message } "
197
+ raise TypeError (msg )
195
198
if "name" in message .additional_kwargs :
196
199
message_dict ["name" ] = message .additional_kwargs ["name" ]
197
200
return message_dict
@@ -214,7 +217,7 @@ def _convert_chunk_to_message_chunk(
214
217
if raw_tool_calls := _dict .get ("tool_calls" ):
215
218
additional_kwargs ["tool_calls" ] = raw_tool_calls
216
219
for rtc in raw_tool_calls :
217
- try :
220
+ with contextlib . suppress ( KeyError ) :
218
221
tool_call_chunks .append (
219
222
create_tool_call_chunk (
220
223
name = rtc ["function" ].get ("name" ),
@@ -223,11 +226,9 @@ def _convert_chunk_to_message_chunk(
223
226
index = rtc .get ("index" ),
224
227
)
225
228
)
226
- except KeyError :
227
- pass
228
229
if role == "user" or default_class == HumanMessageChunk :
229
230
return HumanMessageChunk (content = content )
230
- elif role == "assistant" or default_class == AIMessageChunk :
231
+ if role == "assistant" or default_class == AIMessageChunk :
231
232
if usage := chunk .get ("usage" ):
232
233
input_tokens = usage .get ("prompt_tokens" , 0 )
233
234
output_tokens = usage .get ("completion_tokens" , 0 )
@@ -244,16 +245,15 @@ def _convert_chunk_to_message_chunk(
244
245
tool_call_chunks = tool_call_chunks ,
245
246
usage_metadata = usage_metadata , # type: ignore[arg-type]
246
247
)
247
- elif role == "system" or default_class == SystemMessageChunk :
248
+ if role == "system" or default_class == SystemMessageChunk :
248
249
return SystemMessageChunk (content = content )
249
- elif role == "function" or default_class == FunctionMessageChunk :
250
+ if role == "function" or default_class == FunctionMessageChunk :
250
251
return FunctionMessageChunk (content = content , name = _dict ["name" ])
251
- elif role == "tool" or default_class == ToolMessageChunk :
252
+ if role == "tool" or default_class == ToolMessageChunk :
252
253
return ToolMessageChunk (content = content , tool_call_id = _dict ["tool_call_id" ])
253
- elif role or default_class == ChatMessageChunk :
254
+ if role or default_class == ChatMessageChunk :
254
255
return ChatMessageChunk (content = content , role = role )
255
- else :
256
- return default_class (content = content ) # type: ignore
256
+ return default_class (content = content ) # type: ignore[call-arg]
257
257
258
258
259
259
class _FunctionCall (TypedDict ):
@@ -280,6 +280,7 @@ class ChatFireworks(BaseChatModel):
280
280
from langchain_fireworks.chat_models import ChatFireworks
281
281
fireworks = ChatFireworks(
282
282
model_name="accounts/fireworks/models/llama-v3p1-8b-instruct")
283
+
283
284
"""
284
285
285
286
@property
@@ -326,14 +327,14 @@ def is_lc_serializable(cls) -> bool:
326
327
),
327
328
)
328
329
"""Fireworks API key.
329
-
330
+
330
331
Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
331
332
"""
332
333
333
334
fireworks_api_base : Optional [str ] = Field (
334
335
alias = "base_url" , default_factory = from_env ("FIREWORKS_API_BASE" , default = None )
335
336
)
336
- """Base URL path for API requests, leave blank if not using a proxy or service
337
+ """Base URL path for API requests, leave blank if not using a proxy or service
337
338
emulator."""
338
339
request_timeout : Union [float , tuple [float , float ], Any , None ] = Field (
339
340
default = None , alias = "timeout"
@@ -358,16 +359,17 @@ def is_lc_serializable(cls) -> bool:
358
359
def build_extra (cls , values : dict [str , Any ]) -> Any :
359
360
"""Build extra kwargs from additional params that were passed in."""
360
361
all_required_field_names = get_pydantic_field_names (cls )
361
- values = _build_model_kwargs (values , all_required_field_names )
362
- return values
362
+ return _build_model_kwargs (values , all_required_field_names )
363
363
364
364
@model_validator (mode = "after" )
365
365
def validate_environment (self ) -> Self :
366
366
"""Validate that api key and python package exists in environment."""
367
367
if self .n < 1 :
368
- raise ValueError ("n must be at least 1." )
368
+ msg = "n must be at least 1."
369
+ raise ValueError (msg )
369
370
if self .n > 1 and self .streaming :
370
- raise ValueError ("n must be 1 when streaming." )
371
+ msg = "n must be 1 when streaming."
372
+ raise ValueError (msg )
371
373
372
374
client_params = {
373
375
"api_key" : (
@@ -522,7 +524,7 @@ def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
522
524
"output_tokens" : token_usage .get ("completion_tokens" , 0 ),
523
525
"total_tokens" : token_usage .get ("total_tokens" , 0 ),
524
526
}
525
- generation_info = dict ( finish_reason = res .get ("finish_reason" ))
527
+ generation_info = { " finish_reason" : res .get ("finish_reason" )}
526
528
if "logprobs" in res :
527
529
generation_info ["logprobs" ] = res ["logprobs" ]
528
530
gen = ChatGeneration (
@@ -628,7 +630,7 @@ def bind_functions(
628
630
self ,
629
631
functions : Sequence [Union [dict [str , Any ], type [BaseModel ], Callable , BaseTool ]],
630
632
function_call : Optional [
631
- Union [_FunctionCall , str , Literal ["auto" , "none" ]]
633
+ Union [_FunctionCall , str , Literal ["auto" , "none" ]] # noqa: PYI051
632
634
] = None ,
633
635
** kwargs : Any ,
634
636
) -> Runnable [LanguageModelInput , BaseMessage ]:
@@ -651,8 +653,8 @@ def bind_functions(
651
653
(if any).
652
654
**kwargs: Any additional parameters to pass to the
653
655
:class:`~langchain.runnable.Runnable` constructor.
654
- """
655
656
657
+ """
656
658
formatted_functions = [convert_to_openai_function (fn ) for fn in functions ]
657
659
if function_call is not None :
658
660
function_call = (
@@ -662,18 +664,20 @@ def bind_functions(
662
664
else function_call
663
665
)
664
666
if isinstance (function_call , dict ) and len (formatted_functions ) != 1 :
665
- raise ValueError (
667
+ msg = (
666
668
"When specifying `function_call`, you must provide exactly one "
667
669
"function."
668
670
)
671
+ raise ValueError (msg )
669
672
if (
670
673
isinstance (function_call , dict )
671
674
and formatted_functions [0 ]["name" ] != function_call ["name" ]
672
675
):
673
- raise ValueError (
676
+ msg = (
674
677
f"Function call { function_call } was specified, but the only "
675
678
f"provided function was { formatted_functions [0 ]['name' ]} ."
676
679
)
680
+ raise ValueError (msg )
677
681
kwargs = {** kwargs , "function_call" : function_call }
678
682
return super ().bind (
679
683
functions = formatted_functions ,
@@ -685,7 +689,7 @@ def bind_tools(
685
689
tools : Sequence [Union [dict [str , Any ], type [BaseModel ], Callable , BaseTool ]],
686
690
* ,
687
691
tool_choice : Optional [
688
- Union [dict , str , Literal ["auto" , "any" , "none" ], bool ]
692
+ Union [dict , str , Literal ["auto" , "any" , "none" ], bool ] # noqa: PYI051
689
693
] = None ,
690
694
** kwargs : Any ,
691
695
) -> Runnable [LanguageModelInput , BaseMessage ]:
@@ -705,8 +709,8 @@ def bind_tools(
705
709
``{"type": "function", "function": {"name": <<tool_name>>}}``.
706
710
**kwargs: Any additional parameters to pass to
707
711
:meth:`~langchain_fireworks.chat_models.ChatFireworks.bind`
708
- """
709
712
713
+ """
710
714
formatted_tools = [convert_to_openai_tool (tool ) for tool in tools ]
711
715
if tool_choice is not None and tool_choice :
712
716
if isinstance (tool_choice , str ) and (
@@ -715,10 +719,11 @@ def bind_tools(
715
719
tool_choice = {"type" : "function" , "function" : {"name" : tool_choice }}
716
720
if isinstance (tool_choice , bool ):
717
721
if len (tools ) > 1 :
718
- raise ValueError (
722
+ msg = (
719
723
"tool_choice can only be True when there is one tool. Received "
720
724
f"{ len (tools )} tools."
721
725
)
726
+ raise ValueError (msg )
722
727
tool_name = formatted_tools [0 ]["function" ]["name" ]
723
728
tool_choice = {
724
729
"type" : "function" ,
@@ -779,6 +784,9 @@ def with_structured_output(
779
784
will be caught and returned as well. The final output is always a dict
780
785
with keys "raw", "parsed", and "parsing_error".
781
786
787
+ kwargs:
788
+ Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor.
789
+
782
790
Returns:
783
791
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
784
792
@@ -964,17 +972,20 @@ class AnswerWithJustification(BaseModel):
964
972
# },
965
973
# 'parsing_error': None
966
974
# }
975
+
967
976
""" # noqa: E501
968
977
_ = kwargs .pop ("strict" , None )
969
978
if kwargs :
970
- raise ValueError (f"Received unsupported arguments { kwargs } " )
979
+ msg = f"Received unsupported arguments { kwargs } "
980
+ raise ValueError (msg )
971
981
is_pydantic_schema = _is_pydantic_class (schema )
972
982
if method == "function_calling" :
973
983
if schema is None :
974
- raise ValueError (
984
+ msg = (
975
985
"schema must be specified when method is 'function_calling'. "
976
986
"Received None."
977
987
)
988
+ raise ValueError (msg )
978
989
formatted_tool = convert_to_openai_tool (schema )
979
990
tool_name = formatted_tool ["function" ]["name" ]
980
991
llm = self .bind_tools (
@@ -996,10 +1007,11 @@ class AnswerWithJustification(BaseModel):
996
1007
)
997
1008
elif method == "json_schema" :
998
1009
if schema is None :
999
- raise ValueError (
1010
+ msg = (
1000
1011
"schema must be specified when method is 'json_schema'. "
1001
1012
"Received None."
1002
1013
)
1014
+ raise ValueError (msg )
1003
1015
formatted_schema = convert_to_json_schema (schema )
1004
1016
llm = self .bind (
1005
1017
response_format = {"type" : "json_object" , "schema" : formatted_schema },
@@ -1027,10 +1039,11 @@ class AnswerWithJustification(BaseModel):
1027
1039
else JsonOutputParser ()
1028
1040
)
1029
1041
else :
1030
- raise ValueError (
1042
+ msg = (
1031
1043
f"Unrecognized method argument. Expected one of 'function_calling' or "
1032
1044
f"'json_mode'. Received: '{ method } '"
1033
1045
)
1046
+ raise ValueError (msg )
1034
1047
1035
1048
if include_raw :
1036
1049
parser_assign = RunnablePassthrough .assign (
@@ -1041,8 +1054,7 @@ class AnswerWithJustification(BaseModel):
1041
1054
[parser_none ], exception_key = "parsing_error"
1042
1055
)
1043
1056
return RunnableMap (raw = llm ) | parser_with_fallback
1044
- else :
1045
- return llm | output_parser
1057
+ return llm | output_parser
1046
1058
1047
1059
1048
1060
def _is_pydantic_class (obj : Any ) -> bool :
0 commit comments