From dd552cc465ed945c0e43082d9561c794381cf7e2 Mon Sep 17 00:00:00 2001
From: DavdaJames
Date: Mon, 6 Oct 2025 04:27:46 +0530
Subject: [PATCH 1/3] fix multi-threading issue

---
 nettacker/core/app.py           |  14 +-
 nettacker/core/lib/base.py      |  73 +++++++-
 nettacker/core/module.py        | 105 ++++++++---
 nettacker/core/queue_manager.py | 318 ++++++++++++++++++++++++++++++++
 4 files changed, 474 insertions(+), 36 deletions(-)
 create mode 100644 nettacker/core/queue_manager.py

diff --git a/nettacker/core/app.py b/nettacker/core/app.py
index 5cd47a98b..c9794631e 100644
--- a/nettacker/core/app.py
+++ b/nettacker/core/app.py
@@ -25,6 +25,7 @@
 )
 from nettacker.core.messages import messages as _
 from nettacker.core.module import Module
+from nettacker.core.queue_manager import initialize_thread_pool, shutdown_thread_pool
 from nettacker.core.socks_proxy import set_socks_proxy
 from nettacker.core.utils import common as common_utils
 from nettacker.core.utils.common import wait_for_threads_to_finish
@@ -245,6 +246,12 @@ def start_scan(self, scan_id):
                 target_groups.remove([])

         log.info(_("start_multi_process").format(len(self.arguments.targets), len(target_groups)))
+
+        # Initialize the enhanced thread pool for cross-process sharing
+        num_processes = len(target_groups)
+        max_workers_per_process = getattr(self.arguments, "parallel_module_scan", None)
+        initialize_thread_pool(num_processes, max_workers_per_process)
+
         active_processes = []
         for t_id, target_groups in enumerate(target_groups):
             process = multiprocess.Process(
@@ -253,7 +260,12 @@
             process.start()
             active_processes.append(process)

-        return wait_for_threads_to_finish(active_processes, sub_process=True)
+        result = wait_for_threads_to_finish(active_processes, sub_process=True)
+
+        # Shutdown the thread pool after scanning is complete
+        shutdown_thread_pool()
+
+        return result

     def scan_target(
         self,
diff --git a/nettacker/core/lib/base.py b/nettacker/core/lib/base.py
index 110e51dc8..a611de657 100644
--- a/nettacker/core/lib/base.py
+++ b/nettacker/core/lib/base.py
@@ -9,6 +9,7 @@

 from nettacker.config import Config
 from nettacker.core.messages import messages as _
+from nettacker.core.queue_manager import dependency_resolver
 from nettacker.core.utils.common import merge_logs_to_list, remove_sensitive_header_keys
 from nettacker.database.db import find_temp_events, submit_temp_logs_to_db, submit_logs_to_db
 from nettacker.logger import get_logger, TerminalCodes
@@ -47,14 +48,40 @@ def filter_large_content(self, content, filter_rate=150):
         return content

     def get_dependent_results_from_database(self, target, module_name, scan_id, event_names):
+        """
+        Efficiently get dependency results without busy-waiting.
+        Uses an event-driven approach to avoid CPU consumption.
+ """ + # Try to get results efficiently using the new dependency resolver + results = dependency_resolver.get_dependency_results_efficiently( + target, module_name, scan_id, event_names, {}, self, () + ) + + if results is not None: + return results + + # Fallback to original implementation for backward compatibility + # but with increased sleep time to reduce CPU usage events = [] for event_name in event_names.split(","): - while True: + retry_count = 0 + max_retries = 300 # 30 seconds with 0.1s sleep + + while retry_count < max_retries: event = find_temp_events(target, module_name, scan_id, event_name) if event: events.append(json.loads(event.event)["response"]["conditions_results"]) break - time.sleep(0.1) + + retry_count += 1 + # Exponential backoff to reduce CPU usage + sleep_time = min(0.1 * (1.5 ** (retry_count // 10)), 1.0) + time.sleep(sleep_time) + else: + # Timeout reached + log.warn(f"Timeout waiting for dependency: {event_name} for {target}") + events.append(None) + return events def find_and_replace_dependent_values(self, sub_step, dependent_on_temp_event): @@ -123,18 +150,26 @@ def process_conditions( # Remove sensitive keys from headers before submitting to DB event = remove_sensitive_header_keys(event) if "save_to_temp_events_only" in event.get("response", ""): + event_name = event["response"]["save_to_temp_events_only"] + + # Submit to database submit_temp_logs_to_db( { "date": datetime.now(), "target": target, "module_name": module_name, "scan_id": scan_id, - "event_name": event["response"]["save_to_temp_events_only"], + "event_name": event_name, "port": event.get("ports", ""), "event": event, "data": response, } ) + + # Notify dependency resolver that a dependency is now available + dependency_resolver.notify_dependency_available( + target, module_name, scan_id, event_name, response + ) if event["response"]["conditions_results"] and "save_to_temp_events_only" not in event.get( "response", "" ): @@ -279,9 +314,37 @@ def run( sub_step[attr_name.rstrip("s")] = int(value) if attr_name == "ports" else value if "dependent_on_temp_event" in backup_response: - temp_event = self.get_dependent_results_from_database( - target, module_name, scan_id, backup_response["dependent_on_temp_event"] + # Try to get dependency results efficiently + temp_event = dependency_resolver.get_dependency_results_efficiently( + target, + module_name, + scan_id, + backup_response["dependent_on_temp_event"], + sub_step, + self, + ( + sub_step, + module_name, + target, + scan_id, + options, + process_number, + module_thread_number, + total_module_thread_number, + request_number_counter, + total_number_of_requests, + ), ) + + # If dependencies are not available yet, the task is queued + # Return early to avoid blocking the thread + if temp_event is None: + log.verbose_event_info( + f"Task queued waiting for dependencies: {target} -> {module_name}" + ) + return False + + # Dependencies are available, continue with execution sub_step = self.replace_dependent_values(sub_step, temp_event) action = getattr(self.library(), backup_method) diff --git a/nettacker/core/module.py b/nettacker/core/module.py index 17ab60601..762c0edbd 100644 --- a/nettacker/core/module.py +++ b/nettacker/core/module.py @@ -7,6 +7,7 @@ from nettacker import logger from nettacker.config import Config +from nettacker.core import queue_manager from nettacker.core.messages import messages as _ from nettacker.core.template import TemplateLoader from nettacker.core.utils.common import expand_module_steps, wait_for_threads_to_finish @@ 
@@ -118,26 +119,44 @@ def generate_loops(self):
         self.module_content["payloads"] = expand_module_steps(self.module_content["payloads"])

     def sort_loops(self):
-        steps = []
+        """
+        Sort loops to optimize dependency resolution:
+        1. Independent steps first
+        2. Steps that generate dependencies (save_to_temp_events_only)
+        3. Steps that consume dependencies (dependent_on_temp_event)
+        """
         for index in range(len(self.module_content["payloads"])):
-            for step in copy.deepcopy(self.module_content["payloads"][index]["steps"]):
-                if "dependent_on_temp_event" not in step[0]["response"]:
-                    steps.append(step)
+            independent_steps = []
+            dependency_generators = []
+            dependency_consumers = []

             for step in copy.deepcopy(self.module_content["payloads"][index]["steps"]):
-                if (
-                    "dependent_on_temp_event" in step[0]["response"]
-                    and "save_to_temp_events_only" in step[0]["response"]
-                ):
-                    steps.append(step)
+                step_response = step[0]["response"] if step and len(step) > 0 else {}
+
+                has_dependency = "dependent_on_temp_event" in step_response
+                generates_dependency = "save_to_temp_events_only" in step_response
+
+                if not has_dependency and not generates_dependency:
+                    independent_steps.append(step)
+                elif generates_dependency and not has_dependency:
+                    dependency_generators.append(step)
+                elif generates_dependency and has_dependency:
+                    dependency_generators.append(step)  # Generator first
+                elif has_dependency and not generates_dependency:
+                    dependency_consumers.append(step)
+                else:
+                    independent_steps.append(step)  # Fallback

-            for step in copy.deepcopy(self.module_content["payloads"][index]["steps"]):
-                if (
-                    "dependent_on_temp_event" in step[0]["response"]
-                    and "save_to_temp_events_only" not in step[0]["response"]
-                ):
-                    steps.append(step)
-            self.module_content["payloads"][index]["steps"] = steps
+            # Combine in optimal order
+            sorted_steps = independent_steps + dependency_generators + dependency_consumers
+            self.module_content["payloads"][index]["steps"] = sorted_steps
+
+            log.verbose_info(
+                f"Sorted {len(sorted_steps)} steps: "
+                f"{len(independent_steps)} independent, "
+                f"{len(dependency_generators)} generators, "
+                f"{len(dependency_consumers)} consumers"
+            )

     def start(self):
         active_threads = []
@@ -158,11 +177,16 @@ def start(self):
             importlib.import_module(f"nettacker.core.lib.{library.lower()}"),
             f"{library.capitalize()}Engine",
         )()
+
         for step in payload["steps"]:
             for sub_step in step:
-                thread = Thread(
-                    target=engine.run,
-                    args=(
+                # Try to use shared thread pool if available, otherwise use local threads
+                if queue_manager.thread_pool and hasattr(
+                    queue_manager.thread_pool, "submit_task"
+                ):
+                    # Submit to shared thread pool
+                    queue_manager.thread_pool.submit_task(
+                        engine.run,
                         sub_step,
                         self.module_name,
                         self.target,
@@ -173,9 +197,35 @@ def start(self):
                         self.total_module_thread_number,
                         request_number_counter,
                         total_number_of_requests,
-                    ),
-                )
-                thread.name = f"{self.target} -> {self.module_name} -> {sub_step}"
+                    )
+                else:
+                    # Use local thread (fallback to original behavior)
+                    thread = Thread(
+                        target=engine.run,
+                        args=(
+                            sub_step,
+                            self.module_name,
+                            self.target,
+                            self.scan_id,
+                            self.module_inputs,
+                            self.process_number,
+                            self.module_thread_number,
+                            self.total_module_thread_number,
+                            request_number_counter,
+                            total_number_of_requests,
+                        ),
+                    )
+                    thread.name = f"{self.target} -> {self.module_name} -> {sub_step}"
+                    thread.start()
+                    active_threads.append(thread)
+
+                    # Manage local thread pool size
+                    wait_for_threads_to_finish(
+                        active_threads,
+                        maximum=self.module_inputs["thread_per_host"],
+                        terminable=True,
+                    )
+
                 request_number_counter += 1
                 log.verbose_event_info(
                     _("sending_module_request").format(
@@ -188,13 +238,8 @@ def start(self):
                         total_number_of_requests,
                     )
                 )
-                thread.start()
                 time.sleep(self.module_inputs["time_sleep_between_requests"])
-                active_threads.append(thread)
-                wait_for_threads_to_finish(
-                    active_threads,
-                    maximum=self.module_inputs["thread_per_host"],
-                    terminable=True,
-                )

-        wait_for_threads_to_finish(active_threads, maximum=None, terminable=True)
+        # Wait for any remaining local threads to finish
+        if active_threads:
+            wait_for_threads_to_finish(active_threads, maximum=None, terminable=True)
diff --git a/nettacker/core/queue_manager.py b/nettacker/core/queue_manager.py
new file mode 100644
index 000000000..e008ff678
--- /dev/null
+++ b/nettacker/core/queue_manager.py
@@ -0,0 +1,318 @@
+"""
+Enhanced queue and dependency management system for Nettacker.
+This module provides solutions for:
+1. CPU-efficient dependency resolution using event-driven approach
+2. Cross-subprocess thread sharing for better resource utilization
+"""
+
+import json
+import multiprocessing
+import queue
+import threading
+import time
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Dict, List, Optional, Any
+
+from nettacker.database.db import find_temp_events
+from nettacker.logger import get_logger
+
+log = get_logger()
+
+
+@dataclass
+class DependentTask:
+    """Represents a task waiting for dependencies."""
+
+    target: str
+    module_name: str
+    scan_id: str
+    event_names: List[str]
+    sub_step: Dict[str, Any]
+    engine: Any
+    run_args: tuple
+    created_at: datetime
+    max_wait_time: float = 30.0  # Maximum wait time in seconds
+
+
+class DependencyResolver:
+    """
+    Event-driven dependency resolver that avoids busy-waiting.
+    Instead of polling, tasks are queued and executed when dependencies become available.
+    """
+
+    def __init__(self):
+        self._pending_tasks: Dict[str, List[DependentTask]] = {}
+        self._dependency_cache: Dict[str, Any] = {}
+        self._lock = threading.RLock()
+
+    def _get_dependency_key(
+        self, target: str, module_name: str, scan_id: str, event_name: str
+    ) -> str:
+        """Generate a unique key for dependency tracking."""
+        return f"{target}:{module_name}:{scan_id}:{event_name}"
+
+    def notify_dependency_available(
+        self, target: str, module_name: str, scan_id: str, event_name: str, result: Any
+    ):
+        """
+        Notify that a dependency is now available.
+        This should be called when temp events are saved to the database.
+ """ + dependency_key = self._get_dependency_key(target, module_name, scan_id, event_name) + + with self._lock: + # Cache the result + self._cache_dependency_result(dependency_key, result) + + # Check for pending tasks that can now be executed + self._process_pending_tasks(dependency_key) + + def _cache_dependency_result(self, dependency_key: str, result: Any): + """Cache dependency result for future use.""" + self._dependency_cache[dependency_key] = {"result": result, "timestamp": datetime.now()} + + def _process_pending_tasks(self, dependency_key: str): + """Process tasks that were waiting for the given dependency.""" + if dependency_key not in self._pending_tasks: + return + + ready_tasks = [] + remaining_tasks = [] + + for task in self._pending_tasks[dependency_key]: + if self._all_dependencies_available(task): + ready_tasks.append(task) + else: + # Check if task has expired + elapsed = (datetime.now() - task.created_at).total_seconds() + if elapsed < task.max_wait_time: + remaining_tasks.append(task) + else: + log.warn( + f"Task expired waiting for dependencies: {task.target} -> {task.module_name}" + ) + + # Update pending tasks list + if remaining_tasks: + self._pending_tasks[dependency_key] = remaining_tasks + else: + del self._pending_tasks[dependency_key] + + # Execute ready tasks + for task in ready_tasks: + self._execute_task(task) + + def _all_dependencies_available(self, task: DependentTask) -> bool: + """Check if all dependencies for a task are available.""" + for event_name in task.event_names: + dependency_key = self._get_dependency_key( + task.target, task.module_name, task.scan_id, event_name + ) + if dependency_key not in self._dependency_cache: + return False + return True + + def _execute_task(self, task: DependentTask): + """Execute a task that has all its dependencies available.""" + try: + # Get dependency results + dependency_results = [] + for event_name in task.event_names: + dependency_key = self._get_dependency_key( + task.target, task.module_name, task.scan_id, event_name + ) + dependency_results.append(self._dependency_cache[dependency_key]["result"]) + + # Replace dependent values in sub_step + updated_sub_step = task.engine.replace_dependent_values( + task.sub_step, dependency_results + ) + + # Execute the task + task.engine.run(updated_sub_step, *task.run_args[1:]) + + except Exception as e: + log.error(f"Error executing dependent task: {e}") + + def get_dependency_results_efficiently( + self, + target: str, + module_name: str, + scan_id: str, + event_names: str, + sub_step: Dict, + engine: Any, + run_args: tuple, + ) -> Optional[List[Any]]: + """ + Efficiently get dependency results without busy-waiting. + Returns results immediately if available, otherwise queues the task. 
+ """ + event_name_list = event_names.split(",") + + # Check if all dependencies are already available + all_available = True + results = [] + + with self._lock: + for event_name in event_name_list: + dependency_key = self._get_dependency_key(target, module_name, scan_id, event_name) + + if dependency_key in self._dependency_cache: + results.append(self._dependency_cache[dependency_key]["result"]) + else: + # Try to get from database once + event = find_temp_events(target, module_name, scan_id, event_name) + if event: + result = json.loads(event.event)["response"]["conditions_results"] + self._cache_dependency_result(dependency_key, result) + results.append(result) + else: + all_available = False + break + + if all_available: + return results + + # Dependencies not available - queue the task + task = DependentTask( + target=target, + module_name=module_name, + scan_id=scan_id, + event_names=event_name_list, + sub_step=sub_step, + engine=engine, + run_args=run_args, + created_at=datetime.now(), + ) + + # Add to pending tasks for each missing dependency + for event_name in event_name_list: + dependency_key = self._get_dependency_key(target, module_name, scan_id, event_name) + if dependency_key not in self._dependency_cache: + if dependency_key not in self._pending_tasks: + self._pending_tasks[dependency_key] = [] + self._pending_tasks[dependency_key].append(task) + + return None # Task queued, will be executed later + + +class CrossProcessThreadPool: + """ + Manages a shared thread pool across multiple processes. + Allows processes to share work when they have idle threads. + """ + + def __init__(self, max_workers_per_process: Optional[int] = None): + self.max_workers_per_process = max_workers_per_process or multiprocessing.cpu_count() + self.task_queue = multiprocessing.Queue() + self.workers = [] + self.is_running = multiprocessing.Value("i", 1) + + def start_workers(self, num_processes: int): + """Start worker processes.""" + for i in range(num_processes): + worker = multiprocessing.Process( + target=self._worker_process, + args=(i, self.task_queue, self.is_running), + ) + worker.start() + self.workers.append(worker) + + def submit_task(self, task_func, *args, **kwargs): + """Submit a task to the shared pool.""" + task = {"func": task_func, "args": args, "kwargs": kwargs, "timestamp": time.time()} + self.task_queue.put(task) + + def _worker_process( + self, + worker_id: int, + task_queue: multiprocessing.Queue, + is_running: multiprocessing.Value, + ): + """Worker process that executes tasks from the shared queue.""" + local_threads = [] + max_local_threads = self.max_workers_per_process + + log.info(f"Worker process {worker_id} started with {max_local_threads} threads") + + while is_running.value: + try: + # Clean up finished threads + local_threads = [t for t in local_threads if t.is_alive()] + + # If we have capacity, get a new task + if len(local_threads) < max_local_threads: + try: + task = task_queue.get(timeout=1.0) + + # Create thread to execute task + thread = threading.Thread( + target=self._execute_task, args=(task, worker_id) + ) + thread.start() + local_threads.append(thread) + + except queue.Empty: + continue + else: + # Wait a bit if at capacity + time.sleep(0.1) + + except Exception as e: + log.exception(f"Worker process {worker_id} error: {e}") + + # Wait for remaining threads to finish + for thread in local_threads: + thread.join(timeout=5.0) + + log.info(f"Worker process {worker_id} finished") + + def _execute_task(self, task: Dict, worker_id: int): + """Execute a 
single task.""" + try: + func = task["func"] + args = task["args"] + kwargs = task["kwargs"] + + # Execute the task - engine.run() handles its own results/logging + func(*args, **kwargs) + + log.debug(f"Worker {worker_id} completed task successfully") + + except Exception as e: + log.exception(f"Worker {worker_id} task execution failed: {e}") + + def shutdown(self): + """Shutdown the thread pool.""" + self.is_running.value = 0 + + # Wait for workers to finish + for worker in self.workers: + worker.join(timeout=10.0) + if worker.is_alive(): + worker.terminate() + + log.info("Thread pool shutdown complete") + + +# Global instances +dependency_resolver = DependencyResolver() +thread_pool = None + + +def initialize_thread_pool(num_processes: int, max_workers_per_process: int = None): + """Initialize the global thread pool.""" + global thread_pool + thread_pool = CrossProcessThreadPool(max_workers_per_process) + thread_pool.start_workers(num_processes) + return thread_pool + + +def shutdown_thread_pool(): + """Shutdown the global thread pool.""" + global thread_pool + if thread_pool: + thread_pool.shutdown() + thread_pool = None From 29d84b1597ffad94a12afc8bc61ce88fd77d74e8 Mon Sep 17 00:00:00 2001 From: DavdaJames Date: Mon, 6 Oct 2025 13:29:36 +0530 Subject: [PATCH 2/3] thread pool now wait for submitted task completion --- nettacker/core/module.py | 13 +++++++++ nettacker/core/queue_manager.py | 48 +++++++++++++++++++++++++++++---- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/nettacker/core/module.py b/nettacker/core/module.py index 762c0edbd..75600943b 100644 --- a/nettacker/core/module.py +++ b/nettacker/core/module.py @@ -160,6 +160,7 @@ def sort_loops(self): def start(self): active_threads = [] + used_shared_pool = False # counting total number of requests total_number_of_requests = 0 @@ -198,6 +199,7 @@ def start(self): request_number_counter, total_number_of_requests, ) + used_shared_pool = True else: # Use local thread (fallback to original behavior) thread = Thread( @@ -240,6 +242,17 @@ def start(self): ) time.sleep(self.module_inputs["time_sleep_between_requests"]) + # Wait for completion based on execution path + if used_shared_pool: + # Wait for shared thread pool tasks to complete + if queue_manager.thread_pool and hasattr( + queue_manager.thread_pool, "wait_for_completion" + ): + # Wait with a reasonable timeout to prevent hanging + completed = queue_manager.thread_pool.wait_for_completion(timeout=300) # 5 minutes + if not completed: + log.warn(f"Module {self.module_name} tasks did not complete within timeout") + # Wait for any remaining local threads to finish if active_threads: wait_for_threads_to_finish(active_threads, maximum=None, terminable=True) diff --git a/nettacker/core/queue_manager.py b/nettacker/core/queue_manager.py index e008ff678..36e64c9ca 100644 --- a/nettacker/core/queue_manager.py +++ b/nettacker/core/queue_manager.py @@ -209,13 +209,16 @@ def __init__(self, max_workers_per_process: Optional[int] = None): self.task_queue = multiprocessing.Queue() self.workers = [] self.is_running = multiprocessing.Value("i", 1) + # Task completion tracking + self.submitted_tasks = multiprocessing.Value("i", 0) + self.completed_tasks = multiprocessing.Value("i", 0) def start_workers(self, num_processes: int): """Start worker processes.""" for i in range(num_processes): worker = multiprocessing.Process( target=self._worker_process, - args=(i, self.task_queue, self.is_running), + args=(i, self.task_queue, self.is_running, self.completed_tasks), ) worker.start() 
             self.workers.append(worker)

@@ -224,12 +227,42 @@ def submit_task(self, task_func, *args, **kwargs):
         """Submit a task to the shared pool."""
         task = {"func": task_func, "args": args, "kwargs": kwargs, "timestamp": time.time()}
         self.task_queue.put(task)
+        with self.submitted_tasks.get_lock():
+            self.submitted_tasks.value += 1
+
+    def wait_for_completion(self, timeout: Optional[float] = None) -> bool:
+        """Wait for all submitted tasks to complete.
+
+        Args:
+            timeout: Maximum time to wait in seconds. None means wait indefinitely.
+
+        Returns:
+            True if all tasks completed, False if timeout occurred.
+        """
+        start_time = time.time()
+
+        while True:
+            with self.submitted_tasks.get_lock():
+                submitted = self.submitted_tasks.value
+            with self.completed_tasks.get_lock():
+                completed = self.completed_tasks.value
+
+            if completed >= submitted:
+                return True
+
+            if timeout is not None:
+                elapsed = time.time() - start_time
+                if elapsed >= timeout:
+                    return False
+
+            time.sleep(0.01)  # Small sleep to avoid busy waiting

     def _worker_process(
         self,
         worker_id: int,
         task_queue: multiprocessing.Queue,
         is_running: multiprocessing.Value,
+        completed_tasks: multiprocessing.Value,
     ):
         """Worker process that executes tasks from the shared queue."""
         local_threads = []
@@ -249,7 +282,7 @@ def _worker_process(

                         # Create thread to execute task
                         thread = threading.Thread(
-                            target=self._execute_task, args=(task, worker_id)
+                            target=self._execute_task, args=(task, worker_id, completed_tasks)
                         )
                         thread.start()
                         local_threads.append(thread)
@@ -269,7 +302,7 @@ def _worker_process(

         log.info(f"Worker process {worker_id} finished")

-    def _execute_task(self, task: Dict, worker_id: int):
+    def _execute_task(self, task: Dict, worker_id: int, completed_tasks: multiprocessing.Value):
         """Execute a single task."""
         try:
             func = task["func"]
@@ -279,10 +312,15 @@ def _execute_task(self, task: Dict, worker_id: int, completed_tasks: multiproces
             # Execute the task - engine.run() handles its own results/logging
             func(*args, **kwargs)

-            log.debug(f"Worker {worker_id} completed task successfully")
+            # Don't log successful task completion in normal operation
+            # log.info(f"Worker {worker_id} completed task successfully")

         except Exception as e:
-            log.exception(f"Worker {worker_id} task execution failed: {e}")
+            log.error(f"Worker {worker_id} task execution failed: {e}")
+        finally:
+            # Always increment completed count, even on failure
+            with completed_tasks.get_lock():
+                completed_tasks.value += 1

     def shutdown(self):
         """Shutdown the thread pool."""

From d1bbb1caab8246652148a83206dc9dd66d6893d0 Mon Sep 17 00:00:00 2001
From: DavdaJames
Date: Mon, 6 Oct 2025 13:33:12 +0530
Subject: [PATCH 3/3] switched to joinable queue

---
 nettacker/core/queue_manager.py | 29 ++++++++++++++++++++++++-----
 1 file changed, 24 insertions(+), 5 deletions(-)

diff --git a/nettacker/core/queue_manager.py b/nettacker/core/queue_manager.py
index 36e64c9ca..cd604e2a0 100644
--- a/nettacker/core/queue_manager.py
+++ b/nettacker/core/queue_manager.py
@@ -206,7 +206,8 @@ class CrossProcessThreadPool:

     def __init__(self, max_workers_per_process: Optional[int] = None):
         self.max_workers_per_process = max_workers_per_process or multiprocessing.cpu_count()
-        self.task_queue = multiprocessing.Queue()
+        # Inter-process communication
+        self.task_queue = multiprocessing.JoinableQueue()
         self.workers = []
         self.is_running = multiprocessing.Value("i", 1)
         # Task completion tracking
@@ -282,7 +283,8 @@ def _worker_process(

                         # Create thread to execute task
                         thread = threading.Thread(
-                            target=self._execute_task, args=(task, worker_id, completed_tasks)
+                            target=self._execute_task,
+                            args=(task, worker_id, completed_tasks, task_queue),
                         )
                         thread.start()
                         local_threads.append(thread)
@@ -302,7 +304,9 @@ def _worker_process(

         log.info(f"Worker process {worker_id} finished")

-    def _execute_task(self, task: Dict, worker_id: int, completed_tasks: multiprocessing.Value):
+    def _execute_task(
+        self, task: Dict, worker_id: int, completed_tasks: multiprocessing.Value, task_queue
+    ):
         """Execute a single task."""
         try:
             func = task["func"]
@@ -321,15 +325,30 @@ def _execute_task(self, task: Dict, worker_id: int, completed_tasks: multiproces
             # Always increment completed count, even on failure
             with completed_tasks.get_lock():
                 completed_tasks.value += 1
+            # Mark task as done for JoinableQueue
+            task_queue.task_done()

     def shutdown(self):
-        """Shutdown the thread pool."""
+        """Shutdown the thread pool gracefully, ensuring all queued tasks complete."""
+        log.info("Starting thread pool shutdown...")
+
+        # First, wait for all queued tasks to complete
+        try:
+            log.info("Waiting for queued tasks to complete...")
+            self.task_queue.join()  # Wait for all tasks to be marked as done
+            log.info("All queued tasks completed")
+        except Exception as e:
+            log.error(f"Error while waiting for tasks to complete: {e}")
+
+        # Now signal workers to stop
         self.is_running.value = 0
+        log.info("Signaled workers to stop")

         # Wait for workers to finish
-        for worker in self.workers:
+        for i, worker in enumerate(self.workers):
             worker.join(timeout=10.0)
             if worker.is_alive():
+                log.warn(f"Worker {i} did not terminate gracefully, forcing termination")
                 worker.terminate()

         log.info("Thread pool shutdown complete")
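
Editor's note (not part of the patch series): a minimal standalone sketch of how the new queue_manager API introduced above is meant to be driven, using only the functions added by these patches (initialize_thread_pool, CrossProcessThreadPool.submit_task / wait_for_completion, shutdown_thread_pool). The scan_one_target function, the target strings, and the worker counts are hypothetical placeholders for engine.run and the arguments module.py actually submits.

    from nettacker.core import queue_manager
    from nettacker.core.queue_manager import initialize_thread_pool, shutdown_thread_pool

    def scan_one_target(target, module_name):
        # hypothetical stand-in for the engine.run(...) call submitted from module.py
        print(f"scanning {target} with {module_name}")

    if __name__ == "__main__":
        # one worker process per target group, mirroring start_scan() in app.py
        initialize_thread_pool(num_processes=2, max_workers_per_process=4)

        # any process may submit work; workers pull tasks off the shared JoinableQueue
        for target in ("192.168.1.1", "192.168.1.2"):
            queue_manager.thread_pool.submit_task(scan_one_target, target, "port_scan")

        # block until every submitted task is counted as completed (module.py uses timeout=300)
        queue_manager.thread_pool.wait_for_completion(timeout=300)

        # drains the queue via task_queue.join(), then stops and joins the workers
        shutdown_thread_pool()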