33import argparse
44import collections
55import hashlib
6- import heapq
76import logging
87import multiprocessing
98import os .path
109import re
1110import sqlite3
1211import sys
1312import tarfile
13+ import time
1414import traceback
1515from datetime import datetime
1616from typing import DefaultDict , List , Optional , Set , Tuple
@@ -282,10 +282,10 @@ def extract_database(
282282 if args .workers > 1 :
283283 logger .debug ("Running zstash {} with multiprocessing" .format (cmd ))
284284 failures = multiprocess_extract (
285- args .workers , matches , keep_files , keep , cache , cur , args
285+ args .workers , matches , keep_files , keep , cache , args
286286 )
287287 else :
288- failures = extractFiles (matches , keep_files , keep , cache , cur , args )
288+ failures = extractFiles (matches , keep_files , keep , cache , args , None , cur )
289289
290290 # Close database
291291 logger .debug ("Closing index database" )
@@ -300,7 +300,6 @@ def multiprocess_extract(
300300 keep_files : bool ,
301301 keep_tars : Optional [bool ],
302302 cache : str ,
303- cur : sqlite3 .Cursor ,
304303 args : argparse .Namespace ,
305304) -> List [FilesRow ]:
306305 """
@@ -329,26 +328,12 @@ def multiprocess_extract(
329328 # set the number of workers to the number of tars.
330329 num_workers = min (num_workers , len (tar_to_size ))
331330
332- # For worker i, workers_to_tars[i] is a set of tars
333- # that worker i will work on.
331+ # For worker i, workers_to_tars[i] is a set of tars that worker i will work on.
332+ # Assign tars in round-robin fashion to maintain proper ordering
334333 workers_to_tars : List [set ] = [set () for _ in range (num_workers )]
335- # A min heap, of (work, worker_idx) tuples, work is the size of data
336- # that worker_idx needs to work on.
337- # We can efficiently get the worker with the least amount of work.
338- work_to_workers : List [Tuple [int , int ]] = [(0 , i ) for i in range (num_workers )]
339- heapq .heapify (workers_to_tars )
340-
341- # Using a greedy approach, populate workers_to_tars.
342- for _ , tar in enumerate (tar_to_size ):
343- # The worker with the least work should get the current largest amount of work.
344- workers_work : int
345- worker_idx : int
346- workers_work , worker_idx = heapq .heappop (work_to_workers )
334+ for idx , tar in enumerate (sorted (tar_to_size .keys ())):
335+ worker_idx = idx % num_workers
347336 workers_to_tars [worker_idx ].add (tar )
348- # Add this worker back to the heap, with the new amount of work.
349- worker_tuple : Tuple [float , int ] = (workers_work + tar_to_size [tar ], worker_idx )
350- # FIXME: error: Cannot infer type argument 1 of "heappush"
351- heapq .heappush (work_to_workers , worker_tuple ) # type: ignore
352337
353338 # For worker i, workers_to_matches[i] is a list of
354339 # matches from the database for it to process.
@@ -361,8 +346,15 @@ def multiprocess_extract(
361346 # This worker gets this db_row.
362347 workers_to_matches [workers_idx ].append (db_row )
363348
349+ # Sort each worker's matches by tar to ensure they process in order
350+ for worker_matches in workers_to_matches :
351+ worker_matches .sort (key = lambda t : t .tar )
352+
364353 tar_ordering : List [str ] = sorted ([tar for tar in tar_to_size ])
365- monitor : parallel .PrintMonitor = parallel .PrintMonitor (tar_ordering )
354+ manager = multiprocessing .Manager ()
355+ monitor : parallel .PrintMonitor = parallel .PrintMonitor (
356+ tar_ordering , manager = manager
357+ )
366358
367359 # The return value for extractFiles will be added here.
368360 failure_queue : multiprocessing .Queue [FilesRow ] = multiprocessing .Queue ()
@@ -374,7 +366,7 @@ def multiprocess_extract(
374366 )
375367 process : multiprocessing .Process = multiprocessing .Process (
376368 target = extractFiles ,
377- args = (matches , keep_files , keep_tars , cache , cur , args , worker ),
369+ args = (matches , keep_files , keep_tars , cache , args , worker ),
378370 daemon = True ,
379371 )
380372 process .start ()
@@ -385,10 +377,39 @@ def multiprocess_extract(
385377 # No need to join() each of the processes when doing this,
386378 # because we'll be in this loop until completion.
387379 failures : List [FilesRow ] = []
380+ max_wait_time = 180 # 3 minute timeout for tests
381+ start_time = time .time ()
382+ last_log_time = start_time
383+
388384 while any (p .is_alive () for p in processes ):
385+ elapsed = time .time () - start_time
386+ if elapsed > max_wait_time :
387+ logger .error (
388+ f"Timeout after { elapsed :.1f} s waiting for worker processes. Terminating..."
389+ )
390+ for p in processes :
391+ if p .is_alive ():
392+ logger .error (f"Terminating process { p .pid } " )
393+ p .terminate ()
394+ break
395+
396+ # Log progress every 30 seconds
397+ if time .time () - last_log_time > 30 :
398+ alive_count = sum (1 for p in processes if p .is_alive ())
399+ logger .debug (
400+ f"Still waiting for { alive_count } worker processes after { elapsed :.1f} s"
401+ )
402+ last_log_time = time .time ()
403+
389404 while not failure_queue .empty ():
390405 failures .append (failure_queue .get ())
391406
407+ time .sleep (0.1 ) # Sleep between polls so the wait loop does not busy-spin
408+
409+ # Collect any remaining failures
410+ while not failure_queue .empty ():
411+ failures .append (failure_queue .get ())
412+
392413 # Sort the failures, since they can come in at any order.
393414 failures .sort (key = lambda t : (t .name , t .tar , t .offset ))
394415 return failures
@@ -479,9 +500,9 @@ def extractFiles( # noqa: C901
479500 keep_files : bool ,
480501 keep_tars : Optional [bool ],
481502 cache : str ,
482- cur : sqlite3 .Cursor ,
483503 args : argparse .Namespace ,
484504 multiprocess_worker : Optional [parallel .ExtractWorker ] = None ,
505+ cur : Optional [sqlite3 .Cursor ] = None ,
485506) -> List [FilesRow ]:
486507 """
487508 Given a list of database rows, extract the files from the
@@ -498,21 +519,56 @@ def extractFiles( # noqa: C901
498519 that called this function.
499520 We need a reference to it so we can signal it to print
500521 the contents of what's in its print queue.
522+
523+ If cur is None (when running in parallel), a new database connection
524+ will be opened for this worker process.
525+ """
526+ try :
527+ result = _extractFiles_impl (
528+ files , keep_files , keep_tars , cache , args , multiprocess_worker , cur
529+ )
530+ return result
531+ except Exception as e :
532+ if multiprocess_worker :
533+ # Make sure we report failures even if there's an exception
534+ sys .stderr .write (f"ERROR: Worker encountered fatal error: { e } \n " )
535+ sys .stderr .flush ()
536+ traceback .print_exc (file = sys .stderr )
537+ for f in files :
538+ multiprocess_worker .failure_queue .put (f )
539+ raise
540+
541+
542+ # FIXME: C901 '_extractFiles_impl' is too complex (42)
543+ def _extractFiles_impl ( # noqa: C901
544+ files : List [FilesRow ],
545+ keep_files : bool ,
546+ keep_tars : Optional [bool ],
547+ cache : str ,
548+ args : argparse .Namespace ,
549+ multiprocess_worker : Optional [parallel .ExtractWorker ] = None ,
550+ cur : Optional [sqlite3 .Cursor ] = None ,
551+ ) -> List [FilesRow ]:
552+ """
553+ Implementation of extractFiles - actual extraction logic.
501554 """
555+ # Open database connection if not provided (parallel case)
556+ if cur is None :
557+ con : sqlite3 .Connection = sqlite3 .connect (
558+ get_db_filename (cache ), detect_types = sqlite3 .PARSE_DECLTYPES
559+ )
560+ cur = con .cursor ()
561+ close_db : bool = True
562+ else :
563+ close_db = False
564+
502565 failures : List [FilesRow ] = []
503566 tfname : str
504567 newtar : bool = True
505568 nfiles : int = len (files )
506- if multiprocess_worker :
507- # All messages to the logger will now be sent to
508- # this queue, instead of sys.stdout.
509- sh = logging .StreamHandler (multiprocess_worker .print_queue )
510- sh .setLevel (logging .DEBUG )
511- formatter : logging .Formatter = logging .Formatter ("%(levelname)s: %(message)s" )
512- sh .setFormatter (formatter )
513- logger .addHandler (sh )
514- # Don't have the logger print to the console as the message come in.
515- logger .propagate = False
569+
570+ # Track if we've set up logging yet
571+ logging_setup : bool = False
516572
517573 for i in range (nfiles ):
518574 files_row : FilesRow = files [i ]
@@ -521,16 +577,46 @@ def extractFiles( # noqa: C901
521577 if newtar :
522578 newtar = False
523579 tfname = os .path .join (cache , files_row .tar )
524- # Everytime we're extracting a new tar, if running in parallel,
525- # let the process know.
526- # This is to synchronize the print statements.
580+
581+ # CRITICAL: Wait for our turn BEFORE doing anything with this tar
527582 if multiprocess_worker :
583+ try :
584+ multiprocess_worker .print_monitor .wait_turn (
585+ multiprocess_worker , files_row .tar , indef_wait = True
586+ )
587+ except TimeoutError as e :
588+ logger .error (
589+ f"Timeout waiting for turn to process { files_row .tar } : { e } "
590+ )
591+ # Mark all remaining files from this tar as failed
592+ for j in range (i , nfiles ):
593+ if files [j ].tar == files_row .tar :
594+ failures .append (files [j ])
595+ # Skip to next tar
596+ newtar = True
597+ continue
598+
599+ # NOW set up logging (only once)
600+ if not logging_setup :
601+ sh = logging .StreamHandler (multiprocess_worker .print_queue )
602+ sh .setLevel (logging .DEBUG )
603+ formatter : logging .Formatter = logging .Formatter (
604+ "%(levelname)s: %(message)s"
605+ )
606+ sh .setFormatter (formatter )
607+ logger .addHandler (sh )
608+ logger .propagate = False
609+ logging_setup = True
610+
611+ # Set current tar for this worker
528612 multiprocess_worker .set_curr_tar (files_row .tar )
529613
530- if config .hpss is not None :
531- hpss : str = config .hpss
614+ # Use args.hpss directly - it's always set correctly
615+ if args .hpss is not None :
616+ hpss : str = args .hpss
532617 else :
533- raise TypeError ("Invalid config.hpss={}" .format (config .hpss ))
618+ raise TypeError ("Invalid args.hpss={}" .format (args .hpss ))
619+
534620 tries : int = args .retries + 1
535621 # Set to True to test the `--retries` option with a forced failure.
536622 # Then run `python -m unittest tests.test_extract.TestExtract.testExtractRetries`
@@ -574,8 +660,6 @@ def extractFiles( # noqa: C901
574660 # Extract file
575661 cmd : str = "Extracting" if keep_files else "Checking"
576662 logger .info (cmd + " %s" % (files_row .name ))
577- # if multiprocess_worker:
578- # print('{} is {} {} from {}'.format(multiprocess_worker, cmd, file[1], file[5]))
579663
580664 if keep_files and not should_extract_file (files_row ):
581665 # If we were going to extract, but aren't
@@ -676,9 +760,6 @@ def extractFiles( # noqa: C901
676760 logger .error ("Retrieving {}" .format (files_row .name ))
677761 failures .append (files_row )
678762
679- if multiprocess_worker :
680- multiprocess_worker .print_contents ()
681-
682763 # Close current archive?
683764 if i == nfiles - 1 or files [i ].tar != files [i + 1 ].tar :
684765 # We're either on the last file or the tar is distinct from the tar of the next file.
@@ -688,8 +769,15 @@ def extractFiles( # noqa: C901
688769 tar .close ()
689770
690771 if multiprocess_worker :
772+ # Mark that all output for this tar is queued
691773 multiprocess_worker .done_enqueuing_output_for_tar (files_row .tar )
692774
775+ # Now print everything and advance to next tar
776+ try :
777+ multiprocess_worker .print_all_contents ()
778+ except (TimeoutError , Exception ) as e :
779+ logger .debug (f"Error printing contents for { files_row .tar } : { e } " )
780+
693781 # Open new archive next time
694782 newtar = True
695783
@@ -700,13 +788,13 @@ def extractFiles( # noqa: C901
700788 else :
701789 raise TypeError ("Invalid tfname={}" .format (tfname ))
702790
703- if multiprocess_worker :
704- # If there are things left to print, print them.
705- multiprocess_worker .print_all_contents ()
791+ # Close database connection if we opened it
792+ if close_db :
793+ cur .close ()
794+ con .close ()
706795
707- # Add the failures to the queue.
708- # When running with multiprocessing, the function multiprocess_extract()
709- # that calls this extractFiles() function will return the failures as a list.
796+ if multiprocess_worker :
797+ # Add failures to the queue
710798 for f in failures :
711799 multiprocess_worker .failure_queue .put (f )
712800 return failures
0 commit comments