Skip to content

Commit be255a7

Browse files
tboserKludex
andauthored
Copy context to new thread for sync tool calls (#1576)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 433e1bc commit be255a7

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from types import GenericAlias
1212
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
1313

14+
from anyio.to_thread import run_sync
1415
from pydantic import BaseModel
1516
from pydantic.json_schema import JsonSchemaValue
1617
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
@@ -31,11 +32,8 @@
3132

3233

3334
async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R:
34-
if kwargs:
35-
# noinspection PyTypeChecker
36-
return await asyncio.get_running_loop().run_in_executor(None, partial(func, *args, **kwargs))
37-
else:
38-
return await asyncio.get_running_loop().run_in_executor(None, func, *args) # type: ignore
35+
wrapped_func = partial(func, *args, **kwargs)
36+
return await run_sync(wrapped_func)
3937

4038

4139
def is_model_like(type_: Any) -> bool:

tests/test_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import asyncio
4+
import contextvars
45
import os
56
from collections.abc import AsyncIterator
67
from importlib.metadata import distributions
@@ -9,7 +10,7 @@
910
from inline_snapshot import snapshot
1011

1112
from pydantic_ai import UserError
12-
from pydantic_ai._utils import UNSET, PeekableAsyncStream, check_object_json_schema, group_by_temporal
13+
from pydantic_ai._utils import UNSET, PeekableAsyncStream, check_object_json_schema, group_by_temporal, run_in_executor
1314

1415
from .models.mock_async_stream import MockAsyncStream
1516

@@ -136,3 +137,19 @@ def test_package_versions(capsys: pytest.CaptureFixture[str]):
136137
packages = sorted((package.metadata['Name'], package.version) for package in distributions())
137138
for name, version in packages:
138139
print(f'{name:30} {version}')
140+
141+
142+
async def test_run_in_executor_with_contextvars() -> None:
143+
ctx_var = contextvars.ContextVar('test_var', default='default')
144+
ctx_var.set('original_value')
145+
146+
result = await run_in_executor(ctx_var.get)
147+
assert result == ctx_var.get()
148+
149+
ctx_var.set('new_value')
150+
result = await run_in_executor(ctx_var.get)
151+
assert result == ctx_var.get()
152+
153+
# show that the old version did not work
154+
old_result = asyncio.get_running_loop().run_in_executor(None, ctx_var.get)
155+
assert old_result != ctx_var.get()

0 commit comments

Comments
 (0)