99from ndcube import NDCube
1010from sklearn .exceptions import ConvergenceWarning
1111from sklearn .linear_model import ElasticNet
12+ from tqdm import tqdm
1213
1314from overlappogram .response import prepare_response_function
1415
@@ -60,6 +61,8 @@ def __init__(
6061 response_dependency_list = response_dependency_list ,
6162 )
6263
64+ self ._progress_bar = None # initialized in invert call
65+
6366 @property
6467 def is_inverted (self ) -> bool :
6568 return not any (
@@ -107,7 +110,8 @@ def _progress_indicator(self, future):
107110 with self ._thread_count_lock :
108111 if not future .cancelled ():
109112 self ._completed_row_count += 1
110- print (f"{ self ._completed_row_count / self .total_row_count * 100 :3.0f} % complete" , end = "\r " )
113+ #print(f"{self._completed_row_count / self.total_row_count * 100:3.0f}% complete", end="\r")
114+ self ._progress_bar .update (1 )
111115
112116 def _switch_to_row_inversion (self , model_config , alpha , rho , num_row_threads = 50 ):
113117 self ._mode = InversionMode .ROW
@@ -161,6 +165,29 @@ def _collect_results(self, mode_switch_thread_count, model_config, alpha, rho):
161165 self ._switch_to_row_inversion (model_config , alpha , rho )
162166 break
163167
168+ def _start_row_inversion (self , model_config , alpha , rho , num_threads ):
169+ self .executors = [concurrent .futures .ThreadPoolExecutor (max_workers = num_threads )]
170+
171+ self .futures = {}
172+ self ._models = []
173+ for i , row_index in enumerate (range (self ._detector_row_range [0 ], self ._detector_row_range [1 ])):
174+ enet_model = ElasticNet (
175+ alpha = alpha ,
176+ l1_ratio = rho ,
177+ tol = model_config ["tol" ],
178+ max_iter = model_config ["max_iter" ],
179+ precompute = False , # setting this to true slows down performance dramatically
180+ positive = True ,
181+ copy_X = False ,
182+ fit_intercept = False ,
183+ selection = model_config ["selection" ],
184+ warm_start = False , # warm start doesn't make sense in the row version
185+ )
186+ self ._models .append (enet_model )
187+ future = self .executors [- 1 ].submit (self ._invert_image_row , row_index , i )
188+ future .add_done_callback (self ._progress_indicator )
189+ self .futures [future ] = (row_index , i )
190+
164191 def _start_chunk_inversion (self , model_config , alpha , rho , num_threads ):
165192 starts = np .arange (
166193 self ._detector_row_range [0 ],
@@ -255,6 +282,8 @@ def invert(
255282
256283 self ._mode = mode
257284
285+ self ._progress_bar = tqdm (total = self .total_row_count , unit = 'row' , delay = 1 , leave = False )
286+
258287 self ._models = []
259288 self ._completed_row_count = 0
260289
@@ -269,16 +298,16 @@ def invert(
269298 # mode never switches since count=0
270299 self ._collect_results (0 , model_config , alpha , rho )
271300 elif self ._mode == InversionMode .ROW :
272- self ._start_chunk_inversion (model_config , alpha , rho , num_threads )
273- # TODO: it would be better to have a mode to start in row but right now we fake it with a fast mode switch
274- self ._collect_results (np .inf , model_config , alpha , rho ) # immediately switch mode
301+ self ._start_row_inversion (model_config , alpha , rho , num_threads )
275302 self ._collect_results (np .inf , model_config , alpha , rho )
276303 else :
277304 raise ValueError ("Invalid InversionMode." )
278305
279306 for executor in self .executors :
280307 executor .shutdown ()
281308
309+ self ._progress_bar .close ()
310+
282311 return (
283312 np .transpose (self ._em_data , (2 , 0 , 1 )),
284313 self ._inversion_prediction ,
0 commit comments