Skip to content

Commit 9a4d4ba

Browse files
authored
Merge pull request #20 from mgiammar/main
Resolve a circular import error
2 parents 958275b + bc2a701 commit 9a4d4ba

File tree

5 files changed

+86
-115
lines changed

5 files changed

+86
-115
lines changed

src/leopard_em/backend/core_refine_template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515
from leopard_em.backend.utils import normalize_template_projection
1616
from leopard_em.utils.cross_correlation import handle_correlation_mode
17-
from leopard_em.utils.pre_processing import calculate_ctf_filter_stack
17+
from leopard_em.utils.filter_preprocessing import calculate_ctf_filter_stack
1818

1919
# This is assuming the Euler angles are in the ZYZ intrinsic format
2020
# AND that the angles are ordered in (phi, theta, psi)

src/leopard_em/pydantic_models/match_template_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from leopard_em.pydantic_models.orientation_search import OrientationSearchConfig
1919
from leopard_em.pydantic_models.types import BaseModel2DTM, ExcludedTensor
2020
from leopard_em.utils.data_io import load_mrc_image, load_mrc_volume
21-
from leopard_em.utils.pre_processing import calculate_ctf_filter_stack
21+
from leopard_em.utils.filter_preprocessing import calculate_ctf_filter_stack
2222

2323

2424
class MatchTemplateManager(BaseModel2DTM):

