Skip to content

Commit b4a5ebb

Browse files
committed
FIX: fixed issue with getting the masks from the probe positions. The simulated probe used for masks was being created incorrectly.
1 parent 03f670a commit b4a5ebb

File tree

4 files changed

+32
-68
lines changed

4 files changed

+32
-68
lines changed

src/pyxalign/api/options/projections.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class VolumeWidthOptions:
8888

8989

9090
@dataclasses.dataclass
91-
class SimulatedProbe:
91+
class SimulatedProbeOptions:
9292
"""
9393
Parameters for creating a gaussian probe
9494
"""
@@ -114,7 +114,7 @@ class ProbePositionMaskOptions:
114114
use the probe in the Projections object.
115115
"""
116116

117-
probe: SimulatedProbe = field(default_factory=SimulatedProbe)
117+
probe: SimulatedProbeOptions = field(default_factory=SimulatedProbeOptions)
118118

119119

120120
@dataclasses.dataclass

src/pyxalign/data_structures/projections.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import pyxalign.gpu_utils as gpu_utils
3636
from pyxalign.gpu_wrapper import device_handling_wrapper
3737
from pyxalign.data_structures.volume import Volume
38-
from pyxalign.mask import build_masks_from_threshold
38+
from pyxalign.mask import build_masks_from_threshold, get_simulated_probe_for_masks
3939
from pyxalign.io.utils import load_list_of_arrays
4040
from pyxalign.io.save import save_generic_data_structure_to_h5
4141

@@ -381,9 +381,15 @@ def center_projections(self):
381381

382382
@timer()
383383
def get_masks_from_probe_positions(self):
384+
if self.options.mask_from_positions.use_simulated_probe:
385+
probe = get_simulated_probe_for_masks(
386+
self.probe, self.options.mask_from_positions.probe
387+
)
388+
else:
389+
probe = self.probe
384390
self.masks = build_masks_from_threshold(
385391
self.data.shape,
386-
self.probe,
392+
probe,
387393
self.probe_positions.data,
388394
self.options.mask_from_positions.threshold,
389395
)

src/pyxalign/interactions/mask.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,12 @@
1515
QPushButton,
1616
)
1717

18-
from pyxalign.api.enums import RoundType
1918
import pyxalign.data_structures.projections as p
2019
from pyxalign.interactions.utils.loading_display_tools import loading_bar_wrapper
2120
from pyxalign.interactions.utils.misc import switch_to_matplotlib_qt_backend
22-
from pyxalign.mask import place_patches_fourier_batch
21+
from pyxalign.mask import get_simulated_probe_for_masks, place_patches_fourier_batch
2322
from pyxalign.interactions.viewers.base import IndexSelectorWidget
2423
from pyxalign.mask import clip_masks
25-
from pyxalign.model_functions import symmetric_gaussian_2d
26-
from pyxalign.transformations.helpers import round_to_divisor
27-
from pyxalign.api import constants
2824

2925
"""
3026
Interactive mask threshold selector based on pyqtgraph and the shared
@@ -153,23 +149,14 @@ def __init__(
153149
self.projections = projections
154150
self.options = self.projections.options.mask_from_positions
155151

156-
# use simulated probe if specified by options; this typically
157-
# gives better results
158152
if self.options.use_simulated_probe:
159-
shape = self.projections.probe.shape
160-
probe_width = round_to_divisor(
161-
shape[0] * self.options.probe.fractional_width,
162-
round_type=RoundType.NEAREST,
163-
divisor=constants.divisor
164-
)
165-
probe = symmetric_gaussian_2d(shape, amplitude=1, sigma=probe_width)
153+
probe = get_simulated_probe_for_masks(self.projections.probe, self.options.probe)
166154
else:
167155
probe = self.projections.probe
168-
169156
# Precompute masks (floating-point values)
170-
load_bar_func_wrapper = loading_bar_wrapper("Initializing masks...")(
171-
place_patches_fourier_batch
172-
)
157+
load_bar_func_wrapper = loading_bar_wrapper(
158+
"Initializing masks...", block_all_windows=True
159+
)(place_patches_fourier_batch)
173160
masks = load_bar_func_wrapper(
174161
self.projections.data.shape,
175162
probe,

src/pyxalign/mask.py

Lines changed: 17 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
from tqdm import tqdm
1212
from contextlib import nullcontext
1313
from pyxalign import gpu_utils
14+
from pyxalign.api.enums import RoundType
1415
from pyxalign.api.options.device import DeviceOptions
16+
from pyxalign.api.options.projections import SimulatedProbeOptions
1517
from pyxalign.gpu_wrapper import device_handling_wrapper
1618

1719
# from pyxalign.interactions.mask import ThresholdSelector, illum_map_threshold_plotter
18-
from pyxalign.transformations.helpers import is_array_real
20+
from pyxalign.model_functions import symmetric_gaussian_2d
21+
from pyxalign.transformations.helpers import is_array_real, round_to_divisor
1922
from IPython.display import display
20-
from PyQt5.QtWidgets import QApplication
2123
import plotly.graph_objects as go
2224
from plotly.subplots import make_subplots
2325

@@ -26,8 +28,6 @@
2628
from pyxalign.timing.timer_utils import timer, InlineTimer
2729
from pyxalign.api.types import ArrayType, r_type
2830

29-
from PyQt5.QtWidgets import QWidget
30-
3131

3232
@memory_releasing_error_handler
3333
@timer()
@@ -387,45 +387,16 @@ def build_masks_from_threshold(
387387
return clip_masks(masks, threshold)
388388

389389

390-
# class IlluminationMapMaskBuilder:
391-
# """
392-
# Class for building mask from the illumination map.
393-
# """
394-
395-
# def get_mask_base(
396-
# self,
397-
# probe: np.ndarray,
398-
# positions: list[np.ndarray],
399-
# projections: np.ndarray,
400-
# use_fourier: bool = True,
401-
# ):
402-
# # The base for building the mask is the illumination map
403-
# if use_fourier:
404-
# self.masks = place_patches_fourier_batch(projections.shape, probe, positions)
405-
# else:
406-
# for i in range(len(positions)):
407-
# self.masks = np.zeros_like(projections, dtype=r_type)
408-
# get_illumination_map(self.masks[i], probe, positions[i])
409-
410-
# def set_mask_threshold_interactively(self, projections: np.ndarray) -> float:
411-
# # temporary bugfix: all windows need to be closed or else app.exec_() will
412-
# # hang indefinitely. I am putting this temporary solution (which I don't like
413-
# # very much) in place, because any changes will be overwritten once merged with
414-
# # interactive_pma_gui anyway.
415-
# app = QApplication.instance() or QApplication([])
416-
# app.closeAllWindows()
417-
418-
# # Use interactivity to decide mask threshold"
419-
# self.threshold_selector = illum_map_threshold_plotter(
420-
# self.masks, projections, init_thresh=0.01
421-
# )
422-
# self.threshold_selector.show()
423-
424-
# app.exec_()
425-
# threshold = self.threshold_selector.threshold
426-
# return threshold
427-
428-
# def clip_masks(self, thresh: Optional[float] = None):
429-
# clip_idx = self.masks > thresh
430-
# self.masks[:] = 0
431-
# self.masks[clip_idx] = 1
390+
def get_simulated_probe_for_masks(
391+
probe: np.ndarray,
392+
simulated_probe_options: SimulatedProbeOptions,
393+
):
394+
# simulated probe typically gives better results
395+
shape = probe.shape
396+
probe_width = round_to_divisor(
397+
shape[0] * simulated_probe_options.fractional_width,
398+
round_type=RoundType.NEAREST,
399+
divisor=2,
400+
)
401+
probe = symmetric_gaussian_2d(shape, amplitude=1, sigma=probe_width)
402+
return probe

0 commit comments

Comments
 (0)