Commit 06ab297

fireworks[patch]: ruff fixes and rules (#31903)

* bump ruff deps
* add more thorough ruff rules
* fix said rules

1 parent 63e3f2d commit 06ab297

File tree

12 files changed: +165 -92 lines


libs/partners/fireworks/langchain_fireworks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,8 +4,8 @@
 from langchain_fireworks.version import __version__
 
 __all__ = [
-    "__version__",
     "ChatFireworks",
     "Fireworks",
    "FireworksEmbeddings",
+    "__version__",
 ]
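
The reorder in `__all__` matches what ruff's RUF022 rule (unsorted `__all__`) enforces — an attribution inferred from the change, not stated in the commit. Under a plain case-sensitive sort, the underscore of `__version__` ranks after the uppercase ASCII letters, so the dunder lands last. A minimal sketch of that ordering:

# Hypothetical check of the ordering RUF022 appears to want: plain
# case-sensitive sorting, under which "_" sorts after "A"-"Z".
exports = ["__version__", "ChatFireworks", "Fireworks", "FireworksEmbeddings"]
print(sorted(exports))
# ['ChatFireworks', 'Fireworks', 'FireworksEmbeddings', '__version__']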

libs/partners/fireworks/langchain_fireworks/chat_models.py

Lines changed: 50 additions & 38 deletions
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import contextlib
 import json
 import logging
 from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
@@ -16,7 +17,7 @@
     cast,
 )
 
-from fireworks.client import AsyncFireworks, Fireworks  # type: ignore
+from fireworks.client import AsyncFireworks, Fireworks  # type: ignore[import-untyped]
 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
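
Narrowing the bare `# type: ignore` to `# type: ignore[import-untyped]` is the shape of ruff's PGH003 fix (blanket type ignore) — again an inferred attribution. A scoped ignore silences only the named mypy error code, so any other type error on the line still surfaces:

# Sketch: only the "import-untyped" error from the untyped fireworks SDK
# is suppressed; unrelated type errors on this line would still report.
from fireworks.client import AsyncFireworks, Fireworks  # type: ignore[import-untyped]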
@@ -94,11 +95,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
 
     Returns:
         The LangChain message.
+
     """
     role = _dict.get("role")
     if role == "user":
         return HumanMessage(content=_dict.get("content", ""))
-    elif role == "assistant":
+    if role == "assistant":
         # Fix for azure
         # Also Fireworks returns None for tool invocations
         content = _dict.get("content", "") or ""
@@ -122,13 +124,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
-    elif role == "system":
+    if role == "system":
         return SystemMessage(content=_dict.get("content", ""))
-    elif role == "function":
+    if role == "function":
         return FunctionMessage(
             content=_dict.get("content", ""), name=_dict.get("name", "")
         )
-    elif role == "tool":
+    if role == "tool":
         additional_kwargs = {}
         if "name" in _dict:
             additional_kwargs["name"] = _dict["name"]
@@ -137,8 +139,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_call_id=_dict.get("tool_call_id", ""),
             additional_kwargs=additional_kwargs,
         )
-    else:
-        return ChatMessage(content=_dict.get("content", ""), role=role or "")
+    return ChatMessage(content=_dict.get("content", ""), role=role or "")
 
 
 def _convert_message_to_dict(message: BaseMessage) -> dict:
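
The elif-to-if rewrites above follow the early-return style ruff's RET505 rule (superfluous elif/else after return) enforces — a likely, not stated, attribution. Since every branch returns, dropping `elif`/`else` cannot change which branch runs; it only flattens the ladder. A self-contained sketch of the pattern, with toy role strings standing in for the real message classes:

def role_kind(role: str) -> str:
    # Before: if/elif/elif/else. After: each test returns, so the chain
    # flattens; the bare final return replaces the else arm.
    if role == "user":
        return "human"
    if role == "assistant":
        return "ai"
    if role == "system":
        return "system"
    return "chat"  # previously under `else:`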
@@ -149,6 +150,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
 
     Returns:
         The dictionary.
+
     """
     message_dict: dict[str, Any]
     if isinstance(message, ChatMessage):
@@ -191,7 +193,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
             "tool_call_id": message.tool_call_id,
         }
     else:
-        raise TypeError(f"Got unknown type {message}")
+        msg = f"Got unknown type {message}"
+        raise TypeError(msg)
     if "name" in message.additional_kwargs:
         message_dict["name"] = message.additional_kwargs["name"]
     return message_dict
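
The `msg = ...` then `raise ...(msg)` pattern here recurs throughout the rest of the diff; it matches ruff's flake8-errmsg rules (EM101 for string literals, EM102 for f-strings inside a raise), inferred from the shape of the change. Binding the message first keeps tracebacks from repeating the long expression. A sketch:

def require_str(value: object) -> str:
    if not isinstance(value, str):
        # EM102-clean: bind the f-string, then raise the bound name.
        msg = f"Got unknown type {value}"
        raise TypeError(msg)
    return value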
@@ -214,7 +217,7 @@ def _convert_chunk_to_message_chunk(
     if raw_tool_calls := _dict.get("tool_calls"):
         additional_kwargs["tool_calls"] = raw_tool_calls
         for rtc in raw_tool_calls:
-            try:
+            with contextlib.suppress(KeyError):
                 tool_call_chunks.append(
                     create_tool_call_chunk(
                         name=rtc["function"].get("name"),
@@ -223,11 +226,9 @@ def _convert_chunk_to_message_chunk(
                         index=rtc.get("index"),
                     )
                 )
-            except KeyError:
-                pass
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
-    elif role == "assistant" or default_class == AIMessageChunk:
+    if role == "assistant" or default_class == AIMessageChunk:
         if usage := chunk.get("usage"):
             input_tokens = usage.get("prompt_tokens", 0)
             output_tokens = usage.get("completion_tokens", 0)
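
Swapping `try/except KeyError: pass` for `contextlib.suppress(KeyError)` is the fix ruff's SIM105 rule suggests (an inferred attribution, as above). Both forms skip the append when a required key is missing; `suppress` states that intent in one line. A toy reproduction with a chunk missing its `id`:

import contextlib

ids: list[str] = []
chunk = {"name": "lookup"}  # toy tool-call chunk with no "id" key

with contextlib.suppress(KeyError):
    ids.append(chunk["id"])  # KeyError is swallowed; the append is skipped

print(ids)  # []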
@@ -244,16 +245,15 @@ def _convert_chunk_to_message_chunk(
             tool_call_chunks=tool_call_chunks,
             usage_metadata=usage_metadata,  # type: ignore[arg-type]
         )
-    elif role == "system" or default_class == SystemMessageChunk:
+    if role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
-    elif role == "function" or default_class == FunctionMessageChunk:
+    if role == "function" or default_class == FunctionMessageChunk:
         return FunctionMessageChunk(content=content, name=_dict["name"])
-    elif role == "tool" or default_class == ToolMessageChunk:
+    if role == "tool" or default_class == ToolMessageChunk:
         return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
-    elif role or default_class == ChatMessageChunk:
+    if role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role)
-    else:
-        return default_class(content=content)  # type: ignore
+    return default_class(content=content)  # type: ignore[call-arg]
 
 
 class _FunctionCall(TypedDict):
@@ -280,6 +280,7 @@ class ChatFireworks(BaseChatModel):
         from langchain_fireworks.chat_models import ChatFireworks
         fireworks = ChatFireworks(
             model_name="accounts/fireworks/models/llama-v3p1-8b-instruct")
+
     """
 
     @property
@@ -326,14 +327,14 @@ def is_lc_serializable(cls) -> bool:
         ),
     )
     """Fireworks API key.
-
+
     Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
     """
 
     fireworks_api_base: Optional[str] = Field(
         alias="base_url", default_factory=from_env("FIREWORKS_API_BASE", default=None)
     )
-    """Base URL path for API requests, leave blank if not using a proxy or service
+    """Base URL path for API requests, leave blank if not using a proxy or service
     emulator."""
     request_timeout: Union[float, tuple[float, float], Any, None] = Field(
         default=None, alias="timeout"
@@ -358,16 +359,17 @@ def is_lc_serializable(cls) -> bool:
     def build_extra(cls, values: dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
-        values = _build_model_kwargs(values, all_required_field_names)
-        return values
+        return _build_model_kwargs(values, all_required_field_names)
 
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
         if self.n < 1:
-            raise ValueError("n must be at least 1.")
+            msg = "n must be at least 1."
+            raise ValueError(msg)
         if self.n > 1 and self.streaming:
-            raise ValueError("n must be 1 when streaming.")
+            msg = "n must be 1 when streaming."
+            raise ValueError(msg)
 
         client_params = {
             "api_key": (
@@ -522,7 +524,7 @@ def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
             "output_tokens": token_usage.get("completion_tokens", 0),
             "total_tokens": token_usage.get("total_tokens", 0),
         }
-        generation_info = dict(finish_reason=res.get("finish_reason"))
+        generation_info = {"finish_reason": res.get("finish_reason")}
         if "logprobs" in res:
            generation_info["logprobs"] = res["logprobs"]
         gen = ChatGeneration(
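
Replacing `dict(finish_reason=...)` with a literal matches ruff's C408 rule (unnecessary `dict` call), inferred from the change; the literal skips a name lookup and a call while reading the same:

finish_reason = "stop"  # hypothetical value, for illustration only
generation_info = {"finish_reason": finish_reason}  # literal, not dict(...)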
@@ -628,7 +630,7 @@ def bind_functions(
         self,
         functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
         function_call: Optional[
-            Union[_FunctionCall, str, Literal["auto", "none"]]
+            Union[_FunctionCall, str, Literal["auto", "none"]]  # noqa: PYI051
         ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -651,8 +653,8 @@ def bind_functions(
                 (if any).
             **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
-        """
 
+        """
         formatted_functions = [convert_to_openai_function(fn) for fn in functions]
         if function_call is not None:
             function_call = (
@@ -662,18 +664,20 @@ def bind_functions(
                 else function_call
             )
             if isinstance(function_call, dict) and len(formatted_functions) != 1:
-                raise ValueError(
+                msg = (
                     "When specifying `function_call`, you must provide exactly one "
                     "function."
                 )
+                raise ValueError(msg)
             if (
                 isinstance(function_call, dict)
                 and formatted_functions[0]["name"] != function_call["name"]
             ):
-                raise ValueError(
+                msg = (
                     f"Function call {function_call} was specified, but the only "
                     f"provided function was {formatted_functions[0]['name']}."
                 )
+                raise ValueError(msg)
             kwargs = {**kwargs, "function_call": function_call}
         return super().bind(
             functions=formatted_functions,
@@ -685,7 +689,7 @@ def bind_tools(
         tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
         *,
         tool_choice: Optional[
-            Union[dict, str, Literal["auto", "any", "none"], bool]
+            Union[dict, str, Literal["auto", "any", "none"], bool]  # noqa: PYI051
         ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
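
The two `# noqa: PYI051` markers acknowledge that a `Literal[...]` unioned with plain `str` is redundant to the type checker — `str` already admits every literal — which is what PYI051 flags. The union is kept because the literals document the well-known values, so the rule is silenced rather than obeyed. The flagged shape in isolation, as a sketch:

from typing import Literal, Optional, Union

# str swallows the literals, so PYI051 calls this union redundant; the
# noqa keeps the self-documenting annotation anyway.
ToolChoice = Optional[Union[dict, str, Literal["auto", "any", "none"], bool]]  # noqa: PYI051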
@@ -705,8 +709,8 @@ def bind_tools(
                 ``{"type": "function", "function": {"name": <<tool_name>>}}``.
             **kwargs: Any additional parameters to pass to
                 :meth:`~langchain_fireworks.chat_models.ChatFireworks.bind`
-        """
 
+        """
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         if tool_choice is not None and tool_choice:
             if isinstance(tool_choice, str) and (
@@ -715,10 +719,11 @@ def bind_tools(
                 tool_choice = {"type": "function", "function": {"name": tool_choice}}
             if isinstance(tool_choice, bool):
                 if len(tools) > 1:
-                    raise ValueError(
+                    msg = (
                         "tool_choice can only be True when there is one tool. Received "
                         f"{len(tools)} tools."
                     )
+                    raise ValueError(msg)
                 tool_name = formatted_tools[0]["function"]["name"]
                 tool_choice = {
                     "type": "function",
@@ -779,6 +784,9 @@ def with_structured_output(
                 will be caught and returned as well. The final output is always a dict
                 with keys "raw", "parsed", and "parsing_error".
 
+            kwargs:
+                Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor.
+
         Returns:
             A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
 
@@ -964,17 +972,20 @@ class AnswerWithJustification(BaseModel):
         #     },
         #     'parsing_error': None
         # }
+
         """  # noqa: E501
         _ = kwargs.pop("strict", None)
         if kwargs:
-            raise ValueError(f"Received unsupported arguments {kwargs}")
+            msg = f"Received unsupported arguments {kwargs}"
+            raise ValueError(msg)
         is_pydantic_schema = _is_pydantic_class(schema)
         if method == "function_calling":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'function_calling'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             formatted_tool = convert_to_openai_tool(schema)
             tool_name = formatted_tool["function"]["name"]
             llm = self.bind_tools(
@@ -996,10 +1007,11 @@ class AnswerWithJustification(BaseModel):
             )
         elif method == "json_schema":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'json_schema'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             formatted_schema = convert_to_json_schema(schema)
             llm = self.bind(
                 response_format={"type": "json_object", "schema": formatted_schema},
@@ -1027,10 +1039,11 @@ class AnswerWithJustification(BaseModel):
                 else JsonOutputParser()
             )
         else:
-            raise ValueError(
+            msg = (
                 f"Unrecognized method argument. Expected one of 'function_calling' or "
                 f"'json_mode'. Received: '{method}'"
             )
+            raise ValueError(msg)
 
         if include_raw:
             parser_assign = RunnablePassthrough.assign(
@@ -1041,8 +1054,7 @@ class AnswerWithJustification(BaseModel):
                 [parser_none], exception_key="parsing_error"
             )
             return RunnableMap(raw=llm) | parser_with_fallback
-        else:
-            return llm | output_parser
+        return llm | output_parser
 
 
 def _is_pydantic_class(obj: Any) -> bool:
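
For orientation, a hedged usage sketch of the `with_structured_output` paths touched above — the model name is copied from the class docstring earlier in the diff, and the invocation details are assumed rather than taken from the commit's tests:

from pydantic import BaseModel

from langchain_fireworks import ChatFireworks


class AnswerWithJustification(BaseModel):
    answer: str
    justification: str


llm = ChatFireworks(model_name="accounts/fireworks/models/llama-v3p1-8b-instruct")
structured = llm.with_structured_output(AnswerWithJustification, method="function_calling")
# structured.invoke("...") yields an AnswerWithJustification; with
# include_raw=True it yields a {"raw", "parsed", "parsing_error"} dict instead.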

libs/partners/fireworks/langchain_fireworks/embeddings.py

Lines changed: 1 addition & 3 deletions
@@ -4,8 +4,6 @@
 from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self
 
-# type: ignore
-
 
 class FireworksEmbeddings(BaseModel, Embeddings):
     """Fireworks embedding model integration.
@@ -78,7 +76,7 @@ class FireworksEmbeddings(BaseModel, Embeddings):
         ),
     )
     """Fireworks API key.
-
+
     Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
     """
     model: str = "nomic-ai/nomic-embed-text-v1.5"
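
A short sketch of the embeddings class from this file, using the default model shown in the diff; it assumes `FIREWORKS_API_KEY` is exported and makes a network call when run:

from langchain_fireworks import FireworksEmbeddings

embedder = FireworksEmbeddings()  # defaults to nomic-ai/nomic-embed-text-v1.5
vector = embedder.embed_query("hello world")
print(len(vector))  # dimensionality of the returned embedding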
