|
36 | 36 | import os |
37 | 37 | from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor |
38 | 38 | from functools import partial |
39 | | -from typing import Optional |
| 39 | +from typing import Optional, Union |
40 | 40 |
|
41 | 41 | from tqdm import tqdm |
42 | 42 |
|
@@ -604,6 +604,7 @@ def __init__(self, instrument_config=None, |
604 | 604 | beam['vector']['polar_angle'], |
605 | 605 | ), |
606 | 606 | 'distance': beam.get('source_distance', np.inf), |
| 607 | + 'energy_correction': beam.get('energy_correction', None), |
607 | 608 | } |
608 | 609 |
|
609 | 610 | # Set the active beam name if not set already |
@@ -811,6 +812,7 @@ def _create_default_beam(self): |
811 | 812 | 'energy': beam_energy_DFLT, |
812 | 813 | 'vector': beam_vec_DFLT.copy(), |
813 | 814 | 'distance': np.inf, |
| 815 | + 'energy_correction': None, |
814 | 816 | } |
815 | 817 |
|
816 | 818 | if self._active_beam_name is None: |
@@ -889,6 +891,51 @@ def source_distance(self, x): |
889 | 891 | self.active_beam['distance'] = x |
890 | 892 | self.beam_dict_modified() |
891 | 893 |
|
| 894 | + @property |
| 895 | + def energy_correction(self) -> Union[dict, None]: |
| 896 | + """Energy correction dict appears as follows: |
| 897 | +
|
| 898 | + { |
| 899 | + # The beam energy gradient center, along the specified |
| 900 | + # axis, in millimeters. |
| 901 | + 'intercept': 0.0, |
| 902 | +
|
| 903 | + # The slope of the beam energy gradient along the |
| 904 | + # specified axis, in eV/mm. |
| 905 | + 'slope': 0.0, |
| 906 | +
|
| 907 | + # The specified axis for the beam energy gradient, |
| 908 | + # either 'x' or 'y'. |
| 909 | + 'axis': 'y', |
| 910 | + } |
| 911 | + """ |
| 912 | + return self.active_beam['energy_correction'] |
| 913 | + |
| 914 | + @energy_correction.setter |
| 915 | + def energy_correction(self, v: Union[dict, None]): |
| 916 | + if v is not None: |
| 917 | + # First validate |
| 918 | + keys = sorted(list(v)) |
| 919 | + default_keys = sorted(list( |
| 920 | + self.create_default_energy_correction() |
| 921 | + )) |
| 922 | + if keys != default_keys: |
| 923 | + msg = ( |
| 924 | + f'Keys in energy correction dict, "{keys}", do not match ' |
| 925 | + f'the required keys: "{default_keys}"' |
| 926 | + ) |
| 927 | + raise RuntimeError(msg) |
| 928 | + |
| 929 | + self.active_beam['energy_correction'] = v |
| 930 | + |
| 931 | + @staticmethod |
| 932 | + def create_default_energy_correction() -> dict[str, float]: |
| 933 | + return { |
| 934 | + 'intercept': 0.0, # in mm |
| 935 | + 'slope': 0.0, # eV/mm |
| 936 | + 'axis': 'y', |
| 937 | + } |
| 938 | + |
892 | 939 | @property |
893 | 940 | def eta_vector(self): |
894 | 941 | return self._eta_vector |
@@ -932,6 +979,11 @@ def write_config(self, file=None, style='yaml', calibration_dict={}): |
932 | 979 | if beam['distance'] != np.inf: |
933 | 980 | beam_dict[beam_name]['source_distance'] = beam['distance'] |
934 | 981 |
|
| 982 | + if beam.get('energy_correction') is not None: |
| 983 | + beam_dict[beam_name]['energy_correction'] = beam[ |
| 984 | + 'energy_correction' |
| 985 | + ] |
| 986 | + |
935 | 987 | if len(beam_dict) == 1: |
936 | 988 | # Just write it out a single beam (classical way) |
937 | 989 | beam_dict = next(iter(beam_dict.values())) |
@@ -1508,7 +1560,8 @@ def simulate_rotation_series(self, plane_data, grain_param_list, |
1508 | 1560 | ome_ranges=ome_ranges, |
1509 | 1561 | ome_period=ome_period, |
1510 | 1562 | chi=self.chi, tVec_s=self.tvec, |
1511 | | - wavelength=wavelength) |
| 1563 | + wavelength=wavelength, |
| 1564 | + energy_correction=self.energy_correction) |
1512 | 1565 | return results |
1513 | 1566 |
|
1514 | 1567 | def pull_spots(self, plane_data, grain_params, |
|
0 commit comments