Skip to content

Commit 9b4aaaa

Browse files
committed
fix(portal): improve shutdown safety and add call timeout
- Only close event loop after thread has terminated to avoid undefined behavior when closing a running loop - Add configurable timeout to call() method (default 300s) to prevent deadlocks during shutdown races - Add test for timeout functionality
1 parent 1d630c4 commit 9b4aaaa

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

sqlspec/utils/portal.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def stop(self) -> None:
113113
"""Stop the background thread and event loop.
114114
115115
Gracefully shuts down the event loop and waits for thread to finish.
116+
Only closes the loop after the thread has terminated to avoid
117+
undefined behavior from closing a running loop.
116118
"""
117119
if self._loop is None or self._loop_thread is None:
118120
logger.debug("Portal provider not running")
@@ -122,9 +124,10 @@ def stop(self) -> None:
122124
self._loop_thread.join(timeout=5)
123125

124126
if self._loop_thread.is_alive():
125-
logger.warning("Portal thread did not stop within 5 seconds")
127+
logger.warning("Portal thread did not stop within 5 seconds, skipping loop.close()")
128+
else:
129+
self._loop.close()
126130

127-
self._loop.close()
128131
self._loop = None
129132
self._loop_thread = None
130133
self._ready_event.clear()
@@ -146,22 +149,25 @@ async def _async_caller(
146149
result: _R = await func(*args, **kwargs)
147150
return result
148151

149-
def call(self, func: "Callable[..., Coroutine[Any, Any, _R]]", *args: Any, **kwargs: Any) -> _R:
152+
def call(
153+
self, func: "Callable[..., Coroutine[Any, Any, _R]]", *args: Any, timeout: float = 300.0, **kwargs: Any
154+
) -> _R:
150155
"""Call an async function from synchronous context.
151156
152157
Executes the async function in the background event loop and blocks
153-
until the result is available.
158+
until the result is available or timeout is reached.
154159
155160
Args:
156161
func: The async function to call.
157162
*args: Positional arguments to the function.
163+
timeout: Maximum seconds to wait for result (default 300).
158164
**kwargs: Keyword arguments to the function.
159165
160166
Returns:
161167
Result of the async function.
162168
163169
Raises:
164-
ImproperConfigurationError: If portal provider not started.
170+
ImproperConfigurationError: If portal provider not started or timeout reached.
165171
166172
"""
167173
if self._loop is None or not self.is_running:
@@ -174,7 +180,11 @@ def call(self, func: "Callable[..., Coroutine[Any, Any, _R]]", *args: Any, **kwa
174180

175181
self._loop.call_soon_threadsafe(self._process_request)
176182

177-
result, exception = local_result_queue.get()
183+
try:
184+
result, exception = local_result_queue.get(timeout=timeout)
185+
except queue.Empty:
186+
msg = f"Portal call timed out after {timeout} seconds"
187+
raise ImproperConfigurationError(msg) from None
178188

179189
if exception:
180190
raise exception
@@ -223,19 +233,22 @@ def __init__(self, provider: "PortalProvider") -> None:
223233
"""
224234
self._provider = provider
225235

226-
def call(self, func: "Callable[..., Coroutine[Any, Any, _R]]", *args: Any, **kwargs: Any) -> _R:
236+
def call(
237+
self, func: "Callable[..., Coroutine[Any, Any, _R]]", *args: Any, timeout: float = 300.0, **kwargs: Any
238+
) -> _R:
227239
"""Call an async function using the portal provider.
228240
229241
Args:
230242
func: The async function to call.
231243
*args: Positional arguments to the function.
244+
timeout: Maximum seconds to wait for result (default 300).
232245
**kwargs: Keyword arguments to the function.
233246
234247
Returns:
235248
Result of the async function.
236249
237250
"""
238-
return self._provider.call(func, *args, **kwargs)
251+
return self._provider.call(func, *args, timeout=timeout, **kwargs)
239252

240253

241254
class PortalManager(metaclass=SingletonMeta):

tests/unit/test_utils/test_portal.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,3 +381,19 @@ def test_portal_manager_atexit_cleanup_noop_when_stopped() -> None:
381381
manager._atexit_cleanup() # pyright: ignore[reportPrivateUsage]
382382

383383
assert not manager.is_running
384+
385+
386+
def test_portal_call_timeout() -> None:
387+
"""PortalProvider.call raises error on timeout."""
388+
389+
async def slow_function() -> int:
390+
await asyncio.sleep(10)
391+
return 42
392+
393+
provider = PortalProvider()
394+
provider.start()
395+
396+
with pytest.raises(ImproperConfigurationError, match="timed out after"):
397+
provider.call(slow_function, timeout=0.1)
398+
399+
provider.stop()

0 commit comments

Comments
 (0)