Skip to content

Commit dd76209

Browse files
authored
groq[patch]: ruff fixes and rules (#31904)
* bump ruff deps * add more thorough ruff rules * fix said rules
1 parent 750721b commit dd76209

File tree

9 files changed

+119
-52
lines changed

9 files changed

+119
-52
lines changed

libs/partners/groq/langchain_groq/chat_models.py

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383

8484

8585
class ChatGroq(BaseChatModel):
86-
"""Groq Chat large language models API.
86+
r"""Groq Chat large language models API.
8787
8888
To use, you should have the
8989
environment variable ``GROQ_API_KEY`` set with your API key.
@@ -412,7 +412,8 @@ def build_extra(cls, values: dict[str, Any]) -> Any:
412412
extra = values.get("model_kwargs", {})
413413
for field_name in list(values):
414414
if field_name in extra:
415-
raise ValueError(f"Found {field_name} supplied twice.")
415+
msg = f"Found {field_name} supplied twice."
416+
raise ValueError(msg)
416417
if field_name not in all_required_field_names:
417418
warnings.warn(
418419
f"""WARNING! {field_name} is not default parameter.
@@ -423,10 +424,11 @@ def build_extra(cls, values: dict[str, Any]) -> Any:
423424

424425
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
425426
if invalid_model_kwargs:
426-
raise ValueError(
427+
msg = (
427428
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
428429
f"Instead they were passed in as part of `model_kwargs` parameter."
429430
)
431+
raise ValueError(msg)
430432

431433
values["model_kwargs"] = extra
432434
return values
@@ -435,9 +437,11 @@ def build_extra(cls, values: dict[str, Any]) -> Any:
435437
def validate_environment(self) -> Self:
436438
"""Validate that api key and python package exists in environment."""
437439
if self.n < 1:
438-
raise ValueError("n must be at least 1.")
440+
msg = "n must be at least 1."
441+
raise ValueError(msg)
439442
if self.n > 1 and self.streaming:
440-
raise ValueError("n must be 1 when streaming.")
443+
msg = "n must be 1 when streaming."
444+
raise ValueError(msg)
441445
if self.temperature == 0:
442446
self.temperature = 1e-8
443447

@@ -470,10 +474,11 @@ def validate_environment(self) -> Self:
470474
**client_params, **async_specific
471475
).chat.completions
472476
except ImportError as exc:
473-
raise ImportError(
477+
msg = (
474478
"Could not import groq python package. "
475479
"Please install it with `pip install groq`."
476-
) from exc
480+
)
481+
raise ImportError(msg) from exc
477482
return self
478483

479484
#
@@ -680,7 +685,7 @@ def _default_params(self) -> dict[str, Any]:
680685
return params
681686

682687
def _create_chat_result(
683-
self, response: Union[dict, BaseModel], params: dict
688+
self, response: dict | BaseModel, params: dict
684689
) -> ChatResult:
685690
generations = []
686691
if not isinstance(response, dict):
@@ -698,7 +703,7 @@ def _create_chat_result(
698703
"total_tokens", input_tokens + output_tokens
699704
),
700705
}
701-
generation_info = dict(finish_reason=res.get("finish_reason"))
706+
generation_info = {"finish_reason": res.get("finish_reason")}
702707
if "logprobs" in res:
703708
generation_info["logprobs"] = res["logprobs"]
704709
gen = ChatGeneration(
@@ -755,7 +760,7 @@ def bind_functions(
755760
self,
756761
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
757762
function_call: Optional[
758-
Union[_FunctionCall, str, Literal["auto", "none"]]
763+
Union[_FunctionCall, str, Literal["auto", "none"]] # noqa: PYI051
759764
] = None,
760765
**kwargs: Any,
761766
) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -777,8 +782,8 @@ def bind_functions(
777782
(if any).
778783
**kwargs: Any additional parameters to pass to
779784
:meth:`~langchain_groq.chat_models.ChatGroq.bind`.
780-
"""
781785
786+
"""
782787
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
783788
if function_call is not None:
784789
function_call = (
@@ -788,18 +793,20 @@ def bind_functions(
788793
else function_call
789794
)
790795
if isinstance(function_call, dict) and len(formatted_functions) != 1:
791-
raise ValueError(
796+
msg = (
792797
"When specifying `function_call`, you must provide exactly one "
793798
"function."
794799
)
800+
raise ValueError(msg)
795801
if (
796802
isinstance(function_call, dict)
797803
and formatted_functions[0]["name"] != function_call["name"]
798804
):
799-
raise ValueError(
805+
msg = (
800806
f"Function call {function_call} was specified, but the only "
801807
f"provided function was {formatted_functions[0]['name']}."
802808
)
809+
raise ValueError(msg)
803810
kwargs = {**kwargs, "function_call": function_call}
804811
return super().bind(
805812
functions=formatted_functions,
@@ -811,7 +818,7 @@ def bind_tools(
811818
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
812819
*,
813820
tool_choice: Optional[
814-
Union[dict, str, Literal["auto", "any", "none"], bool]
821+
Union[dict, str, Literal["auto", "any", "none"], bool] # noqa: PYI051
815822
] = None,
816823
**kwargs: Any,
817824
) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -829,8 +836,8 @@ def bind_tools(
829836
{"type": "function", "function": {"name": <<tool_name>>}}.
830837
**kwargs: Any additional parameters to pass to the
831838
:class:`~langchain.runnable.Runnable` constructor.
832-
"""
833839
840+
"""
834841
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
835842
if tool_choice is not None and tool_choice:
836843
if tool_choice == "any":
@@ -841,10 +848,11 @@ def bind_tools(
841848
tool_choice = {"type": "function", "function": {"name": tool_choice}}
842849
if isinstance(tool_choice, bool):
843850
if len(tools) > 1:
844-
raise ValueError(
851+
msg = (
845852
"tool_choice can only be True when there is one tool. Received "
846853
f"{len(tools)} tools."
847854
)
855+
raise ValueError(msg)
848856
tool_name = formatted_tools[0]["function"]["name"]
849857
tool_choice = {
850858
"type": "function",
@@ -861,8 +869,8 @@ def with_structured_output(
861869
method: Literal["function_calling", "json_mode"] = "function_calling",
862870
include_raw: bool = False,
863871
**kwargs: Any,
864-
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
865-
"""Model wrapper that returns outputs formatted to match the given schema.
872+
) -> Runnable[LanguageModelInput, dict | BaseModel]:
873+
r"""Model wrapper that returns outputs formatted to match the given schema.
866874
867875
Args:
868876
schema:
@@ -895,6 +903,9 @@ def with_structured_output(
895903
response will be returned. If an error occurs during output parsing it
896904
will be caught and returned as well. The final output is always a dict
897905
with keys "raw", "parsed", and "parsing_error".
906+
kwargs:
907+
Any additional parameters to pass to the
908+
:class:`~langchain.runnable.Runnable` constructor.
898909
899910
Returns:
900911
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
@@ -1075,21 +1086,24 @@ class AnswerWithJustification(BaseModel):
10751086
# },
10761087
# 'parsing_error': None
10771088
# }
1089+
10781090
""" # noqa: E501
10791091
_ = kwargs.pop("strict", None)
10801092
if kwargs:
1081-
raise ValueError(f"Received unsupported arguments {kwargs}")
1093+
msg = f"Received unsupported arguments {kwargs}"
1094+
raise ValueError(msg)
10821095
is_pydantic_schema = _is_pydantic_class(schema)
10831096
if method == "json_schema":
10841097
# Some applications require that incompatible parameters (e.g., unsupported
10851098
# methods) be handled.
10861099
method = "function_calling"
10871100
if method == "function_calling":
10881101
if schema is None:
1089-
raise ValueError(
1102+
msg = (
10901103
"schema must be specified when method is 'function_calling'. "
10911104
"Received None."
10921105
)
1106+
raise ValueError(msg)
10931107
formatted_tool = convert_to_openai_tool(schema)
10941108
tool_name = formatted_tool["function"]["name"]
10951109
llm = self.bind_tools(
@@ -1123,10 +1137,11 @@ class AnswerWithJustification(BaseModel):
11231137
else JsonOutputParser()
11241138
)
11251139
else:
1126-
raise ValueError(
1140+
msg = (
11271141
f"Unrecognized method argument. Expected one of 'function_calling' or "
11281142
f"'json_mode'. Received: '{method}'"
11291143
)
1144+
raise ValueError(msg)
11301145

11311146
if include_raw:
11321147
parser_assign = RunnablePassthrough.assign(
@@ -1137,8 +1152,7 @@ class AnswerWithJustification(BaseModel):
11371152
[parser_none], exception_key="parsing_error"
11381153
)
11391154
return RunnableMap(raw=llm) | parser_with_fallback
1140-
else:
1141-
return llm | output_parser
1155+
return llm | output_parser
11421156

11431157

11441158
def _is_pydantic_class(obj: Any) -> bool:
@@ -1160,6 +1174,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
11601174
11611175
Returns:
11621176
The dictionary.
1177+
11631178
"""
11641179
message_dict: dict[str, Any]
11651180
if isinstance(message, ChatMessage):
@@ -1200,7 +1215,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
12001215
"tool_call_id": message.tool_call_id,
12011216
}
12021217
else:
1203-
raise TypeError(f"Got unknown type {message}")
1218+
msg = f"Got unknown type {message}"
1219+
raise TypeError(msg)
12041220
if "name" in message.additional_kwargs:
12051221
message_dict["name"] = message.additional_kwargs["name"]
12061222
return message_dict
@@ -1224,7 +1240,7 @@ def _convert_chunk_to_message_chunk(
12241240

12251241
if role == "user" or default_class == HumanMessageChunk:
12261242
return HumanMessageChunk(content=content)
1227-
elif role == "assistant" or default_class == AIMessageChunk:
1243+
if role == "assistant" or default_class == AIMessageChunk:
12281244
if reasoning := _dict.get("reasoning"):
12291245
additional_kwargs["reasoning_content"] = reasoning
12301246
if usage := (chunk.get("x_groq") or {}).get("usage"):
@@ -1242,16 +1258,15 @@ def _convert_chunk_to_message_chunk(
12421258
additional_kwargs=additional_kwargs,
12431259
usage_metadata=usage_metadata, # type: ignore[arg-type]
12441260
)
1245-
elif role == "system" or default_class == SystemMessageChunk:
1261+
if role == "system" or default_class == SystemMessageChunk:
12461262
return SystemMessageChunk(content=content)
1247-
elif role == "function" or default_class == FunctionMessageChunk:
1263+
if role == "function" or default_class == FunctionMessageChunk:
12481264
return FunctionMessageChunk(content=content, name=_dict["name"])
1249-
elif role == "tool" or default_class == ToolMessageChunk:
1265+
if role == "tool" or default_class == ToolMessageChunk:
12501266
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
1251-
elif role or default_class == ChatMessageChunk:
1267+
if role or default_class == ChatMessageChunk:
12521268
return ChatMessageChunk(content=content, role=role)
1253-
else:
1254-
return default_class(content=content) # type: ignore
1269+
return default_class(content=content) # type: ignore[call-arg]
12551270

12561271

12571272
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
@@ -1262,12 +1277,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
12621277
12631278
Returns:
12641279
The LangChain message.
1280+
12651281
"""
12661282
id_ = _dict.get("id")
12671283
role = _dict.get("role")
12681284
if role == "user":
12691285
return HumanMessage(content=_dict.get("content", ""))
1270-
elif role == "assistant":
1286+
if role == "assistant":
12711287
content = _dict.get("content", "") or ""
12721288
additional_kwargs: dict = {}
12731289
if reasoning := _dict.get("reasoning"):
@@ -1292,11 +1308,11 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
12921308
tool_calls=tool_calls,
12931309
invalid_tool_calls=invalid_tool_calls,
12941310
)
1295-
elif role == "system":
1311+
if role == "system":
12961312
return SystemMessage(content=_dict.get("content", ""))
1297-
elif role == "function":
1313+
if role == "function":
12981314
return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name")) # type: ignore[arg-type]
1299-
elif role == "tool":
1315+
if role == "tool":
13001316
additional_kwargs = {}
13011317
if "name" in _dict:
13021318
additional_kwargs["name"] = _dict["name"]
@@ -1305,8 +1321,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
13051321
tool_call_id=_dict.get("tool_call_id"),
13061322
additional_kwargs=additional_kwargs,
13071323
)
1308-
else:
1309-
return ChatMessage(content=_dict.get("content", ""), role=role) # type: ignore[arg-type]
1324+
return ChatMessage(content=_dict.get("content", ""), role=role) # type: ignore[arg-type]
13101325

13111326

13121327
def _lc_tool_call_to_groq_tool_call(tool_call: ToolCall) -> dict:

libs/partners/groq/pyproject.toml

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,58 @@ disallow_untyped_defs = "True"
4444
target-version = "py39"
4545

4646
[tool.ruff.lint]
47-
select = ["E", "F", "I", "W", "UP", "S"]
48-
ignore = [ "UP007", ]
47+
select = [
48+
"A", # flake8-builtins
49+
"ASYNC", # flake8-async
50+
"C4", # flake8-comprehensions
51+
"COM", # flake8-commas
52+
"D", # pydocstyle
53+
"DOC", # pydoclint
54+
"E", # pycodestyle error
55+
"EM", # flake8-errmsg
56+
"F", # pyflakes
57+
"FA", # flake8-future-annotations
58+
"FBT", # flake8-boolean-trap
59+
"FLY", # flake8-flynt
60+
"I", # isort
61+
"ICN", # flake8-import-conventions
62+
"INT", # flake8-gettext
63+
"ISC", # flake8-implicit-str-concat
64+
"PGH", # pygrep-hooks
65+
"PIE", # flake8-pie
66+
"PERF", # perflint
67+
"PYI", # flake8-pyi
68+
"Q", # flake8-quotes
69+
"RET", # flake8-return
70+
"RSE", # flake8-raise
71+
"RUF", # ruff
72+
"S", # flake8-bandit
73+
"SLF", # flake8-self
74+
"SLOT", # flake8-slots
75+
"SIM", # flake8-simplify
76+
"T10", # flake8-debugger
77+
"T20", # flake8-print
78+
"TID", # flake8-tidy-imports
79+
"UP", # pyupgrade
80+
"W", # pycodestyle warning
81+
"YTT", # flake8-2020
82+
]
83+
ignore = [
84+
"D100", # Missing docstring in public module
85+
"D101", # Missing docstring in public class
86+
"D102", # Missing docstring in public method
87+
"D103", # Missing docstring in public function
88+
"D104", # Missing docstring in public package
89+
"D105", # Missing docstring in magic method
90+
"D107", # Missing docstring in __init__
91+
"COM812", # Messes with the formatter
92+
"ISC001", # Messes with the formatter
93+
"PERF203", # Rarely useful
94+
"S112", # Rarely useful
95+
"RUF012", # Doesn't play well with Pydantic
96+
"SLF001", # Private member access
97+
"UP007", # pyupgrade: non-pep604-annotation-union
98+
]
4999

50100
[tool.coverage.run]
51101
omit = ["tests/*"]

libs/partners/groq/scripts/check_imports.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
SourceFileLoader("x", file).load_module()
1111
except Exception:
1212
has_failure = True
13-
print(file)
1413
traceback.print_exc()
15-
print()
1614

1715
sys.exit(1 if has_failure else 0)

0 commit comments

Comments
 (0)