Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3229.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid holding refs to result/exception from ``trio.to_thread.run_sync``.
16 changes: 1 addition & 15 deletions src/trio/_core/_tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
create_asyncio_future_in_new_loop,
gc_collect_harder,
ignore_coroutine_never_awaited_warnings,
no_other_refs,
restore_unraisablehook,
slow,
)
Expand Down Expand Up @@ -2802,25 +2803,10 @@ async def spawn_tasks_in_old_nursery(task_status: _core.TaskStatus[None]) -> Non
assert RaisesGroup(ValueError, ValueError).matches(excinfo.value.__cause__)


if sys.version_info >= (3, 11):

def no_other_refs() -> list[object]:
return []

else:

def no_other_refs() -> list[object]:
return [sys._getframe(1)]


@pytest.mark.skipif(
sys.implementation.name != "cpython",
reason="Only makes sense with refcounting GC",
)
@pytest.mark.xfail(
sys.version_info >= (3, 14),
reason="https://github.com/python/cpython/issues/125603",
)
async def test_ki_protection_doesnt_leave_cyclic_garbage() -> None:
class MyException(Exception):
pass
Expand Down
17 changes: 17 additions & 0 deletions src/trio/_core/_tests/tutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,20 @@ def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) ->
def create_asyncio_future_in_new_loop() -> asyncio.Future[object]:
with closing(asyncio.new_event_loop()) as loop:
return loop.create_future()


if sys.version_info >= (3, 14):

def no_other_refs() -> list[object]:
gen = sys._getframe().f_generator
return [] if gen is None else [gen]

elif sys.version_info >= (3, 11):

def no_other_refs() -> list[object]:
return []

else:

def no_other_refs() -> list[object]:
return [sys._getframe(1)]
58 changes: 57 additions & 1 deletion src/trio/_tests/test_threads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextvars
import gc
import queue as stdlib_queue
import re
import sys
Expand Down Expand Up @@ -29,7 +30,7 @@
sleep_forever,
)
from .._core._tests.test_ki import ki_self
from .._core._tests.tutil import slow
from .._core._tests.tutil import gc_collect_harder, no_other_refs, slow
from .._threads import (
active_thread_count,
current_default_thread_limiter,
Expand Down Expand Up @@ -1141,3 +1142,58 @@ async def wait_no_threads_left() -> None:
async def test_wait_all_threads_completed_no_threads() -> None:
await wait_all_threads_completed()
assert active_thread_count() == 0


@pytest.mark.skipif(
sys.implementation.name == "pypy",
reason=(
"gc.get_referrers is broken on PyPy (see "
"https://github.com/pypy/pypy/issues/5075)"
),
)
async def test_run_sync_worker_references() -> None:
class Foo:
pass

def foo(_: Foo) -> Foo:
return Foo()

cvar = contextvars.ContextVar[Foo]("cvar")
contextval = Foo()
arg = Foo()
cvar.set(contextval)
v = await to_thread_run_sync(foo, arg)

cvar.set(Foo())
gc_collect_harder()

assert gc.get_referrers(contextval) == no_other_refs()
assert gc.get_referrers(foo) == no_other_refs()
assert gc.get_referrers(arg) == no_other_refs()
assert gc.get_referrers(v) == no_other_refs()


@pytest.mark.skipif(
sys.implementation.name == "pypy",
reason=(
"gc.get_referrers is broken on PyPy (see "
"https://github.com/pypy/pypy/issues/5075)"
),
)
async def test_run_sync_workerreferences_exc() -> None:

class MyException(Exception):
pass

def throw() -> None:
raise MyException

e = None
try:
await to_thread_run_sync(throw)
except MyException as err:
e = err

gc_collect_harder()

assert gc.get_referrers(e) == no_other_refs()
64 changes: 50 additions & 14 deletions src/trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import queue as stdlib_queue
import threading
from itertools import count
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Final, Generic, NoReturn, Protocol, TypeVar

import attrs
import outcome
Expand Down Expand Up @@ -36,6 +36,7 @@
Ts = TypeVarTuple("Ts")

RetT = TypeVar("RetT")
T_co = TypeVar("T_co", covariant=True)


class _ParentTaskData(threading.local):
Expand Down Expand Up @@ -253,6 +254,32 @@ def run_in_system_nursery(self, token: TrioToken) -> None:
token.run_sync_soon(self.run_sync)


class _SupportsUnwrap(Protocol, Generic[T_co]):
def unwrap(self) -> T_co: ...


class _Value(_SupportsUnwrap[T_co]):
def __init__(self, v: T_co) -> None:
self._v: Final = v

def unwrap(self) -> T_co:
try:
return self._v
finally:
del self._v


class _Error(_SupportsUnwrap[NoReturn]):
def __init__(self, e: BaseException) -> None:
self._e: Final = e

def unwrap(self) -> NoReturn:
try:
raise self._e
finally:
del self._e


@enable_ki_protection
async def to_thread_run_sync(
sync_fn: Callable[[Unpack[Ts]], RetT],
Expand Down Expand Up @@ -375,8 +402,15 @@ def do_release_then_return_result() -> RetT:
limiter.release_on_behalf_of(placeholder)

result = outcome.capture(do_release_then_return_result)
if isinstance(result, outcome.Error):
result2: _SupportsUnwrap[RetT] = _Error(result.error)
elif isinstance(result, outcome.Value):
result2 = _Value(result.value)
else:
raise RuntimeError("invalid outcome")
del result
if task_register[0] is not None:
trio.lowlevel.reschedule(task_register[0], outcome.Value(result))
trio.lowlevel.reschedule(task_register[0], outcome.Value(result2))

current_trio_token = trio.lowlevel.current_trio_token()

Expand Down Expand Up @@ -440,20 +474,22 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:

while True:
# wait_task_rescheduled return value cannot be typed
msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[object] = (
msg_from_thread: _Value[RetT] | _Error | Run[object] | RunSync[object] = (
await trio.lowlevel.wait_task_rescheduled(abort)
)
if isinstance(msg_from_thread, outcome.Outcome):
return msg_from_thread.unwrap()
elif isinstance(msg_from_thread, Run):
await msg_from_thread.run()
elif isinstance(msg_from_thread, RunSync):
msg_from_thread.run_sync()
else: # pragma: no cover, internal debugging guard TODO: use assert_never
raise TypeError(
f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}.",
)
del msg_from_thread
try:
if isinstance(msg_from_thread, (_Value, _Error)):
return msg_from_thread.unwrap()
elif isinstance(msg_from_thread, Run):
await msg_from_thread.run()
elif isinstance(msg_from_thread, RunSync):
msg_from_thread.run_sync()
else: # pragma: no cover, internal debugging guard TODO: use assert_never
raise TypeError(
f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}.",
)
finally:
del msg_from_thread


def from_thread_check_cancelled() -> None:
Expand Down
Loading