@@ -53,6 +53,7 @@ def __init__(
5353 self ._row_scores : np .ndarray | None = None
5454 self ._overlappogram_width : int | None = None
5555 self ._overlappogram_height : int | None = None
56+ self ._n_iter : np .ndarray | None = None
5657
5758 self ._thread_count_lock = Lock ()
5859
@@ -104,13 +105,15 @@ def _invert_image_row(self, row_index, chunk_index):
104105 data_out = model .predict (masked_response_function )
105106 em = model .coef_
106107 score_data = model .score (masked_response_function , image_row )
108+ n_iter = model .n_iter_
107109 except ConvergenceWarning :
108110 self ._unconverged_rows .append (row_index )
109111 em = np .zeros ((self ._num_slits * self ._num_deps ), dtype = np .float32 )
110112 data_out = np .zeros (self ._overlappogram_width , dtype = np .float32 )
111113 score_data = - 999
114+ n_iter = - 1
112115
113- return row_index , em , data_out , score_data
116+ return row_index , em , data_out , score_data , n_iter
114117
115118 def _progress_indicator (self , future ):
116119 """used in multithreading to track progress of inversion"""
@@ -155,7 +158,7 @@ def _switch_to_row_inversion(self, model_config, alpha, rho, num_row_threads=50)
155158
156159 def _collect_results (self , mode_switch_thread_count , model_config , alpha , rho ):
157160 for future in concurrent .futures .as_completed (self .futures ):
158- row_index , em , data_out , score_data = future .result ()
161+ row_index , em , data_out , score_data , n_iter = future .result ()
159162 for slit_num in range (self ._num_slits ):
160163 if self ._smooth_over == "dependence" :
161164 slit_em = em [slit_num * self ._num_deps : (slit_num + 1 ) * self ._num_deps ]
@@ -164,6 +167,7 @@ def _collect_results(self, mode_switch_thread_count, model_config, alpha, rho):
164167 self ._em_data [row_index , slit_num , :] = slit_em
165168 self ._inversion_prediction [row_index , :] = data_out
166169 self ._row_scores [row_index ] = score_data
170+ self ._n_iter [row_index ] = n_iter
167171
168172 rows_remaining = self .total_row_count - self ._completed_row_count
169173
@@ -173,7 +177,6 @@ def _collect_results(self, mode_switch_thread_count, model_config, alpha, rho):
173177
174178 def _start_row_inversion (self , model_config , alpha , rho , num_threads ):
175179 self .executors = [concurrent .futures .ThreadPoolExecutor (max_workers = num_threads )]
176-
177180 self .futures = {}
178181 self ._models = []
179182 for i , row_index in enumerate (range (self ._detector_row_range [0 ], self ._detector_row_range [1 ])):
@@ -253,6 +256,7 @@ def _initialize_with_overlappogram(self, overlappogram):
253256 self ._em_data = np .zeros ((self ._overlappogram_height , self ._num_slits , self ._num_deps ), dtype = np .float32 )
254257 self ._inversion_prediction = np .zeros ((self ._overlappogram_height , self ._overlappogram_width ), dtype = np .float32 )
255258 self ._row_scores = np .zeros ((self ._overlappogram_height , 1 ), dtype = np .float32 )
259+ self ._n_iter = np .zeros ((self ._overlappogram_height , 1 ), dtype = np .int32 )
256260
257261 def invert (
258262 self ,
@@ -301,4 +305,5 @@ def invert(
301305 NDCube (data = self ._inversion_prediction , wcs = out_wcs , meta = self ._response_meta ),
302306 self ._row_scores ,
303307 self ._unconverged_rows ,
308+ self ._n_iter
304309 )
0 commit comments