|
1 | | -"""Helper functions for pre-processing data for 2DTM.""" |
| 1 | +"""Helper functions for the CTF filter preprocessing.""" |
2 | 2 |
|
3 | 3 | import einops |
4 | 4 | import torch |
5 | 5 | from torch_fourier_filter.ctf import calculate_ctf_2d |
6 | 6 |
|
7 | | -from leopard_em.pydantic_models import WhiteningFilterConfig |
| 7 | + |
| 8 | +def get_Cs_range( |
| 9 | + pixel_size: float, |
| 10 | + pixel_size_offsets: torch.Tensor, |
| 11 | + Cs: float = 2.7, |
| 12 | +) -> torch.Tensor: |
| 13 | + """Get the Cs values for a range of pixel sizes. |
| 14 | +
|
| 15 | + Parameters |
| 16 | + ---------- |
| 17 | + pixel_size : float |
| 18 | + The nominal pixel size. |
| 19 | + pixel_size_offsets : torch.Tensor |
| 20 | + The pixel size offsets. |
| 21 | + Cs : float, optional |
| 22 | + The Cs value, by default 2.7. |
| 23 | +
|
| 24 | + Returns |
| 25 | + ------- |
| 26 | + torch.Tensor |
| 27 | + The Cs values for the range of pixel sizes. |
| 28 | + """ |
| 29 | + pixel_sizes = pixel_size + pixel_size_offsets |
| 30 | + Cs_values = Cs / torch.pow(pixel_sizes / pixel_size, 4) |
| 31 | + return Cs_values |
| 32 | + |
| 33 | + |
| 34 | +def Cs_to_pixel_size( |
| 35 | + Cs_vals: torch.Tensor, |
| 36 | + nominal_pixel_size: float, |
| 37 | + nominal_Cs: float = 2.7, |
| 38 | +) -> torch.Tensor: |
| 39 | + """Convert Cs values to pixel sizes. |
| 40 | +
|
| 41 | + Parameters |
| 42 | + ---------- |
| 43 | + Cs_vals : torch.Tensor |
| 44 | + The Cs values. |
| 45 | + nominal_pixel_size : float |
| 46 | + The nominal pixel size. |
| 47 | + nominal_Cs : float, optional |
| 48 | + The nominal Cs value, by default 2.7. |
| 49 | +
|
| 50 | + Returns |
| 51 | + ------- |
| 52 | + torch.Tensor |
| 53 | + The pixel sizes. |
| 54 | + """ |
| 55 | + pixel_size = torch.pow(nominal_Cs / Cs_vals, 0.25) * nominal_pixel_size |
| 56 | + return pixel_size |
8 | 57 |
|
9 | 58 |
|
10 | 59 | def calculate_ctf_filter_stack( |
@@ -98,112 +147,3 @@ def calculate_ctf_filter_stack( |
98 | 147 | # The CTF will have a shape of (n_Cs n_defoc, nx, ny) |
99 | 148 | # These will catch any potential errors |
100 | 149 | return ctf |
101 | | - |
102 | | - |
103 | | -def do_image_preprocessing( |
104 | | - image_rfft: torch.Tensor, |
105 | | - wf_config: WhiteningFilterConfig, |
106 | | -) -> torch.Tensor: |
107 | | - """Pre-processes the input image before running the algorithm. |
108 | | -
|
109 | | - 1. Zero central pixel (0, 0) |
110 | | - 2. Calculate a whitening filter |
111 | | - 3. Do element-wise multiplication with the whitening filter |
112 | | - 4. Zero central pixel again (superfluous, but following cisTEM) |
113 | | - 5. Normalize (x /= sqrt(sum(abs(x)**2)); pixelwise) |
114 | | -
|
115 | | - Parameters |
116 | | - ---------- |
117 | | - image_rfft : torch.Tensor |
118 | | - The input image, RFFT'd and unshifted. |
119 | | - wf_config : WhiteningFilterConfig |
120 | | - The configuration for the whitening filter. |
121 | | -
|
122 | | - Returns |
123 | | - ------- |
124 | | - torch.Tensor |
125 | | - The pre-processed image. |
126 | | -
|
127 | | - """ |
128 | | - H, W = image_rfft.shape |
129 | | - W = (W - 1) * 2 # Account for RFFT |
130 | | - npix_real = H * W |
131 | | - |
132 | | - # Zero out the constant term |
133 | | - image_rfft[0, 0] = 0 + 0j |
134 | | - |
135 | | - wf_image = wf_config.calculate_whitening_filter( |
136 | | - ref_img_rfft=image_rfft, |
137 | | - output_shape=image_rfft.shape, |
138 | | - ) |
139 | | - image_rfft *= wf_image |
140 | | - image_rfft[0, 0] = 0 + 0j # superfluous, but following cisTEM |
141 | | - |
142 | | - # NOTE: Extra indexing happening with squared_sum so that Hermitian pairs are |
143 | | - # counted, but we skip the first column of the RFFT which should not be duplicated. |
144 | | - squared_image_rfft = torch.abs(image_rfft) ** 2 |
145 | | - squared_sum = squared_image_rfft.sum() + squared_image_rfft[:, 1:].sum() |
146 | | - image_rfft /= torch.sqrt(squared_sum) |
147 | | - |
148 | | - # # real-space image will now have mean=0 and variance=1 |
149 | | - # image_rfft *= npix_real # NOTE: This would set the variance to 1 exactly, but... |
150 | | - |
151 | | - # NOTE: We add on extra division by sqrt(num_pixels) so the cross-correlograms |
152 | | - # are roughly normalized to have mean 0 and variance 1. |
153 | | - # We do this here since Fourier transform is linear, and we don't have to multiply |
154 | | - # the cross correlation at each iteration. This *will not* make the image |
155 | | - # have variance 1. |
156 | | - image_rfft *= npix_real**0.5 |
157 | | - |
158 | | - return image_rfft |
159 | | - |
160 | | - |
161 | | -def get_Cs_range( |
162 | | - pixel_size: float, |
163 | | - pixel_size_offsets: torch.Tensor, |
164 | | - Cs: float = 2.7, |
165 | | -) -> torch.Tensor: |
166 | | - """Get the Cs values for a range of pixel sizes. |
167 | | -
|
168 | | - Parameters |
169 | | - ---------- |
170 | | - pixel_size : float |
171 | | - The nominal pixel size. |
172 | | - pixel_size_offsets : torch.Tensor |
173 | | - The pixel size offsets. |
174 | | - Cs : float, optional |
175 | | - The Cs value, by default 2.7. |
176 | | -
|
177 | | - Returns |
178 | | - ------- |
179 | | - torch.Tensor |
180 | | - The Cs values for the range of pixel sizes. |
181 | | - """ |
182 | | - pixel_sizes = pixel_size + pixel_size_offsets |
183 | | - Cs_values = Cs / torch.pow(pixel_sizes / pixel_size, 4) |
184 | | - return Cs_values |
185 | | - |
186 | | - |
187 | | -def Cs_to_pixel_size( |
188 | | - Cs_vals: torch.Tensor, |
189 | | - nominal_pixel_size: float, |
190 | | - nominal_Cs: float = 2.7, |
191 | | -) -> torch.Tensor: |
192 | | - """Convert Cs values to pixel sizes. |
193 | | -
|
194 | | - Parameters |
195 | | - ---------- |
196 | | - Cs_vals : torch.Tensor |
197 | | - The Cs values. |
198 | | - nominal_pixel_size : float |
199 | | - The nominal pixel size. |
200 | | - nominal_Cs : float, optional |
201 | | - The nominal Cs value, by default 2.7. |
202 | | -
|
203 | | - Returns |
204 | | - ------- |
205 | | - torch.Tensor |
206 | | - The pixel sizes. |
207 | | - """ |
208 | | - pixel_size = torch.pow(nominal_Cs / Cs_vals, 0.25) * nominal_pixel_size |
209 | | - return pixel_size |
0 commit comments