Skip to content

Commit 816e7f6

Browse files
committed
allows weightless mode
1 parent ebf5080 commit 816e7f6

File tree

4 files changed

+27
-10
lines changed

4 files changed

+27
-10
lines changed

overlappogram/cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def unfold(config):
3131

3232
os.makedirs(config["output"]["directory"], exist_ok=True) # make sure output directory exists
3333

34-
overlappogram = load_overlappogram(config["paths"]["overlappogram"], config["paths"]["weights"])
34+
overlappogram = load_overlappogram(config["paths"]["overlappogram"],
35+
config["paths"]["weights"] if 'weights' in config['paths'] else None)
3536
response_cube = load_response_cube(config["paths"]["response"])
3637

3738
inversion = Inverter(

overlappogram/error.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
class OverlappogramWarning(Warning):
2+
pass
3+
4+
5+
class NoWeightsWarnings(OverlappogramWarning):
6+
"""There are no weights passed to unfold."""

overlappogram/inversion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sklearn.linear_model import ElasticNet
1212
from tqdm import tqdm
1313

14+
from overlappogram.error import NoWeightsWarnings
1415
from overlappogram.response import prepare_response_function
1516

1617
__all__ = ["Inverter"]
@@ -226,6 +227,9 @@ def _start_chunk_inversion(self, model_config, alpha, rho, num_threads):
226227
def _initialize_with_overlappogram(self, overlappogram):
227228
self._overlappogram = overlappogram
228229

230+
if self._overlappogram.uncertainty is None:
231+
warnings.warn("Running in weightless mode since no weights array was provided.", NoWeightsWarnings)
232+
229233
if self._detector_row_range is None:
230234
self._detector_row_range = (0, overlappogram.data.shape[0])
231235
self.total_row_count = self._detector_row_range[1] - self._detector_row_range[0]
@@ -239,7 +243,7 @@ def _initialize_with_overlappogram(self, overlappogram):
239243

240244
self._overlappogram.data[np.where(self._overlappogram.data < 0.0)] = 0.0
241245

242-
# initialize all result cubes
246+
# initialize all results cubes
243247
self._overlappogram_height, self._overlappogram_width = self._overlappogram.data.shape
244248
self._em_data = np.zeros((self._overlappogram_height, self._num_slits, self._num_deps), dtype=np.float32)
245249
self._inversion_prediction = np.zeros((self._overlappogram_height, self._overlappogram_width), dtype=np.float32)

overlappogram/io.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,23 @@
66
__all__ = ["load_overlappogram", "load_response_cube", "save_em_cube", "save_spectral_cube", "save_prediction"]
77

88

9-
def load_overlappogram(image_path, weights_path) -> NDCube:
9+
def load_overlappogram(image_path: str, weights_path: str | None) -> NDCube:
1010
with fits.open(image_path) as image_hdul:
1111
image = image_hdul[0].data
1212
header = image_hdul[0].header
1313
wcs = WCS(image_hdul[0].header)
14-
with fits.open(weights_path) as weights_hdul:
15-
weights = weights_hdul[0].data
16-
return NDCube(image, wcs=wcs, uncertainty=StdDevUncertainty(1 / weights), meta=dict(header))
1714

15+
if weights_path is None:
16+
uncertainty = None
17+
else:
18+
with fits.open(weights_path) as weights_hdul:
19+
weights = weights_hdul[0].data
20+
uncertainty = StdDevUncertainty(1 / weights)
1821

19-
def load_response_cube(path) -> NDCube:
22+
return NDCube(image, wcs=wcs, uncertainty=uncertainty, meta=dict(header))
23+
24+
25+
def load_response_cube(path: str) -> NDCube:
2026
with fits.open(path) as hdul:
2127
response = hdul[0].data
2228
header = hdul[0].header
@@ -28,13 +34,13 @@ def load_response_cube(path) -> NDCube:
2834
return NDCube(response, wcs=wcs, meta=meta)
2935

3036

31-
def save_em_cube(cube, path, overwrite=True) -> None:
37+
def save_em_cube(cube, path: str, overwrite: bool = True) -> None:
3238
fits.writeto(path, cube, overwrite=overwrite)
3339

3440

35-
def save_prediction(prediction, path, overwrite=True) -> None:
41+
def save_prediction(prediction, path: str, overwrite: bool = True) -> None:
3642
fits.writeto(path, prediction, overwrite=overwrite)
3743

3844

39-
def save_spectral_cube(spectral_cube, path, overwrite=True) -> None:
45+
def save_spectral_cube(spectral_cube, path: str, overwrite: bool = True) -> None:
4046
fits.writeto(path, spectral_cube, overwrite=overwrite)

0 commit comments

Comments
 (0)