Skip to content

Commit c6f1247

Browse files
authored
Merge pull request #752 from ChristosT/integrate-quad-dexelas-distortion
Integrate quad dexela distortion
2 parents 047ce00 + 88b3371 commit c6f1247

File tree

4 files changed

+155
-23
lines changed

4 files changed

+155
-23
lines changed

hexrd/distortion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from . import nyi
1010
from . import ge_41rt
1111
from . import dexela_2923
12+
from . import dexela_2923_quad
1213

1314
__all__ = ['maptypes', 'get_mapping']
1415

hexrd/distortion/dexela_2923.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -106,26 +106,3 @@ def _dexela_2923_inverse_distortion(out_, in_, params):
106106
# 1st quadrant
107107
out_[el, :] = in_[el, :] - params[0:2]
108108

109-
def test_disortion():
110-
pts = np.random.randn(16, 2)
111-
qi = _find_quadrant(pts)
112-
113-
# test trivial
114-
params = np.zeros(8)
115-
dc = Dexela_2923(params)
116-
if not np.all(dc.apply(pts) - pts == 0.):
117-
raise RuntimeError("distortion apply failed!")
118-
if not np.all(dc.apply_inverse(pts) - pts == 0.):
119-
raise RuntimeError("distortion apply_inverse failed!")
120-
121-
# test non-trivial
122-
params = np.random.randn(8)
123-
dc = Dexela_2923(params)
124-
ptile = np.vstack([params.reshape(4, 2)[j - 1, :] for j in qi])
125-
result = dc.apply(pts) - pts
126-
result_inv = dc.apply_inverse(pts) - pts
127-
if not np.all(abs(result - ptile) <= constants.epsf):
128-
raise RuntimeError("distortion apply failed!")
129-
if not np.all(abs(result_inv + ptile) <= constants.epsf):
130-
raise RuntimeError("distortion apply_inverse failed!")
131-
return True
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import numpy as np
2+
import numba
3+
from hexrd import constants
4+
5+
from .distortionabc import DistortionABC
6+
from .registry import _RegisterDistortionClass
7+
8+
9+
class Dexela_2923_quad(DistortionABC, metaclass=_RegisterDistortionClass):
10+
11+
maptype = "Dexela_2923_quad"
12+
13+
def __init__(self, params, **kwargs):
14+
assert len(params) == 6, "parameter list must have len of 6"
15+
self._params = np.asarray(params, dtype=float).flatten()
16+
17+
@property
18+
def params(self):
19+
return self._params
20+
21+
@params.setter
22+
def params(self, x):
23+
assert len(x) == 6, "parameter list must have len of 6"
24+
self._params = np.asarray(x, dtype=float).flatten()
25+
26+
@property
27+
def is_trivial(self):
28+
return np.all(self.params == 0)
29+
30+
def apply(self, xy_in):
31+
if self.is_trivial:
32+
return xy_in
33+
else:
34+
xy_in = np.asarray(xy_in, dtype=float)
35+
xy_out = np.empty_like(xy_in)
36+
_dexela_2923_quad_distortion(
37+
xy_out, xy_in, np.asarray(self.params)
38+
)
39+
return xy_out
40+
41+
def apply_inverse(self, xy_in):
42+
if self.is_trivial:
43+
return xy_in
44+
else:
45+
xy_in = np.asarray(xy_in, dtype=float)
46+
xy_out = np.empty_like(xy_in)
47+
_dexela_2923_quad_inverse_distortion(
48+
xy_out, xy_in, np.asarray(self.params)
49+
)
50+
return xy_out
51+
52+
53+
def _find_quadrant(xy_in):
54+
quad_label = np.zeros(len(xy_in), dtype=int)
55+
in_2_or_3 = xy_in[:, 0] < 0.0
56+
in_1_or_4 = ~in_2_or_3
57+
in_3_or_4 = xy_in[:, 1] < 0.0
58+
in_1_or_2 = ~in_3_or_4
59+
quad_label[np.logical_and(in_1_or_4, in_1_or_2)] = 1
60+
quad_label[np.logical_and(in_2_or_3, in_1_or_2)] = 2
61+
quad_label[np.logical_and(in_2_or_3, in_3_or_4)] = 3
62+
quad_label[np.logical_and(in_1_or_4, in_3_or_4)] = 4
63+
return quad_label
64+
65+
66+
@numba.njit(nogil=True, cache=True)
67+
def _dexela_2923_quad_distortion(out, in_, params):
68+
# 1 + x + y, inverse. Someone should definitely check my math here...
69+
p0, p1, p2, p3, p4, p5 = params[0:6]
70+
p1 = p1 + 1e-12
71+
p5 = p5 + 1e-12
72+
out[:, 0] = (
73+
in_[:, 0] / p1 - p0 / p1 - (p2 / (p1 * p5) * (in_[:, 1] - p3))
74+
) / (1 - (p2 * p4) / (p1 * p5))
75+
out[:, 1] = (in_[:, 1] - p3 - p4 * out[:, 0]) / p5
76+
77+
return out
78+
79+
80+
@numba.njit(nogil=True, cache=True)
81+
def _dexela_2923_quad_inverse_distortion(out, in_, params):
82+
# 1 + x + y
83+
p0, p1, p2, p3, p4, p5 = params[0:6]
84+
p1 = p1 + 1e-12
85+
p5 = p5 + 1e-12
86+
out[:, 0] = p0 + p1 * in_[:, 0] + p2 * in_[:, 1]
87+
out[:, 1] = p3 + p4 * in_[:, 0] + p5 * in_[:, 1]
88+
89+
return out

