Skip to content

Commit 466df90

Browse files
committed
Distortion: introduce quadradic distortion from dexelas
source: #749
1 parent 37359c8 commit 466df90

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

hexrd/distortion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from . import nyi
1010
from . import ge_41rt
1111
from . import dexela_2923
12+
from . import dexela_2923_quad
1213

1314
__all__ = ['maptypes', 'get_mapping']
1415

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import numpy as np
2+
import numba
3+
from hexrd import constants
4+
5+
from .distortionabc import DistortionABC
6+
from .registry import _RegisterDistortionClass
7+
8+
9+
class Dexela_2923_quad(DistortionABC, metaclass=_RegisterDistortionClass):
10+
11+
maptype = "Dexela_2923_quad"
12+
13+
def __init__(self, params, **kwargs):
14+
assert len(params) == 6, "parameter list must have len of 6"
15+
self._params = np.asarray(params, dtype=float).flatten()
16+
17+
@property
18+
def params(self):
19+
return self._params
20+
21+
@params.setter
22+
def params(self, x):
23+
assert len(x) == 6, "parameter list must have len of 6"
24+
self._params = np.asarray(x, dtype=float).flatten()
25+
26+
@property
27+
def is_trivial(self):
28+
return np.all(self.params == 0)
29+
30+
def apply(self, xy_in):
31+
if self.is_trivial:
32+
return xy_in
33+
else:
34+
xy_in = np.asarray(xy_in, dtype=float)
35+
xy_out = np.empty_like(xy_in)
36+
_dexela_2923_quad_distortion(
37+
xy_out, xy_in, np.asarray(self.params)
38+
)
39+
return xy_out
40+
41+
def apply_inverse(self, xy_in):
42+
if self.is_trivial:
43+
return xy_in
44+
else:
45+
xy_in = np.asarray(xy_in, dtype=float)
46+
xy_out = np.empty_like(xy_in)
47+
_dexela_2923_quad_inverse_distortion(
48+
xy_out, xy_in, np.asarray(self.params)
49+
)
50+
return xy_out
51+
52+
53+
def _find_quadrant(xy_in):
54+
quad_label = np.zeros(len(xy_in), dtype=int)
55+
in_2_or_3 = xy_in[:, 0] < 0.0
56+
in_1_or_4 = ~in_2_or_3
57+
in_3_or_4 = xy_in[:, 1] < 0.0
58+
in_1_or_2 = ~in_3_or_4
59+
quad_label[np.logical_and(in_1_or_4, in_1_or_2)] = 1
60+
quad_label[np.logical_and(in_2_or_3, in_1_or_2)] = 2
61+
quad_label[np.logical_and(in_2_or_3, in_3_or_4)] = 3
62+
quad_label[np.logical_and(in_1_or_4, in_3_or_4)] = 4
63+
return quad_label
64+
65+
66+
@numba.njit(nogil=True, cache=True)
67+
def _dexela_2923_quad_distortion(out, in_, params):
68+
# 1 + x + y, inverse. Someone should definitely check my math here...
69+
p0, p1, p2, p3, p4, p5 = params[0:6]
70+
p1 = p1 + 1e-12
71+
p5 = p5 + 1e-12
72+
out[:, 0] = (
73+
in_[:, 0] / p1 - p0 / p1 - (p2 / (p1 * p5) * (in_[:, 1] - p3))
74+
) / (1 - (p2 * p4) / (p1 * p5))
75+
out[:, 1] = (in_[:, 1] - p3 - p4 * out[:, 0]) / p5
76+
77+
return out
78+
79+
80+
@numba.njit(nogil=True, cache=True)
81+
def _dexela_2923_quad_inverse_distortion(out, in_, params):
82+
# 1 + x + y
83+
p0, p1, p2, p3, p4, p5 = params[0:6]
84+
p1 = p1 + 1e-12
85+
p5 = p5 + 1e-12
86+
out[:, 0] = p0 + p1 * in_[:, 0] + p2 * in_[:, 1]
87+
out[:, 1] = p3 + p4 * in_[:, 0] + p5 * in_[:, 1]
88+
89+
return out

0 commit comments

Comments
 (0)