Skip to content

Commit 35b1508

Browse files
Use separate thread for running sync tasks (#565)
Sync tasks are broken in the following scenario: * `_run_and_stop` in the runtime runs the task and blocks on `self.func()` in `task.execute`. * This [raise](https://github.com/flyteorg/flyte-sdk/blob/4a5f4e1bffa13ffcee00a09057018367b0d66187/src/flyte/_internal/controllers/remote/_core.py#L131) never happens, because it would have to run in the main thread and the main thread is blocked on `self.func()`. The way around this is to create a new thread to run `self.func`. --------- Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
1 parent 1e2d231 commit 35b1508

File tree

4 files changed

+234
-4
lines changed

4 files changed

+234
-4
lines changed

src/flyte/_internal/controllers/remote/_core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,10 @@ def _bg_thread_target(self):
213213
"""Target function for the controller thread that creates and manages its own event loop"""
214214
try:
215215
# Create a new event loop for this thread
216-
self._loop = asyncio.new_event_loop()
217-
asyncio.set_event_loop(self._loop)
218-
self._loop.set_exception_handler(flyte.errors.silence_grpc_polling_error)
216+
with self._thread_com_lock:
217+
self._loop = asyncio.new_event_loop()
218+
asyncio.set_event_loop(self._loop)
219+
self._loop.set_exception_handler(flyte.errors.silence_grpc_polling_error)
219220
logger.debug(f"Controller thread started with new event loop: {threading.current_thread().name}")
220221

221222
# Create an event to signal the errors were observed in the thread's loop

src/flyte/_task.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525

2626
from flyte._pod import PodTemplate
27+
from flyte._utils.asyncify import run_sync_with_loop
2728
from flyte.errors import RuntimeSystemError, RuntimeUserError
2829

2930
from ._cache import Cache, CacheRequest
@@ -493,7 +494,8 @@ async def execute(self, *args: P.args, **kwargs: P.kwargs) -> R:
493494
if iscoroutinefunction(self.func):
494495
v = await self.func(*args, **kwargs)
495496
else:
496-
v = self.func(*args, **kwargs)
497+
v = await run_sync_with_loop(self.func, *args, **kwargs)
498+
497499
await self.post(v)
498500
return v
499501

src/flyte/_utils/asyncify.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import contextvars
5+
import inspect
6+
import random
7+
import threading
8+
from typing import Callable, TypeVar
9+
10+
from typing_extensions import ParamSpec
11+
12+
from flyte._logging import logger
13+
14+
T = TypeVar("T")
15+
P = ParamSpec("P")
16+
17+
18+
async def run_sync_with_loop(
    func: Callable[P, T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """
    Run a synchronous function from an async context with its own event loop.

    This function:
    - Copies the current context variables and runs the sync function inside that copy
    - Creates a new event loop in a dedicated daemon thread and runs the sync function there
    - Allows the sync function to potentially use asyncio operations
    - Returns the result without blocking the calling async event loop
    - Stops the dedicated event loop once the call has finished (or was cancelled), so the
      helper thread can close the loop and exit instead of idling forever

    Args:
        func: The synchronous function to run (must not be an async function)
        *args: Positional arguments to pass to the function
        **kwargs: Keyword arguments to pass to the function

    Returns:
        The result of calling func(*args, **kwargs)

    Raises:
        TypeError: If func is an async function (coroutine function)
        RuntimeError: If the helper thread failed to create its event loop

    Example:
        async def my_async_function():
            result = await run_sync_with_loop(some_sync_function, arg1, arg2)
            return result
    """
    # Not every callable has __name__ (e.g. functools.partial), so look it up defensively.
    func_name = getattr(func, "__name__", "unknown")

    # Check if func is an async function
    if inspect.iscoroutinefunction(func):
        raise TypeError(
            f"Cannot call run_sync_with_loop with async function '{func_name}'. "
            "This utility is for running sync functions from async contexts."
        )

    copied_ctx = contextvars.copy_context()
    execute_loop: asyncio.AbstractEventLoop | None = None
    execute_loop_created = threading.Event()
    # Set by the caller when the helper loop is no longer needed.
    shutdown_requested = threading.Event()

    # Build thread name with random suffix for uniqueness
    current_thread = threading.current_thread().name
    random_suffix = f"{random.getrandbits(32):08x}"
    full_thread_name = f"sync-executor-{random_suffix}_from_{current_thread}"

    def _sync_thread_loop_runner() -> None:
        """This method runs the event loop and should be invoked in a separate thread."""
        nonlocal execute_loop
        loop = None
        try:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            execute_loop = loop
            logger.debug(f"Created event loop for function '{func_name}' in thread '{full_thread_name}'")
            execute_loop_created.set()
            # The caller may already have given up (e.g. it was cancelled while waiting
            # for the loop); in that case there is nothing to serve.
            if not shutdown_requested.is_set():
                loop.run_forever()
        except Exception as e:
            logger.error(f"Exception in thread '{full_thread_name}' running '{func_name}': {e}", exc_info=True)
            raise
        finally:
            # Always release the waiting caller, even when loop creation itself failed;
            # otherwise the caller would wait on the event forever.
            execute_loop_created.set()
            if loop is not None:
                logger.debug(f"Stopping event loop for function '{func_name}' in thread '{full_thread_name}'")
                loop.close()
                logger.debug(f"Cleaned up event loop for function '{func_name}' in thread '{full_thread_name}'")

    executor_thread = threading.Thread(
        name=full_thread_name,
        daemon=True,
        target=_sync_thread_loop_runner,
    )
    logger.debug(f"Starting executor thread '{full_thread_name}' for function '{func_name}'")
    executor_thread.start()

    async def async_wrapper():
        # Runs inside the helper thread's loop; the function executes in the copied context.
        return copied_ctx.run(func, *args, **kwargs)

    try:
        # Wait for the loop to be created in a thread to avoid blocking the current thread
        await asyncio.get_running_loop().run_in_executor(None, execute_loop_created.wait)
        if execute_loop is None:
            # An explicit error instead of `assert`, which is stripped under `python -O`.
            raise RuntimeError(
                f"Failed to create event loop in thread '{full_thread_name}' for function '{func_name}'"
            )
        fut = asyncio.run_coroutine_threadsafe(async_wrapper(), loop=execute_loop)
        return await asyncio.wrap_future(fut)
    finally:
        # Ask the helper loop to stop so run_forever() returns and the thread can close
        # the loop and exit. Without this every call would leave a thread and loop behind.
        shutdown_requested.set()
        if execute_loop is not None:
            try:
                execute_loop.call_soon_threadsafe(execute_loop.stop)
            except RuntimeError:
                # The helper thread has already closed its loop; nothing left to stop.
                pass

tests/flyte/utils/test_asyncify.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import asyncio
2+
import contextvars
3+
import threading
4+
5+
import pytest
6+
7+
from flyte._utils.asyncify import run_sync_with_loop
8+
9+
# Context variable for testing context preservation
10+
test_context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context_var")
11+
12+
13+
@pytest.mark.asyncio
async def test_basic_sync_function():
    """A plain sync callable invoked with positional arguments yields its return value."""

    def sync_add(a: int, b: int) -> int:
        total = a + b
        return total

    # 5 + 7 must come back unchanged through the helper.
    assert await run_sync_with_loop(sync_add, 5, 7) == 12
22+
23+
24+
@pytest.mark.asyncio
async def test_sync_function_with_kwargs():
    """Keyword arguments are forwarded to the sync callable alongside positional ones."""

    def sync_multiply(x: int, y: int, multiplier: int = 1) -> int:
        product = x * y
        return product * multiplier

    # 3 * 4 * 2 == 24 only if `multiplier` actually reached the function.
    outcome = await run_sync_with_loop(sync_multiply, 3, 4, multiplier=2)
    assert outcome == 24
33+
34+
35+
@pytest.mark.asyncio
async def test_context_variable_preservation():
    """A context variable set by the caller is visible inside the sync function."""
    expected = "test_value"
    test_context_var.set(expected)

    def get_context_value() -> str:
        # Read the variable from whatever context the helper runs this function in.
        observed = test_context_var.get()
        return observed

    assert await run_sync_with_loop(get_context_value) == expected
45+
46+
47+
@pytest.mark.asyncio
async def test_raises_error_on_async_function():
    """Passing a coroutine function is rejected with a descriptive TypeError."""

    async def async_function():
        return 42

    with pytest.raises(TypeError) as exc_info:
        await run_sync_with_loop(async_function)

    # The message must carry both the fixed explanation and the offending function's name.
    message = str(exc_info.value)
    for fragment in ("Cannot call run_sync_with_loop with async function", "async_function"):
        assert fragment in message
59+
60+
61+
@pytest.mark.asyncio
async def test_sync_function_has_own_event_loop():
    """Test that the sync function runs with its own event loop, in its own thread."""
    # Inside a coroutine the running loop is the one to compare against.
    main_loop_id = id(asyncio.get_running_loop())
    main_thread_name = threading.current_thread().name

    def get_loop_info() -> tuple:
        # Get the loop that the sync function is running in, the way plain sync code would
        loop = asyncio.get_event_loop()
        loop_id = id(loop)
        thread_name = threading.current_thread().name
        return loop_id, thread_name

    loop_id, thread_name = await run_sync_with_loop(get_loop_info)

    # The sync function should have a different event loop than the main async function
    assert loop_id != main_loop_id
    # And it should be running in a different thread: check the executor naming scheme
    # and that the thread really is not the one driving this test.
    assert "sync-executor" in thread_name
    assert thread_name != main_thread_name
79+
80+
81+
@pytest.mark.asyncio
async def test_thread_name_uniqueness():
    """Test that different invocations create threads with unique names."""
    thread_names = []

    def capture_thread_name() -> str:
        name = threading.current_thread().name
        thread_names.append(name)
        return name

    # Run multiple times
    name1 = await run_sync_with_loop(capture_thread_name)
    name2 = await run_sync_with_loop(capture_thread_name)

    # Each await ran the function exactly once, and the returned names are the ones
    # observed inside the executor threads.
    assert thread_names == [name1, name2]

    # Thread names should be different due to random suffix
    assert name1 != name2
    for name in (name1, name2):
        assert "sync-executor" in name
        assert "_from_" in name
101+
102+
103+
@pytest.mark.asyncio
async def test_exception_propagation():
    """An exception raised by the sync function surfaces at the await site unchanged."""

    def sync_function_that_raises():
        raise ValueError("Test error message")

    with pytest.raises(ValueError) as exc_info:
        await run_sync_with_loop(sync_function_that_raises)

    # Type is checked by pytest.raises; the original message must survive as well.
    rendered = str(exc_info.value)
    assert "Test error message" in rendered
114+
115+
116+
@pytest.mark.asyncio
async def test_return_complex_types():
    """Nested containers returned by the sync function arrive intact."""

    def sync_function_returning_dict() -> dict:
        payload = {"key1": "value1", "key2": [1, 2, 3], "key3": {"nested": True}}
        return payload

    expected = {"key1": "value1", "key2": [1, 2, 3], "key3": {"nested": True}}
    assert await run_sync_with_loop(sync_function_returning_dict) == expected

0 commit comments

Comments
 (0)