Skip to content

Commit 0f7d097

Browse files
authored
Merge pull request #766 from HEXRD/group-relative-constraints
Add group ("Detector") relative constraints
2 parents df38aa9 + a0032bc commit 0f7d097

File tree

5 files changed

+677
-62
lines changed

5 files changed

+677
-62
lines changed

hexrd/fitting/calibration/lmfit_param_handling.py

Lines changed: 138 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from hexrd.instrument import (
77
calc_angles_from_beam_vec,
88
calc_beam_vec,
9+
Detector,
910
HEDMInstrument,
1011
)
1112
from hexrd.rotations import (
@@ -26,6 +27,8 @@
2627
# First is the axes_order, second is extrinsic
2728
DEFAULT_EULER_CONVENTION = ('zxz', False)
2829

30+
EULER_CONVENTION_TYPES = dict | tuple | None
31+
2932

3033
def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION,
3134
relative_constraints=None):
@@ -66,8 +69,12 @@ def create_instr_params(instr, euler_convention=DEFAULT_EULER_CONVENTION,
6669
parms_list,
6770
)
6871
elif relative_constraints.type == RelativeConstraintsType.group:
69-
# This should be implemented soon
70-
raise NotImplementedError(relative_constraints.type)
72+
add_group_constrained_detector_parameters(
73+
instr,
74+
euler_convention,
75+
parms_list,
76+
relative_constraints,
77+
)
7178
elif relative_constraints.type == RelativeConstraintsType.system:
7279
add_system_constrained_detector_parameters(
7380
instr,
@@ -120,12 +127,14 @@ def add_unconstrained_detector_parameters(instr, euler_convention, parms_list):
120127
-np.inf, np.inf))
121128

122129

123-
def add_system_constrained_detector_parameters(
124-
instr, euler_convention,
125-
parms_list, relative_constraints: RelativeConstraints):
126-
system_params = relative_constraints.params
127-
system_tvec = system_params['translation']
128-
system_tilt = system_params['tilt']
130+
def _add_constrained_detector_parameters(
131+
euler_convention: EULER_CONVENTION_TYPES,
132+
parms_list: list[tuple],
133+
prefix: str,
134+
constraint_params: dict,
135+
):
136+
tvec = constraint_params['translation']
137+
tilt = constraint_params['tilt']
129138

130139
if euler_convention is not None:
131140
# Convert the tilt to the specified Euler convention
@@ -136,30 +145,63 @@ def add_system_constrained_detector_parameters(
136145
extrinsic=normalized[1],
137146
)
138147

139-
rme.rmat = _tilt_to_rmat(system_tilt, None)
140-
system_tilt = np.degrees(rme.angles)
148+
rme.rmat = _tilt_to_rmat(tilt, None)
149+
tilt = np.degrees(rme.angles)
141150

142151
tvec_names = [
143-
'system_tvec_x',
144-
'system_tvec_y',
145-
'system_tvec_z',
152+
f'{prefix}_tvec_x',
153+
f'{prefix}_tvec_y',
154+
f'{prefix}_tvec_z',
146155
]
147156
tvec_deltas = [1, 1, 1]
148157

149-
tilt_names = param_names_euler_convention('system', euler_convention)
158+
tilt_names = param_names_euler_convention(prefix, euler_convention)
150159
tilt_deltas = [2, 2, 2]
151160

152161
for i, name in enumerate(tvec_names):
153-
value = system_tvec[i]
162+
value = tvec[i]
154163
delta = tvec_deltas[i]
155164
parms_list.append((name, value, True, value - delta, value + delta))
156165

157166
for i, name in enumerate(tilt_names):
158-
value = system_tilt[i]
167+
value = tilt[i]
159168
delta = tilt_deltas[i]
160169
parms_list.append((name, value, True, value - delta, value + delta))
161170

162171

172+
def add_system_constrained_detector_parameters(
173+
instr: HEDMInstrument,
174+
euler_convention: EULER_CONVENTION_TYPES,
175+
parms_list: list[tuple],
176+
relative_constraints: RelativeConstraints,
177+
):
178+
prefix = 'system'
179+
constraint_params = relative_constraints.params
180+
_add_constrained_detector_parameters(
181+
euler_convention,
182+
parms_list,
183+
prefix,
184+
constraint_params,
185+
)
186+
187+
188+
def add_group_constrained_detector_parameters(
189+
instr: HEDMInstrument,
190+
euler_convention: EULER_CONVENTION_TYPES,
191+
parms_list: list[tuple],
192+
relative_constraints: RelativeConstraints,
193+
):
194+
for group in instr.detector_groups:
195+
prefix = group.replace('-', '_')
196+
constraint_params = relative_constraints.params[group]
197+
_add_constrained_detector_parameters(
198+
euler_convention,
199+
parms_list,
200+
prefix,
201+
constraint_params,
202+
)
203+
204+
163205
def create_beam_param_names(instr: HEDMInstrument) -> dict[str, str]:
164206
param_names = {}
165207
for k, v in instr.beam_dict.items():
@@ -220,8 +262,12 @@ def update_instrument_from_params(
220262
euler_convention,
221263
)
222264
elif relative_constraints.type == RelativeConstraintsType.group:
223-
# This should be implemented soon
224-
raise NotImplementedError(relative_constraints.type)
265+
update_group_constrained_detector_parameters(
266+
instr,
267+
params,
268+
euler_convention,
269+
relative_constraints,
270+
)
225271
elif relative_constraints.type == RelativeConstraintsType.system:
226272
update_system_constrained_detector_parameters(
227273
instr,
@@ -263,60 +309,109 @@ def update_unconstrained_detector_parameters(instr, params, euler_convention):
263309
)
264310

