Commit 7085aad

Return comments Claude mistakenly removed

1 parent d9950f9 commit 7085aad

1 file changed: zstash/extract.py (52 additions, 0 deletions)
@@ -437,6 +437,8 @@ def multiprocess_extract(
 ) -> List[FilesRow]:
     """
     Extract the files from the matches in parallel with checkpoint support.
+
+    A single unit of work is a tar and all of the files in it to extract.
     """
     # Create checkpoint queue and saver process ONLY for check operations
     checkpoint_queue: Optional[multiprocessing.Queue[Optional[Tuple]]] = None
@@ -448,6 +450,8 @@ def multiprocess_extract(
         checkpoint_saver.start()

     # A dict of tar -> size of files in it.
+    # This is because we're trying to balance the load between
+    # the processes.
     tar_to_size_unsorted: DefaultDict[str, float] = collections.defaultdict(float)
     db_row: FilesRow
     tar: str
@@ -456,35 +460,51 @@ def multiprocess_extract(
         tar, size = db_row.tar, db_row.size
         tar_to_size_unsorted[tar] += size

+    # Sort by the size.
     tar_to_size: collections.OrderedDict[str, float] = collections.OrderedDict(
         sorted(tar_to_size_unsorted.items(), key=lambda x: x[1])
     )

+    # We don't want to instantiate more processes than we need to.
+    # So, if the number of tars is less than the number of workers,
+    # set the number of workers to the number of tars.
     num_workers = min(num_workers, len(tar_to_size))

+    # For worker i, workers_to_tars[i] is a set of tars
+    # that worker i will work on.
     workers_to_tars: List[set] = [set() for _ in range(num_workers)]
+    # A min heap of (work, worker_idx) tuples; work is the size of the data
+    # that worker_idx needs to work on.
+    # We can efficiently get the worker with the least amount of work.
     work_to_workers: List[Tuple[float, int]] = [(0.0, i) for i in range(num_workers)]
     heapq.heapify(work_to_workers)  # Fixed: was heapify(workers_to_tars)

+    # Using a greedy approach, populate workers_to_tars.
     for _, tar in enumerate(tar_to_size):
+        # The worker with the least work should get the current largest amount of work.
         workers_work: float  # Changed from int
         worker_idx: int
         workers_work, worker_idx = heapq.heappop(work_to_workers)
         workers_to_tars[worker_idx].add(tar)
+        # Add this worker back to the heap, with the new amount of work.
         worker_tuple: Tuple[float, int] = (workers_work + tar_to_size[tar], worker_idx)
         heapq.heappush(work_to_workers, worker_tuple)  # No type: ignore needed!

+    # For worker i, workers_to_matches[i] is a list of
+    # matches from the database for it to process.
     workers_to_matches: List[List[FilesRow]] = [[] for _ in range(num_workers)]
     for db_row in matches:
         tar = db_row.tar
         workers_idx: int
         for workers_idx in range(len(workers_to_tars)):
             if tar in workers_to_tars[workers_idx]:
+                # This worker gets this db_row.
                 workers_to_matches[workers_idx].append(db_row)

     tar_ordering: List[str] = sorted([tar for tar in tar_to_size])
     monitor: parallel.PrintMonitor = parallel.PrintMonitor(tar_ordering)

+    # The return value for extractFiles will be added here.
     failure_queue: multiprocessing.Queue[FilesRow] = multiprocessing.Queue()
     processes: List[multiprocessing.Process] = []
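
Aside: the comments restored above describe a classic greedy balanced-partition scheme. A minimal, self-contained sketch of the same idea, with hypothetical tar names and sizes (none of these values come from zstash):

```python
import heapq
from typing import Dict, List, Set, Tuple

def assign_tars(tar_to_size: Dict[str, float], num_workers: int) -> List[Set[str]]:
    """Greedily give each tar to the currently least-loaded worker."""
    num_workers = min(num_workers, len(tar_to_size))
    workers_to_tars: List[Set[str]] = [set() for _ in range(num_workers)]
    # Min-heap of (total assigned size, worker index): the root is always
    # the worker with the least work, so each pop/push is O(log num_workers).
    heap: List[Tuple[float, int]] = [(0.0, i) for i in range(num_workers)]
    heapq.heapify(heap)
    # Iterate tars sorted by size, as multiprocess_extract does.
    for tar, size in sorted(tar_to_size.items(), key=lambda x: x[1]):
        work, idx = heapq.heappop(heap)
        workers_to_tars[idx].add(tar)
        heapq.heappush(heap, (work + size, idx))
    return workers_to_tars

# Hypothetical sizes, for illustration only.
print(assign_tars({"000000.tar": 10.0, "000001.tar": 70.0, "000002.tar": 25.0}, 2))
```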

@@ -513,6 +533,10 @@ def multiprocess_extract(
         process.start()
         processes.append(process)

+    # While the processes are running, we need to empty the queue.
+    # Otherwise, it causes hanging.
+    # No need to join() each of the processes when doing this,
+    # because we'll be in this loop until completion.
     failures: List[FilesRow] = []
     while any(p.is_alive() for p in processes):
         while not failure_queue.empty():
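
The restored comments point at a real multiprocessing pitfall: a child process cannot exit while its queue feeder thread still holds undelivered items, so calling join() before draining can deadlock. A minimal sketch of the drain-while-alive pattern (the worker function and item counts are illustrative, not zstash's):

```python
import multiprocessing
import queue

def worker(q: "multiprocessing.Queue[int]") -> None:
    for i in range(1000):
        q.put(i)  # a child cannot exit until its queued items are flushed

if __name__ == "__main__":
    q: "multiprocessing.Queue[int]" = multiprocessing.Queue()
    procs = [multiprocessing.Process(target=worker, args=(q,)) for _ in range(2)]
    for p in procs:
        p.start()
    results = []
    # Drain while any child is alive; join()-ing first with a loaded
    # queue is the classic multiprocessing deadlock.
    while any(p.is_alive() for p in procs):
        try:
            results.append(q.get(timeout=0.1))
        except queue.Empty:
            pass
    # Children have exited; collect any stragglers (empty() is approximate,
    # but the writers are gone by now).
    while not q.empty():
        results.append(q.get())
    print(len(results))  # 2000
```
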
@@ -628,6 +652,21 @@ def extractFiles( # noqa: C901
     """
     Extract files with checkpoint support even in multiprocessing mode.

+    Given a list of database rows, extract the files from the
+    tar archives to the current location on disk.
+
+    If keep_files is False, the files are not extracted.
+    This is used for when checking if the files in an HPSS
+    repository are valid.
+
+    If keep_tars is True, the tar archives that are downloaded are kept,
+    even after the program has terminated. Otherwise, they are deleted.
+
+    If running in parallel, then multiprocess_worker is the Worker
+    that called this function.
+    We need a reference to it so we can signal it to print
+    the contents of what's in its print queue.
+
     If checkpoint_queue is provided, checkpoint data is sent to it
     instead of saving directly to the database.
     """
@@ -638,11 +677,14 @@ def extractFiles( # noqa: C901
     files_processed: int = 0

     if multiprocess_worker:
+        # All messages to the logger will now be sent to
+        # this queue, instead of sys.stdout.
         sh = logging.StreamHandler(multiprocess_worker.print_queue)
         sh.setLevel(logging.DEBUG)
         formatter: logging.Formatter = logging.Formatter("%(levelname)s: %(message)s")
         sh.setFormatter(formatter)
         logger.addHandler(sh)
+        # Don't have the logger print to the console as the messages come in.
         logger.propagate = False

     for i in range(nfiles):
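
For context on the handler block above: logging.StreamHandler only requires its stream argument to have a write() method, which is presumably what makes handing it multiprocess_worker.print_queue work. A sketch of the same redirection with a hypothetical QueueStream stand-in (not zstash's actual Worker or print_queue):

```python
import logging
import multiprocessing

class QueueStream:
    """Minimal file-like shim: logging.StreamHandler only needs write()."""

    def __init__(self, q: "multiprocessing.Queue[str]") -> None:
        self.q = q

    def write(self, msg: str) -> None:
        # StreamHandler writes the formatted record plus a newline in one call.
        self.q.put(msg.rstrip("\n"))

    def flush(self) -> None:
        pass  # nothing buffered

logger = logging.getLogger("queue_logging_sketch")
logger.setLevel(logging.DEBUG)
print_queue: "multiprocessing.Queue[str]" = multiprocessing.Queue()
sh = logging.StreamHandler(QueueStream(print_queue))
sh.setLevel(logging.DEBUG)
sh.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
logger.addHandler(sh)
logger.propagate = False  # don't also emit to the root logger's console handler
logger.info("extracting 000042.tar")
print(print_queue.get())  # INFO: extracting 000042.tar
```
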
@@ -812,6 +854,9 @@ def extractFiles( # noqa: C901

         # Close current archive?
         if i == nfiles - 1 or files[i].tar != files[i + 1].tar:
+            # We're either on the last file or the tar is distinct from the tar of the next file.
+
+            # Close current archive file
             logger.debug("Closing tar archive {}".format(tfname))
             tar.close()
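
The boundary test above (last file, or next file lives in a different tar) works because the rows arrive ordered by tar, so it is equivalent to grouping consecutive rows by tar name. An illustrative reformulation with itertools.groupby (tuples stand in for FilesRow):

```python
from itertools import groupby

# Rows are ordered by tar, so consecutive runs share one archive.
files = [("000000.tar", "a.txt"), ("000000.tar", "b.txt"), ("000001.tar", "c.txt")]
for tar_name, rows in groupby(files, key=lambda row: row[0]):
    # Open tar_name once, extract every member in the run, then close it.
    print(tar_name, [name for _, name in rows])
```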

@@ -841,16 +886,23 @@ def extractFiles( # noqa: C901
             if multiprocess_worker:
                 multiprocess_worker.done_enqueuing_output_for_tar(files_row.tar)

+            # Open new archive next time
             newtar = True

+            # Delete this tar if the corresponding command-line arg was used.
             if not keep_tars:
                 if tfname is not None:
                     os.remove(tfname)
                 else:
                     raise TypeError("Invalid tfname={}".format(tfname))

     if multiprocess_worker:
+        # If there are things left to print, print them.
         multiprocess_worker.print_all_contents()
+
+        # Add the failures to the queue.
+        # When running with multiprocessing, the function multiprocess_extract()
+        # that calls this extractFiles() function will return the failures as a list.
         for f in failures:
             multiprocess_worker.failure_queue.put(f)