Skip to content

Commit 0ebd7a5

Browse files
committed
Restore comments and type annotations
1 parent f33ea80 commit 0ebd7a5

File tree

2 files changed

+119
-4
lines changed

2 files changed

+119
-4
lines changed

zstash/extract.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,24 +304,39 @@ def multiprocess_extract(
304304
) -> List[FilesRow]:
305305
"""
306306
Extract the files from the matches in parallel.
307+
308+
A single unit of work is a tar and all of
309+
the files in it to extract.
307310
"""
311+
# A dict of tar -> size of files in it.
312+
# This is because we're trying to balance the load between
313+
# the processes.
308314
tar_to_size_unsorted: DefaultDict[str, float] = collections.defaultdict(float)
315+
db_row: FilesRow
309316
for db_row in matches:
310317
tar_to_size_unsorted[db_row.tar] += db_row.size
311318

312319
tar_to_size: collections.OrderedDict[str, float] = collections.OrderedDict(
313320
sorted(tar_to_size_unsorted.items(), key=lambda x: x[1])
314321
)
315322

323+
# We don't want to instantiate more processes than we need to.
324+
# So, if the number of tars is less than the number of workers,
325+
# set the number of workers to the number of tars.
316326
num_workers = min(num_workers, len(tar_to_size))
317327

328+
# For worker i, workers_to_tars[i] is a set of tars
329+
# that worker i will work on.
318330
# Round-robin assignment for predictable ordering
319331
workers_to_tars: List[set] = [set() for _ in range(num_workers)]
332+
tar: str
320333
for idx, tar in enumerate(sorted(tar_to_size.keys())):
321334
workers_to_tars[idx % num_workers].add(tar)
322335

323336
workers_to_matches: List[List[FilesRow]] = [[] for _ in range(num_workers)]
337+
workers_idx: int
324338
for db_row in matches:
339+
tar = db_row.tar
325340
for workers_idx in range(len(workers_to_tars)):
326341
if db_row.tar in workers_to_tars[workers_idx]:
327342
workers_to_matches[workers_idx].append(db_row)
@@ -336,6 +351,7 @@ def multiprocess_extract(
336351
tar_ordering, manager=manager
337352
)
338353

354+
# The return value for extractFiles will be added here.
339355
failure_queue: multiprocessing.Queue[FilesRow] = multiprocessing.Queue()
340356
processes: List[multiprocessing.Process] = []
341357

