Skip to content

Commit 64d8e31

Browse files
authored
Merge pull request #161 from DiamondLightSource/remdev
distortion correction preview refactoring
2 parents 74f47da + 0671389 commit 64d8e31

File tree

3 files changed

+32
-69
lines changed

3 files changed

+32
-69
lines changed

.github/workflows/httomolibgpu_nightly_build.yml

Lines changed: 0 additions & 40 deletions
This file was deleted.

httomolibgpu/prep/alignment.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
else:
3434
map_coordinates = Mock()
3535

36-
from typing import Dict, List
36+
from typing import Dict, List, Tuple
3737

3838
__all__ = [
3939
"distortion_correction_proj_discorpy",
@@ -48,9 +48,10 @@
4848
def distortion_correction_proj_discorpy(
4949
data: cp.ndarray,
5050
metadata_path: str,
51-
preview: Dict[str, List[int]],
52-
order: int = 1,
53-
mode: str = "reflect",
51+
shift_xy: List[int] = [0, 0],
52+
step_xy: List[int] = [1, 1],
53+
order: int = 3,
54+
mode: str = "constant",
5455
):
5556
"""Unwarp a stack of images using a backward model. See :cite:`vo2015radial`.
5657
@@ -63,18 +64,18 @@ def distortion_correction_proj_discorpy(
6364
The path to the file containing the distortion coefficients for the
6465
data.
6566
66-
preview : Dict[str, List[int]]
67-
A dict containing three key-value pairs:
68-
- a list containing the `start` value of each dimension
69-
- a list containing the `stop` value of each dimension
70-
- a list containing the `step` value of each dimension
67+
shift_xy: List[int]
68+
Centers of distortion in x (from the left of the image) and y directions (from the top of the image).
7169
72-
order : int, optional.
73-
The order of the spline interpolation.
70+
step_xy: List[int]
71+
Steps in x and y directions respectively. They need to be not larger than one.
7472
75-
mode : {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest',
76-
'mirror', 'grid-wrap', 'wrap'}, optional
77-
To determine how to handle image boundaries.
73+
order : int, optional
74+
The order of the spline interpolation, default is 3. Must be in the range 0-5.
75+
76+
mode : str, optional
77+
Points outside the boundaries of the input are filled according to the given mode
78+
('constant', 'nearest', 'mirror', 'reflect', 'wrap', 'grid-mirror', 'grid-wrap', 'grid-constant' or 'opencv').
7879
7980
Returns
8081
-------
@@ -90,26 +91,25 @@ def distortion_correction_proj_discorpy(
9091

9192
# Use preview information to offset the x and y coords of the center of
9293
# distortion
93-
shift = preview["starts"]
94-
step = preview["steps"]
95-
x_dim = 1
96-
y_dim = 0
97-
step_check = max([step[i] for i in [x_dim, y_dim]]) > 1
98-
if step_check:
94+
det_x_step = step_xy[0]
95+
det_y_step = step_xy[1]
96+
97+
if det_y_step > 1 or det_x_step > 1:
9998
msg = (
10099
"\n***********************************************\n"
101-
"!!! ERROR !!! -> Method doesn't work with the step in"
102-
" the preview larger than 1 \n"
100+
"!!! ERROR !!! -> Method doesn't work with the step parameter"
101+
" larger than 1 \n"
103102
"***********************************************\n"
104103
)
105104
raise ValueError(msg)
106105

107-
x_offset = shift[x_dim]
108-
y_offset = shift[y_dim]
109-
xcenter = xcenter - x_offset
110-
ycenter = ycenter - y_offset
106+
det_x_shift = shift_xy[0]
107+
det_y_shift = shift_xy[1]
108+
109+
xcenter = xcenter - det_x_shift
110+
ycenter = ycenter - det_y_shift
111111

112-
height, width = data.shape[y_dim + 1], data.shape[x_dim + 1]
112+
height, width = data.shape[1], data.shape[2]
113113
xu_list = cp.arange(width) - xcenter
114114
yu_list = cp.arange(height) - ycenter
115115
xu_mat, yu_mat = cp.meshgrid(xu_list, yu_list)

tests/test_prep/test_alignment.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,11 @@ def test_correct_distortion(
4040
im_host = imread(path)
4141
im = cp.asarray(im_host)
4242

43-
preview = {"starts": [0, 0], "stops": [im.shape[0], im.shape[1]], "steps": [1, 1]}
44-
corrected_data = implementation(im, distortion_coeffs_path, preview).get()
43+
shift_xy = [0, 0]
44+
step_xy = [1, 1]
45+
corrected_data = implementation(
46+
im, distortion_coeffs_path, shift_xy, step_xy, order=1, mode="reflect"
47+
).get()
4548

4649
assert_allclose(np.mean(corrected_data), mean_value)
4750
assert np.max(corrected_data) == max_value

0 commit comments

Comments
 (0)