Skip to content

Commit da4572a

Browse files
authored
Merge pull request #83 from ECCCO-mission/speed-up
Adds tracking of num iterations
2 parents 96fe742 + 5dae4cc commit da4572a

File tree

5 files changed

+29
-16
lines changed

5 files changed

+29
-16
lines changed

overlappogram/cli.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ def unfold(config):
4949

5050
for alpha in config["model"]["alphas"]:
5151
for rho in config["model"]["rhos"]:
52-
print(80*"-")
52+
print(80 * "-")
5353
print(f"Beginning inversion for alpha={alpha}, rho={rho}.")
5454
start = time.time()
55-
em_cube, prediction, scores, unconverged_rows = inversion.invert(
56-
overlappogram,
57-
config["model"],
58-
alpha,
59-
rho,
60-
num_threads=config["execution"]["num_threads"],
61-
mode_switch_thread_count=config["execution"]["mode_switch_thread_count"],
62-
mode=MODE_MAPPING.get(config['execution']['mode'], 'invalid')
55+
em_cube, prediction, scores, unconverged_rows, n_iter = inversion.invert(
56+
overlappogram,
57+
config["model"],
58+
alpha,
59+
rho,
60+
num_threads=config["execution"]["num_threads"],
61+
mode_switch_thread_count=config["execution"]["mode_switch_thread_count"],
62+
mode=MODE_MAPPING.get(config['execution']['mode'], 'invalid')
6363
)
6464
end = time.time()
6565
print(
@@ -68,8 +68,11 @@ def unfold(config):
6868
f"seconds; {len(unconverged_rows)} unconverged rows",
6969
)
7070

71+
print(f"Unconverged rows: {unconverged_rows}")
72+
7173
postfix = (
72-
"x" + str(config["inversion"]["solution_fov_width"]) + "_" + str(rho * 10) + "_" + str(alpha) + "_wpsf"
74+
"x" + str(config["inversion"]["solution_fov_width"]) + "_" + str(rho * 10) + "_" + str(
75+
alpha) + "_wpsf"
7376
)
7477
save_em_cube(
7578
em_cube,
@@ -90,6 +93,11 @@ def unfold(config):
9093
with open(scores_path, 'w') as f:
9194
f.write("\n".join(scores.flatten().astype(str).tolist()))
9295

96+
niter_path = os.path.join(config["output"]["directory"],
97+
f"{config['output']['prefix']}_niter_{postfix}.txt")
98+
with open(niter_path, 'w') as f:
99+
f.write("\n".join(n_iter.flatten().astype(str).tolist()))
100+
93101
if config["output"]["make_spectral"]:
94102
spectral_images = create_spectrally_pure_images(
95103
[em_cube], config["paths"]["gnt"], config["inversion"]["response_dependency_list"]

overlappogram/inversion.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
self._row_scores: np.ndarray | None = None
5454
self._overlappogram_width: int | None = None
5555
self._overlappogram_height: int | None = None
56+
self._n_iter: np.ndarray | None = None
5657

5758
self._thread_count_lock = Lock()
5859

@@ -104,13 +105,15 @@ def _invert_image_row(self, row_index, chunk_index):
104105
data_out = model.predict(masked_response_function)
105106
em = model.coef_
106107
score_data = model.score(masked_response_function, image_row)
108+
n_iter = model.n_iter_
107109
except ConvergenceWarning:
108110
self._unconverged_rows.append(row_index)
109111
em = np.zeros((self._num_slits * self._num_deps), dtype=np.float32)
110112
data_out = np.zeros(self._overlappogram_width, dtype=np.float32)
111113
score_data = -999
114+
n_iter = -1
112115

113-
return row_index, em, data_out, score_data
116+
return row_index, em, data_out, score_data, n_iter
114117

115118
def _progress_indicator(self, future):
116119
"""used in multithreading to track progress of inversion"""
@@ -155,7 +158,7 @@ def _switch_to_row_inversion(self, model_config, alpha, rho, num_row_threads=50)
155158

156159
def _collect_results(self, mode_switch_thread_count, model_config, alpha, rho):
157160
for future in concurrent.futures.as_completed(self.futures):
158-
row_index, em, data_out, score_data = future.result()
161+
row_index, em, data_out, score_data, n_iter = future.result()
159162
for slit_num in range(self._num_slits):
160163
if self._smooth_over == "dependence":
161164
slit_em = em[slit_num * self._num_deps : (slit_num + 1) * self._num_deps]
@@ -164,6 +167,7 @@ def _collect_results(self, mode_switch_thread_count, model_config, alpha, rho):
164167
self._em_data[row_index, slit_num, :] = slit_em
165168
self._inversion_prediction[row_index, :] = data_out
166169
self._row_scores[row_index] = score_data
170+
self._n_iter[row_index] = n_iter
167171

168172
rows_remaining = self.total_row_count - self._completed_row_count
169173

@@ -173,7 +177,6 @@ def _collect_results(self, mode_switch_thread_count, model_config, alpha, rho):
173177

174178
def _start_row_inversion(self, model_config, alpha, rho, num_threads):
175179
self.executors = [concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)]
176-
177180
self.futures = {}
178181
self._models = []
179182
for i, row_index in enumerate(range(self._detector_row_range[0], self._detector_row_range[1])):
@@ -253,6 +256,7 @@ def _initialize_with_overlappogram(self, overlappogram):
253256
self._em_data = np.zeros((self._overlappogram_height, self._num_slits, self._num_deps), dtype=np.float32)
254257
self._inversion_prediction = np.zeros((self._overlappogram_height, self._overlappogram_width), dtype=np.float32)
255258
self._row_scores = np.zeros((self._overlappogram_height, 1), dtype=np.float32)
259+
self._n_iter = np.zeros((self._overlappogram_height, 1), dtype=np.int32)
256260

257261
def invert(
258262
self,
@@ -301,4 +305,5 @@ def invert(
301305
NDCube(data=self._inversion_prediction, wcs=out_wcs, meta=self._response_meta),
302306
self._row_scores,
303307
self._unconverged_rows,
308+
self._n_iter
304309
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = ["setuptools",
55

66
[project]
77
name = "overlappogram"
8-
version = "0.0.9"
8+
version = "0.0.10"
99
dependencies = ["numpy<2.0.0",
1010
"astropy",
1111
"scikit-learn",

tests/test_inversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_inversion_runs(tmp_path, inversion_mode, is_weighted):
4545
detector_row_range=config["inversion"]["detector_row_range"],
4646
)
4747

48-
em_cube, prediction, scores, unconverged_rows = inversion.invert(
48+
em_cube, prediction, scores, unconverged_rows, niter = inversion.invert(
4949
overlappogram,
5050
config["model"],
5151
3E-5,

tests/test_spectral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_create_spectrally_pure_images(tmp_path):
3434
detector_row_range=config["inversion"]["detector_row_range"],
3535
)
3636

37-
em_cube, prediction, scores, unconverged_rows = inversion.invert(
37+
em_cube, prediction, scores, unconverged_rows, _ = inversion.invert(
3838
overlappogram,
3939
config["model"],
4040
3E-5,

0 commit comments

Comments
 (0)