Skip to content

Commit d904e73

Browse files
authored
Merge pull request #68 from mgiammar/main
Sync minor test and linting updates upstream
2 parents c14df84 + 17de2e0 commit d904e73

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

src/leopard_em/utils/fourier_slice.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import roma
44
import torch
55
from torch_fourier_slice import extract_central_slices_rfft_3d
6-
from torch_fourier_slice._dft_utils import _fftshift_3d, _ifftshift_2d
76
from torch_grid_utils import fftfreq_grid
87

98

@@ -35,11 +34,12 @@ def _rfft_slices_to_real_projections(
3534
torch.Tensor
3635
The real-space projections.
3736
"""
38-
fourier_slices = _ifftshift_2d(fourier_slices, rfft=True)
37+
# pylint: disable=not-callable
38+
fourier_slices = torch.fft.fftshift(fourier_slices, dim=(-2,))
3939
# pylint: disable=not-callable
4040
projections = torch.fft.irfftn(fourier_slices, dim=(-2, -1))
41-
projections = _ifftshift_2d(projections, rfft=False)
42-
41+
# pylint: disable=not-callable
42+
projections = torch.fft.ifftshift(projections, dim=(-2, -1))
4343
return projections
4444

4545

@@ -72,10 +72,9 @@ def get_rfft_slices_from_volume(
7272
7373
"""
7474
shape = volume.shape
75-
volume_rfft = _fftshift_3d(volume, rfft=False)
76-
# pylint: disable=not-callable
77-
volume_rfft = torch.fft.fftn(volume_rfft, dim=(-3, -2, -1))
78-
volume_rfft = _fftshift_3d(volume_rfft, rfft=True)
75+
volume_rfft = torch.fft.fftshift(volume, dim=(-3, -2, -1)) # pylint: disable=not-callable
76+
volume_rfft = torch.fft.fftn(volume_rfft, dim=(-3, -2, -1)) # pylint: disable=not-callable
77+
volume_rfft = torch.fft.fftshift(volume_rfft, dim=(-3, -2)) # pylint: disable=not-callable
7978

8079
# Use roma to keep angles on same device
8180
rot_matrix = roma.euler_to_rotmat("ZYZ", (phi, theta, psi), degrees=degrees)

tests/self_consistency/test_backend_cross_correlate.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import torch
2323
from scipy.ndimage import gaussian_filter
2424
from torch_fourier_filter.ctf import calculate_ctf_2d
25+
from torch_fourier_filter.envelopes import b_envelope
2526

2627
from leopard_em.backend.cross_correlation import (
2728
do_batched_orientation_cross_correlate,
@@ -70,7 +71,7 @@ def sample_input_data() -> dict[str, torch.Tensor]:
7071
template_fft = torch.fft.fftshift(template_fft, dim=(0, 1))
7172

7273
# Generate a set of projective filters (CTFs) for the template
73-
defocus_values = torch.linspace(500, 1500, NUM_DEFOCUS_VALUES)
74+
defocus_values = torch.linspace(2000, 4000, NUM_DEFOCUS_VALUES)
7475
pixel_sizes = torch.linspace(0.8, 1.2, NUM_PIXEL_SIZES)
7576
cs_values = get_cs_range(1.0, pixel_sizes, 2.7)
7677

@@ -83,7 +84,6 @@ def sample_input_data() -> dict[str, torch.Tensor]:
8384
voltage=300, # 300 kV
8485
spherical_aberration=cs_val,
8586
amplitude_contrast=0.07,
86-
b_factor=100.0,
8787
phase_shift=0.0,
8888
pixel_size=1.0,
8989
image_shape=template.shape[-2:],
@@ -92,7 +92,17 @@ def sample_input_data() -> dict[str, torch.Tensor]:
9292
)
9393
ctf_list.append(tmp)
9494

95+
# Apply a b-factor envelope to the CTFs
96+
b_factor = 100.0 # arbitrary value
97+
b_envelope_values = b_envelope(
98+
B=b_factor,
99+
image_shape=template.shape[-2:],
100+
pixel_size=1.0,
101+
device="cuda",
102+
)
103+
95104
projective_filters = torch.stack(ctf_list, dim=0).to(device="cuda")
105+
projective_filters *= b_envelope_values[None, None]
96106

97107
return {
98108
"image_dft": image_fft,

tests/utils/test_crop_extraction.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,23 +61,28 @@ def test_get_cropped_image_regions_numpy_fixed_positions():
6161
def test_get_cropped_image_regions_numpy_random_nonoverlapping():
6262
"""Random non-overlapping pos for _get_cropped_image_regions_numpy function."""
6363
box_size = (5, 5)
64-
num_patches = 10
64+
num_patches = 32
6565
image_size = (256, 256)
6666
test_patch = get_test_patch(box_size)
6767
image = np.zeros(image_size, dtype=np.float32)
6868

6969
# Generate non-overlapping positions
7070
positions = []
7171
for _ in range(num_patches):
72-
while True:
72+
total_failures = 0
73+
while total_failures < 100:
7374
y = np.random.randint(0, image_size[0] - box_size[0] + 1)
7475
x = np.random.randint(0, image_size[1] - box_size[1] + 1)
75-
if all(
76-
not (y <= py < y + box_size[0] and x <= px < x + box_size[1])
76+
# Check if new position overlaps with any existing position
77+
overlap = any(
78+
abs(y - py) < box_size[0] and abs(x - px) < box_size[1]
7779
for py, px in positions
78-
):
80+
)
81+
82+
if not overlap:
7983
positions.append((y, x))
8084
break
85+
total_failures += 1
8186

8287
pos_y, pos_x = zip(*positions)
8388

0 commit comments

Comments
 (0)