Skip to content

Commit 3f33cd1

Browse files
Merge pull request #3586 from Textualize/wokers-inside-workers
Workers inside workers
2 parents 001881c + 841f726 commit 3f33cd1

File tree

3 files changed

+108
-4
lines changed

3 files changed

+108
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
4545
- DataTable now has a max-height of 100vh rather than 100%, which doesn't work with auto
4646
- Breaking change: empty rules now result in an error https://github.com/Textualize/textual/pull/3566
4747
- Improved startup time by caching CSS parsing https://github.com/Textualize/textual/pull/3575
48+
- Workers are now created/run in a thread-safe way https://github.com/Textualize/textual/pull/3586
4849

4950
### Added
5051

src/textual/dom.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""
2-
32
A DOMNode is a base class for any object within the Textual Document Object Model,
43
which includes all Widgets, Screens, and Apps.
54
"""
@@ -8,7 +7,8 @@
87
from __future__ import annotations
98

109
import re
11-
from functools import lru_cache
10+
import threading
11+
from functools import lru_cache, partial
1212
from inspect import getfile
1313
from typing import (
1414
TYPE_CHECKING,
@@ -267,7 +267,14 @@ def run_worker(
267267
Returns:
268268
New Worker instance.
269269
"""
270-
worker: Worker[ResultType] = self.workers._new_worker(
270+
271+
# If we're running a worker from inside a secondary thread,
272+
# do so in a thread-safe way.
273+
if self.app._thread_id != threading.get_ident():
274+
creator = partial(self.app.call_from_thread, self.workers._new_worker)
275+
else:
276+
creator = self.workers._new_worker
277+
worker: Worker[ResultType] = creator(
271278
work,
272279
self,
273280
name=name,

tests/workers/test_work_decorator.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from time import sleep
3-
from typing import Callable
3+
from typing import Callable, List, Tuple
44

55
import pytest
66

@@ -88,3 +88,99 @@ class _(App[None]):
8888
@work(thread=False)
8989
def foo(self) -> None:
9090
pass
91+
92+
93+
class NestedWorkersApp(App[None]):
94+
def __init__(self, call_stack: List[str]):
95+
self.call_stack = call_stack
96+
super().__init__()
97+
98+
def call_from_stack(self):
99+
if self.call_stack:
100+
call_now = self.call_stack.pop()
101+
getattr(self, call_now)()
102+
103+
@work(thread=False)
104+
async def async_no_thread(self):
105+
self.call_from_stack()
106+
107+
@work(thread=True)
108+
async def async_thread(self):
109+
self.call_from_stack()
110+
111+
@work(thread=True)
112+
def thread(self):
113+
self.call_from_stack()
114+
115+
116+
@pytest.mark.parametrize(
117+
"call_stack",
118+
[ # from itertools import product; list(product("async_no_thread async_thread thread".split(), repeat=3))
119+
("async_no_thread", "async_no_thread", "async_no_thread"),
120+
("async_no_thread", "async_no_thread", "async_thread"),
121+
("async_no_thread", "async_no_thread", "thread"),
122+
("async_no_thread", "async_thread", "async_no_thread"),
123+
("async_no_thread", "async_thread", "async_thread"),
124+
("async_no_thread", "async_thread", "thread"),
125+
("async_no_thread", "thread", "async_no_thread"),
126+
("async_no_thread", "thread", "async_thread"),
127+
("async_no_thread", "thread", "thread"),
128+
("async_thread", "async_no_thread", "async_no_thread"),
129+
("async_thread", "async_no_thread", "async_thread"),
130+
("async_thread", "async_no_thread", "thread"),
131+
("async_thread", "async_thread", "async_no_thread"),
132+
("async_thread", "async_thread", "async_thread"),
133+
("async_thread", "async_thread", "thread"),
134+
("async_thread", "thread", "async_no_thread"),
135+
("async_thread", "thread", "async_thread"),
136+
("async_thread", "thread", "thread"),
137+
("thread", "async_no_thread", "async_no_thread"),
138+
("thread", "async_no_thread", "async_thread"),
139+
("thread", "async_no_thread", "thread"),
140+
("thread", "async_thread", "async_no_thread"),
141+
("thread", "async_thread", "async_thread"),
142+
("thread", "async_thread", "thread"),
143+
("thread", "thread", "async_no_thread"),
144+
("thread", "thread", "async_thread"),
145+
("thread", "thread", "thread"),
146+
( # Plus a longer chain to stress test this mechanism.
147+
"async_no_thread",
148+
"async_no_thread",
149+
"thread",
150+
"thread",
151+
"async_thread",
152+
"async_thread",
153+
"async_no_thread",
154+
"async_thread",
155+
"async_no_thread",
156+
"async_thread",
157+
"thread",
158+
"async_thread",
159+
"async_thread",
160+
"async_no_thread",
161+
"async_no_thread",
162+
"thread",
163+
"thread",
164+
"async_no_thread",
165+
"async_no_thread",
166+
"thread",
167+
"async_no_thread",
168+
"thread",
169+
"thread",
170+
),
171+
],
172+
)
173+
async def test_calling_workers_from_within_workers(call_stack: Tuple[str]):
174+
"""Regression test for https://github.com/Textualize/textual/issues/3472.
175+
176+
This makes sure we can nest worker calls without a problem.
177+
"""
178+
app = NestedWorkersApp(list(call_stack))
179+
async with app.run_test():
180+
app.call_from_stack()
181+
# We need multiple awaits because we're creating a chain of workers that may
182+
# have multiple async workers, each of which may need the await to have enough
183+
# time to call the next one in the chain.
184+
for _ in range(len(call_stack)):
185+
await app.workers.wait_for_complete()
186+
assert app.call_stack == []

0 commit comments

Comments
 (0)