265311

266-
def update_system_constrained_detector_parameters(
267-
instr, params, euler_convention,
268-
relative_constraints: RelativeConstraints):
269-
system_params = relative_constraints.params
270-
system_tvec = system_params['translation']
271-
system_tilt = system_params['tilt']
312+
def _update_constrained_detector_parameters(
313+
detectors: list[Detector],
314+
params: dict,
315+
rotation_center: np.ndarray,
316+
euler_convention: EULER_CONVENTION_TYPES,
317+
prefix: str,
318+
constraint_params: dict,
319+
320+
):
321+
tvec = constraint_params['translation']
322+
tilt = constraint_params['tilt']
272323

273324
tvec_names = [
274-
'system_tvec_x',
275-
'system_tvec_y',
276-
'system_tvec_z',
325+
f'{prefix}_tvec_x',
326+
f'{prefix}_tvec_y',
327+
f'{prefix}_tvec_z',
277328
]
278-
tilt_names = param_names_euler_convention('system', euler_convention)
329+
tilt_names = param_names_euler_convention(prefix, euler_convention)
279330

280331
# Just like the detectors, we will apply tilt first and then translation
281332
# Only apply these transforms if they were marked "Vary".
282333

283334
if any(params[x].vary for x in tilt_names):
284-
# Get the center of rotation (depending on the settings)
285-
rotation_center = relative_constraints.center_of_rotation(instr)
286-
287335
# Find the change in tilt, create an rmat, then apply to detector tilts
288336
# and translations.
289-
new_system_tilt = np.array([params[x].value for x in tilt_names])
337+
new_tilt = np.array([params[x].value for x in tilt_names])
290338

291-
# The old system tilt was in the None convention
292-
old_rmat = _tilt_to_rmat(system_tilt, None)
293-
new_rmat = _tilt_to_rmat(new_system_tilt, euler_convention)
339+
# The old tilt was in the None convention
340+
old_rmat = _tilt_to_rmat(tilt, None)
341+
new_rmat = _tilt_to_rmat(new_tilt, euler_convention)
294342

295343
# Compute the rmat used to convert from old to new
296344
rmat_diff = new_rmat @ old_rmat.T
297345

298346
# Rotate each detector using the rmat_diff
299-
for panel in instr.detectors.values():
347+
for panel in detectors:
300348
panel.tilt = _rmat_to_tilt(rmat_diff @ panel.rmat)
301349

302350
# Also rotate the detectors about the rotation center
303351
panel.tvec = (
304352
rmat_diff @ (panel.tvec - rotation_center) + rotation_center
305353
)
306354

307-
# Update the system tilt
308-
system_tilt[:] = _rmat_to_tilt(new_rmat)
355+
# Update the tilt
356+
tilt[:] = _rmat_to_tilt(new_rmat)
309357

310358
if any(params[x].vary for x in tvec_names):
311359
# Find the change in center and shift all tvecs
312-
new_system_tvec = np.array([params[x].value for x in tvec_names])
360+
new_tvec = np.array([params[x].value for x in tvec_names])
313361

314-
diff = new_system_tvec - system_tvec
315-
for panel in instr.detectors.values():
362+
diff = new_tvec - tvec
363+
for panel in detectors:
316364
panel.tvec += diff
317365

318-
# Update the system tvec
319-
system_tvec[:] = new_system_tvec
366+
# Update the tvec
367+
tvec[:] = new_tvec
368+
369+
370+
def update_system_constrained_detector_parameters(
371+
instr: HEDMInstrument,
372+
params: dict,
373+
euler_convention: EULER_CONVENTION_TYPES,
374+
relative_constraints: RelativeConstraints,
375+
):
376+
detectors = list(instr.detectors.values())
377+
378+
# Get the center of rotation (depending on the settings)
379+
rotation_center = relative_constraints.center_of_rotation(instr)
380+
prefix = 'system'
381+
constraint_params = relative_constraints.params
382+
383+
_update_constrained_detector_parameters(
384+
detectors,
385+
params,
386+
rotation_center,
387+
euler_convention,
388+
prefix,
389+
constraint_params,
390+
)
391+
392+
393+
def update_group_constrained_detector_parameters(
394+
instr: HEDMInstrument,
395+
params: dict,
396+
euler_convention: EULER_CONVENTION_TYPES,
397+
relative_constraints: RelativeConstraints,
398+
):
399+
for group in instr.detector_groups:
400+
detectors = list(instr.detectors_in_group(group).values())
401+
402+
# Get the center of rotation (depending on the settings)
403+
rotation_center = relative_constraints.center_of_rotation(instr, group)
404+
prefix = group.replace('-', '_')
405+
constraint_params = relative_constraints.params[group]
406+
407+
_update_constrained_detector_parameters(
408+
detectors,
409+
params,
410+
rotation_center,
411+
euler_convention,
412+
prefix,
413+
constraint_params,
414+
)
320415