tests/test_distortion.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from hexrd.distortion.dexela_2923 import Dexela_2923, _find_quadrant
2+
from hexrd.distortion.dexela_2923_quad import Dexela_2923_quad
3+
from hexrd import constants
4+
import numpy as np
5+
6+
7+
def test_dexela_2923_distortion():
8+
pts = np.random.randn(16, 2)
9+
qi = _find_quadrant(pts)
10+
11+
# test trivial
12+
params = np.zeros(8)
13+
dc = Dexela_2923(params)
14+
if not np.all(dc.apply(pts) - pts == 0.0):
15+
raise RuntimeError("distortion apply failed!")
16+
if not np.all(dc.apply_inverse(pts) - pts == 0.0):
17+
raise RuntimeError("distortion apply_inverse failed!")
18+
19+
# test non-trivial
20+
params = np.random.randn(8)
21+
dc = Dexela_2923(params)
22+
ptile = np.vstack([params.reshape(4, 2)[j - 1, :] for j in qi])
23+
result = dc.apply(pts) - pts
24+
result_inv = dc.apply_inverse(pts) - pts
25+
if not np.all(abs(result - ptile) <= constants.ten_epsf):
26+
raise RuntimeError("distortion apply failed!")
27+
if not np.all(abs(result_inv + ptile) <= constants.ten_epsf):
28+
raise RuntimeError("distortion apply_inverse failed!")
29+
30+
31+
def test_dexela_2923_quad_distortion():
32+
pts = np.random.randn(16, 2)
33+
qi = _find_quadrant(pts)
34+
35+
# test trivial
36+
params = np.zeros(10)
37+
dc = Dexela_2923(params)
38+
if not np.all(dc.apply(pts) - pts == 0.0):
39+
raise RuntimeError("distortion apply failed!")
40+
if not np.all(dc.apply_inverse(pts) - pts == 0.0):
41+
raise RuntimeError("distortion apply_inverse failed!")
42+
43+
# test non-trivial
44+
45+
# this is the original test submited in
46+
# https://github.com/HEXRD/hexrd/issues/749
47+
# but it does not currently work. First, params needs to be of size 6, but
48+
# this break vstack command bellow. Not sure how to adapt it.
49+
# params = np.random.randn(10)
50+
# dc = Dexela_2923_quad(params)
51+
# ptile = np.vstack([params.reshape(4, 2)[j - 1, :] for j in qi])
52+
# result = dc.apply(pts) - pts
53+
# result_inv = dc.apply_inverse(pts) - pts
54+
# if not np.all(abs(result - ptile) <= constants.epsf):
55+
# raise RuntimeError("distortion apply failed!")
56+
# if not np.all(abs(result_inv + ptile) <= constants.epsf):
57+
# raise RuntimeError("distortion apply_inverse failed!")
58+
# return True
59+
60+
# we simply test that apply and reverse cancel each other
61+
params = np.random.randn(6)
62+
dc = Dexela_2923_quad(params)
63+
result = dc.apply_inverse(dc.apply(pts))
64+
if not np.all(abs(result - pts) <= 100 * constants.epsf):
65+
raise RuntimeError("distortion apply_inverse(apply) failed!")

0 commit comments

Comments
 (0)