From e315362a8512680c27db1606a6957ffe9a15d0c5 Mon Sep 17 00:00:00 2001 From: Ryan Forsyth Date: Tue, 18 Nov 2025 17:57:06 -0600 Subject: [PATCH 1/9] Add Python 3.14 support --- .github/workflows/build_workflow.yml | 10 +++++----- conda/dev.yml | 2 +- setup.cfg | 2 +- setup.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index d0aa857a..0fa59694 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -19,10 +19,10 @@ jobs: - name: Checkout Code Repository uses: actions/checkout@v3 - - name: Set up Python 3.13 + - name: Set up Python 3.14 uses: actions/setup-python@v4 with: - python-version: "3.13" + python-version: "3.14" # Run all pre-commit hooks on all the files. # Getting only staged files can be tricky in case a new PR is opened @@ -36,7 +36,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.11", "3.12", "3.13"] + python-version: ["3.11", "3.12", "3.13", "3.14"] defaults: run: shell: bash -l {0} @@ -73,7 +73,7 @@ jobs: # Ensure we have the right Python version python --version # Fix pip issues for Python 3.12+ - if [[ "${{ matrix.python-version }}" == "3.12" ]] || [[ "${{ matrix.python-version }}" == "3.13" ]]; then + if [[ "${{ matrix.python-version }}" == "3.12" ]] || [[ "${{ matrix.python-version }}" == "3.13" ]] || [[ "${{ matrix.python-version }}" == "3.14" ]]; then python -m ensurepip --upgrade || true python -m pip install --upgrade --force-reinstall pip setuptools wheel fi @@ -121,7 +121,7 @@ jobs: environment-file: conda/dev.yml channel-priority: flexible # Changed from strict to flexible auto-update-conda: true - python-version: "3.13" # Use stable Python version for docs + python-version: "3.14" # Use stable Python version for docs # sphinx-multiversion allows for version docs. - name: Build Sphinx Docs diff --git a/conda/dev.yml b/conda/dev.yml index 6ec8e1b9..7aa2deac 100644 --- a/conda/dev.yml +++ b/conda/dev.yml @@ -5,7 +5,7 @@ dependencies: # Base # ================= - pip - - python >=3.11,<3.14 + - python >=3.11,<3.15 - sqlite - six >=1.16.0 - globus-sdk >=3.15.0,<4.0 diff --git a/setup.cfg b/setup.cfg index 4e6c6cb4..5a75b3ef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,7 +46,7 @@ exclude = venv [mypy] -python_version = 3.13 +python_version = 3.14 check_untyped_defs = True ignore_missing_imports = True warn_unused_ignores = True diff --git a/setup.py b/setup.py index 292653e0..67df2e8e 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,6 @@ author_email="forsyth2@llnl.gov, golaz1@llnl.gov, shaheen2@llnl.gov", description="Long term HPSS archiving software for E3SM", packages=find_packages(include=["zstash", "zstash.*"]), - python_requires=">=3.11,<3.14", + python_requires=">=3.11,<3.15", entry_points={"console_scripts": ["zstash=zstash.main:main"]}, ) From 562c67c8c6aadeb72e1637e8e1adee97ed0e5ace Mon Sep 17 00:00:00 2001 From: Ryan Forsyth Date: Tue, 24 Feb 2026 13:05:07 -0800 Subject: [PATCH 2/9] Address review comments --- .github/workflows/build_workflow.yml | 5 ----- conda/dev.yml | 1 + 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index 0fa59694..bd158784 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -72,11 +72,6 @@ jobs: conda list # Ensure we have the right Python version python --version - # Fix pip issues for Python 3.12+ - if [[ "${{ matrix.python-version }}" == "3.12" ]] || [[ "${{ matrix.python-version }}" == "3.13" ]] || [[ "${{ matrix.python-version }}" == "3.14" ]]; then - python -m ensurepip --upgrade || true - python -m pip install --upgrade --force-reinstall pip setuptools wheel - fi - name: Install `zstash` Package run: | diff --git a/conda/dev.yml b/conda/dev.yml index 7aa2deac..85f64d32 100644 --- a/conda/dev.yml +++ b/conda/dev.yml @@ -6,6 +6,7 @@ dependencies: # ================= - pip - python >=3.11,<3.15 + - setuptools - sqlite - six >=1.16.0 - globus-sdk >=3.15.0,<4.0 From 8f479a646bfb5627bd666b1bad95018ebe2b3cc9 Mon Sep 17 00:00:00 2001 From: Ryan Forsyth Date: Wed, 19 Nov 2025 16:07:22 -0600 Subject: [PATCH 3/9] Implementation changes to support Python 3.14 --- .github/workflows/build_workflow.yml | 1 + zstash/extract.py | 192 +++++++++++++++++++-------- zstash/parallel.py | 151 +++++++++++++-------- 3 files changed, 234 insertions(+), 110 deletions(-) diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index bd158784..a5e0913e 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -35,6 +35,7 @@ jobs: build: runs-on: ubuntu-latest strategy: + fail-fast: false # Continue all jobs even if one fails matrix: python-version: ["3.11", "3.12", "3.13", "3.14"] defaults: diff --git a/zstash/extract.py b/zstash/extract.py index 64977aef..b5e65501 100644 --- a/zstash/extract.py +++ b/zstash/extract.py @@ -3,7 +3,6 @@ import argparse import collections import hashlib -import heapq import logging import multiprocessing import os.path @@ -11,6 +10,7 @@ import sqlite3 import sys import tarfile +import time import traceback from datetime import datetime from typing import DefaultDict, List, Optional, Set, Tuple @@ -282,10 +282,10 @@ def extract_database( if args.workers > 1: logger.debug("Running zstash {} with multiprocessing".format(cmd)) failures = multiprocess_extract( - args.workers, matches, keep_files, keep, cache, cur, args + args.workers, matches, keep_files, keep, cache, args ) else: - failures = extractFiles(matches, keep_files, keep, cache, cur, args) + failures = extractFiles(matches, keep_files, keep, cache, args, None, cur) # Close database logger.debug("Closing index database") @@ -300,7 +300,6 @@ def multiprocess_extract( keep_files: bool, keep_tars: Optional[bool], cache: str, - cur: sqlite3.Cursor, args: argparse.Namespace, ) -> List[FilesRow]: """ @@ -329,26 +328,12 @@ def multiprocess_extract( # set the number of workers to the number of tars. num_workers = min(num_workers, len(tar_to_size)) - # For worker i, workers_to_tars[i] is a set of tars - # that worker i will work on. + # For worker i, workers_to_tars[i] is a set of tars that worker i will work on. + # Assign tars in round-robin fashion to maintain proper ordering workers_to_tars: List[set] = [set() for _ in range(num_workers)] - # A min heap, of (work, worker_idx) tuples, work is the size of data - # that worker_idx needs to work on. - # We can efficiently get the worker with the least amount of work. - work_to_workers: List[Tuple[int, int]] = [(0, i) for i in range(num_workers)] - heapq.heapify(workers_to_tars) - - # Using a greedy approach, populate workers_to_tars. - for _, tar in enumerate(tar_to_size): - # The worker with the least work should get the current largest amount of work. - workers_work: int - worker_idx: int - workers_work, worker_idx = heapq.heappop(work_to_workers) + for idx, tar in enumerate(sorted(tar_to_size.keys())): + worker_idx = idx % num_workers workers_to_tars[worker_idx].add(tar) - # Add this worker back to the heap, with the new amount of work. - worker_tuple: Tuple[float, int] = (workers_work + tar_to_size[tar], worker_idx) - # FIXME: error: Cannot infer type argument 1 of "heappush" - heapq.heappush(work_to_workers, worker_tuple) # type: ignore # For worker i, workers_to_matches[i] is a list of # matches from the database for it to process. @@ -361,8 +346,15 @@ def multiprocess_extract( # This worker gets this db_row. workers_to_matches[workers_idx].append(db_row) + # Sort each worker's matches by tar to ensure they process in order + for worker_matches in workers_to_matches: + worker_matches.sort(key=lambda t: t.tar) + tar_ordering: List[str] = sorted([tar for tar in tar_to_size]) - monitor: parallel.PrintMonitor = parallel.PrintMonitor(tar_ordering) + manager = multiprocessing.Manager() + monitor: parallel.PrintMonitor = parallel.PrintMonitor( + tar_ordering, manager=manager + ) # The return value for extractFiles will be added here. failure_queue: multiprocessing.Queue[FilesRow] = multiprocessing.Queue() @@ -374,7 +366,7 @@ def multiprocess_extract( ) process: multiprocessing.Process = multiprocessing.Process( target=extractFiles, - args=(matches, keep_files, keep_tars, cache, cur, args, worker), + args=(matches, keep_files, keep_tars, cache, args, worker), daemon=True, ) process.start() @@ -385,10 +377,39 @@ def multiprocess_extract( # No need to join() each of the processes when doing this, # because we'll be in this loop until completion. failures: List[FilesRow] = [] + max_wait_time = 180 # 3 minute timeout for tests + start_time = time.time() + last_log_time = start_time + while any(p.is_alive() for p in processes): + elapsed = time.time() - start_time + if elapsed > max_wait_time: + logger.error( + f"Timeout after {elapsed:.1f}s waiting for worker processes. Terminating..." + ) + for p in processes: + if p.is_alive(): + logger.error(f"Terminating process {p.pid}") + p.terminate() + break + + # Log progress every 30 seconds + if time.time() - last_log_time > 30: + alive_count = sum(1 for p in processes if p.is_alive()) + logger.debug( + f"Still waiting for {alive_count} worker processes after {elapsed:.1f}s" + ) + last_log_time = time.time() + while not failure_queue.empty(): failures.append(failure_queue.get()) + time.sleep(0.1) # Larger sleep to reduce CPU usage + + # Collect any remaining failures + while not failure_queue.empty(): + failures.append(failure_queue.get()) + # Sort the failures, since they can come in at any order. failures.sort(key=lambda t: (t.name, t.tar, t.offset)) return failures @@ -479,9 +500,9 @@ def extractFiles( # noqa: C901 keep_files: bool, keep_tars: Optional[bool], cache: str, - cur: sqlite3.Cursor, args: argparse.Namespace, multiprocess_worker: Optional[parallel.ExtractWorker] = None, + cur: Optional[sqlite3.Cursor] = None, ) -> List[FilesRow]: """ Given a list of database rows, extract the files from the @@ -498,21 +519,56 @@ def extractFiles( # noqa: C901 that called this function. We need a reference to it so we can signal it to print the contents of what's in its print queue. + + If cur is None (when running in parallel), a new database connection + will be opened for this worker process. + """ + try: + result = _extractFiles_impl( + files, keep_files, keep_tars, cache, args, multiprocess_worker, cur + ) + return result + except Exception as e: + if multiprocess_worker: + # Make sure we report failures even if there's an exception + sys.stderr.write(f"ERROR: Worker encountered fatal error: {e}\n") + sys.stderr.flush() + traceback.print_exc(file=sys.stderr) + for f in files: + multiprocess_worker.failure_queue.put(f) + raise + + +# FIXME: C901 '_extractFiles_impl' is too complex (42) +def _extractFiles_impl( # noqa: C901 + files: List[FilesRow], + keep_files: bool, + keep_tars: Optional[bool], + cache: str, + args: argparse.Namespace, + multiprocess_worker: Optional[parallel.ExtractWorker] = None, + cur: Optional[sqlite3.Cursor] = None, +) -> List[FilesRow]: + """ + Implementation of extractFiles - actual extraction logic. """ + # Open database connection if not provided (parallel case) + if cur is None: + con: sqlite3.Connection = sqlite3.connect( + get_db_filename(cache), detect_types=sqlite3.PARSE_DECLTYPES + ) + cur = con.cursor() + close_db: bool = True + else: + close_db = False + failures: List[FilesRow] = [] tfname: str newtar: bool = True nfiles: int = len(files) - if multiprocess_worker: - # All messages to the logger will now be sent to - # this queue, instead of sys.stdout. - sh = logging.StreamHandler(multiprocess_worker.print_queue) - sh.setLevel(logging.DEBUG) - formatter: logging.Formatter = logging.Formatter("%(levelname)s: %(message)s") - sh.setFormatter(formatter) - logger.addHandler(sh) - # Don't have the logger print to the console as the message come in. - logger.propagate = False + + # Track if we've set up logging yet + logging_setup: bool = False for i in range(nfiles): files_row: FilesRow = files[i] @@ -521,16 +577,46 @@ def extractFiles( # noqa: C901 if newtar: newtar = False tfname = os.path.join(cache, files_row.tar) - # Everytime we're extracting a new tar, if running in parallel, - # let the process know. - # This is to synchronize the print statements. + + # CRITICAL: Wait for our turn BEFORE doing anything with this tar if multiprocess_worker: + try: + multiprocess_worker.print_monitor.wait_turn( + multiprocess_worker, files_row.tar, indef_wait=True + ) + except TimeoutError as e: + logger.error( + f"Timeout waiting for turn to process {files_row.tar}: {e}" + ) + # Mark all remaining files from this tar as failed + for j in range(i, nfiles): + if files[j].tar == files_row.tar: + failures.append(files[j]) + # Skip to next tar + newtar = True + continue + + # NOW set up logging (only once) + if not logging_setup: + sh = logging.StreamHandler(multiprocess_worker.print_queue) + sh.setLevel(logging.DEBUG) + formatter: logging.Formatter = logging.Formatter( + "%(levelname)s: %(message)s" + ) + sh.setFormatter(formatter) + logger.addHandler(sh) + logger.propagate = False + logging_setup = True + + # Set current tar for this worker multiprocess_worker.set_curr_tar(files_row.tar) - if config.hpss is not None: - hpss: str = config.hpss + # Use args.hpss directly - it's always set correctly + if args.hpss is not None: + hpss: str = args.hpss else: - raise TypeError("Invalid config.hpss={}".format(config.hpss)) + raise TypeError("Invalid args.hpss={}".format(args.hpss)) + tries: int = args.retries + 1 # Set to True to test the `--retries` option with a forced failure. # Then run `python -m unittest tests.test_extract.TestExtract.testExtractRetries` @@ -574,8 +660,6 @@ def extractFiles( # noqa: C901 # Extract file cmd: str = "Extracting" if keep_files else "Checking" logger.info(cmd + " %s" % (files_row.name)) - # if multiprocess_worker: - # print('{} is {} {} from {}'.format(multiprocess_worker, cmd, file[1], file[5])) if keep_files and not should_extract_file(files_row): # If we were going to extract, but aren't @@ -676,9 +760,6 @@ def extractFiles( # noqa: C901 logger.error("Retrieving {}".format(files_row.name)) failures.append(files_row) - if multiprocess_worker: - multiprocess_worker.print_contents() - # Close current archive? if i == nfiles - 1 or files[i].tar != files[i + 1].tar: # We're either on the last file or the tar is distinct from the tar of the next file. @@ -688,8 +769,15 @@ def extractFiles( # noqa: C901 tar.close() if multiprocess_worker: + # Mark that all output for this tar is queued multiprocess_worker.done_enqueuing_output_for_tar(files_row.tar) + # Now print everything and advance to next tar + try: + multiprocess_worker.print_all_contents() + except (TimeoutError, Exception) as e: + logger.debug(f"Error printing contents for {files_row.tar}: {e}") + # Open new archive next time newtar = True @@ -700,13 +788,13 @@ def extractFiles( # noqa: C901 else: raise TypeError("Invalid tfname={}".format(tfname)) - if multiprocess_worker: - # If there are things left to print, print them. - multiprocess_worker.print_all_contents() + # Close database connection if we opened it + if close_db: + cur.close() + con.close() - # Add the failures to the queue. - # When running with multiprocessing, the function multiprocess_extract() - # that calls this extractFiles() function will return the failures as a list. + if multiprocess_worker: + # Add failures to the queue for f in failures: multiprocess_worker.failure_queue.put(f) return failures diff --git a/zstash/parallel.py b/zstash/parallel.py index 53beb53d..636fdb73 100644 --- a/zstash/parallel.py +++ b/zstash/parallel.py @@ -1,8 +1,8 @@ from __future__ import print_function import collections -import ctypes import multiprocessing +import time from typing import Dict, List, Optional from .settings import FilesRow @@ -24,7 +24,7 @@ class PrintMonitor(object): for that tar will print it's output. """ - def __init__(self, tars_to_print: List[str], *args, **kwargs): + def __init__(self, tars_to_print: List[str], manager=None, *args, **kwargs): # A list of tars to print. # Ex: ['000000.tar', '000008.tar', '00001a.tar'] if not tars_to_print: @@ -32,75 +32,92 @@ def __init__(self, tars_to_print: List[str], *args, **kwargs): msg += " the order of which to print the results." raise RuntimeError(msg) - # The variables below are modified/accessed by different processes, - # so they need to be in shared memory. - self._cv: multiprocessing.synchronize.Condition = multiprocessing.Condition() - - self._tars_to_print: multiprocessing.Queue[str] = multiprocessing.Queue() - tar: str - for tar in tars_to_print: - # Add the tar to the queue to be printed. - self._tars_to_print.put(tar) - - # We need a manager to instantiate the Value instead of multiprocessing.Value. - # If we didn't use a manager, it seems to get some junk value. - self._manager: multiprocessing.managers.SyncManager = multiprocessing.Manager() - self._current_tar: multiprocessing.managers.ValueProxy = self._manager.Value( - ctypes.c_char_p, self._tars_to_print.get() + # Accept manager from outside to avoid pickling issues + # The manager must be created in the main process before forking + if manager is None: + raise ValueError("manager must be provided to PrintMonitor") + + # Store the ordered list of tars + self._tars_list: List[str] = tars_to_print + + # Use a simple counter instead of condition variables + # Tracks which tar index we're currently on + self._current_tar_index: multiprocessing.managers.ValueProxy = manager.Value( + "i", 0 ) + # Lock for updating the counter + self._lock: multiprocessing.synchronize.Lock = manager.Lock() + def wait_turn( - # TODO: worker has type `ExtractWorker` - self, - worker, - workers_curr_tar: str, - indef_wait: bool = True, - *args, - **kwargs + self, worker, workers_curr_tar: str, indef_wait: bool = True, *args, **kwargs ): - """ - While a worker's current tar isn't the one - needed to be printed, wait. - - A timeout is passed into self._cv.wait(), and if the - turn isn't given within that, a NotYourTurnError is raised. + import sys - If indef_wait is True, indefinitely wait until it's - the worker's turn. - """ - with self._cv: - attempted: bool = False - while self._current_tar.value != workers_curr_tar: - if attempted and not indef_wait: - # It's not this worker's turn. - raise NotYourTurnError() - - attempted = True - # Wait 0.001 to see if it's the worker's turn. - self._cv.wait(0.001) - - def done_dequeuing_output_for_tar( + # Find the index of the worker's tar in the ordered list + try: + tar_index = self._tars_list.index(workers_curr_tar) + except ValueError: + sys.stderr.write(f"DEBUG: Tar {workers_curr_tar} not in list!\n") + sys.stderr.flush() + return + + sys.stderr.write( + f"DEBUG: Worker waiting for tar {workers_curr_tar} (index {tar_index}), current index is {self._current_tar_index.value}\n" + ) + sys.stderr.flush() + + max_wait_time = 180.0 + start_time = time.time() + attempted = False + + while True: + if self._current_tar_index.value == tar_index: + sys.stderr.write( + f"DEBUG: Worker got turn for tar {workers_curr_tar}!\n" + ) + sys.stderr.flush() + return + + if attempted and not indef_wait: + # It's not this worker's turn. + raise NotYourTurnError() + + # Check if we've been waiting too long + if indef_wait and (time.time() - start_time) > max_wait_time: + raise TimeoutError( + f"Worker timed out waiting for turn to print {workers_curr_tar}. " + f"Current tar index is {self._current_tar_index.value} (expecting {tar_index})" + ) + + attempted = True + # Sleep briefly and check again + time.sleep(0.1) + + def done_enqueuing_output_for_tar( # TODO: worker has type `ExtractWorker` self, worker, workers_curr_tar: str, *args, - **kwargs + **kwargs, ): """ A worker has finished printing the output for workers_curr_tar from the print queue. - If possible, update self._current_tar. - If there aren't anymore tars to print, set self._current_tar to None. + Advance to the next tar in the sequence. """ - # It must be the worker's turn before this can happen. - self.wait_turn(worker, workers_curr_tar, *args, **kwargs) + # Find our tar's index + try: + tar_index = self._tars_list.index(workers_curr_tar) + except ValueError: + return - with self._cv: - self._current_tar.value = ( - self._tars_to_print.get() if not self._tars_to_print.empty() else "" - ) - self._cv.notify_all() + # Advance to the next tar ONLY if we're the current tar + # This allows workers to signal completion without blocking + with self._lock: + if self._current_tar_index.value == tar_index: + self._current_tar_index.value += 1 class ExtractWorker(object): @@ -120,7 +137,7 @@ def __init__( # TODO: failure_queue has type `multiprocessing.Queue[FilesRow]` failure_queue, *args, - **kwargs + **kwargs, ): """ print_monitor is used to determine if it's this worker's turn to print. @@ -186,7 +203,21 @@ def print_all_contents(self, *args, **kwargs): while self.has_to_print(): # Try to print the first element in the queue. tar_to_print: str = self.print_queue[0].tar - self.print_monitor.wait_turn(self, tar_to_print, *args, **kwargs) + + try: + self.print_monitor.wait_turn(self, tar_to_print, *args, **kwargs) + except TimeoutError: + # If we timeout waiting to print, dump to stdout without ordering + # This prevents deadlocks + while self.print_queue and (self.print_queue[0].tar == tar_to_print): + err_msg: str = self.print_queue.popleft().msg + print(err_msg, end="", flush=True) + + # Mark as done even though we couldn't print in order + if self.is_output_done_enqueuing.get(tar_to_print, False): + # Skip the monitor sync since it timed out + pass + continue # Print all applicable values in the print_queue. while self.print_queue and (self.print_queue[0].tar == tar_to_print): @@ -197,7 +228,11 @@ def print_all_contents(self, *args, **kwargs): # Since we just finished printing all of it, we can move onto the next one. if self.is_output_done_enqueuing[tar_to_print]: # Let all of the other workers know that this worker is done. - self.print_monitor.done_dequeuing_output_for_tar(self, tar_to_print) + try: + self.print_monitor.done_enqueuing_output_for_tar(self, tar_to_print) + except TimeoutError: + # If we can't update the monitor, just continue + pass class PrintQueue(collections.deque): From d4a092937771b58a1583ab681554dcd893b0240a Mon Sep 17 00:00:00 2001 From: Ryan Forsyth Date: Mon, 2 Feb 2026 17:02:17 -0800 Subject: [PATCH 4/9] Fixes to make Perlmutter tests pass --- zstash/extract.py | 6 +++++- zstash/parallel.py | 21 ++++++++++----------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/zstash/extract.py b/zstash/extract.py index b5e65501..b88bf9e0 100644 --- a/zstash/extract.py +++ b/zstash/extract.py @@ -743,7 +743,11 @@ def _extractFiles_impl( # noqa: C901 logger.debug("Valid md5: {} {}".format(md5, fname)) elif extract_this_file: - tar.extract(tarinfo) + # Python 3.11 and earlier don't support the filter parameter at all + if sys.version_info >= (3, 12): + tar.extract(tarinfo, filter="tar") + else: + tar.extract(tarinfo) # Note: tar.extract() will not restore time stamps of symbolic # links. Could not find a Python-way to restore it either, so # relying here on 'touch'. This is not the prettiest solution. diff --git a/zstash/parallel.py b/zstash/parallel.py index 636fdb73..e4295029 100644 --- a/zstash/parallel.py +++ b/zstash/parallel.py @@ -52,20 +52,19 @@ def __init__(self, tars_to_print: List[str], manager=None, *args, **kwargs): def wait_turn( self, worker, workers_curr_tar: str, indef_wait: bool = True, *args, **kwargs ): - import sys # Find the index of the worker's tar in the ordered list try: tar_index = self._tars_list.index(workers_curr_tar) except ValueError: - sys.stderr.write(f"DEBUG: Tar {workers_curr_tar} not in list!\n") - sys.stderr.flush() + # sys.stderr.write(f"DEBUG: Tar {workers_curr_tar} not in list!\n") + # sys.stderr.flush() return - sys.stderr.write( - f"DEBUG: Worker waiting for tar {workers_curr_tar} (index {tar_index}), current index is {self._current_tar_index.value}\n" - ) - sys.stderr.flush() + # sys.stderr.write( + # f"DEBUG: Worker waiting for tar {workers_curr_tar} (index {tar_index}), current index is {self._current_tar_index.value}\n" + # ) + # sys.stderr.flush() max_wait_time = 180.0 start_time = time.time() @@ -73,10 +72,10 @@ def wait_turn( while True: if self._current_tar_index.value == tar_index: - sys.stderr.write( - f"DEBUG: Worker got turn for tar {workers_curr_tar}!\n" - ) - sys.stderr.flush() + # sys.stderr.write( + # f"DEBUG: Worker got turn for tar {workers_curr_tar}!\n" + # ) + # sys.stderr.flush() return if attempted and not indef_wait: From 62df9c71368f81cee53cd2a4464255577ab3c64b Mon Sep 17 00:00:00 2001 From: Ryan Forsyth Date: Tue, 3 Feb 2026 11:21:24 -0800 Subject: [PATCH 5/9] Claude's changes to reduce over-engineering --- zstash/extract.py | 223 ++++++--------------------------------------- zstash/parallel.py | 97 +++----------------- 2 files changed, 38 insertions(+), 282 deletions(-) diff --git a/zstash/extract.py b/zstash/extract.py index b88bf9e0..2ecfe792 100644 --- a/zstash/extract.py +++ b/zstash/extract.py @@ -304,49 +304,29 @@ def multiprocess_extract( ) -> List[FilesRow]: """ Extract the files from the matches in parallel. - - A single unit of work is a tar and all of - the files in it to extract. """ - # A dict of tar -> size of files in it. - # This is because we're trying to balance the load between - # the processes. tar_to_size_unsorted: DefaultDict[str, float] = collections.defaultdict(float) - db_row: FilesRow - tar: str - size: int for db_row in matches: - tar, size = db_row.tar, db_row.size - tar_to_size_unsorted[tar] += size - # Sort by the size. + tar_to_size_unsorted[db_row.tar] += db_row.size + tar_to_size: collections.OrderedDict[str, float] = collections.OrderedDict( sorted(tar_to_size_unsorted.items(), key=lambda x: x[1]) ) - # We don't want to instantiate more processes than we need to. - # So, if the number of tars is less than the number of workers, - # set the number of workers to the number of tars. num_workers = min(num_workers, len(tar_to_size)) - # For worker i, workers_to_tars[i] is a set of tars that worker i will work on. - # Assign tars in round-robin fashion to maintain proper ordering + # Round-robin assignment for predictable ordering workers_to_tars: List[set] = [set() for _ in range(num_workers)] for idx, tar in enumerate(sorted(tar_to_size.keys())): - worker_idx = idx % num_workers - workers_to_tars[worker_idx].add(tar) + workers_to_tars[idx % num_workers].add(tar) - # For worker i, workers_to_matches[i] is a list of - # matches from the database for it to process. workers_to_matches: List[List[FilesRow]] = [[] for _ in range(num_workers)] for db_row in matches: - tar = db_row.tar - workers_idx: int for workers_idx in range(len(workers_to_tars)): - if tar in workers_to_tars[workers_idx]: - # This worker gets this db_row. + if db_row.tar in workers_to_tars[workers_idx]: workers_to_matches[workers_idx].append(db_row) - # Sort each worker's matches by tar to ensure they process in order + # Ensure each worker processes tars in order for worker_matches in workers_to_matches: worker_matches.sort(key=lambda t: t.tar) @@ -356,9 +336,9 @@ def multiprocess_extract( tar_ordering, manager=manager ) - # The return value for extractFiles will be added here. failure_queue: multiprocessing.Queue[FilesRow] = multiprocessing.Queue() processes: List[multiprocessing.Process] = [] + for matches in workers_to_matches: tars_for_this_worker: List[str] = list(set(match.tar for match in matches)) worker: parallel.ExtractWorker = parallel.ExtractWorker( @@ -366,51 +346,21 @@ def multiprocess_extract( ) process: multiprocessing.Process = multiprocessing.Process( target=extractFiles, - args=(matches, keep_files, keep_tars, cache, args, worker), + args=(matches, keep_files, keep_tars, cache, args, worker, None), daemon=True, ) process.start() processes.append(process) - # While the processes are running, we need to empty the queue. - # Otherwise, it causes hanging. - # No need to join() each of the processes when doing this, - # because we'll be in this loop until completion. failures: List[FilesRow] = [] - max_wait_time = 180 # 3 minute timeout for tests - start_time = time.time() - last_log_time = start_time - while any(p.is_alive() for p in processes): - elapsed = time.time() - start_time - if elapsed > max_wait_time: - logger.error( - f"Timeout after {elapsed:.1f}s waiting for worker processes. Terminating..." - ) - for p in processes: - if p.is_alive(): - logger.error(f"Terminating process {p.pid}") - p.terminate() - break - - # Log progress every 30 seconds - if time.time() - last_log_time > 30: - alive_count = sum(1 for p in processes if p.is_alive()) - logger.debug( - f"Still waiting for {alive_count} worker processes after {elapsed:.1f}s" - ) - last_log_time = time.time() - while not failure_queue.empty(): failures.append(failure_queue.get()) + time.sleep(0.01) - time.sleep(0.1) # Larger sleep to reduce CPU usage - - # Collect any remaining failures while not failure_queue.empty(): failures.append(failure_queue.get()) - # Sort the failures, since they can come in at any order. failures.sort(key=lambda t: (t.name, t.tar, t.offset)) return failures @@ -421,61 +371,46 @@ def check_sizes_match(cur, tfname, error_on_duplicate_tar): actual_size: int = os.path.getsize(tfname) name_only: str = os.path.split(tfname)[1] - # Get ALL entries for this tar name cur.execute( "SELECT size FROM tars WHERE name = ? ORDER by id DESC", (name_only,) ) results = cur.fetchall() if not results: - # Cannot access size information; assume the sizes match. logger.error(f"No database entries found for {name_only}") return True - # Check for multiple entries if len(results) > 1: - # Extract just the size values sizes: List[int] = [row[0] for row in results] error_str: str = ( f"Database corruption detected! Found {len(results)} database entries for {name_only}, with sizes {sizes}" ) if error_on_duplicate_tar: - # Tested by database_corruption.bash Case 5 logger.error(error_str) raise RuntimeError(error_str) logger.warning(error_str) - # We ordered the results by id DESC, - # so the first entry is the most recent. most_recent_size: int = sizes[0] if actual_size == most_recent_size: - # Tested by database_corruption.bash Case 7 - # If the actual size matches the most recent size, - # then we can assume that the tar is valid. logger.info( f"{name_only}: The most recent database entry has the same size as the actual file size: {actual_size}." ) return True unique_sizes: Set[int] = set(sizes) if actual_size in unique_sizes: - # Tested by database_corruption.bash Case 8 logger.info( f"{name_only}: A database entry matches the actual file size, {actual_size}, but it is not the most recent entry." ) else: - # Tested by database_corruption.bash Case 6 logger.info( f"{name_only}: No database entry matches the actual file size: {actual_size}." ) return False else: - # Tested by database_corruption.bash Cases 1,2,4 - # Single entry - normal case logger.info(f"{name_only}: Found a single database entry.") expected_size = results[0][0] - # Now check if actual size matches expected size if expected_size != actual_size: error_msg = ( f"{name_only}: Size mismatch! " @@ -485,16 +420,13 @@ def check_sizes_match(cur, tfname, error_on_duplicate_tar): logger.error(error_msg) return False else: - # Sizes match logger.info(f"{name_only}: Size check passed ({actual_size} bytes)") return True else: - # Cannot access size information; assume the sizes match. logger.debug("Cannot access tar size information; assuming sizes match") return True -# FIXME: C901 'extractFiles' is too complex (33) def extractFiles( # noqa: C901 files: List[FilesRow], keep_files: bool, @@ -505,53 +437,11 @@ def extractFiles( # noqa: C901 cur: Optional[sqlite3.Cursor] = None, ) -> List[FilesRow]: """ - Given a list of database rows, extract the files from the - tar archives to the current location on disk. - - If keep_files is False, the files are not extracted. - This is used for when checking if the files in an HPSS - repository are valid. - - If keep_tars is True, the tar archives that are downloaded are kept, - even after the program has terminated. Otherwise, they are deleted. - - If running in parallel, then multiprocess_worker is the Worker - that called this function. - We need a reference to it so we can signal it to print - the contents of what's in its print queue. + Given a list of database rows, extract the files from the tar archives. If cur is None (when running in parallel), a new database connection will be opened for this worker process. """ - try: - result = _extractFiles_impl( - files, keep_files, keep_tars, cache, args, multiprocess_worker, cur - ) - return result - except Exception as e: - if multiprocess_worker: - # Make sure we report failures even if there's an exception - sys.stderr.write(f"ERROR: Worker encountered fatal error: {e}\n") - sys.stderr.flush() - traceback.print_exc(file=sys.stderr) - for f in files: - multiprocess_worker.failure_queue.put(f) - raise - - -# FIXME: C901 '_extractFiles_impl' is too complex (42) -def _extractFiles_impl( # noqa: C901 - files: List[FilesRow], - keep_files: bool, - keep_tars: Optional[bool], - cache: str, - args: argparse.Namespace, - multiprocess_worker: Optional[parallel.ExtractWorker] = None, - cur: Optional[sqlite3.Cursor] = None, -) -> List[FilesRow]: - """ - Implementation of extractFiles - actual extraction logic. - """ # Open database connection if not provided (parallel case) if cur is None: con: sqlite3.Connection = sqlite3.connect( @@ -567,60 +457,38 @@ def _extractFiles_impl( # noqa: C901 newtar: bool = True nfiles: int = len(files) - # Track if we've set up logging yet - logging_setup: bool = False + # Set up logging redirection for multiprocessing + if multiprocess_worker: + sh = logging.StreamHandler(multiprocess_worker.print_queue) + sh.setLevel(logging.DEBUG) + formatter: logging.Formatter = logging.Formatter("%(levelname)s: %(message)s") + sh.setFormatter(formatter) + logger.addHandler(sh) + logger.propagate = False for i in range(nfiles): files_row: FilesRow = files[i] - # Open new tar archive if newtar: newtar = False tfname = os.path.join(cache, files_row.tar) - # CRITICAL: Wait for our turn BEFORE doing anything with this tar + # Wait for turn before processing this tar if multiprocess_worker: - try: - multiprocess_worker.print_monitor.wait_turn( - multiprocess_worker, files_row.tar, indef_wait=True - ) - except TimeoutError as e: - logger.error( - f"Timeout waiting for turn to process {files_row.tar}: {e}" - ) - # Mark all remaining files from this tar as failed - for j in range(i, nfiles): - if files[j].tar == files_row.tar: - failures.append(files[j]) - # Skip to next tar - newtar = True - continue - - # NOW set up logging (only once) - if not logging_setup: - sh = logging.StreamHandler(multiprocess_worker.print_queue) - sh.setLevel(logging.DEBUG) - formatter: logging.Formatter = logging.Formatter( - "%(levelname)s: %(message)s" - ) - sh.setFormatter(formatter) - logger.addHandler(sh) - logger.propagate = False - logging_setup = True - - # Set current tar for this worker + multiprocess_worker.print_monitor.wait_turn( + multiprocess_worker, files_row.tar + ) multiprocess_worker.set_curr_tar(files_row.tar) - # Use args.hpss directly - it's always set correctly + # Use args.hpss directly if args.hpss is not None: hpss: str = args.hpss else: raise TypeError("Invalid args.hpss={}".format(args.hpss)) tries: int = args.retries + 1 - # Set to True to test the `--retries` option with a forced failure. - # Then run `python -m unittest tests.test_extract.TestExtract.testExtractRetries` test_retry: bool = False + while tries > 0: tries -= 1 do_retrieve: bool @@ -644,12 +512,10 @@ def _extractFiles_impl( # noqa: C901 raise RuntimeError( f"{tfname} size does not match expected size." ) - # `hpss_get` successful or not needed: no more tries needed break except RuntimeError as e: if tries > 0: logger.info(f"Retrying HPSS get: {tries} tries remaining.") - # Run the try-except block again continue else: raise e @@ -662,30 +528,23 @@ def _extractFiles_impl( # noqa: C901 logger.info(cmd + " %s" % (files_row.name)) if keep_files and not should_extract_file(files_row): - # If we were going to extract, but aren't - # because a matching file is on disk msg: str = "Not extracting {}, because it" msg += " already exists on disk with the same" msg += " size and modification date." logger.info(msg.format(files_row.name)) - # True if we should actually extract the file from the tar extract_this_file: bool = keep_files and should_extract_file(files_row) try: - # Seek file position if tar.fileobj is not None: fileobj = tar.fileobj else: raise TypeError("Invalid tar.fileobj={}".format(tar.fileobj)) fileobj.seek(files_row.offset) - # Get next member tarinfo: tarfile.TarInfo = tar.tarinfo.fromtarfile(tar) if tarinfo.isfile(): - # fileobj to extract - # error: Name 'tarfile.ExFileObject' is not defined extracted_file: Optional[tarfile.ExFileObject] = tar.extractfile(tarinfo) # type: ignore if extracted_file: fin: tarfile.ExFileObject = extracted_file @@ -698,11 +557,8 @@ def _extractFiles_impl( # noqa: C901 path, name = os.path.split(fname) if path != "" and extract_this_file: if not os.path.isdir(path): - # The path doesn't exist, so create it. os.makedirs(path) if extract_this_file: - # If we're keeping the files, - # then have an output file fout: _io.BufferedWriter = open(fname, "wb") hash_md5: _hashlib.HASH = hashlib.md5() @@ -721,37 +577,26 @@ def _extractFiles_impl( # noqa: C901 md5: str = hash_md5.hexdigest() if extract_this_file: - # numeric_owner is a required arg in Python 3. - # If True, "only the numbers for user/group names - # are used and not the names". tar.chown(tarinfo, fname, numeric_owner=False) tar.chmod(tarinfo, fname) tar.utime(tarinfo, fname) - # Verify size if os.path.getsize(fname) != files_row.size: logger.error("size mismatch for: {}".format(fname)) - # Verify md5 checksum files_row_md5: Optional[str] = files_row.md5 if md5 != files_row_md5: logger.error("md5 mismatch for: {}".format(fname)) logger.error("md5 of extracted file: {}".format(md5)) logger.error("md5 of original file: {}".format(files_row_md5)) - failures.append(files_row) else: logger.debug("Valid md5: {} {}".format(md5, fname)) elif extract_this_file: - # Python 3.11 and earlier don't support the filter parameter at all if sys.version_info >= (3, 12): tar.extract(tarinfo, filter="tar") else: tar.extract(tarinfo) - # Note: tar.extract() will not restore time stamps of symbolic - # links. Could not find a Python-way to restore it either, so - # relying here on 'touch'. This is not the prettiest solution. - # Maybe a better one can be implemented later. if tarinfo.issym(): tmp1 = tarinfo.mtime tmp2: datetime = datetime.fromtimestamp(tmp1) @@ -759,33 +604,21 @@ def _extractFiles_impl( # noqa: C901 os.system("touch -h -t %s %s" % (tmp3, tarinfo.name)) except Exception: - # Catch all exceptions here. traceback.print_exc() logger.error("Retrieving {}".format(files_row.name)) failures.append(files_row) # Close current archive? if i == nfiles - 1 or files[i].tar != files[i + 1].tar: - # We're either on the last file or the tar is distinct from the tar of the next file. - - # Close current archive file logger.debug("Closing tar archive {}".format(tfname)) tar.close() if multiprocess_worker: - # Mark that all output for this tar is queued multiprocess_worker.done_enqueuing_output_for_tar(files_row.tar) + multiprocess_worker.print_all_contents() - # Now print everything and advance to next tar - try: - multiprocess_worker.print_all_contents() - except (TimeoutError, Exception) as e: - logger.debug(f"Error printing contents for {files_row.tar}: {e}") - - # Open new archive next time newtar = True - # Delete this tar if the corresponding command-line arg was used. if not keep_tars: if tfname is not None: os.remove(tfname) @@ -798,9 +631,9 @@ def _extractFiles_impl( # noqa: C901 con.close() if multiprocess_worker: - # Add failures to the queue for f in failures: multiprocess_worker.failure_queue.put(f) + return failures @@ -815,15 +648,11 @@ def should_extract_file(db_row: FilesRow) -> bool: file_name, size_db, mod_time_db = db_row.name, db_row.size, db_row.mtime if not os.path.exists(file_name): - # The file doesn't exist locally. - # We must get files that are not on disk. return True size_disk: int = os.path.getsize(file_name) mod_time_disk: datetime = datetime.utcfromtimestamp(os.path.getmtime(file_name)) - # Only extract when the times and sizes are not the same (within tolerance) - # We have a TIME_TOL because mod_time_disk doesn't have the microseconds. return not ( (size_disk == size_db) and (abs(mod_time_disk - mod_time_db).total_seconds() < TIME_TOL) diff --git a/zstash/parallel.py b/zstash/parallel.py index e4295029..78d3a3de 100644 --- a/zstash/parallel.py +++ b/zstash/parallel.py @@ -20,28 +20,21 @@ class NotYourTurnError(Exception): class PrintMonitor(object): """ Used to synchronize the printing of the output between workers. - Depending on the current_tar, the worker processing the work - for that tar will print it's output. """ def __init__(self, tars_to_print: List[str], manager=None, *args, **kwargs): - # A list of tars to print. - # Ex: ['000000.tar', '000008.tar', '00001a.tar'] if not tars_to_print: msg: str = "You must pass in a list of tars, which dictates" msg += " the order of which to print the results." raise RuntimeError(msg) - # Accept manager from outside to avoid pickling issues - # The manager must be created in the main process before forking if manager is None: raise ValueError("manager must be provided to PrintMonitor") # Store the ordered list of tars self._tars_list: List[str] = tars_to_print - # Use a simple counter instead of condition variables - # Tracks which tar index we're currently on + # Use a simple counter to track which tar we're on self._current_tar_index: multiprocessing.managers.ValueProxy = manager.Value( "i", 0 ) @@ -52,68 +45,37 @@ def __init__(self, tars_to_print: List[str], manager=None, *args, **kwargs): def wait_turn( self, worker, workers_curr_tar: str, indef_wait: bool = True, *args, **kwargs ): - - # Find the index of the worker's tar in the ordered list + """ + Wait until it's this worker's turn to process workers_curr_tar. + """ try: tar_index = self._tars_list.index(workers_curr_tar) except ValueError: - # sys.stderr.write(f"DEBUG: Tar {workers_curr_tar} not in list!\n") - # sys.stderr.flush() return - # sys.stderr.write( - # f"DEBUG: Worker waiting for tar {workers_curr_tar} (index {tar_index}), current index is {self._current_tar_index.value}\n" - # ) - # sys.stderr.flush() - - max_wait_time = 180.0 - start_time = time.time() attempted = False - while True: if self._current_tar_index.value == tar_index: - # sys.stderr.write( - # f"DEBUG: Worker got turn for tar {workers_curr_tar}!\n" - # ) - # sys.stderr.flush() return if attempted and not indef_wait: - # It's not this worker's turn. raise NotYourTurnError() - # Check if we've been waiting too long - if indef_wait and (time.time() - start_time) > max_wait_time: - raise TimeoutError( - f"Worker timed out waiting for turn to print {workers_curr_tar}. " - f"Current tar index is {self._current_tar_index.value} (expecting {tar_index})" - ) - attempted = True - # Sleep briefly and check again - time.sleep(0.1) + time.sleep(0.01) def done_enqueuing_output_for_tar( - # TODO: worker has type `ExtractWorker` - self, - worker, - workers_curr_tar: str, - *args, - **kwargs, + self, worker, workers_curr_tar: str, *args, **kwargs ): """ - A worker has finished printing the output for workers_curr_tar - from the print queue. + A worker has finished printing output for workers_curr_tar. Advance to the next tar in the sequence. """ - # Find our tar's index try: tar_index = self._tars_list.index(workers_curr_tar) except ValueError: return - # Advance to the next tar ONLY if we're the current tar - # This allows workers to signal completion without blocking with self._lock: if self._current_tar_index.value == tar_index: self._current_tar_index.value += 1 @@ -125,15 +87,12 @@ class ExtractWorker(object): It redirects all of the output of the logging module to a queue. Then with a PrintMonitor, it prints to the terminal in the order defined by the PrintMonitor. - - This worker is called during `zstash extract`. """ def __init__( self, print_monitor: PrintMonitor, tars_to_work_on: List[str], - # TODO: failure_queue has type `multiprocessing.Queue[FilesRow]` failure_queue, *args, **kwargs, @@ -144,14 +103,10 @@ def __init__( Any failures are added to the failure_queue, to return any failed values. """ self.print_monitor: PrintMonitor = print_monitor - # Every call to print() in the original function will - # be piped to this queue instead of the screen. self.print_queue: PrintQueue = PrintQueue() - # A tar is mapped to True when all of its output is in the queue. self.is_output_done_enqueuing: Dict[str, bool] = { tar: False for tar in tars_to_work_on } - # After extractFiles is done, all of the failures will be added to this queue. self.failure_queue: multiprocessing.Queue[FilesRow] = failure_queue def set_curr_tar(self, tar: str): @@ -164,7 +119,6 @@ def done_enqueuing_output_for_tar(self, tar: str): """ All of the output for extracting this tar is in the print queue. """ - msg: str if tar not in self.is_output_done_enqueuing: msg = "This tar {} isn't assigned to this worker." raise RuntimeError(msg.format(tar)) @@ -180,10 +134,8 @@ def print_contents(self): Try to print the contents from self.print_queue. """ try: - # We only wait for 0.001 seconds. self.print_all_contents(indef_wait=False) except NotYourTurnError: - # It's not our turn, so try again the next time this function is called. pass def has_to_print(self) -> bool: @@ -195,43 +147,20 @@ def has_to_print(self) -> bool: def print_all_contents(self, *args, **kwargs): """ Block until all of the contents of self.print_queue are printed. - - If it's not our turn and the passed in timeout to print_monitor.wait_turn - is over, a NotYourTurnError exception is raised. """ while self.has_to_print(): - # Try to print the first element in the queue. tar_to_print: str = self.print_queue[0].tar - try: - self.print_monitor.wait_turn(self, tar_to_print, *args, **kwargs) - except TimeoutError: - # If we timeout waiting to print, dump to stdout without ordering - # This prevents deadlocks - while self.print_queue and (self.print_queue[0].tar == tar_to_print): - err_msg: str = self.print_queue.popleft().msg - print(err_msg, end="", flush=True) - - # Mark as done even though we couldn't print in order - if self.is_output_done_enqueuing.get(tar_to_print, False): - # Skip the monitor sync since it timed out - pass - continue - - # Print all applicable values in the print_queue. + self.print_monitor.wait_turn(self, tar_to_print, *args, **kwargs) + + # Print all applicable values in the print_queue while self.print_queue and (self.print_queue[0].tar == tar_to_print): msg: str = self.print_queue.popleft().msg print(msg, end="", flush=True) - # If True, then all of the output for extracting tar_to_print was in the queue. - # Since we just finished printing all of it, we can move onto the next one. + # If all output for this tar is done, advance the monitor if self.is_output_done_enqueuing[tar_to_print]: - # Let all of the other workers know that this worker is done. - try: - self.print_monitor.done_enqueuing_output_for_tar(self, tar_to_print) - except TimeoutError: - # If we can't update the monitor, just continue - pass + self.print_monitor.done_enqueuing_output_for_tar(self, tar_to_print) class PrintQueue(collections.deque): @@ -249,8 +178,6 @@ def write(self, msg: str): self.append(TarAndMsg(self.curr_tar, msg)) def flush(self): - # Not needed, but it's called by some internal Python code. - # So we need to provide a function like this. pass From 6c9cf36ab047593f07811874a3e93f620a837bb8 Mon Sep 17 00:00:00 2001 From: Ryan Forsyth Date: Tue, 3 Feb 2026 11:43:57 -0800 Subject: [PATCH 6/9] Restore comments and type annotations --- zstash/extract.py | 99 +++++++++++++++++++++++++++++++++++++++++++++- zstash/parallel.py | 24 +++++++++-- 2 files changed, 119 insertions(+), 4 deletions(-) diff --git a/zstash/extract.py b/zstash/extract.py index 2ecfe792..1ea4596a 100644 --- a/zstash/extract.py +++ b/zstash/extract.py @@ -304,8 +304,15 @@ def multiprocess_extract( ) -> List[FilesRow]: """ Extract the files from the matches in parallel. + + A single unit of work is a tar and all of + the files in it to extract. """ + # A dict of tar -> size of files in it. + # This is because we're trying to balance the load between + # the processes. tar_to_size_unsorted: DefaultDict[str, float] = collections.defaultdict(float) + db_row: FilesRow for db_row in matches: tar_to_size_unsorted[db_row.tar] += db_row.size @@ -313,15 +320,23 @@ def multiprocess_extract( sorted(tar_to_size_unsorted.items(), key=lambda x: x[1]) ) + # We don't want to instantiate more processes than we need to. + # So, if the number of tars is less than the number of workers, + # set the number of workers to the number of tars. num_workers = min(num_workers, len(tar_to_size)) + # For worker i, workers_to_tars[i] is a set of tars + # that worker i will work on. # Round-robin assignment for predictable ordering workers_to_tars: List[set] = [set() for _ in range(num_workers)] + tar: str for idx, tar in enumerate(sorted(tar_to_size.keys())): workers_to_tars[idx % num_workers].add(tar) workers_to_matches: List[List[FilesRow]] = [[] for _ in range(num_workers)] + workers_idx: int for db_row in matches: + tar = db_row.tar for workers_idx in range(len(workers_to_tars)): if db_row.tar in workers_to_tars[workers_idx]: workers_to_matches[workers_idx].append(db_row) @@ -336,6 +351,7 @@ def multiprocess_extract( tar_ordering, manager=manager ) + # The return value for extractFiles will be added here. failure_queue: multiprocessing.Queue[FilesRow] = multiprocessing.Queue() processes: List[multiprocessing.Process] = [] @@ -352,6 +368,10 @@ def multiprocess_extract( process.start() processes.append(process) + # While the processes are running, we need to empty the queue. + # Otherwise, it causes hanging. + # No need to join() each of the processes when doing this, + # because we'll be in this loop until completion. failures: List[FilesRow] = [] while any(p.is_alive() for p in processes): while not failure_queue.empty(): @@ -361,6 +381,7 @@ def multiprocess_extract( while not failure_queue.empty(): failures.append(failure_queue.get()) + # Sort the failures, since they can come in at any order. failures.sort(key=lambda t: (t.name, t.tar, t.offset)) return failures @@ -371,46 +392,61 @@ def check_sizes_match(cur, tfname, error_on_duplicate_tar): actual_size: int = os.path.getsize(tfname) name_only: str = os.path.split(tfname)[1] + # Get ALL entries for this tar name cur.execute( "SELECT size FROM tars WHERE name = ? ORDER by id DESC", (name_only,) ) results = cur.fetchall() if not results: + # Cannot access size information; assume the sizes match. logger.error(f"No database entries found for {name_only}") return True + # Check for multiple entries if len(results) > 1: + # Extract just the size values sizes: List[int] = [row[0] for row in results] error_str: str = ( f"Database corruption detected! Found {len(results)} database entries for {name_only}, with sizes {sizes}" ) if error_on_duplicate_tar: + # Tested by database_corruption.bash Case 5 logger.error(error_str) raise RuntimeError(error_str) logger.warning(error_str) + # We ordered the results by id DESC, + # so the first entry is the most recent. most_recent_size: int = sizes[0] if actual_size == most_recent_size: + # Tested by database_corruption.bash Case 7 + # If the actual size matches the most recent size, + # then we can assume that the tar is valid. logger.info( f"{name_only}: The most recent database entry has the same size as the actual file size: {actual_size}." ) return True unique_sizes: Set[int] = set(sizes) if actual_size in unique_sizes: + # Tested by database_corruption.bash Case 8 logger.info( f"{name_only}: A database entry matches the actual file size, {actual_size}, but it is not the most recent entry." ) else: + # Tested by database_corruption.bash Case 6 logger.info( f"{name_only}: No database entry matches the actual file size: {actual_size}." ) return False else: + # Tested by database_corruption.bash Cases 1,2,4 + # Single entry - normal case logger.info(f"{name_only}: Found a single database entry.") expected_size = results[0][0] + # Now check if actual size matches expected size if expected_size != actual_size: error_msg = ( f"{name_only}: Size mismatch! " @@ -420,13 +456,16 @@ def check_sizes_match(cur, tfname, error_on_duplicate_tar): logger.error(error_msg) return False else: + # Sizes match logger.info(f"{name_only}: Size check passed ({actual_size} bytes)") return True else: + # Cannot access size information; assume the sizes match. logger.debug("Cannot access tar size information; assuming sizes match") return True +# FIXME: C901 'extractFiles' is too complex (33) def extractFiles( # noqa: C901 files: List[FilesRow], keep_files: bool, @@ -437,7 +476,20 @@ def extractFiles( # noqa: C901 cur: Optional[sqlite3.Cursor] = None, ) -> List[FilesRow]: """ - Given a list of database rows, extract the files from the tar archives. + Given a list of database rows, extract the files from the + tar archives to the current location on disk. + + If keep_files is False, the files are not extracted. + This is used for when checking if the files in an HPSS + repository are valid. + + If keep_tars is True, the tar archives that are downloaded are kept, + even after the program has terminated. Otherwise, they are deleted. + + If running in parallel, then multiprocess_worker is the Worker + that called this function. + We need a reference to it so we can signal it to print + the contents of what's in its print queue. If cur is None (when running in parallel), a new database connection will be opened for this worker process. @@ -459,19 +511,26 @@ def extractFiles( # noqa: C901 # Set up logging redirection for multiprocessing if multiprocess_worker: + # All messages to the logger will now be sent to + # this queue, instead of sys.stdout. sh = logging.StreamHandler(multiprocess_worker.print_queue) sh.setLevel(logging.DEBUG) formatter: logging.Formatter = logging.Formatter("%(levelname)s: %(message)s") sh.setFormatter(formatter) logger.addHandler(sh) + # Don't have the logger print to the console as the message come in. logger.propagate = False for i in range(nfiles): files_row: FilesRow = files[i] + # Open new tar archive if newtar: newtar = False tfname = os.path.join(cache, files_row.tar) + # Everytime we're extracting a new tar, if running in parallel, + # let the process know. + # This is to synchronize the print statements. # Wait for turn before processing this tar if multiprocess_worker: @@ -487,6 +546,8 @@ def extractFiles( # noqa: C901 raise TypeError("Invalid args.hpss={}".format(args.hpss)) tries: int = args.retries + 1 + # Set to True to test the `--retries` option with a forced failure. + # Then run `python -m unittest tests.test_extract.TestExtract.testExtractRetries` test_retry: bool = False while tries > 0: @@ -512,10 +573,12 @@ def extractFiles( # noqa: C901 raise RuntimeError( f"{tfname} size does not match expected size." ) + # `hpss_get` successful or not needed: no more tries needed break except RuntimeError as e: if tries > 0: logger.info(f"Retrying HPSS get: {tries} tries remaining.") + # Run the try-except block again continue else: raise e @@ -526,25 +589,34 @@ def extractFiles( # noqa: C901 # Extract file cmd: str = "Extracting" if keep_files else "Checking" logger.info(cmd + " %s" % (files_row.name)) + # if multiprocess_worker: + # print('{} is {} {} from {}'.format(multiprocess_worker, cmd, file[1], file[5])) if keep_files and not should_extract_file(files_row): + # If we were going to extract, but aren't + # because a matching file is on disk msg: str = "Not extracting {}, because it" msg += " already exists on disk with the same" msg += " size and modification date." logger.info(msg.format(files_row.name)) + # True if we should actually extract the file from the tar extract_this_file: bool = keep_files and should_extract_file(files_row) try: + # Seek file position if tar.fileobj is not None: fileobj = tar.fileobj else: raise TypeError("Invalid tar.fileobj={}".format(tar.fileobj)) fileobj.seek(files_row.offset) + # Get next member tarinfo: tarfile.TarInfo = tar.tarinfo.fromtarfile(tar) if tarinfo.isfile(): + # fileobj to extract + # error: Name 'tarfile.ExFileObject' is not defined extracted_file: Optional[tarfile.ExFileObject] = tar.extractfile(tarinfo) # type: ignore if extracted_file: fin: tarfile.ExFileObject = extracted_file @@ -557,8 +629,11 @@ def extractFiles( # noqa: C901 path, name = os.path.split(fname) if path != "" and extract_this_file: if not os.path.isdir(path): + # The path doesn't exist, so create it. os.makedirs(path) if extract_this_file: + # If we're keeping the files, + # then have an output file fout: _io.BufferedWriter = open(fname, "wb") hash_md5: _hashlib.HASH = hashlib.md5() @@ -577,12 +652,17 @@ def extractFiles( # noqa: C901 md5: str = hash_md5.hexdigest() if extract_this_file: + # numeric_owner is a required arg in Python 3. + # If True, "only the numbers for user/group names + # are used and not the names". tar.chown(tarinfo, fname, numeric_owner=False) tar.chmod(tarinfo, fname) tar.utime(tarinfo, fname) + # Verify size if os.path.getsize(fname) != files_row.size: logger.error("size mismatch for: {}".format(fname)) + # Verify md5 checksum files_row_md5: Optional[str] = files_row.md5 if md5 != files_row_md5: logger.error("md5 mismatch for: {}".format(fname)) @@ -597,6 +677,10 @@ def extractFiles( # noqa: C901 tar.extract(tarinfo, filter="tar") else: tar.extract(tarinfo) + # Note: tar.extract() will not restore time stamps of symbolic + # links. Could not find a Python-way to restore it either, so + # relying here on 'touch'. This is not the prettiest solution. + # Maybe a better one can be implemented later. if tarinfo.issym(): tmp1 = tarinfo.mtime tmp2: datetime = datetime.fromtimestamp(tmp1) @@ -604,12 +688,16 @@ def extractFiles( # noqa: C901 os.system("touch -h -t %s %s" % (tmp3, tarinfo.name)) except Exception: + # Catch all exceptions here. traceback.print_exc() logger.error("Retrieving {}".format(files_row.name)) failures.append(files_row) # Close current archive? if i == nfiles - 1 or files[i].tar != files[i + 1].tar: + # We're either on the last file or the tar is distinct from the tar of the next file. + + # Close current archive file logger.debug("Closing tar archive {}".format(tfname)) tar.close() @@ -617,8 +705,10 @@ def extractFiles( # noqa: C901 multiprocess_worker.done_enqueuing_output_for_tar(files_row.tar) multiprocess_worker.print_all_contents() + # Open new archive next time newtar = True + # Delete this tar if the corresponding command-line arg was used. if not keep_tars: if tfname is not None: os.remove(tfname) @@ -630,6 +720,9 @@ def extractFiles( # noqa: C901 cur.close() con.close() + # Add the failures to the queue. + # When running with multiprocessing, the function multiprocess_extract() + # that calls this extractFiles() function will return the failures as a list. if multiprocess_worker: for f in failures: multiprocess_worker.failure_queue.put(f) @@ -648,11 +741,15 @@ def should_extract_file(db_row: FilesRow) -> bool: file_name, size_db, mod_time_db = db_row.name, db_row.size, db_row.mtime if not os.path.exists(file_name): + # The file doesn't exist locally. + # We must get files that are not on disk. return True size_disk: int = os.path.getsize(file_name) mod_time_disk: datetime = datetime.utcfromtimestamp(os.path.getmtime(file_name)) + # Only extract when the times and sizes are not the same (within tolerance) + # We have a TIME_TOL because mod_time_disk doesn't have the microseconds. return not ( (size_disk == size_db) and (abs(mod_time_disk - mod_time_db).total_seconds() < TIME_TOL) diff --git a/zstash/parallel.py b/zstash/parallel.py index 78d3a3de..4470a399 100644 --- a/zstash/parallel.py +++ b/zstash/parallel.py @@ -20,9 +20,13 @@ class NotYourTurnError(Exception): class PrintMonitor(object): """ Used to synchronize the printing of the output between workers. + Depending on the current_tar, the worker processing the work + for that tar will print it's output. """ def __init__(self, tars_to_print: List[str], manager=None, *args, **kwargs): + # A list of tars to print. + # Ex: ['000000.tar', '000008.tar', '00001a.tar'] if not tars_to_print: msg: str = "You must pass in a list of tars, which dictates" msg += " the order of which to print the results." @@ -87,15 +91,18 @@ class ExtractWorker(object): It redirects all of the output of the logging module to a queue. Then with a PrintMonitor, it prints to the terminal in the order defined by the PrintMonitor. + + This worker is called during `zstash extract`. """ def __init__( self, print_monitor: PrintMonitor, tars_to_work_on: List[str], + # TODO: failure_queue has type `multiprocessing.Queue[FilesRow]` failure_queue, *args, - **kwargs, + **kwargs ): """ print_monitor is used to determine if it's this worker's turn to print. @@ -103,10 +110,14 @@ def __init__( Any failures are added to the failure_queue, to return any failed values. """ self.print_monitor: PrintMonitor = print_monitor + # Every call to print() in the original function will + # be piped to this queue instead of the screen. self.print_queue: PrintQueue = PrintQueue() + # A tar is mapped to True when all of its output is in the queue. self.is_output_done_enqueuing: Dict[str, bool] = { tar: False for tar in tars_to_work_on } + # After extractFiles is done, all of the failures will be added to this queue. self.failure_queue: multiprocessing.Queue[FilesRow] = failure_queue def set_curr_tar(self, tar: str): @@ -119,6 +130,7 @@ def done_enqueuing_output_for_tar(self, tar: str): """ All of the output for extracting this tar is in the print queue. """ + msg: str if tar not in self.is_output_done_enqueuing: msg = "This tar {} isn't assigned to this worker." raise RuntimeError(msg.format(tar)) @@ -136,6 +148,7 @@ def print_contents(self): try: self.print_all_contents(indef_wait=False) except NotYourTurnError: + # It's not our turn, so try again the next time this function is called. pass def has_to_print(self) -> bool: @@ -147,13 +160,16 @@ def has_to_print(self) -> bool: def print_all_contents(self, *args, **kwargs): """ Block until all of the contents of self.print_queue are printed. + + If it's not our turn and the passed in timeout to print_monitor.wait_turn + is over, a NotYourTurnError exception is raised. """ while self.has_to_print(): + # Try to print the first element in the queue. tar_to_print: str = self.print_queue[0].tar - self.print_monitor.wait_turn(self, tar_to_print, *args, **kwargs) - # Print all applicable values in the print_queue + # Print all applicable values in the print_queue. while self.print_queue and (self.print_queue[0].tar == tar_to_print): msg: str = self.print_queue.popleft().msg print(msg, end="", flush=True) @@ -178,6 +194,8 @@ def write(self, msg: str): self.append(TarAndMsg(self.curr_tar, msg)) def flush(self): + # Not needed, but it's called by some internal Python code. + # So we need to provide a function like this. pass From b4e9957309ac3d58cb750e8362bc28542282f8b2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Feb 2026 19:23:00 +0000 Subject: [PATCH 7/9] Restore heap-based worker allocation in multiprocess_extract Co-authored-by: forsyth2 <30700190+forsyth2@users.noreply.github.com> --- zstash/extract.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/zstash/extract.py b/zstash/extract.py index 1ea4596a..b38f28df 100644 --- a/zstash/extract.py +++ b/zstash/extract.py @@ -3,6 +3,7 @@ import argparse import collections import hashlib +import heapq import logging import multiprocessing import os.path @@ -327,11 +328,25 @@ def multiprocess_extract( # For worker i, workers_to_tars[i] is a set of tars # that worker i will work on. - # Round-robin assignment for predictable ordering workers_to_tars: List[set] = [set() for _ in range(num_workers)] + # A min heap, of (work, worker_idx) tuples, work is the size of data + # that worker_idx needs to work on. + # We can efficiently get the worker with the least amount of work. + work_to_workers: List[Tuple[float, int]] = [(0.0, i) for i in range(num_workers)] + heapq.heapify(work_to_workers) + + # Using a greedy approach, populate workers_to_tars. tar: str - for idx, tar in enumerate(sorted(tar_to_size.keys())): - workers_to_tars[idx % num_workers].add(tar) + for tar in tar_to_size: + # The worker with the least work should get the current largest amount of work. + workers_work: float + worker_idx: int + workers_work, worker_idx = heapq.heappop(work_to_workers) + workers_to_tars[worker_idx].add(tar) + # Add this worker back to the heap, with the new amount of work. + worker_tuple: Tuple[float, int] = (workers_work + tar_to_size[tar], worker_idx) + # FIXME: error: Cannot infer type argument 1 of "heappush" + heapq.heappush(work_to_workers, worker_tuple) # type: ignore workers_to_matches: List[List[FilesRow]] = [[] for _ in range(num_workers)] workers_idx: int From 72ef3ac068086b7a13e9120be7fd199a4a7ec09a Mon Sep 17 00:00:00 2001 From: Ryan Forsyth Date: Mon, 23 Feb 2026 11:44:07 -0800 Subject: [PATCH 8/9] Fix pre-commit check --- zstash/extract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zstash/extract.py b/zstash/extract.py index b38f28df..4ee91ee3 100644 --- a/zstash/extract.py +++ b/zstash/extract.py @@ -346,7 +346,7 @@ def multiprocess_extract( # Add this worker back to the heap, with the new amount of work. worker_tuple: Tuple[float, int] = (workers_work + tar_to_size[tar], worker_idx) # FIXME: error: Cannot infer type argument 1 of "heappush" - heapq.heappush(work_to_workers, worker_tuple) # type: ignore + heapq.heappush(work_to_workers, worker_tuple) workers_to_matches: List[List[FilesRow]] = [[] for _ in range(num_workers)] workers_idx: int From a87b57c6bdd8a9c1311cecf9f112068508373f59 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 18:52:44 +0000 Subject: [PATCH 9/9] Address code review: remove wait_turn from extraction, fix HPSS auth, manager cleanup, queue drain, precomputed tar index map Co-authored-by: forsyth2 <30700190+forsyth2@users.noreply.github.com> --- zstash/extract.py | 34 +++++++++++++++++++++------------- zstash/parallel.py | 19 +++++++++++-------- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/zstash/extract.py b/zstash/extract.py index 4ee91ee3..c72432fa 100644 --- a/zstash/extract.py +++ b/zstash/extract.py @@ -7,6 +7,7 @@ import logging import multiprocessing import os.path +import queue import re import sqlite3 import sys @@ -389,12 +390,21 @@ def multiprocess_extract( # because we'll be in this loop until completion. failures: List[FilesRow] = [] while any(p.is_alive() for p in processes): - while not failure_queue.empty(): - failures.append(failure_queue.get()) + try: + while True: + failures.append(failure_queue.get_nowait()) + except queue.Empty: + pass time.sleep(0.01) - while not failure_queue.empty(): - failures.append(failure_queue.get()) + # Drain any remaining failures after all processes have exited. + try: + while True: + failures.append(failure_queue.get_nowait()) + except queue.Empty: + pass + + manager.shutdown() # Sort the failures, since they can come in at any order. failures.sort(key=lambda t: (t.name, t.tar, t.offset)) @@ -547,18 +557,16 @@ def extractFiles( # noqa: C901 # let the process know. # This is to synchronize the print statements. - # Wait for turn before processing this tar if multiprocess_worker: - multiprocess_worker.print_monitor.wait_turn( - multiprocess_worker, files_row.tar - ) multiprocess_worker.set_curr_tar(files_row.tar) - # Use args.hpss directly + # Use args.hpss, falling back to config.hpss when not provided if args.hpss is not None: hpss: str = args.hpss + elif config.hpss is not None: + hpss = config.hpss else: - raise TypeError("Invalid args.hpss={}".format(args.hpss)) + raise TypeError("Invalid config.hpss={}".format(config.hpss)) tries: int = args.retries + 1 # Set to True to test the `--retries` option with a forced failure. @@ -735,9 +743,9 @@ def extractFiles( # noqa: C901 cur.close() con.close() - # Add the failures to the queue. - # When running with multiprocessing, the function multiprocess_extract() - # that calls this extractFiles() function will return the failures as a list. + # Add the failures to the queue. + # When running with multiprocessing, the function multiprocess_extract() + # that calls this extractFiles() function will return the failures as a list. if multiprocess_worker: for f in failures: multiprocess_worker.failure_queue.put(f) diff --git a/zstash/parallel.py b/zstash/parallel.py index 4470a399..4bcda598 100644 --- a/zstash/parallel.py +++ b/zstash/parallel.py @@ -38,6 +38,11 @@ def __init__(self, tars_to_print: List[str], manager=None, *args, **kwargs): # Store the ordered list of tars self._tars_list: List[str] = tars_to_print + # Precomputed mapping from tar name to its position in the ordered list. + self._tar_to_index: Dict[str, int] = { + tar: i for i, tar in enumerate(tars_to_print) + } + # Use a simple counter to track which tar we're on self._current_tar_index: multiprocessing.managers.ValueProxy = manager.Value( "i", 0 @@ -52,10 +57,9 @@ def wait_turn( """ Wait until it's this worker's turn to process workers_curr_tar. """ - try: - tar_index = self._tars_list.index(workers_curr_tar) - except ValueError: - return + if workers_curr_tar not in self._tar_to_index: + raise RuntimeError("Tar {} not in ordered list".format(workers_curr_tar)) + tar_index = self._tar_to_index[workers_curr_tar] attempted = False while True: @@ -75,10 +79,9 @@ def done_enqueuing_output_for_tar( A worker has finished printing output for workers_curr_tar. Advance to the next tar in the sequence. """ - try: - tar_index = self._tars_list.index(workers_curr_tar) - except ValueError: - return + if workers_curr_tar not in self._tar_to_index: + raise RuntimeError("Tar {} not in ordered list".format(workers_curr_tar)) + tar_index = self._tar_to_index[workers_curr_tar] with self._lock: if self._current_tar_index.value == tar_index: