Skip to content

Commit 1712647

Browse files
committed
distortion correction preview refactoring
1 parent 9cf81e4 commit 1712647

File tree

2 files changed

+21
-20
lines changed

2 files changed

+21
-20
lines changed

httomolibgpu/prep/alignment.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
else:
3434
map_coordinates = Mock()
3535

36-
from typing import Dict, List
36+
from typing import Dict, List, Tuple
3737

3838
__all__ = [
3939
"distortion_correction_proj_discorpy",
@@ -48,7 +48,8 @@
4848
def distortion_correction_proj_discorpy(
4949
data: cp.ndarray,
5050
metadata_path: str,
51-
preview: Dict[str, List[int]],
51+
shift: Tuple[int, int] = (0, 0),
52+
step: Tuple[int, int] = (1, 1),
5253
order: int = 1,
5354
mode: str = "reflect",
5455
):
@@ -63,11 +64,11 @@ def distortion_correction_proj_discorpy(
6364
The path to the file containing the distortion coefficients for the
6465
data.
6566
66-
preview : Dict[str, List[int]]
67-
A dict containing three key-value pairs:
68-
- a list containing the `start` value of each dimension
69-
- a list containing the `stop` value of each dimension
70-
- a list containing the `step` value of each dimension
67+
shift: tuple, optional
68+
Centers of distortion in x (from the left of the image) and y directions (from the top of the image).
69+
70+
step: tuple, optional
71+
Steps in x and y directions respectively. They need to be not larger than one.
7172
7273
order : int, optional.
7374
The order of the spline interpolation.
@@ -90,12 +91,10 @@ def distortion_correction_proj_discorpy(
9091

9192
# Use preview information to offset the x and y coords of the center of
9293
# distortion
93-
shift = preview["starts"]
94-
step = preview["steps"]
95-
x_dim = 1
96-
y_dim = 0
97-
step_check = max([step[i] for i in [x_dim, y_dim]]) > 1
98-
if step_check:
94+
det_x_step = step[0]
95+
det_y_step = step[1]
96+
97+
if det_y_step > 1 or det_x_step > 1:
9998
msg = (
10099
"\n***********************************************\n"
101100
"!!! ERROR !!! -> Method doesn't work with the step in"
@@ -104,12 +103,13 @@ def distortion_correction_proj_discorpy(
104103
)
105104
raise ValueError(msg)
106105

107-
x_offset = shift[x_dim]
108-
y_offset = shift[y_dim]
109-
xcenter = xcenter - x_offset
110-
ycenter = ycenter - y_offset
106+
det_x_shift = shift[0]
107+
det_y_shift = shift[1]
108+
109+
xcenter = xcenter - det_x_shift
110+
ycenter = ycenter - det_y_shift
111111

112-
height, width = data.shape[y_dim + 1], data.shape[x_dim + 1]
112+
height, width = data.shape[1], data.shape[2]
113113
xu_list = cp.arange(width) - xcenter
114114
yu_list = cp.arange(height) - ycenter
115115
xu_mat, yu_mat = cp.meshgrid(xu_list, yu_list)

tests/test_prep/test_alignment.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ def test_correct_distortion(
4040
im_host = imread(path)
4141
im = cp.asarray(im_host)
4242

43-
preview = {"starts": [0, 0], "stops": [im.shape[0], im.shape[1]], "steps": [1, 1]}
44-
corrected_data = implementation(im, distortion_coeffs_path, preview).get()
43+
shift = (0, 0)
44+
step = (1, 1)
45+
corrected_data = implementation(im, distortion_coeffs_path, shift, step).get()
4546

4647
assert_allclose(np.mean(corrected_data), mean_value)
4748
assert np.max(corrected_data) == max_value

0 commit comments

Comments
 (0)