src/leopard_em/pydantic_models/particle_stack.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ def construct_cropped_statistic_stack(
219219

220220
# with reference to the exact pixel of the statistic (top-left)
221221
# need to account for relative extracted box size
222-
pos_y = self._df.loc[indexes, "pos_y"]
223-
pos_x = self._df.loc[indexes, "pos_x"]
222+
pos_y = self._df.loc[indexes, "pos_y"].to_numpy()
223+
pos_x = self._df.loc[indexes, "pos_x"].to_numpy()
224224
pos_y = torch.tensor(pos_y)
225225
pos_x = torch.tensor(pos_x)
226226
pos_y -= (H - h) // 2

src/leopard_em/utils/__init__.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Utilities submodule for various data and pre- and post-processing tasks."""
2+
3+
from .cross_correlation import handle_correlation_mode
4+
from .data_io import (
5+
load_mrc_image,
6+
load_mrc_volume,
7+
read_mrc_to_numpy,
8+
read_mrc_to_tensor,
9+
write_mrc_from_numpy,
10+
write_mrc_from_tensor,
11+
)
12+
from .filter_preprocessing import (
13+
Cs_to_pixel_size,
14+
calculate_ctf_filter_stack,
15+
get_Cs_range,
16+
)
17+
from .particle_stack import get_cropped_image_regions
18+
19+
__all__ = [
20+
"handle_correlation_mode",
21+
"read_mrc_to_numpy",
22+
"read_mrc_to_tensor",
23+
"write_mrc_from_numpy",
24+
"write_mrc_from_tensor",
25+
"load_mrc_image",
26+
"load_mrc_volume",
27+
"get_cropped_image_regions",
28+
"calculate_ctf_filter_stack",
29+
"get_Cs_range",
30+
"Cs_to_pixel_size",
31+
]
Lines changed: 51 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,59 @@
1-
"""Helper functions for pre-processing data for 2DTM."""
1+
"""Helper functions for the CTF filter preprocessing."""
22

33
import einops
44
import torch
55
from torch_fourier_filter.ctf import calculate_ctf_2d
66

7-
from leopard_em.pydantic_models import WhiteningFilterConfig
7+
8+
def get_Cs_range(
9+
pixel_size: float,
10+
pixel_size_offsets: torch.Tensor,
11+
Cs: float = 2.7,
12+
) -> torch.Tensor:
13+
"""Get the Cs values for a range of pixel sizes.
14+
15+
Parameters
16+
----------
17+
pixel_size : float
18+
The nominal pixel size.
19+
pixel_size_offsets : torch.Tensor
20+
The pixel size offsets.
21+
Cs : float, optional
22+
The Cs value, by default 2.7.
23+
24+
Returns
25+
-------
26+
torch.Tensor
27+
The Cs values for the range of pixel sizes.
28+
"""
29+
pixel_sizes = pixel_size + pixel_size_offsets
30+
Cs_values = Cs / torch.pow(pixel_sizes / pixel_size, 4)
31+
return Cs_values
32+
33+
34+
def Cs_to_pixel_size(
35+
Cs_vals: torch.Tensor,
36+
nominal_pixel_size: float,
37+
nominal_Cs: float = 2.7,
38+
) -> torch.Tensor:
39+
"""Convert Cs values to pixel sizes.
40+
41+
Parameters
42+
----------
43+
Cs_vals : torch.Tensor
44+
The Cs values.
45+
nominal_pixel_size : float
46+
The nominal pixel size.
47+
nominal_Cs : float, optional
48+
The nominal Cs value, by default 2.7.
49+
50+
Returns
51+
-------
52+
torch.Tensor
53+
The pixel sizes.
54+
"""
55+
pixel_size = torch.pow(nominal_Cs / Cs_vals, 0.25) * nominal_pixel_size
56+
return pixel_size
857

958

1059
def calculate_ctf_filter_stack(
@@ -98,112 +147,3 @@ def calculate_ctf_filter_stack(
98147
# The CTF will have a shape of (n_Cs n_defoc, nx, ny)
99148
# These will catch any potential errors
100149
return ctf
101-
102-
103-
def do_image_preprocessing(
104-
image_rfft: torch.Tensor,
105-
wf_config: WhiteningFilterConfig,
106-
) -> torch.Tensor:
107-
"""Pre-processes the input image before running the algorithm.
108-
109-
1. Zero central pixel (0, 0)
110-
2. Calculate a whitening filter
111-
3. Do element-wise multiplication with the whitening filter
112-
4. Zero central pixel again (superfluous, but following cisTEM)
113-
5. Normalize (x /= sqrt(sum(abs(x)**2)); pixelwise)
114-
115-
Parameters
116-
----------
117-
image_rfft : torch.Tensor
118-
The input image, RFFT'd and unshifted.
119-
wf_config : WhiteningFilterConfig
120-
The configuration for the whitening filter.
121-
122-
Returns
123-
-------
124-
torch.Tensor
125-
The pre-processed image.
126-
127-
"""
128-
H, W = image_rfft.shape
129-
W = (W - 1) * 2 # Account for RFFT
130-
npix_real = H * W
131-
132-
# Zero out the constant term
133-
image_rfft[0, 0] = 0 + 0j
134-
135-
wf_image = wf_config.calculate_whitening_filter(
136-
ref_img_rfft=image_rfft,
137-
output_shape=image_rfft.shape,
138-
)
139-
image_rfft *= wf_image
140-
image_rfft[0, 0] = 0 + 0j # superfluous, but following cisTEM
141-
142-
# NOTE: Extra indexing happening with squared_sum so that Hermitian pairs are
143-
# counted, but we skip the first column of the RFFT which should not be duplicated.
144-
squared_image_rfft = torch.abs(image_rfft) ** 2
145-
squared_sum = squared_image_rfft.sum() + squared_image_rfft[:, 1:].sum()
146-
image_rfft /= torch.sqrt(squared_sum)
147-
148-
# # real-space image will now have mean=0 and variance=1
149-
# image_rfft *= npix_real # NOTE: This would set the variance to 1 exactly, but...
150-
151-
# NOTE: We add on extra division by sqrt(num_pixels) so the cross-correlograms
152-
# are roughly normalized to have mean 0 and variance 1.
153-
# We do this here since Fourier transform is linear, and we don't have to multiply
154-
# the cross correlation at each iteration. This *will not* make the image
155-
# have variance 1.
156-
image_rfft *= npix_real**0.5
157-
158-
return image_rfft
159-
160-
161-
def get_Cs_range(
162-
pixel_size: float,
163-
pixel_size_offsets: torch.Tensor,
164-
Cs: float = 2.7,
165-
) -> torch.Tensor:
166-
"""Get the Cs values for a range of pixel sizes.
167-
168-
Parameters
169-
----------
170-
pixel_size : float
171-
The nominal pixel size.
172-
pixel_size_offsets : torch.Tensor
173-
The pixel size offsets.
174-
Cs : float, optional
175-
The Cs value, by default 2.7.
176-
177-
Returns
178-
-------
179-
torch.Tensor
180-
The Cs values for the range of pixel sizes.
181-
"""
182-
pixel_sizes = pixel_size + pixel_size_offsets
183-
Cs_values = Cs / torch.pow(pixel_sizes / pixel_size, 4)
184-
return Cs_values
185-
186-
187-
def Cs_to_pixel_size(
188-
Cs_vals: torch.Tensor,
189-
nominal_pixel_size: float,
190-
nominal_Cs: float = 2.7,
191-
) -> torch.Tensor:
192-
"""Convert Cs values to pixel sizes.
193-
194-
Parameters
195-
----------
196-
Cs_vals : torch.Tensor
197-
The Cs values.
198-
nominal_pixel_size : float
199-
The nominal pixel size.
200-
nominal_Cs : float, optional
201-
The nominal Cs value, by default 2.7.
202-
203-
Returns
204-
-------
205-
torch.Tensor
206-
The pixel sizes.
207-
"""
208-
pixel_size = torch.pow(nominal_Cs / Cs_vals, 0.25) * nominal_pixel_size
209-
return pixel_size

0 commit comments

Comments
 (0)