diff --git a/nettacker/cli/db_worker.py b/nettacker/cli/db_worker.py new file mode 100644 index 000000000..4bd464170 --- /dev/null +++ b/nettacker/cli/db_worker.py @@ -0,0 +1,47 @@ +import argparse +import signal +import time + +from nettacker.database.writer import get_writer + + +def _handle_sig(signum, frame): + writer = get_writer() + writer.stop() + raise SystemExit(0) + + +def run(): + parser = argparse.ArgumentParser() + parser.add_argument("--once", action="store_true", help="Drain the queue once and exit") + parser.add_argument("--batch-size", type=int, default=None, help="Writer batch size") + parser.add_argument("--interval", type=float, default=None, help="Writer sleep interval") + parser.add_argument( + "--max-items", type=int, default=None, help="Max items to process in --once mode" + ) + parser.add_argument("--summary", action="store_true", help="Print a summary after --once") + args = parser.parse_args() + + signal.signal(signal.SIGINT, _handle_sig) + signal.signal(signal.SIGTERM, _handle_sig) + + # apply runtime config + from nettacker.database.writer import get_writer_configured, get_stats + + writer = get_writer_configured(batch_size=args.batch_size, interval=args.interval) + if args.once: + processed = writer.drain_once(max_iterations=args.max_items or 100000) + if args.summary: + stats = get_stats() + print( + f"processed={processed} total_processed={stats.get('processed')} queue_size={stats.get('queue_size')}" + ) + return + + # Main loop - will be terminated by signal handlers + while True: + time.sleep(1) + + +if __name__ == "__main__": + run() diff --git a/nettacker/database/db.py b/nettacker/database/db.py index 19eda9b2a..0022d9aa6 100644 --- a/nettacker/database/db.py +++ b/nettacker/database/db.py @@ -9,6 +9,7 @@ from nettacker.config import Config from nettacker.core.messages import messages from nettacker.database.models import HostsLog, Report, TempEvents +from nettacker.database.writer import get_writer config = Config() log = logger.get_logger() @@ -95,6 +96,22 @@ def submit_report_to_db(event): return True if submitted otherwise False """ log.verbose_info(messages("inserting_report_db")) + writer = get_writer() + job = { + "action": "insert_report", + "payload": { + "date": event["date"], + "scan_id": event["scan_id"], + "report_path_filename": event["options"]["report_path_filename"], + "options": event["options"], + }, + } + try: + if writer.enqueue(job): + return True + except Exception: + pass + # Fallback to direct write session = create_connection() session.add( Report( @@ -140,6 +157,14 @@ def submit_logs_to_db(log): True if success otherwise False """ if isinstance(log, dict): + writer = get_writer() + job = {"action": "insert_hostslog", "payload": log} + try: + if writer.enqueue(job): + return True + except Exception: + pass + # Fallback session = create_connection() session.add( HostsLog( @@ -169,6 +194,14 @@ def submit_temp_logs_to_db(log): True if success otherwise False """ if isinstance(log, dict): + writer = get_writer() + job = {"action": "insert_tempevent", "payload": log} + try: + if writer.enqueue(job): + return True + except Exception: + pass + # Fallback session = create_connection() session.add( TempEvents( diff --git a/nettacker/database/writer.py b/nettacker/database/writer.py new file mode 100644 index 000000000..6f96c9708 --- /dev/null +++ b/nettacker/database/writer.py @@ -0,0 +1,368 @@ +import json +import threading +import time +from multiprocessing import Queue +from pathlib import Path + +from sqlalchemy import create_engine 
+from sqlalchemy.orm import sessionmaker + +from nettacker import logger +from nettacker.config import Config +from nettacker.database.models import Report, HostsLog, TempEvents + +log = logger.get_logger() + + +class DBWriter: + def __init__(self, batch_size=100, interval=0.5): + self.batch_size = int(batch_size) + self.interval = float(interval) + self._stop = threading.Event() + self._thread = None + # total processed across lifetime + self._processed_count = 0 + + self._use_litequeue = False + self._lq = None + self._lq_put = None + self._lq_get = None + + try: + import litequeue as _litequeue + + queue_file = Path(Config.path.data_dir) / "nettacker_db_queue.lq" + queue_file.parent.mkdir(parents=True, exist_ok=True) + # try common constructors + if hasattr(_litequeue, "LiteQueue"): + self._lq = _litequeue.LiteQueue(str(queue_file)) + elif hasattr(_litequeue, "Queue"): + self._lq = _litequeue.Queue(str(queue_file)) + else: + # fallback to a module-level factory + try: + self._lq = _litequeue.open(str(queue_file)) + except Exception: + self._lq = None + + if self._lq is not None: + # prefer destructive pop/get ordering + if hasattr(self._lq, "put"): + self._lq_put = self._lq.put + elif hasattr(self._lq, "push"): + self._lq_put = self._lq.push + elif hasattr(self._lq, "add"): + self._lq_put = self._lq.add + + if hasattr(self._lq, "pop"): + self._lq_get = self._lq.pop + elif hasattr(self._lq, "get"): + # note: some implementations require message_id; prefer pop above + self._lq_get = self._lq.get + elif hasattr(self._lq, "take"): + self._lq_get = self._lq.take + + if self._lq_put and self._lq_get: + self._use_litequeue = True + except Exception: + self._use_litequeue = False + + if not self._use_litequeue: + self.queue = Queue() + + db_url = Config.db.as_dict() + engine_url = ( + "sqlite:///{name}".format(**db_url) + if Config.db.engine.startswith("sqlite") + else Config.db.engine + ) + connect_args = {} + if engine_url.startswith("sqlite"): + connect_args["check_same_thread"] = False + + self.engine = create_engine(engine_url, connect_args=connect_args, pool_pre_ping=True) + if engine_url.startswith("sqlite"): + try: + with self.engine.connect() as conn: + conn.execute("PRAGMA journal_mode=WAL") + except Exception: + pass + + self.Session = sessionmaker(bind=self.engine) + + def start(self): + if self._thread and self._thread.is_alive(): + return + self._stop.clear() + self._thread = threading.Thread(target=self._run, name="nettacker-db-writer", daemon=True) + self._thread.start() + + def stop(self): + self._stop.set() + if self._thread: + self._thread.join(timeout=5) + + def enqueue(self, job): + try: + if self._use_litequeue: + self._lq_put(json.dumps(job)) + return True + self.queue.put(job) + return True + except Exception: + log.warn("DBWriter: failed to enqueue job") + return False + + def _acknowledge_message(self, message_id): + """Acknowledge a successfully processed message.""" + if self._use_litequeue and message_id is not None: + try: + if hasattr(self._lq, "done"): + self._lq.done(message_id) + except Exception: + pass + + def _pop_one(self): + if self._use_litequeue: + try: + # litequeue: use pop() to get and lock message, then mark done() AFTER processing + msg = None + if hasattr(self._lq, "pop"): + msg = self._lq.pop() + elif hasattr(self._lq, "get"): + # fallback: try to get next via get + msg = self._lq.get() + + if msg is None: + return None + + if hasattr(msg, "data"): + payload = msg.data + elif hasattr(msg, "message"): + payload = msg.message + else: + payload = 
str(msg) + + if isinstance(payload, (bytes, bytearray)): + payload = payload.decode() + + # Return both the payload and message_id for deferred acknowledgment + job_data = json.loads(payload) + if hasattr(msg, "message_id"): + return {"data": job_data, "message_id": msg.message_id} + else: + return {"data": job_data, "message_id": None} + except Exception: + return None + else: + try: + job_data = self.queue.get_nowait() + return {"data": job_data, "message_id": None} + + except Exception: + return None + + def _run(self): + pending = [] + while not self._stop.is_set(): + try: + while len(pending) < self.batch_size: + job = self._pop_one() + if job is None: + break + pending.append(job) + + if pending: + # Process each job individually with immediate commit + for job in pending: + job_session = self.Session() # Fresh session per job + try: + # Handle both litequeue format {"data": job, "message_id": id} and direct job + job_data = ( + job["data"] if isinstance(job, dict) and "data" in job else job + ) + self._apply_job(job_session, job_data) + job_session.commit() # Immediate commit per job + self._processed_count += 1 + + # Only acknowledge after successful commit + if isinstance(job, dict) and "message_id" in job: + self._acknowledge_message(job["message_id"]) + + except Exception as e: + job_session.rollback() + log.error(f"Failed to process job: {e}") + # Job is not acknowledged, so it can be retried + finally: + job_session.close() + + pending = [] + else: + time.sleep(self.interval) + except Exception: + time.sleep(0.1) + + # Final cleanup: process any remaining jobs individually + try: + while True: + job = self._pop_one() + if job is None: + break + + # Process final job individually with immediate commit + cleanup_session = self.Session() + try: + job_data = job["data"] if isinstance(job, dict) and "data" in job else job + self._apply_job(cleanup_session, job_data) + cleanup_session.commit() + self._processed_count += 1 + + # Only acknowledge after successful commit + if isinstance(job, dict) and "message_id" in job: + self._acknowledge_message(job["message_id"]) + + except Exception as e: + cleanup_session.rollback() + log.error(f"Failed to process cleanup job: {e}") + finally: + cleanup_session.close() + except Exception: + pass + + def drain_once(self, max_iterations=100000): + """Consume all queued jobs and return when queue is empty. + + This method is intended for on-demand draining (not long-lived). 
+ """ + iterations = 0 + processed = 0 + + try: + while iterations < max_iterations: + job = self._pop_one() + if job is None: + break + + # Process each job individually with immediate commit for durability + job_session = self.Session() # Fresh session per job + try: + # Handle both litequeue format {"data": job, "message_id": id} and direct job + job_data = job["data"] if isinstance(job, dict) and "data" in job else job + self._apply_job(job_session, job_data) + job_session.commit() # Immediate commit per job + processed += 1 + self._processed_count += 1 + + # Only acknowledge after successful commit + if isinstance(job, dict) and "message_id" in job: + self._acknowledge_message(job["message_id"]) + + except Exception as e: + job_session.rollback() + log.error(f"Failed to process job during drain: {e}") + # Job is not acknowledged, so it can be retried + finally: + job_session.close() + + iterations += 1 + except Exception as e: + log.error(f"Error during drain operation: {e}") + + return processed + + def _apply_job(self, session, job): + action = job.get("action") + payload = job.get("payload", {}) + if action == "insert_report": + session.add( + Report( + date=payload.get("date"), + scan_unique_id=payload.get("scan_id"), + report_path_filename=payload.get("report_path_filename"), + options=json.dumps(payload.get("options", {})), + ) + ) + return + if action == "insert_hostslog": + session.add( + HostsLog( + target=payload.get("target"), + date=payload.get("date"), + module_name=payload.get("module_name"), + scan_unique_id=payload.get("scan_id"), + port=json.dumps(payload.get("port")), + event=json.dumps(payload.get("event")), + json_event=json.dumps(payload.get("json_event")), + ) + ) + return + if action == "insert_tempevent": + session.add( + TempEvents( + target=payload.get("target"), + date=payload.get("date"), + module_name=payload.get("module_name"), + scan_unique_id=payload.get("scan_id"), + event_name=payload.get("event_name"), + port=json.dumps(payload.get("port")), + event=json.dumps(payload.get("event")), + data=json.dumps(payload.get("data")), + ) + ) + return + log.warn(f"DBWriter: unsupported job action {action}") + + +# singleton writer +_writer = None + + +def get_writer(): + global _writer + if _writer is None: + _writer = DBWriter() + try: + _writer.start() + except Exception: + pass + return _writer + + +def get_writer_configured(batch_size=None, interval=None): + """Return singleton writer, applying optional configuration. + + If the writer already exists, provided parameters will update its settings. + """ + w = get_writer() + if batch_size is not None: + try: + w.batch_size = int(batch_size) + except Exception: + pass + if interval is not None: + try: + w.interval = float(interval) + except Exception: + pass + return w + + +def get_stats(): + w = get_writer() + queue_size = None + if getattr(w, "_use_litequeue", False) and getattr(w, "_lq", None) is not None: + try: + if hasattr(w._lq, "qsize"): + queue_size = w._lq.qsize() + elif hasattr(w._lq, "__len__"): + queue_size = len(w._lq) + elif hasattr(w._lq, "size"): + queue_size = w._lq.size() + except Exception: + queue_size = None + else: + try: + queue_size = w.queue.qsize() + except Exception: + queue_size = None + return {"processed": getattr(w, "_processed_count", 0), "queue_size": queue_size} diff --git a/poetry.lock b/poetry.lock index 0dae0b991..ea5d810a9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. 
+# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -973,7 +973,7 @@ description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" groups = ["main"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"}, {file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"}, @@ -1129,6 +1129,18 @@ files = [ dnspython = "*" ldap3 = ">2.5.0,<2.5.2 || >2.5.2,<2.6 || >2.6" +[[package]] +name = "litequeue" +version = "0.9" +description = "Simple queue built on top of SQLite" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "litequeue-0.9-py3-none-any.whl", hash = "sha256:344312748f9d118253ecaa2fb82d432ab6f63dd3bd5c40922a63933bc47cd2e3"}, + {file = "litequeue-0.9.tar.gz", hash = "sha256:368f56b9de5c76fc6f2adc66177e81a59f545f3b5b95a26cae8562b858914647"}, +] + [[package]] name = "markupsafe" version = "2.1.5" @@ -2023,7 +2035,7 @@ files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] -markers = {dev = "python_version < \"3.10\""} +markers = {dev = "python_version == \"3.9\""} [[package]] name = "urllib3" @@ -2254,4 +2266,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.9, <3.13" -content-hash = "0e1731401cd6acfc4d45ede5e18668530aae6a6b2e359d7dc8d8d635635a1257" +content-hash = "aa676fcd9a242a436052e31b320c0c3d99451dc4323d6ae99fbe8f4f49e0d747" diff --git a/pyproject.toml b/pyproject.toml index e6939b9b5..a81f498a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ release_name = "QUIN" [tool.poetry.scripts] nettacker = "nettacker.main:run" +nettacker-db-worker = "nettacker.cli.db_worker:run" [tool.poetry.dependencies] python = "^3.9, <3.13" @@ -65,6 +66,7 @@ zipp = "^3.19.1" uvloop = "^0.21.0" pymysql = "^1.1.1" impacket = "^0.11.0" +litequeue = "^0.9" [tool.poetry.group.dev.dependencies] ipython = "^8.16.1" diff --git a/tests/database/test_writer.py b/tests/database/test_writer.py new file mode 100644 index 000000000..5e4a91a65 --- /dev/null +++ b/tests/database/test_writer.py @@ -0,0 +1,115 @@ +import os +import subprocess +from pathlib import Path + + +def start_worker_process(tmp_path): + # run worker in separate process + env = os.environ.copy() + # ensure current project is first on PYTHONPATH + env["PYTHONPATH"] = str(Path.cwd()) + os.pathsep + env.get("PYTHONPATH", "") + # Pass config through environment + data_dir = tmp_path / ".data" + env["NETTACKER_DATA_DIR"] = str(data_dir) + env["NETTACKER_DB_NAME"] = str(data_dir / "nettacker.db") + + proc = subprocess.Popen( + [ + env.get("PYTHON_BIN", "python"), + "-m", + "nettacker.cli.db_worker", + "--once", # Process all items and exit + "--max-items", + "10", + "--summary", # Show processing stats + ], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + print(f"Started worker process {proc.pid} with data_dir={data_dir}") + return proc + + +def test_worker_writes(tmp_path): + """Test that the database writer correctly processes queued jobs and writes to database.""" + # Create test 
database + data_dir = tmp_path / ".data" + data_dir.mkdir() + db_path = str(data_dir / "nettacker.db") + + # Create database tables + from sqlalchemy import create_engine, text + + from nettacker.database.models import Base + + engine = create_engine(f"sqlite:///{db_path}") + Base.metadata.create_all(engine) + engine.dispose() + + # Create a writer configured to use the test database + from sqlalchemy.orm import sessionmaker + + from nettacker.database.writer import DBWriter + + writer = DBWriter() + # Override the database connection to use our test database + writer.engine = create_engine( + f"sqlite:///{db_path}", connect_args={"check_same_thread": False}, pool_pre_ping=True + ) + # Enable WAL mode for better concurrency + with writer.engine.connect() as conn: + conn.execute(text("PRAGMA journal_mode=WAL")) + conn.commit() + writer.Session = sessionmaker(bind=writer.engine) + + # Create test jobs for both report and hosts log + jobs = [ + { + "action": "insert_report", + "payload": { + "date": None, + "scan_id": "test-scan", + "report_path_filename": str(data_dir / "r.html"), + "options": {"report_path_filename": str(data_dir / "r.html")}, + }, + }, + { + "action": "insert_hostslog", + "payload": { + "date": None, + "target": "127.0.0.1", + "module_name": "m", + "scan_id": "test-scan", + "port": [], + "event": {}, + "json_event": {}, + }, + }, + ] + + # Enqueue jobs to the writer + for job in jobs: + writer.enqueue(job) + + # Process all queued jobs + processed_count = writer.drain_once(max_iterations=10) + assert processed_count == 2 + + # Verify the jobs were written to the database + import sqlite3 + + conn = sqlite3.connect(db_path) + c = conn.cursor() + + c.execute("select count(*) from reports where scan_unique_id = ?", ("test-scan",)) + report_count = c.fetchone()[0] + + c.execute("select count(*) from scan_events where scan_unique_id = ?", ("test-scan",)) + hosts_count = c.fetchone()[0] + + conn.close() + + assert report_count == 1 + assert hosts_count == 1
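
For reference, the producer side of the new write path can be driven directly, much as tests/database/test_writer.py does. The sketch below is illustrative only: the field values and module name are placeholders, and it assumes the default Nettacker configuration resolves a writable data directory and database.

# Illustrative producer-side sketch (not part of this change set).
from nettacker.database.writer import DBWriter

writer = DBWriter(batch_size=50, interval=0.5)  # start() is not called, so no background thread competes for the queue

job = {
    "action": "insert_hostslog",  # same job shape that submit_logs_to_db() enqueues
    "payload": {
        "date": None,
        "target": "127.0.0.1",
        "module_name": "port_scan",  # placeholder module name
        "scan_id": "example-scan",
        "port": [80],
        "event": {},
        "json_event": {},
    },
}

if not writer.enqueue(job):
    # Callers in db.py fall back to a direct session.add() when enqueue() fails.
    raise RuntimeError("enqueue failed")

processed = writer.drain_once(max_iterations=1000)
assert processed >= 1  # the persistent litequeue file may also hold jobs from earlier runs

The same drain can be run out of process with the new console script, for example nettacker-db-worker --once --summary, which wraps get_writer_configured() and drain_once() and prints the get_stats() counters.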
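
The deferred acknowledgment in _pop_one()/_acknowledge_message() is what gives the writer at-least-once delivery when litequeue is available. Below is a stripped-down sketch of that pattern, assuming a litequeue build that exposes LiteQueue.put()/pop()/done() and Message.data/.message_id (the same attributes DBWriter probes for with hasattr); apply_to_database() and the queue path are placeholders.

import json

from litequeue import LiteQueue


def apply_to_database(job):
    """Placeholder for DBWriter._apply_job(session, job) followed by session.commit()."""
    print("writing", job["action"])


queue = LiteQueue("/tmp/example_queue.lq")  # placeholder path; the writer uses <data_dir>/nettacker_db_queue.lq
queue.put(json.dumps({"action": "insert_hostslog", "payload": {}}))

message = queue.pop()  # take the next message and lock it
if message is not None:
    try:
        apply_to_database(json.loads(message.data))
        queue.done(message.message_id)  # acknowledge only after the write has committed
    except Exception:
        # No done() call: the message stays unacknowledged so it can be retried,
        # which is what the writer's error paths rely on.
        pass

Acknowledging only after the commit trades occasional duplicate writes (if the process dies between commit and done()) for never dropping an event that was popped but not yet persisted.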