@@ -352,6 +368,10 @@ def multiprocess_extract(
352368
process.start()
353369
processes.append(process)
354370

371+
# While the processes are running, we need to empty the queue.
372+
# Otherwise, it causes hanging.
373+
# No need to join() each of the processes when doing this,
374+
# because we'll be in this loop until completion.
355375
failures: List[FilesRow] = []
356376
while any(p.is_alive() for p in processes):
357377
while not failure_queue.empty():
@@ -361,6 +381,7 @@ def multiprocess_extract(
361381
while not failure_queue.empty():
362382
failures.append(failure_queue.get())
363383

384+
# Sort the failures, since they can arrive in any order.
364385
failures.sort(key=lambda t: (t.name, t.tar, t.offset))
365386
return failures
366387

@@ -371,46 +392,61 @@ def check_sizes_match(cur, tfname, error_on_duplicate_tar):
371392
actual_size: int = os.path.getsize(tfname)
372393
name_only: str = os.path.split(tfname)[1]
373394

395+
# Get ALL entries for this tar name
374396
cur.execute(
375397
"SELECT size FROM tars WHERE name = ? ORDER by id DESC", (name_only,)
376398
)
377399
results = cur.fetchall()
378400

379401
if not results:
402+
# Cannot access size information; assume the sizes match.
380403
logger.error(f"No database entries found for {name_only}")
381404
return True
382405

406+
# Check for multiple entries
383407
if len(results) > 1:
408+
# Extract just the size values
384409
sizes: List[int] = [row[0] for row in results]
385410
error_str: str = (
386411
f"Database corruption detected! Found {len(results)} database entries for {name_only}, with sizes {sizes}"
387412
)
388413

389414
if error_on_duplicate_tar:
415+
# Tested by database_corruption.bash Case 5
390416
logger.error(error_str)
391417
raise RuntimeError(error_str)
392418
logger.warning(error_str)
393419

420+
# We ordered the results by id DESC,
421+
# so the first entry is the most recent.
394422
most_recent_size: int = sizes[0]
395423
if actual_size == most_recent_size:
424+
# Tested by database_corruption.bash Case 7
425+
# If the actual size matches the most recent size,
426+
# then we can assume that the tar is valid.
396427
logger.info(
397428
f"{name_only}: The most recent database entry has the same size as the actual file size: {actual_size}."
398429
)
399430
return True
400431
unique_sizes: Set[int] = set(sizes)
401432
if actual_size in unique_sizes:
433+
# Tested by database_corruption.bash Case 8
402434
logger.info(
403435
f"{name_only}: A database entry matches the actual file size, {actual_size}, but it is not the most recent entry."
404436
)
405437
else:
438+
# Tested by database_corruption.bash Case 6
406439
logger.info(
407440
f"{name_only}: No database entry matches the actual file size: {actual_size}."
408441
)
409442
return False
410443
else:
444+
# Tested by database_corruption.bash Cases 1,2,4
445+
# Single entry - normal case
411446
logger.info(f"{name_only}: Found a single database entry.")
412447
expected_size = results[0][0]
413448

449+
# Now check if actual size matches expected size
414450
if expected_size != actual_size:
415451
error_msg = (
416452
f"{name_only}: Size mismatch! "
@@ -420,13 +456,16 @@ def check_sizes_match(cur, tfname, error_on_duplicate_tar):
420456
logger.error(error_msg)
421457
return False
422458
else:
459+
# Sizes match
423460
logger.info(f"{name_only}: Size check passed ({actual_size} bytes)")
424461
return True
425462
else:
463+
# Cannot access size information; assume the sizes match.
426464
logger.debug("Cannot access tar size information; assuming sizes match")
427465
return True
428466

429467

468+
# FIXME: C901 'extractFiles' is too complex (33)
430469
def extractFiles( # noqa: C901
431470
files: List[FilesRow],
432471
keep_files: bool,
@@ -437,7 +476,20 @@ def extractFiles( # noqa: C901
437476
cur: Optional[sqlite3.Cursor] = None,
438477
) -> List[FilesRow]:
439478
"""
440-
Given a list of database rows, extract the files from the tar archives.
479+
Given a list of database rows, extract the files from the
480+
tar archives to the current location on disk.
481+
482+
If keep_files is False, the files are not extracted; this is used when checking whether the files in an HPSS repository are valid.
483+
This is used for when checking if the files in an HPSS
484+
repository are valid.
485+
486+
If keep_tars is True, the tar archives that are downloaded are kept,
487+
even after the program has terminated. Otherwise, they are deleted.
488+
489+
If running in parallel, then multiprocess_worker is the Worker
490+
that called this function.
491+
We need a reference to it so we can signal it to print
492+
the contents of what's in its print queue.
441493
442494
If cur is None (when running in parallel), a new database connection
443495
will be opened for this worker process.
@@ -459,19 +511,26 @@ def extractFiles( # noqa: C901
459511

460512
# Set up logging redirection for multiprocessing
461513
if multiprocess_worker:
514+
# All messages to the logger will now be sent to
515+
# this queue, instead of sys.stdout.
462516
sh = logging.StreamHandler(multiprocess_worker.print_queue)
463517
sh.setLevel(logging.DEBUG)
464518
formatter: logging.Formatter = logging.Formatter("%(levelname)s: %(message)s")
465519
sh.setFormatter(formatter)
466520
logger.addHandler(sh)
521+
# Don't have the logger print to the console as the messages come in.
467522
logger.propagate = False
468523

469524
for i in range(nfiles):
470525
files_row: FilesRow = files[i]
471526

527+
# Open new tar archive
472528
if newtar:
473529
newtar = False
474530
tfname = os.path.join(cache, files_row.tar)
531+
# Every time we're extracting a new tar, if running in parallel,
532+
# let the process know.
533+
# This is to synchronize the print statements.
475534

476535
# Wait for turn before processing this tar
477536
if multiprocess_worker:
@@ -487,6 +546,8 @@ def extractFiles( # noqa: C901
487546
raise TypeError("Invalid args.hpss={}".format(args.hpss))
488547

489548
tries: int = args.retries + 1
549+
# Set to True to test the `--retries` option with a forced failure.
550+
# Then run `python -m unittest tests.test_extract.TestExtract.testExtractRetries`
490551
test_retry: bool = False
491552

492553
while tries > 0:
@@ -512,10 +573,12 @@ def extractFiles( # noqa: C901
512573
raise RuntimeError(
513574
f"{tfname} size does not match expected size."
514575
)
576+
# `hpss_get` successful or not needed: no more tries needed
515577
break
516578
except RuntimeError as e:
517579
if tries > 0:
518580
logger.info(f"Retrying HPSS get: {tries} tries remaining.")
581+
# Run the try-except block again
519582
continue
520583
else:
521584
raise e
@@ -526,25 +589,34 @@ def extractFiles( # noqa: C901
526589
# Extract file
527590
cmd: str = "Extracting" if keep_files else "Checking"
528591
logger.info(cmd + " %s" % (files_row.name))
592+
# if multiprocess_worker:
593+
# print('{} is {} {} from {}'.format(multiprocess_worker, cmd, file[1], file[5]))
529594

530595
if keep_files and not should_extract_file(files_row):
596+
# If we were going to extract, but aren't
597+
# because a matching file is on disk
531598
msg: str = "Not extracting {}, because it"
532599
msg += " already exists on disk with the same"
533600
msg += " size and modification date."
534601
logger.info(msg.format(files_row.name))
535602

603+
# True if we should actually extract the file from the tar
536604
extract_this_file: bool = keep_files and should_extract_file(files_row)
537605

538606
try:
607+
# Seek file position
539608
if tar.fileobj is not None:
540609
fileobj = tar.fileobj
541610
else:
542611
raise TypeError("Invalid tar.fileobj={}".format(tar.fileobj))
543612
fileobj.seek(files_row.offset)
544613

614+
# Get next member
545615
tarinfo: tarfile.TarInfo = tar.tarinfo.fromtarfile(tar)
546616

547617
if tarinfo.isfile():
618+
# fileobj to extract
619+
# error: Name 'tarfile.ExFileObject' is not defined
548620
extracted_file: Optional[tarfile.ExFileObject] = tar.extractfile(tarinfo) # type: ignore
549621
if extracted_file:
550622
fin: tarfile.ExFileObject = extracted_file
@@ -557,8 +629,11 @@ def extractFiles( # noqa: C901
557629
path, name = os.path.split(fname)
558630
if path != "" and extract_this_file:
559631
if not os.path.isdir(path):
632+
# The path doesn't exist, so create it.
560633
os.makedirs(path)
561634
if extract_this_file:
635+
# If we're keeping the files,
636+
# then have an output file
562637
fout: _io.BufferedWriter = open(fname, "wb")
563638

564639
hash_md5: _hashlib.HASH = hashlib.md5()
@@ -577,12 +652,17 @@ def extractFiles( # noqa: C901
577652

578653
md5: str = hash_md5.hexdigest()
579654
if extract_this_file:
655+
# numeric_owner is a required arg in Python 3.
656+
# If True, "only the numbers for user/group names
657+
# are used and not the names".
580658
tar.chown(tarinfo, fname, numeric_owner=False)
581659
tar.chmod(tarinfo, fname)
582660
tar.utime(tarinfo, fname)
661+
# Verify size
583662
if os.path.getsize(fname) != files_row.size:
584663
logger.error("size mismatch for: {}".format(fname))
585664

665+
# Verify md5 checksum
586666
files_row_md5: Optional[str] = files_row.md5
587667
if md5 != files_row_md5:
588668
logger.error("md5 mismatch for: {}".format(fname))
@@ -597,28 +677,38 @@ def extractFiles( # noqa: C901
597677
tar.extract(tarinfo, filter="tar")
598678
else:
599679
tar.extract(tarinfo)
680+
# Note: tar.extract() will not restore time stamps of symbolic
681+
# links. Could not find a Python-way to restore it either, so
682+
# relying here on 'touch'. This is not the prettiest solution.
683+
# Maybe a better one can be implemented later.
600684
if tarinfo.issym():
601685
tmp1 = tarinfo.mtime
602686
tmp2: datetime = datetime.fromtimestamp(tmp1)
603687
tmp3: str = tmp2.strftime("%Y%m%d%H%M.%S")
604688
os.system("touch -h -t %s %s" % (tmp3, tarinfo.name))
605689

606690
except Exception:
691+
# Catch all exceptions here.
607692
traceback.print_exc()
608693
logger.error("Retrieving {}".format(files_row.name))
609694
failures.append(files_row)
610695

611696
# Close current archive?
612697
if i == nfiles - 1 or files[i].tar != files[i + 1].tar:
698+
# We're either on the last file or the tar is distinct from the tar of the next file.
699+
700+
# Close current archive file
613701
logger.debug("Closing tar archive {}".format(tfname))
614702
tar.close()
615703

616704
if multiprocess_worker:
617705
multiprocess_worker.done_enqueuing_output_for_tar(files_row.tar)
618706
multiprocess_worker.print_all_contents()
619707

708+
# Open new archive next time
620709
newtar = True
621710

711+
# Delete this tar if the corresponding command-line arg was used.
622712
if not keep_tars:
623713
if tfname is not None:
624714
os.remove(tfname)
@@ -630,6 +720,9 @@ def extractFiles( # noqa: C901
630720
cur.close()
631721
con.close()
632722

723+
# Add the failures to the queue.
724+
# When running with multiprocessing, the function multiprocess_extract()
725+
# that calls this extractFiles() function will return the failures as a list.
633726
if multiprocess_worker:
634727
for f in failures:
635728
multiprocess_worker.failure_queue.put(f)
@@ -648,11 +741,15 @@ def should_extract_file(db_row: FilesRow) -> bool:
648741
file_name, size_db, mod_time_db = db_row.name, db_row.size, db_row.mtime
649742

650743
if not os.path.exists(file_name):
744+
# The file doesn't exist locally.
745+
# We must get files that are not on disk.
651746
return True
652747

653748
size_disk: int = os.path.getsize(file_name)
654749
mod_time_disk: datetime = datetime.utcfromtimestamp(os.path.getmtime(file_name))
655750

751+
# Only extract when the times and sizes are not the same (within tolerance)
752+
# We have a TIME_TOL because mod_time_disk doesn't have the microseconds.
656753
return not (
657754
(size_disk == size_db)
658755
and (abs(mod_time_disk - mod_time_db).total_seconds() < TIME_TOL)

0 commit comments

Comments
 (0)