diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index d0aa857a..a5e0913e 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -19,10 +19,10 @@ jobs: - name: Checkout Code Repository uses: actions/checkout@v3 - - name: Set up Python 3.13 + - name: Set up Python 3.14 uses: actions/setup-python@v4 with: - python-version: "3.13" + python-version: "3.14" # Run all pre-commit hooks on all the files. # Getting only staged files can be tricky in case a new PR is opened @@ -35,8 +35,9 @@ jobs: build: runs-on: ubuntu-latest strategy: + fail-fast: false # Continue all jobs even if one fails matrix: - python-version: ["3.11", "3.12", "3.13"] + python-version: ["3.11", "3.12", "3.13", "3.14"] defaults: run: shell: bash -l {0} @@ -72,11 +73,6 @@ jobs: conda list # Ensure we have the right Python version python --version - # Fix pip issues for Python 3.12+ - if [[ "${{ matrix.python-version }}" == "3.12" ]] || [[ "${{ matrix.python-version }}" == "3.13" ]]; then - python -m ensurepip --upgrade || true - python -m pip install --upgrade --force-reinstall pip setuptools wheel - fi - name: Install `zstash` Package run: | @@ -121,7 +117,7 @@ jobs: environment-file: conda/dev.yml channel-priority: flexible # Changed from strict to flexible auto-update-conda: true - python-version: "3.13" # Use stable Python version for docs + python-version: "3.14" # Use newest supported Python version for docs # sphinx-multiversion allows for version docs. 
- name: Build Sphinx Docs diff --git a/conda/dev.yml b/conda/dev.yml index 6ec8e1b9..85f64d32 100644 --- a/conda/dev.yml +++ b/conda/dev.yml @@ -5,7 +5,8 @@ dependencies: # Base # ================= - pip - - python >=3.11,<3.14 + - python >=3.11,<3.15 + - setuptools - sqlite - six >=1.16.0 - globus-sdk >=3.15.0,<4.0 diff --git a/setup.cfg b/setup.cfg index 4e6c6cb4..5a75b3ef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,7 +46,7 @@ exclude = venv [mypy] -python_version = 3.13 +python_version = 3.14 check_untyped_defs = True ignore_missing_imports = True warn_unused_ignores = True diff --git a/setup.py b/setup.py index 292653e0..67df2e8e 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,6 @@ author_email="forsyth2@llnl.gov, golaz1@llnl.gov, shaheen2@llnl.gov", description="Long term HPSS archiving software for E3SM", packages=find_packages(include=["zstash", "zstash.*"]), - python_requires=">=3.11,<3.14", + python_requires=">=3.11,<3.15", entry_points={"console_scripts": ["zstash=zstash.main:main"]}, ) diff --git a/zstash/extract.py b/zstash/extract.py index 64977aef..c72432fa 100644 --- a/zstash/extract.py +++ b/zstash/extract.py @@ -7,10 +7,12 @@ import logging import multiprocessing import os.path +import queue import re import sqlite3 import sys import tarfile +import time import traceback from datetime import datetime from typing import DefaultDict, List, Optional, Set, Tuple @@ -282,10 +284,10 @@ def extract_database( if args.workers > 1: logger.debug("Running zstash {} with multiprocessing".format(cmd)) failures = multiprocess_extract( - args.workers, matches, keep_files, keep, cache, cur, args + args.workers, matches, keep_files, keep, cache, args ) else: - failures = extractFiles(matches, keep_files, keep, cache, cur, args) + failures = extractFiles(matches, keep_files, keep, cache, args, None, cur) # Close database logger.debug("Closing index database") @@ -300,7 +302,6 @@ def multiprocess_extract( keep_files: bool, keep_tars: Optional[bool], cache: str, - 
cur: sqlite3.Cursor, args: argparse.Namespace, ) -> List[FilesRow]: """ @@ -314,12 +315,9 @@ def multiprocess_extract( # the processes. tar_to_size_unsorted: DefaultDict[str, float] = collections.defaultdict(float) db_row: FilesRow - tar: str - size: int for db_row in matches: - tar, size = db_row.tar, db_row.size - tar_to_size_unsorted[tar] += size - # Sort by the size. + tar_to_size_unsorted[db_row.tar] += db_row.size + tar_to_size: collections.OrderedDict[str, float] = collections.OrderedDict( sorted(tar_to_size_unsorted.items(), key=lambda x: x[1]) ) @@ -335,38 +333,44 @@ def multiprocess_extract( # A min heap, of (work, worker_idx) tuples, work is the size of data # that worker_idx needs to work on. # We can efficiently get the worker with the least amount of work. - work_to_workers: List[Tuple[int, int]] = [(0, i) for i in range(num_workers)] - heapq.heapify(workers_to_tars) + work_to_workers: List[Tuple[float, int]] = [(0.0, i) for i in range(num_workers)] + heapq.heapify(work_to_workers) # Using a greedy approach, populate workers_to_tars. - for _, tar in enumerate(tar_to_size): + tar: str + for tar in tar_to_size: # The worker with the least work should get the current largest amount of work. - workers_work: int + workers_work: float worker_idx: int workers_work, worker_idx = heapq.heappop(work_to_workers) workers_to_tars[worker_idx].add(tar) # Add this worker back to the heap, with the new amount of work. worker_tuple: Tuple[float, int] = (workers_work + tar_to_size[tar], worker_idx) # FIXME: error: Cannot infer type argument 1 of "heappush" - heapq.heappush(work_to_workers, worker_tuple) # type: ignore + heapq.heappush(work_to_workers, worker_tuple) - # For worker i, workers_to_matches[i] is a list of - # matches from the database for it to process. 
workers_to_matches: List[List[FilesRow]] = [[] for _ in range(num_workers)] + workers_idx: int for db_row in matches: tar = db_row.tar - workers_idx: int for workers_idx in range(len(workers_to_tars)): - if tar in workers_to_tars[workers_idx]: - # This worker gets this db_row. + if db_row.tar in workers_to_tars[workers_idx]: workers_to_matches[workers_idx].append(db_row) + # Ensure each worker processes tars in order + for worker_matches in workers_to_matches: + worker_matches.sort(key=lambda t: t.tar) + tar_ordering: List[str] = sorted([tar for tar in tar_to_size]) - monitor: parallel.PrintMonitor = parallel.PrintMonitor(tar_ordering) + manager = multiprocessing.Manager() + monitor: parallel.PrintMonitor = parallel.PrintMonitor( + tar_ordering, manager=manager + ) # The return value for extractFiles will be added here. failure_queue: multiprocessing.Queue[FilesRow] = multiprocessing.Queue() processes: List[multiprocessing.Process] = [] + for matches in workers_to_matches: tars_for_this_worker: List[str] = list(set(match.tar for match in matches)) worker: parallel.ExtractWorker = parallel.ExtractWorker( @@ -374,7 +378,7 @@ def multiprocess_extract( ) process: multiprocessing.Process = multiprocessing.Process( target=extractFiles, - args=(matches, keep_files, keep_tars, cache, cur, args, worker), + args=(matches, keep_files, keep_tars, cache, args, worker, None), daemon=True, ) process.start() @@ -386,8 +390,21 @@ def multiprocess_extract( # because we'll be in this loop until completion. failures: List[FilesRow] = [] while any(p.is_alive() for p in processes): - while not failure_queue.empty(): - failures.append(failure_queue.get()) + try: + while True: + failures.append(failure_queue.get_nowait()) + except queue.Empty: + pass + time.sleep(0.01) + + # Drain any remaining failures after all processes have exited. 
+ try: + while True: + failures.append(failure_queue.get_nowait()) + except queue.Empty: + pass + + manager.shutdown() # Sort the failures, since they can come in at any order. failures.sort(key=lambda t: (t.name, t.tar, t.offset)) @@ -479,9 +496,9 @@ def extractFiles( # noqa: C901 keep_files: bool, keep_tars: Optional[bool], cache: str, - cur: sqlite3.Cursor, args: argparse.Namespace, multiprocess_worker: Optional[parallel.ExtractWorker] = None, + cur: Optional[sqlite3.Cursor] = None, ) -> List[FilesRow]: """ Given a list of database rows, extract the files from the @@ -498,11 +515,26 @@ def extractFiles( # noqa: C901 that called this function. We need a reference to it so we can signal it to print the contents of what's in its print queue. + + If cur is None (when running in parallel), a new database connection + will be opened for this worker process. """ + # Open database connection if not provided (parallel case) + if cur is None: + con: sqlite3.Connection = sqlite3.connect( + get_db_filename(cache), detect_types=sqlite3.PARSE_DECLTYPES + ) + cur = con.cursor() + close_db: bool = True + else: + close_db = False + failures: List[FilesRow] = [] tfname: str newtar: bool = True nfiles: int = len(files) + + # Set up logging redirection for multiprocessing if multiprocess_worker: # All messages to the logger will now be sent to # this queue, instead of sys.stdout. @@ -524,17 +556,23 @@ def extractFiles( # noqa: C901 # Everytime we're extracting a new tar, if running in parallel, # let the process know. # This is to synchronize the print statements. 
+ if multiprocess_worker: multiprocess_worker.set_curr_tar(files_row.tar) - if config.hpss is not None: - hpss: str = config.hpss + # Use args.hpss, falling back to config.hpss when not provided + if args.hpss is not None: + hpss: str = args.hpss + elif config.hpss is not None: + hpss = config.hpss else: raise TypeError("Invalid config.hpss={}".format(config.hpss)) + tries: int = args.retries + 1 # Set to True to test the `--retries` option with a forced failure. # Then run `python -m unittest tests.test_extract.TestExtract.testExtractRetries` test_retry: bool = False + while tries > 0: tries -= 1 do_retrieve: bool @@ -653,13 +691,15 @@ def extractFiles( # noqa: C901 logger.error("md5 mismatch for: {}".format(fname)) logger.error("md5 of extracted file: {}".format(md5)) logger.error("md5 of original file: {}".format(files_row_md5)) - failures.append(files_row) else: logger.debug("Valid md5: {} {}".format(md5, fname)) elif extract_this_file: - tar.extract(tarinfo) + if sys.version_info >= (3, 12): + tar.extract(tarinfo, filter="tar") + else: + tar.extract(tarinfo) # Note: tar.extract() will not restore time stamps of symbolic # links. Could not find a Python-way to restore it either, so # relying here on 'touch'. This is not the prettiest solution. @@ -676,9 +716,6 @@ def extractFiles( # noqa: C901 logger.error("Retrieving {}".format(files_row.name)) failures.append(files_row) - if multiprocess_worker: - multiprocess_worker.print_contents() - # Close current archive? if i == nfiles - 1 or files[i].tar != files[i + 1].tar: # We're either on the last file or the tar is distinct from the tar of the next file. 
@@ -689,6 +726,7 @@ def extractFiles( # noqa: C901 if multiprocess_worker: multiprocess_worker.done_enqueuing_output_for_tar(files_row.tar) + multiprocess_worker.print_all_contents() # Open new archive next time newtar = True @@ -700,15 +738,18 @@ def extractFiles( # noqa: C901 else: raise TypeError("Invalid tfname={}".format(tfname)) - if multiprocess_worker: - # If there are things left to print, print them. - multiprocess_worker.print_all_contents() + # Close database connection if we opened it + if close_db: + cur.close() + con.close() - # Add the failures to the queue. - # When running with multiprocessing, the function multiprocess_extract() - # that calls this extractFiles() function will return the failures as a list. + # Add the failures to the queue. + # When running with multiprocessing, the function multiprocess_extract() + # that calls this extractFiles() function will return the failures as a list. + if multiprocess_worker: for f in failures: multiprocess_worker.failure_queue.put(f) + return failures diff --git a/zstash/parallel.py b/zstash/parallel.py index 53beb53d..4bcda598 100644 --- a/zstash/parallel.py +++ b/zstash/parallel.py @@ -1,8 +1,8 @@ from __future__ import print_function import collections -import ctypes import multiprocessing +import time from typing import Dict, List, Optional from .settings import FilesRow @@ -24,7 +24,7 @@ class PrintMonitor(object): for that tar will print it's output. """ - def __init__(self, tars_to_print: List[str], *args, **kwargs): + def __init__(self, tars_to_print: List[str], manager=None, *args, **kwargs): # A list of tars to print. # Ex: ['000000.tar', '000008.tar', '00001a.tar'] if not tars_to_print: @@ -32,75 +32,60 @@ def __init__(self, tars_to_print: List[str], *args, **kwargs): msg += " the order of which to print the results." raise RuntimeError(msg) - # The variables below are modified/accessed by different processes, - # so they need to be in shared memory. 
- self._cv: multiprocessing.synchronize.Condition = multiprocessing.Condition() - - self._tars_to_print: multiprocessing.Queue[str] = multiprocessing.Queue() - tar: str - for tar in tars_to_print: - # Add the tar to the queue to be printed. - self._tars_to_print.put(tar) - - # We need a manager to instantiate the Value instead of multiprocessing.Value. - # If we didn't use a manager, it seems to get some junk value. - self._manager: multiprocessing.managers.SyncManager = multiprocessing.Manager() - self._current_tar: multiprocessing.managers.ValueProxy = self._manager.Value( - ctypes.c_char_p, self._tars_to_print.get() + if manager is None: + raise ValueError("manager must be provided to PrintMonitor") + + # Store the ordered list of tars + self._tars_list: List[str] = tars_to_print + + # Precomputed mapping from tar name to its position in the ordered list. + self._tar_to_index: Dict[str, int] = { + tar: i for i, tar in enumerate(tars_to_print) + } + + # Use a simple counter to track which tar we're on + self._current_tar_index: multiprocessing.managers.ValueProxy = manager.Value( + "i", 0 ) + # Lock for updating the counter + self._lock: multiprocessing.synchronize.Lock = manager.Lock() + def wait_turn( - # TODO: worker has type `ExtractWorker` - self, - worker, - workers_curr_tar: str, - indef_wait: bool = True, - *args, - **kwargs + self, worker, workers_curr_tar: str, indef_wait: bool = True, *args, **kwargs ): """ - While a worker's current tar isn't the one - needed to be printed, wait. + Wait until it's this worker's turn to process workers_curr_tar. + """ + if workers_curr_tar not in self._tar_to_index: + raise RuntimeError("Tar {} not in ordered list".format(workers_curr_tar)) + tar_index = self._tar_to_index[workers_curr_tar] - A timeout is passed into self._cv.wait(), and if the - turn isn't given within that, a NotYourTurnError is raised. 
+ attempted = False + while True: + if self._current_tar_index.value == tar_index: + return - If indef_wait is True, indefinitely wait until it's - the worker's turn. - """ - with self._cv: - attempted: bool = False - while self._current_tar.value != workers_curr_tar: - if attempted and not indef_wait: - # It's not this worker's turn. - raise NotYourTurnError() + if attempted and not indef_wait: + raise NotYourTurnError() - attempted = True - # Wait 0.001 to see if it's the worker's turn. - self._cv.wait(0.001) + attempted = True + time.sleep(0.01) - def done_dequeuing_output_for_tar( - # TODO: worker has type `ExtractWorker` - self, - worker, - workers_curr_tar: str, - *args, - **kwargs + def done_enqueuing_output_for_tar( + self, worker, workers_curr_tar: str, *args, **kwargs ): """ - A worker has finished printing the output for workers_curr_tar - from the print queue. - If possible, update self._current_tar. - If there aren't anymore tars to print, set self._current_tar to None. + A worker has finished printing output for workers_curr_tar. + Advance to the next tar in the sequence. """ - # It must be the worker's turn before this can happen. - self.wait_turn(worker, workers_curr_tar, *args, **kwargs) + if workers_curr_tar not in self._tar_to_index: + raise RuntimeError("Tar {} not in ordered list".format(workers_curr_tar)) + tar_index = self._tar_to_index[workers_curr_tar] - with self._cv: - self._current_tar.value = ( - self._tars_to_print.get() if not self._tars_to_print.empty() else "" - ) - self._cv.notify_all() + with self._lock: + if self._current_tar_index.value == tar_index: + self._current_tar_index.value += 1 class ExtractWorker(object): @@ -164,7 +149,6 @@ def print_contents(self): Try to print the contents from self.print_queue. """ try: - # We only wait for 0.001 seconds. self.print_all_contents(indef_wait=False) except NotYourTurnError: # It's not our turn, so try again the next time this function is called. 
@@ -193,11 +177,9 @@ def print_all_contents(self, *args, **kwargs): msg: str = self.print_queue.popleft().msg print(msg, end="", flush=True) - # If True, then all of the output for extracting tar_to_print was in the queue. - # Since we just finished printing all of it, we can move onto the next one. + # If all output for this tar is done, advance the monitor if self.is_output_done_enqueuing[tar_to_print]: - # Let all of the other workers know that this worker is done. - self.print_monitor.done_dequeuing_output_for_tar(self, tar_to_print) + self.print_monitor.done_enqueuing_output_for_tar(self, tar_to_print) class PrintQueue(collections.deque):