diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py
index d37da955..422e4e05 100644
--- a/hatchet_sdk/worker/runner/runner.py
+++ b/hatchet_sdk/worker/runner/runner.py
@@ -5,12 +5,11 @@
 import json
 import logging
 import traceback
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from enum import Enum
 from io import StringIO
 from logging import StreamHandler
-from multiprocessing import Queue
-from threading import Thread, current_thread
+from multiprocessing import Process, Queue, current_process
 from typing import Any, Callable, Coroutine, Dict
 
 from hatchet_sdk.client import new_client_raw
@@ -140,8 +139,8 @@ def __init__(
         self.event_queue = event_queue
 
         # The thread pool is used for synchronous functions which need to run concurrently
-        self.thread_pool = ThreadPoolExecutor(max_workers=max_runs)
-        self.threads: Dict[str, Thread] = {}  # Store run ids and threads
+        self.thread_pool = ProcessPoolExecutor(max_workers=max_runs)
+        self.threads: Dict[str, Process] = {}  # Store run ids and threads
 
         self.killing = False
         self.handle_kill = handle_kill
@@ -262,12 +261,12 @@ def inner_callback(task: asyncio.Task):
     def thread_action_func(self, context, action_func, action: Action):
         if action.step_run_id is not None and action.step_run_id != "":
-            self.threads[action.step_run_id] = current_thread()
+            self.threads[action.step_run_id] = current_process()
         elif (
             action.get_group_key_run_id is not None
             and action.get_group_key_run_id != ""
         ):
-            self.threads[action.get_group_key_run_id] = current_thread()
+            self.threads[action.get_group_key_run_id] = current_process()
 
         return action_func(context)
@@ -284,6 +283,7 @@ async def async_wrapped_action_func(
         ) or asyncio.iscoroutinefunction(action_func):
             return await action_func(context)
         else:
+            pfunc = functools.partial(
                 # we must copy the context vars to the new thread, as only asyncio natively supports
                 # contextvars
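
The truncated final hunk is building a `functools.partial` that restores the caller's contextvars inside the pool worker before the synchronous action runs, since only asyncio tasks inherit contextvars automatically. Below is a minimal, self-contained sketch of that pattern, not the SDK's actual code: the helper `copy_context_vars`, the `request_id` variable, and `sync_action` are illustrative assumptions.

```python
# Sketch: copy the caller's contextvars into an executor worker, then dispatch
# a blocking function from asyncio via loop.run_in_executor.
import asyncio
import contextvars
import functools
from concurrent.futures import ThreadPoolExecutor

request_id = contextvars.ContextVar("request_id")  # illustrative ContextVar


def copy_context_vars(ctx_vars, func, *args):
    # Re-set each captured ContextVar in the worker before calling func,
    # because plain threads do not inherit the submitting task's context.
    for var, value in ctx_vars:
        var.set(value)
    return func(*args)


def sync_action(payload: str) -> str:
    # A blocking "step" that reads the restored context.
    return f"{request_id.get('<unset>')}: {payload}"


async def main() -> None:
    request_id.set("run-123")
    pool = ThreadPoolExecutor(max_workers=4)
    pfunc = functools.partial(
        copy_context_vars,
        contextvars.copy_context().items(),  # snapshot of current contextvars
        sync_action,
        "hello",
    )
    loop = asyncio.get_running_loop()
    print(await loop.run_in_executor(pool, pfunc))  # -> "run-123: hello"
    pool.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```

Note that the diff swaps the executor for a `ProcessPoolExecutor` and `current_thread()` bookkeeping for `current_process()`; with a process pool, whatever is submitted (the callable and its arguments) must be picklable, so the contextvar-copying shown here is generally only straightforward when the workers are threads.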