Skip to content

Commit 539e5b6

Browse files
authored
core: Add mypy strict-equality rule (#31286)
1 parent 2c4e0ab commit 539e5b6

File tree

13 files changed

+43
-55
lines changed

13 files changed

+43
-55
lines changed

libs/core/langchain_core/language_models/chat_models.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
AIMessage,
5252
AnyMessage,
5353
BaseMessage,
54-
BaseMessageChunk,
5554
HumanMessage,
5655
convert_to_messages,
5756
convert_to_openai_image_block,
@@ -446,13 +445,10 @@ def stream(
446445
*,
447446
stop: Optional[list[str]] = None,
448447
**kwargs: Any,
449-
) -> Iterator[BaseMessageChunk]:
448+
) -> Iterator[BaseMessage]:
450449
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
451450
# model doesn't implement streaming, so use default implementation
452-
yield cast(
453-
"BaseMessageChunk",
454-
self.invoke(input, config=config, stop=stop, **kwargs),
455-
)
451+
yield self.invoke(input, config=config, stop=stop, **kwargs)
456452
else:
457453
config = ensure_config(config)
458454
messages = self._convert_input(input).to_messages()
@@ -537,13 +533,10 @@ async def astream(
537533
*,
538534
stop: Optional[list[str]] = None,
539535
**kwargs: Any,
540-
) -> AsyncIterator[BaseMessageChunk]:
536+
) -> AsyncIterator[BaseMessage]:
541537
if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
542538
# No async or sync stream is implemented, so fall back to ainvoke
543-
yield cast(
544-
"BaseMessageChunk",
545-
await self.ainvoke(input, config=config, stop=stop, **kwargs),
546-
)
539+
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
547540
return
548541

549542
config = ensure_config(config)
@@ -1454,7 +1447,7 @@ class AnswerWithJustification(BaseModel):
14541447
PydanticToolsParser,
14551448
)
14561449

1457-
if self.bind_tools is BaseChatModel.bind_tools:
1450+
if type(self).bind_tools is BaseChatModel.bind_tools:
14581451
msg = "with_structured_output is not implemented for this model."
14591452
raise NotImplementedError(msg)
14601453