321416

322417
def _tilt_to_rmat(tilt: np.ndarray,

hexrd/fitting/calibration/relative_constraints.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ class RotationCenter(Enum):
2121
# Rotate about the mean center of all the detectors
2222
instrument_mean_center = 'InstrumentMeanCenter'
2323

24+
# Rotate about the mean center of the group
25+
group_mean_center = 'GroupMeanCenter'
26+
2427
# Rotate about lab origin, which is (0, 0, 0)
2528
lab_origin = 'Origin'
2629

@@ -75,11 +78,7 @@ class RelativeConstraintsGroup(RelativeConstraints):
7578
type = RelativeConstraintsType.group
7679

7780
def __init__(self, instr: HEDMInstrument):
78-
self._groups = []
79-
for panel in instr.detectors.values():
80-
if panel.group is not None and panel.group not in self._groups:
81-
self._groups.append(panel.group)
82-
81+
self._groups = instr.detector_groups
8382
self.reset()
8483

8584
def reset(self):
@@ -96,7 +95,7 @@ def reset_params(self):
9695
}
9796

9897
def reset_rotation_center(self):
99-
self._rotation_center = RotationCenter.instrument_mean_center
98+
self._rotation_center = RotationCenter.group_mean_center
10099

101100
@property
102101
def params(self) -> dict:
@@ -110,6 +109,20 @@ def rotation_center(self):
110109
def rotation_center(self, v: RotationCenter):
111110
self._rotation_center = v
112111

112+
def center_of_rotation(
113+
self,
114+
instr: HEDMInstrument,
115+
group: str,
116+
) -> np.ndarray:
117+
if self.rotation_center == RotationCenter.instrument_mean_center:
118+
return instr.mean_detector_center
119+
elif self.rotation_center == RotationCenter.lab_origin:
120+
return np.array([0.0, 0.0, 0.0])
121+
elif self.rotation_center == RotationCenter.group_mean_center:
122+
return instr.mean_group_center(group)
123+
124+
raise NotImplementedError(self.rotation_center)
125+
113126

114127
class RelativeConstraintsSystem(RelativeConstraints):
115128
type = RelativeConstraintsType.system
@@ -160,7 +173,7 @@ def create_relative_constraints(type: RelativeConstraintsType,
160173
}
161174

162175
kwargs = {}
163-
if type == 'System':
176+
if type.value == 'Group':
164177
kwargs['instr'] = instr
165178

166179
return types[type.value](**kwargs)

hexrd/instrument/hedm_instrument.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
from .cylindrical_detector import CylindricalDetector
8181
from .detector import (
8282
beam_energy_DFLT,
83+
Detector,
8384
max_workers_DFLT,
8485
)
8586
from .planar_detector import PlanarDetector
@@ -797,21 +798,25 @@ def mean_detector_center(self) -> np.ndarray:
797798
centers = np.array([panel.tvec for panel in self.detectors.values()])
798799
return centers.sum(axis=0) / len(centers)
799800

801+
def mean_group_center(self, group: str) -> np.ndarray:
802+
"""Return the mean center for detectors belonging to a group"""
803+
centers = np.array([
804+
x.tvec for x in self.detectors_in_group(group).values()
805+
])
806+
return centers.sum(axis=0) / len(centers)
807+
800808
@property
801-
def mean_group_centers(self) -> dict[str, np.ndarray]:
802-
"""Return the mean center for every group of detectors"""
803-
centers = {}
809+
def detector_groups(self) -> list[str]:
810+
groups = []
804811
for panel in self.detectors.values():
805-
if panel.group is None:
806-
# Skip over panels without groups
807-
continue
808-
809-
if panel.group not in centers:
810-
centers[panel.group] = []
812+
group = panel.group
813+
if group is not None and group not in groups:
814+
groups.append(group)
811815

812-
centers[panel.group].append(panel.tvec)
816+
return groups
813817

814-
return {k: v.sum(axis=0) / len(v) for k, v in centers.items()}
818+
def detectors_in_group(self, group: str) -> dict[str, Detector]:
819+
return {k: v for k, v in self.detectors.items() if v.group == group}
815820

816821
# properties for physical size of rectangular detector
817822
@property

0 commit comments

Comments
 (0)