11import math
22import time
3- from threading import Thread
4-
3+ import threading
54
6- from aperturedb import ProgressBar
5+ from threading import Thread
6+ from tqdm import tqdm as tqdm
77
88
99class Parallelizer :
@@ -28,10 +28,7 @@ class Parallelizer:
2828 ```
2929 """
3030
31- def __init__ (self , progress_to_file = "" ):
32-
33- self .pb_file = progress_to_file
34-
31+ def __init__ (self ):
3532 self ._reset ()
3633
3734 def _reset (self , batchsize : int = 1 , numthreads : int = 1 ):
@@ -46,17 +43,13 @@ def _reset(self, batchsize: int = 1, numthreads: int = 1):
4643 self .error_counter = 0
4744 self .actual_stats = []
4845
49- if self .pb_file :
50- self .pb = ProgressBar .ProgressBar (self .pb_file )
51- else :
52- self .pb = ProgressBar .ProgressBar ()
53-
5446 def get_times (self ):
5547
5648 return self .times_arr
5749
58- def run (self , generator , batchsize : int , numthreads : int , stats : bool ):
59-
50+ def batched_run (self , generator , batchsize : int , numthreads : int , stats : bool ):
51+ run_event = threading .Event ()
52+ run_event .set ()
6053 self ._reset (batchsize , numthreads )
6154 self .stats = stats
6255 self .generator = generator
@@ -65,7 +58,8 @@ def run(self, generator, batchsize: int, numthreads: int, stats: bool):
6558 self .total_actions = generator .sample_count
6659 else :
6760 self .total_actions = len (generator )
68-
61+ self .pb = tqdm (total = self .total_actions , desc = "Progress" ,
62+ unit = "batches" , unit_scale = True , dynamic_ncols = True )
6963 start_time = time .time ()
7064
7165 if self .total_actions < batchsize :
@@ -82,15 +76,22 @@ def run(self, generator, batchsize: int, numthreads: int, stats: bool):
8276 self .total_actions )
8377
8478 thread_add = Thread (target = self .worker ,
85- args = (i , generator , idx_start , idx_end ))
79+ args = (i , generator , idx_start , idx_end , run_event ))
8680 thread_arr .append (thread_add )
8781
8882 a = [th .start () for th in thread_arr ]
89- a = [th .join () for th in thread_arr ]
83+ try :
84+ while run_event .is_set () and any ([th .is_alive () for th in thread_arr ]):
85+ time .sleep (1 )
86+ except KeyboardInterrupt :
87+ print ("Interrupted ... Shutting down workers" )
88+ finally :
89+ run_event .clear ()
90+ a = [th .join () for th in thread_arr ]
9091
9192 # Update progress bar to completion
9293 if self .stats :
93- self .pb .update ( 1 )
94+ self .pb .close ( )
9495
9596 self .total_actions_time = time .time () - start_time
9697
0 commit comments