libs/core/langchain_core/runnables/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4331,8 +4331,9 @@ def __init__(
43314331
self,
43324332
func: Union[
43334333
Union[
4334-
Callable[[Input], Output],
43354334
Callable[[Input], Iterator[Output]],
4335+
Callable[[Input], Runnable[Input, Output]],
4336+
Callable[[Input], Output],
43364337
Callable[[Input, RunnableConfig], Output],
43374338
Callable[[Input, CallbackManagerForChainRun], Output],
43384339
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],

libs/core/langchain_core/utils/pydantic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def get_pydantic_major_version() -> int:
7777
TypeBaseModel = type[BaseModel]
7878
elif IS_PYDANTIC_V2:
7979
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
80+
from pydantic.v1.fields import ModelField
8081

8182
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
8283
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
@@ -373,20 +374,20 @@ def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ...
373374
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
374375

375376
@overload
376-
def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
377+
def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...
377378

378379
@overload
379-
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
380+
def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...
380381

381382
def get_fields(
382383
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
383-
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
384+
) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
384385
"""Get the field names of a Pydantic model."""
385386
if hasattr(model, "model_fields"):
386387
return model.model_fields
387388

388389
if hasattr(model, "__fields__"):
389-
return model.__fields__ # type: ignore[return-value]
390+
return model.__fields__
390391
msg = f"Expected a Pydantic model. Got {type(model)}"
391392
raise TypeError(msg)
392393

libs/core/pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ report_deprecated_as_note = "True"
7474
# TODO: activate for 'strict' checking
7575
disallow_any_generics = "False"
7676
warn_return_any = "False"
77-
strict_equality = "False"
7877

7978

8079
[tool.ruff]

libs/core/tests/unit_tests/fake/test_fake_chat_model.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Tests for verifying that testing utility code works as expected."""
22

3+
import operator
4+
from functools import reduce
35
from itertools import cycle
46
from typing import Any, Optional, Union
57
from uuid import UUID
@@ -115,12 +117,7 @@ async def test_generic_fake_chat_model_stream() -> None:
115117
]
116118
assert len({chunk.id for chunk in chunks}) == 1
117119

118-
accumulate_chunks = None
119-
for chunk in chunks:
120-
if accumulate_chunks is None:
121-
accumulate_chunks = chunk
122-
else:
123-
accumulate_chunks += chunk
120+
accumulate_chunks = reduce(operator.add, chunks)
124121

125122
assert accumulate_chunks == AIMessageChunk(
126123
content="",

libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Test PydanticOutputParser."""
22

33
from enum import Enum
4-
from typing import Literal, Optional
4+
from typing import Literal, Optional, Union
55

66
import pydantic
77
import pytest
@@ -30,7 +30,7 @@ class ForecastV1(V1BaseModel):
3030

3131
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
3232
def test_pydantic_parser_chaining(
33-
pydantic_object: TBaseModel,
33+
pydantic_object: Union[type[ForecastV2], type[ForecastV1]],
3434
) -> None:
3535
prompt = PromptTemplate(
3636
template="""{{
@@ -43,11 +43,11 @@ def test_pydantic_parser_chaining(
4343

4444
model = ParrotFakeChatModel()
4545

46-
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated]
46+
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[type-var]
4747
chain = prompt | model | parser
4848

4949
res = chain.invoke({})
50-
assert type(res) is pydantic_object
50+
assert isinstance(res, pydantic_object)
5151
assert res.f_or_c == "C"
5252
assert res.temperature == 20
5353
assert res.forecast == "Sunny"

libs/core/tests/unit_tests/prompts/test_prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def test_basic_sandboxing_with_jinja2() -> None:
441441
template = " {{''.__class__.__bases__[0] }} " # malicious code
442442
prompt = PromptTemplate.from_template(template, template_format="jinja2")
443443
with pytest.raises(jinja2.exceptions.SecurityError):
444-
assert prompt.format() == []
444+
prompt.format()
445445

446446

447447
@pytest.mark.requires("jinja2")

libs/core/tests/unit_tests/prompts/test_structured.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class OutputSchema(BaseModel):
5151

5252
chain = prompt | model
5353

54-
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42)
54+
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42) # type: ignore[comparison-overlap]
5555

5656

5757
def test_structured_prompt_dict() -> None:
@@ -73,13 +73,13 @@ def test_structured_prompt_dict() -> None:
7373

7474
chain = prompt | model
7575

76-
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
76+
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} # type: ignore[comparison-overlap]
7777

7878
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
7979

8080
chain = loads(dumps(prompt)) | model
8181

82-
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
82+
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} # type: ignore[comparison-overlap]
8383

8484

8585
def test_structured_prompt_kwargs() -> None:
@@ -99,10 +99,10 @@ def test_structured_prompt_kwargs() -> None:
9999
)
100100
model = FakeStructuredChatModel(responses=[])
101101
chain = prompt | model
102-
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
102+
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} # type: ignore[comparison-overlap]
103103
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
104104
chain = loads(dumps(prompt)) | model
105-
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
105+
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} # type: ignore[comparison-overlap]
106106

107107
class OutputSchema(BaseModel):
108108
name: str
@@ -116,7 +116,7 @@ class OutputSchema(BaseModel):
116116

117117
chain = prompt | model
118118

119-
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)
119+
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7) # type: ignore[comparison-overlap]
120120

121121

122122
def test_structured_prompt_template_format() -> None:

libs/core/tests/unit_tests/runnables/test_runnable.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ def foo(x: int) -> None:
671671

672672
def test_schema_with_itemgetter() -> None:
673673
"""Test runnable with itemgetter."""
674-
foo = RunnableLambda(itemgetter("hello"))
674+
foo: Runnable = RunnableLambda(itemgetter("hello"))
675675
assert _schema(foo.input_schema) == {
676676
"properties": {"hello": {"title": "Hello"}},
677677
"required": ["hello"],
@@ -4001,7 +4001,7 @@ def test_runnable_lambda_stream() -> None:
40014001
# sleep to better simulate a real stream
40024002
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
40034003

4004-
output = list(RunnableLambda(lambda _: llm).stream(""))
4004+
output = list(RunnableLambda[str, str](lambda _: llm).stream(""))
40054005
assert output == list(llm_res)
40064006

40074007

@@ -4014,9 +4014,9 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
40144014
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
40154015
config: RunnableConfig = {"callbacks": [tracer]}
40164016

4017-
assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list(
4018-
llm_res
4019-
)
4017+
assert list(
4018+
RunnableLambda[str, str](lambda _: llm).stream("", config=config)
4019+
) == list(llm_res)
40204020

40214021
assert len(tracer.runs) == 1
40224022
assert tracer.runs[0].error is None
@@ -4075,10 +4075,7 @@ async def afunc(*args: Any, **kwargs: Any) -> Any:
40754075
assert output == list(llm_res)
40764076

40774077
output = [
4078-
chunk
4079-
async for chunk in cast(
4080-
"AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("")
4081-
)
4078+
chunk async for chunk in RunnableLambda[str, str](lambda _: llm).astream("")
40824079
]
40834080
assert output == list(llm_res)
40844081

@@ -4093,7 +4090,10 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
40934090
config: RunnableConfig = {"callbacks": [tracer]}
40944091

40954092
assert [
4096-
_ async for _ in RunnableLambda(lambda _: llm).astream("", config=config)
4093+
_
4094+
async for _ in RunnableLambda[str, str](lambda _: llm).astream(
4095+
"", config=config
4096+
)
40974097
] == list(llm_res)
40984098

40994099
assert len(tracer.runs) == 1
@@ -5300,7 +5300,7 @@ async def idchain_async(_input: dict, /) -> bool:
53005300
def func(_input: dict, /) -> Runnable:
53015301
return idchain
53025302

5303-
assert await RunnableLambda(func).ainvoke({})
5303+
assert await RunnableLambda[dict, bool](func).ainvoke({})
53045304

53055305

53065306
def test_invoke_stream_passthrough_assign_trace() -> None:

libs/core/tests/unit_tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def tool_func_v1(*, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
210210
return f"{arg1} {arg2} {arg3}"
211211

212212
assert isinstance(tool_func_v1, BaseTool)
213-
assert tool_func_v1.args_schema == _MockSchemaV1
213+
assert tool_func_v1.args_schema == cast("ArgsSchema", _MockSchemaV1)
214214

215215

216216
def test_decorated_function_schema_equivalent() -> None:

0 commit comments

Comments
 (0)