
Commit aab2e42

core[patch]: Use Blockbuster to detect blocking calls in asyncio during tests (#29043)
This PR uses the [blockbuster](https://github.com/cbornet/blockbuster) library in langchain-core to detect blocking calls made in the asyncio event loop during unit tests.

Avoiding blocking calls is hard, as they can be deeply buried in the code or made in third-party libraries. Blockbuster makes them easier to detect by raising an exception whenever a known blocking function (e.g. `time.sleep`) is called from the event loop. Adding Blockbuster surfaced a blocking call in `aconfig_with_context`: it ends up calling `get_function_nonlocals`, which loads function source code.

**Dependencies:** blockbuster (test)

**Twitter handle:** cbornet_
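As a rough illustration of the mechanism described above (a minimal sketch, not code from this PR — it assumes `blockbuster_ctx()` can be used without a package filter and that a violation surfaces as `blockbuster.BlockingError`):

import asyncio
import time

from blockbuster import BlockingError, blockbuster_ctx


async def main() -> None:
    with blockbuster_ctx():
        try:
            time.sleep(0.1)  # a known blocking call, made on the event loop
        except BlockingError as e:
            print(f"caught blocking call: {e}")


asyncio.run(main())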
1 parent ceda8bc commit aab2e42

File tree

16 files changed: +550 −308 lines changed

libs/core/poetry.lock

Lines changed: 27 additions & 3 deletions
Some generated files are not rendered by default.

libs/core/pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -118,6 +118,7 @@ grandalf = "^0.8"
 responses = "^0.25.0"
 pytest-socket = "^0.7.0"
 pytest-xdist = "^3.6.1"
+blockbuster = "~1.5.11"
 [[tool.poetry.group.test.dependencies.numpy]]
 version = "^1.24.0"
 python = "<3.12"

libs/core/tests/unit_tests/conftest.py

Lines changed: 28 additions & 1 deletion
@@ -1,14 +1,41 @@
 """Configuration for unit tests."""
 
-from collections.abc import Sequence
+from collections.abc import Iterator, Sequence
 from importlib import util
 from uuid import UUID
 
 import pytest
+from blockbuster import BlockBuster, blockbuster_ctx
 from pytest import Config, Function, Parser
 from pytest_mock import MockerFixture
 
 
+@pytest.fixture(autouse=True)
+def blockbuster() -> Iterator[BlockBuster]:
+    with blockbuster_ctx("langchain_core") as bb:
+        for func in ["os.stat", "os.path.abspath"]:
+            (
+                bb.functions[func]
+                .can_block_in("langchain_core/_api/internal.py", "is_caller_internal")
+                .can_block_in("langchain_core/runnables/base.py", "__repr__")
+                .can_block_in(
+                    "langchain_core/beta/runnables/context.py", "aconfig_with_context"
+                )
+            )
+
+        for func in ["os.stat", "io.TextIOWrapper.read"]:
+            bb.functions[func].can_block_in(
+                "langsmith/client.py", "_default_retry_config"
+            )
+
+        for bb_function in bb.functions.values():
+            bb_function.can_block_in(
+                "freezegun/api.py", "_get_cached_module_attributes"
+            )
+
+        yield bb
+
+
 def pytest_addoption(parser: Parser) -> None:
     """Add custom command line options to pytest."""
     parser.addoption(
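The fixture pairs global detection with a per-call allow-list: `can_block_in(filename, function_name)` tells Blockbuster that a given blocking function is acceptable when it originates from that specific file and function. A hedged sketch of how a downstream project might reuse the pattern — the package, file, and function names below are hypothetical:

import asyncio

from blockbuster import blockbuster_ctx


async def exercise_mypkg() -> None:
    """Stand-in for the async code under test."""


with blockbuster_ctx("mypkg") as bb:
    # Hypothetical allow-list entry: os.stat is tolerated when it
    # originates from mypkg's (made-up) lazy loader.
    bb.functions["os.stat"].can_block_in("mypkg/loader.py", "_lazy_import")
    asyncio.run(exercise_mypkg())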

libs/core/tests/unit_tests/fake/test_fake_chat_model.py

Lines changed: 6 additions & 1 deletion
@@ -191,7 +191,12 @@ async def on_llm_new_token(
     model = GenericFakeChatModel(messages=infinite_cycle)
     tokens: list[str] = []
     # New model
-    results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
+    results = [
+        chunk
+        async for chunk in model.astream(
+            "meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}
+        )
+    ]
     assert results == [
         _any_id_ai_message_chunk(content="hello"),
         _any_id_ai_message_chunk(content=" "),
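The pattern above in isolation: rather than draining the sync `stream()` iterator on the event loop, the chunks are gathered with an async comprehension over `astream()`. A minimal self-contained sketch (the generator below is a stand-in, not the langchain-core API):

import asyncio
from collections.abc import AsyncIterator


async def astream() -> AsyncIterator[str]:
    # Stand-in for model.astream(...): yields chunks without blocking the loop.
    for chunk in ("hello", " ", "goodbye"):
        yield chunk


async def main() -> None:
    results = [chunk async for chunk in astream()]
    assert results == ["hello", " ", "goodbye"]


asyncio.run(main())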

libs/core/tests/unit_tests/language_models/chat_models/test_base.py

Lines changed: 37 additions & 22 deletions
@@ -18,6 +18,7 @@
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.outputs.llm_result import LLMResult
+from langchain_core.tracers import LogStreamCallbackHandler
 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.context import collect_runs
 from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
@@ -303,39 +304,48 @@ def _stream(
 
 
 @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
-async def test_disable_streaming(
+def test_disable_streaming(
     disable_streaming: Union[bool, Literal["tool_calling"]],
 ) -> None:
     model = StreamingModel(disable_streaming=disable_streaming)
     assert model.invoke([]).content == "invoke"
-    assert (await model.ainvoke([])).content == "invoke"
 
     expected = "invoke" if disable_streaming is True else "stream"
     assert next(model.stream([])).content == expected
-    async for c in model.astream([]):
-        assert c.content == expected
-        break
+    assert (
+        model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content
+        == expected
+    )
+
+    expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
+    assert next(model.stream([], tools=[{"type": "function"}])).content == expected
     assert (
         model.invoke(
-            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+            [], config={"callbacks": [LogStreamCallbackHandler()]}, tools=[{}]
         ).content
         == expected
     )
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming_async(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = StreamingModel(disable_streaming=disable_streaming)
+    assert (await model.ainvoke([])).content == "invoke"
+
+    expected = "invoke" if disable_streaming is True else "stream"
+    async for c in model.astream([]):
+        assert c.content == expected
+        break
     assert (
         await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
     ).content == expected
 
     expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
-    assert next(model.stream([], tools=[{"type": "function"}])).content == expected
     async for c in model.astream([], tools=[{}]):
         assert c.content == expected
         break
-    assert (
-        model.invoke(
-            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
-        ).content
-        == expected
-    )
     assert (
         await model.ainvoke(
             [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
@@ -344,26 +354,31 @@ async def test_disable_streaming(
 
 
 @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
-async def test_disable_streaming_no_streaming_model(
+def test_disable_streaming_no_streaming_model(
    disable_streaming: Union[bool, Literal["tool_calling"]],
 ) -> None:
     model = NoStreamingModel(disable_streaming=disable_streaming)
     assert model.invoke([]).content == "invoke"
-    assert (await model.ainvoke([])).content == "invoke"
     assert next(model.stream([])).content == "invoke"
-    async for c in model.astream([]):
-        assert c.content == "invoke"
-        break
     assert (
-        model.invoke(
-            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
-        ).content
+        model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content
         == "invoke"
     )
+    assert next(model.stream([], tools=[{}])).content == "invoke"
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming_no_streaming_model_async(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = NoStreamingModel(disable_streaming=disable_streaming)
+    assert (await model.ainvoke([])).content == "invoke"
+    async for c in model.astream([]):
+        assert c.content == "invoke"
+        break
     assert (
         await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
     ).content == "invoke"
-    assert next(model.stream([], tools=[{}])).content == "invoke"
     async for c in model.astream([], tools=[{}]):
         assert c.content == "invoke"
         break

libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py

Lines changed: 9 additions & 0 deletions
@@ -1,11 +1,20 @@
 import time
 from typing import Optional as Optional
 
+import pytest
+from blockbuster import BlockBuster
+
 from langchain_core.caches import InMemoryCache
 from langchain_core.language_models import GenericFakeChatModel
 from langchain_core.rate_limiters import InMemoryRateLimiter
 
 
+@pytest.fixture(autouse=True)
+def deactivate_blockbuster(blockbuster: BlockBuster) -> None:
+    # Deactivate BlockBuster to not disturb the rate limiter timings
+    blockbuster.deactivate()
+
+
 def test_rate_limit_invoke() -> None:
     """Add rate limiter."""
     model = GenericFakeChatModel(

libs/core/tests/unit_tests/prompts/test_chat.py

Lines changed: 30 additions & 37 deletions
@@ -1,5 +1,3 @@
-import base64
-import tempfile
 import warnings
 from pathlib import Path
 from typing import Any, Union, cast
@@ -727,44 +725,39 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None:
 async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
     """Verify that we cannot pass `path` for an image as a variable."""
     in_mem = "base64mem"
-    in_file_data = "base64file01"
 
-    with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file:
-        temp_file.write(base64.b64decode(in_file_data))
-        temp_file.flush()
-
-        template = ChatPromptTemplate.from_messages(
-            [
-                ("system", "You are an AI assistant named {name}."),
-                (
-                    "human",
-                    [
-                        {"type": "text", "text": "What's in this image?"},
-                        {
-                            "type": "image_url",
-                            "image_url": "data:image/jpeg;base64,{in_mem}",
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {"path": "{file_path}"},
-                        },
-                    ],
-                ),
-            ]
+    template = ChatPromptTemplate.from_messages(
+        [
+            ("system", "You are an AI assistant named {name}."),
+            (
+                "human",
+                [
+                    {"type": "text", "text": "What's in this image?"},
+                    {
+                        "type": "image_url",
+                        "image_url": "data:image/jpeg;base64,{in_mem}",
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {"path": "{file_path}"},
+                    },
+                ],
+            ),
+        ]
+    )
+    with pytest.raises(ValueError):
+        template.format_messages(
+            name="R2D2",
+            in_mem=in_mem,
+            file_path="some/path",
         )
-        with pytest.raises(ValueError):
-            template.format_messages(
-                name="R2D2",
-                in_mem=in_mem,
-                file_path=temp_file.name,
-            )
 
-        with pytest.raises(ValueError):
-            await template.aformat_messages(
-                name="R2D2",
-                in_mem=in_mem,
-                file_path=temp_file.name,
-            )
+    with pytest.raises(ValueError):
+        await template.aformat_messages(
+            name="R2D2",
+            in_mem=in_mem,
+            file_path="some/path",
        )
 
 
 def test_messages_placeholder() -> None:

libs/core/tests/unit_tests/runnables/test_context.py

Lines changed: 12 additions & 5 deletions
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any, Callable, NamedTuple, Union
 
 import pytest
@@ -330,19 +331,26 @@ def seq_naive_rag_scoped() -> Runnable:
 
 
 @pytest.mark.parametrize("runnable, cases", test_cases)
-async def test_context_runnables(
+def test_context_runnables(
     runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
 ) -> None:
     runnable = runnable if isinstance(runnable, Runnable) else runnable()
     assert runnable.invoke(cases[0].input) == cases[0].output
-    assert await runnable.ainvoke(cases[1].input) == cases[1].output
     assert runnable.batch([case.input for case in cases]) == [
         case.output for case in cases
     ]
+    assert add(runnable.stream(cases[0].input)) == cases[0].output
+
+
+@pytest.mark.parametrize("runnable, cases", test_cases)
+async def test_context_runnables_async(
+    runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
+) -> None:
+    runnable = runnable if isinstance(runnable, Runnable) else runnable()
+    assert await runnable.ainvoke(cases[1].input) == cases[1].output
     assert await runnable.abatch([case.input for case in cases]) == [
         case.output for case in cases
     ]
-    assert add(runnable.stream(cases[0].input)) == cases[0].output
     assert await aadd(runnable.astream(cases[1].input)) == cases[1].output
@@ -390,8 +398,7 @@ async def test_runnable_seq_streaming_chunks() -> None:
             "prompt": Context.getter("prompt"),
         }
     )
-
-    chunks = list(chain.stream({"foo": "foo", "bar": "bar"}))
+    chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"}))
     achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
     for c in chunks:
         assert c in achunks
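The `asyncio.to_thread` change above keeps the sync/async streaming comparison in one test while satisfying Blockbuster: the blocking `list(...)` consumption runs in a worker thread instead of on the event loop. The same pattern in isolation (the generator below is a stand-in for `chain.stream(...)`, not the langchain-core API):

import asyncio
import time
from collections.abc import Iterator


def slow_chunks() -> Iterator[str]:
    # Stand-in for chain.stream(...): a sync generator that blocks between chunks.
    for chunk in ("foo", "bar"):
        time.sleep(0.01)
        yield chunk


async def main() -> None:
    # list() drains the blocking iterator in a worker thread,
    # so the event loop stays responsive.
    chunks = await asyncio.to_thread(list, slow_chunks())
    assert chunks == ["foo", "bar"]


asyncio.run(main())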
