|
6 | 6 |
|
7 | 7 | import asyncio |
8 | 8 | import inspect |
| 9 | +from concurrent.futures import ProcessPoolExecutor |
9 | 10 | from typing import TYPE_CHECKING, Any, TypeVar, cast |
10 | 11 |
|
11 | 12 | from dag_simple.context import ExecutionContext |
@@ -232,3 +233,77 @@ async def _execute_node_without_cache( |
232 | 233 | validate_output_type(node, result, node.type_hints) |
233 | 234 |
|
234 | 235 | return cast(R, result) |
| 236 | + |
| 237 | + |
def run_sync_in_process(
    node: Node[R],
    *,
    enable_cache: bool = True,
    executor: ProcessPoolExecutor | None = None,
    **inputs: Any,
) -> R:
    """Run ``run_sync`` for *node* inside a separate worker process.

    Args:
        node: The root node to execute.
        enable_cache: Whether caching is enabled for this execution.
        executor: Optional ``ProcessPoolExecutor`` that receives the work.
            If ``None``, a throwaway single-worker pool is created just for
            this call and shut down afterwards.
        **inputs: Keyword arguments forwarded as DAG inputs.

    Returns:
        The value produced by ``run_sync`` in the worker process.
    """
    # NOTE(review): node and inputs must be picklable to cross the process
    # boundary — confirm against the Node implementation.
    if executor is None:
        # No pool supplied: spin up a dedicated one-worker pool for this call.
        with ProcessPoolExecutor(max_workers=1) as one_shot_pool:
            return one_shot_pool.submit(
                _run_sync_entry_point, node, enable_cache, inputs
            ).result()
    return executor.submit(_run_sync_entry_point, node, enable_cache, inputs).result()
| 265 | + |
| 266 | + |
def run_async_in_process(
    node: Node[R],
    *,
    enable_cache: bool = True,
    executor: ProcessPoolExecutor | None = None,
    **inputs: Any,
) -> R:
    """Run ``run_async`` for *node* inside a separate worker process.

    Args:
        node: The root node to execute.
        enable_cache: Whether caching is enabled for this execution.
        executor: Optional ``ProcessPoolExecutor`` that receives the work.
            If ``None``, a throwaway single-worker pool is created just for
            this call and shut down afterwards.
        **inputs: Keyword arguments forwarded as DAG inputs.

    Returns:
        The value produced by ``run_async`` in the worker process.
    """
    # NOTE(review): node and inputs must be picklable to cross the process
    # boundary — confirm against the Node implementation.
    if executor is None:
        # No pool supplied: spin up a dedicated one-worker pool for this call.
        with ProcessPoolExecutor(max_workers=1) as one_shot_pool:
            return one_shot_pool.submit(
                _run_async_entry_point, node, enable_cache, inputs
            ).result()
    return executor.submit(_run_async_entry_point, node, enable_cache, inputs).result()
| 294 | + |
| 295 | + |
def _run_sync_entry_point(
    node: Node[R], enable_cache: bool, inputs: dict[str, Any]
) -> R:
    """Process entry point for ``run_sync_in_process``.

    Executes in the worker process, which is why the call carries a
    ``pragma: no cover`` marker — parent-process coverage cannot observe it.

    Args:
        node: The root node to execute.
        enable_cache: Whether caching is enabled for this execution.
        inputs: DAG inputs, re-expanded into keyword arguments.

    Returns:
        The result of ``run_sync``.
    """

    return run_sync(node, enable_cache=enable_cache, **inputs)  # pragma: no cover
| 302 | + |
| 303 | + |
def _run_async_entry_point(
    node: Node[R], enable_cache: bool, inputs: dict[str, Any]
) -> R:
    """Process entry point for ``run_async_in_process``.

    Executes in the worker process, which is why the call carries a
    ``pragma: no cover`` marker — parent-process coverage cannot observe it.
    A fresh event loop is created via ``asyncio.run`` because the worker
    process starts without one.

    Args:
        node: The root node to execute.
        enable_cache: Whether caching is enabled for this execution.
        inputs: DAG inputs, re-expanded into keyword arguments.

    Returns:
        The result awaited from ``run_async``.
    """

    return asyncio.run(run_async(node, enable_cache=enable_cache, **inputs))  # pragma: no cover
0 commit comments