Skip to content

Commit 4876720

Browse files
committed
⚡ Optimized imports in xtl.diffraction & xtl.math
xtl.cli.diffraction & xtl.cli.math - Optimized imports throughout to improve startup performance
1 parent 3eb8242 commit 4876720

File tree

8 files changed

+79
-39
lines changed

8 files changed

+79
-39
lines changed

src/xtl/cli/diffraction/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22

33
from .geometry import app as geometry_app
44
from .plot import app as plot_app
5+
# from .sum import app as sum_app
56
from .integrate import app as integrate_app
67
from .correlate import app as correlate_app
8+
# from .array import app as array_app
79

810

911
app = typer.Typer(name='xtl.diffraction', help='Utilities for diffraction data')
1012
app.add_typer(geometry_app)
1113
app.add_typer(plot_app)
14+
# app.add_typer(sum_app)
1215
app.add_typer(integrate_app)
1316
app.add_typer(correlate_app)
17+
# app.add_typer(array_app)

src/xtl/cli/diffraction/cli_utils.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from enum import Enum
22
import re
3-
from typing import Optional
3+
from typing import Optional, TYPE_CHECKING
44

5-
from pyFAI.geometry import Geometry
6-
from pyFAI.detectors import Detector
5+
if TYPE_CHECKING:
6+
from pyFAI.geometry import Geometry
77

8-
from xtl.diffraction.images.images import Image
9-
from xtl.units.crystallography.radial import RadialUnit, RadialUnitType
8+
from xtl.diffraction.images.images import Image
9+
from xtl.units.crystallography.radial import RadialUnit
1010

1111

1212
class ZScale(Enum):
@@ -26,7 +26,7 @@ class IntegrationRadialUnits(Enum):
2626
Q_NM = 'q_nm'
2727

2828

29-
def get_image_frames(images: list[str]) -> list[Image]:
29+
def get_image_frames(images: list[str]) -> list['Image']:
3030
opened_images = []
3131
for i, img in enumerate(images):
3232
parts = img.split(':')
@@ -42,6 +42,8 @@ def get_image_frames(images: list[str]) -> list[Image]:
4242
else:
4343
raise ValueError(f'Invalid image format for image [{i}]: {img!r}')
4444

45+
from xtl.diffraction.images.images import Image
46+
4547
image = Image()
4648
try:
4749
image.open(file=file, frame=frame, is_eager=False)
@@ -52,11 +54,14 @@ def get_image_frames(images: list[str]) -> list[Image]:
5254
return opened_images
5355

5456

55-
def get_geometry_from_header(header: str) -> Geometry:
57+
def get_geometry_from_header(header: str) -> 'Geometry':
5658
"""
5759
Return a pyFAI Geometry object from the header of an NPX file written by
5860
AzimuthalCrossCorrelatorQQ_1 or Integrator.
5961
"""
62+
from pyFAI.geometry import Geometry
63+
from pyFAI.detectors import Detector
64+
6065
lines = []
6166
for line in header.splitlines():
6267
if line.startswith('pyFAI.Geometry'):
@@ -89,7 +94,9 @@ def get_geometry_from_header(header: str) -> Geometry:
8994
return Geometry(**kwargs)
9095

9196

92-
def get_radial_units_from_header(header: str) -> Optional[RadialUnit]:
97+
def get_radial_units_from_header(header: str) -> Optional['RadialUnit']:
98+
from xtl.units.crystallography.radial import RadialUnit, RadialUnitType
99+
93100
for line in header.splitlines():
94101
if line.startswith('pyFAI.AzimuthalIntegrator.unit'):
95102
units = line.split(':')[-1].strip()

src/xtl/cli/diffraction/correlate/fft.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
1-
from functools import partial
21
from pathlib import Path
32

4-
import matplotlib.pyplot as plt
5-
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
6-
import numpy as np
73
import typer
84

95
from xtl.cli.cliio import Console, epilog
10-
from xtl.cli.utils import Timer
11-
from xtl.cli.diffraction.cli_utils import get_geometry_from_header, get_radial_units_from_header, ZScale
12-
from xtl.exceptions.utils import Catcher
13-
from xtl.files.npx import NpxFile
14-
from xtl.units.crystallography.radial import RadialUnitType, RadialValue
6+
from xtl.cli.diffraction.cli_utils import ZScale
157

168

179
app = typer.Typer()
@@ -60,12 +52,16 @@ def cli_diffraction_correlate_fft(
6052
cli.print('Select only one parameter to calculate FFT (--2theta, --q)', style='red')
6153
raise typer.Abort()
6254

55+
from xtl.units.crystallography.radial import RadialUnitType, RadialValue
6356
if selection_2theta is not None:
6457
selection = RadialValue(value=selection_2theta, type=RadialUnitType.TWOTHETA_DEG)
6558
else:
6659
selection = RadialValue(value=selection_q, type=RadialUnitType.Q_NM)
6760

6861
# Load CCF data
62+
from xtl.exceptions.utils import Catcher
63+
from xtl.files.npx import NpxFile
64+
6965
with Catcher(echo_func=cli.print, traceback_func=cli.print_traceback) as catcher:
7066
acc = NpxFile.load(ccf_file)
7167
for key in ['radial', 'delta', 'ccf']:
@@ -91,6 +87,8 @@ def cli_diffraction_correlate_fft(
9187
raise typer.Abort()
9288

9389
# Get geometry from CCF file
90+
from xtl.cli.diffraction.cli_utils import get_geometry_from_header, get_radial_units_from_header
91+
9492
with Catcher(echo_func=cli.print, traceback_func=cli.print_traceback) as catcher:
9593
geometry = get_geometry_from_header(acc.header)
9694
if verbose > 2:
@@ -127,6 +125,9 @@ def cli_diffraction_correlate_fft(
127125
raise typer.Abort()
128126

129127
# Get the radial index of the selection
128+
import numpy as np
129+
from xtl.cli.utils import Timer
130+
130131
ccf_i = np.argmin(np.abs(acc.data['radial'] - selection.value))
131132

132133
# Calculate the FFT of the CCF
@@ -154,6 +155,9 @@ def cli_diffraction_correlate_fft(
154155
f'{d.name.latex}={d.value:.2f} {d.units.latex}')
155156

156157
# Prepare plots
158+
import matplotlib.pyplot as plt
159+
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
160+
157161
fig = plt.figure('XCCA overview', figsize=(16 / 1.2, 9 / 1.2))
158162
fig.suptitle(f'{ccf_file.name}\n{subtitle}')
159163
gs = fig.add_gridspec(2, 3, wspace=0.2,)
@@ -196,6 +200,8 @@ def cli_diffraction_correlate_fft(
196200
if zmax is not None:
197201
vmax = zmax
198202

203+
from functools import partial
204+
199205
if zscale == ZScale.LINEAR:
200206
norm = partial(Normalize, clip=False)
201207
elif zscale == ZScale.LOG:

src/xtl/cli/diffraction/correlate/qq.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,9 @@
11
from pathlib import Path
22

3-
import matplotlib.pyplot as plt
4-
import numpy as np
5-
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, MofNCompleteColumn
63
import typer
74

85
from xtl.cli.cliio import Console, epilog
9-
from xtl.cli.utils import Timer
10-
from xtl.cli.diffraction.cli_utils import get_image_frames, IntegrationErrorModel, IntegrationRadialUnits
11-
from xtl.diffraction.images.correlators import AzimuthalCrossCorrelatorQQ_1
12-
from xtl.exceptions.utils import Catcher
6+
from xtl.cli.diffraction.cli_utils import IntegrationErrorModel, IntegrationRadialUnits
137

148
import warnings
159

@@ -56,7 +50,10 @@ def cli_diffraction_correlate_qq(
5650
cli = Console(verbose=verbose, debug=debug)
5751
input_images = images
5852

53+
from xtl.exceptions.utils import Catcher
54+
5955
with Catcher(echo_func=cli.print, traceback_func=cli.print_traceback) as catcher:
56+
from xtl.cli.diffraction.cli_utils import get_image_frames
6057
images = get_image_frames(input_images)
6158
if catcher.raised:
6259
cli.print(f'Error: Failed to read all images', style='red')
@@ -84,6 +81,11 @@ def cli_diffraction_correlate_qq(
8481
'error_model': error_model.value if error_model != IntegrationErrorModel.NONE else None,
8582
}
8683

84+
from xtl.cli.utils import Timer
85+
from xtl.diffraction.images.correlators import AzimuthalCrossCorrelatorQQ_1
86+
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, \
87+
MofNCompleteColumn
88+
8789
with (Catcher(echo_func=cli.print, traceback_func=cli.print_traceback),
8890
Progress(SpinnerColumn(), *Progress.get_default_columns(),
8991
TimeElapsedColumn(), MofNCompleteColumn(),
@@ -115,7 +117,7 @@ def cli_diffraction_correlate_qq(
115117
with warnings.catch_warnings():
116118
warnings.filterwarnings('ignore')
117119
accf.correlate(points_radial=points_radial, points_azimuthal=points_azimuthal,
118-
units_radial=units_radial.value, method=0)
120+
units_radial=units_radial.value, method=0)
119121

120122
ccf_file = output_dir / f'{dataset_name}_ccf.npx'
121123
accf.save(ccf_file, overwrite=overwrite)
@@ -129,6 +131,9 @@ def cli_diffraction_correlate_qq(
129131
progress.console.print(f'Saved 2D integration results to {ai2_file}')
130132

131133
# Prepare plots
134+
import matplotlib.pyplot as plt
135+
import numpy as np
136+
132137
fig = plt.figure('XCCA overview', figsize=(16 / 1.2, 9 / 1.2))
133138
gs0 = fig.add_gridspec(1, 2, wspace=0.2,
134139
width_ratios=[1.2, 2]) # outer grid (1x2)

src/xtl/cli/diffraction/geometry.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
from pathlib import Path
2+
from typing import TYPE_CHECKING
23

3-
from pyFAI.detectors import ALL_DETECTORS, Detector
4-
from pyFAI.geometry import Geometry
54
import typer
5+
if TYPE_CHECKING:
6+
from pyFAI.detectors import ALL_DETECTORS, Detector
7+
from pyFAI.geometry import Geometry
68

79
from xtl.cli.cliio import Console
810

911

1012
app = typer.Typer()
1113

1214

13-
def get_detector_info(detector: Detector) -> dict:
15+
def get_detector_info(detector: 'Detector') -> dict:
1416
detector_info = {}
1517
detector_info['Detector name'] = detector.name
1618
if isinstance(detector.MANUFACTURER, list):
@@ -31,6 +33,8 @@ def get_detector_info(detector: Detector) -> dict:
3133

3234

3335
def get_detectors_list() -> dict:
36+
from pyFAI.detectors import ALL_DETECTORS
37+
3438
detectors = {}
3539
for alias, detector in ALL_DETECTORS.items():
3640
detector = detector()
@@ -50,7 +54,7 @@ def get_detectors_list() -> dict:
5054
return {k: v['info'] for k, v in detectors.items()}
5155

5256

53-
def get_geometry_info(geometry: Geometry) -> dict:
57+
def get_geometry_info(geometry: 'Geometry') -> dict:
5458
return {f'{k}': f'{v}' for k, v in geometry.get_config().items() if k not in ['poni_version', 'detector_config']}
5559

5660

@@ -63,15 +67,20 @@ def cli_diffraction_geometry():
6367
while detector is None:
6468
answer = typer.prompt('Enter detector name (? for list)').lower()
6569
if answer in ['detector', 'custom']:
70+
from pyFAI.detectors import Detector
71+
6672
cli.print('Initializing custom detector...')
6773
pixel1 = typer.prompt('Enter pixel size (horizontal) [pixel1 in \u03bcm]', default=50., type=float)
6874
pixel2 = typer.prompt('Enter pixel size (vertical) [pixel2 in \u03bcm]', default=50., type=float)
6975
shape1 = typer.prompt('Enter number of pixels (horizontal) [shape[1] in px]', default=1024, type=int)
7076
shape2 = typer.prompt('Enter number of pixels (vertical) [shape[0] in px', default=1024, type=int)
7177
detector = Detector(pixel1=pixel1/1e6, pixel2=pixel2/1e6, max_shape=(shape2, shape1))
72-
elif answer in ALL_DETECTORS:
73-
detector: Detector = ALL_DETECTORS[answer]()
7478
else:
79+
from pyFAI.detectors import ALL_DETECTORS
80+
if answer in ALL_DETECTORS:
81+
detector: 'Detector' = ALL_DETECTORS[answer]()
82+
continue
83+
7584
detector_list = get_detectors_list()
7685
headers = list(list(detector_list.values())[0].keys())
7786
cli.print_table([d.values() for d in detector_list.values()], headers=headers,
@@ -89,6 +98,8 @@ def cli_diffraction_geometry():
8998
detector = None
9099

91100
# Build Geometry object
101+
from pyFAI.geometry import Geometry
102+
92103
geometry = None
93104
while geometry is None:
94105
poni2 = typer.prompt('Enter point of normal incidence along horizontal (x) axis [poni2 in px]', default=0., type=float)

src/xtl/cli/diffraction/integrate/integrate_1d.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from pathlib import Path
22

3-
import matplotlib.pyplot as plt
4-
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, MofNCompleteColumn
3+
from rich.progress import Progress
54
import typer
65

76
from xtl.cli.cliio import Console, epilog
8-
from xtl.cli.diffraction.cli_utils import get_image_frames, IntegrationErrorModel, IntegrationRadialUnits
7+
from xtl.cli.diffraction.cli_utils import IntegrationErrorModel, IntegrationRadialUnits
98

109

1110
app = typer.Typer()
@@ -48,6 +47,8 @@ def cli_diffraction_integrate_1d(
4847
input_images = images
4948

5049
try:
50+
from xtl.cli.diffraction.cli_utils import get_image_frames
51+
5152
images = get_image_frames(input_images)
5253
except ValueError as e:
5354
cli.print_traceback(e)
@@ -78,6 +79,8 @@ def cli_diffraction_integrate_1d(
7879
'error_model': error_model.value if error_model != IntegrationErrorModel.NONE else None,
7980
}
8081

82+
import matplotlib.pyplot as plt
83+
8184
output_dir = output_dir.expanduser().resolve()
8285
fig, ax = plt.subplots()
8386
xscale = 'log' if xlog else 'linear'

src/xtl/cli/diffraction/plot/plot_2d.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
import math
22
from pathlib import Path
33

4-
import matplotlib.pyplot as plt
5-
from mpl_toolkits.axes_grid1 import AxesGrid
6-
import numpy as np
74
import typer
85

96
from xtl.cli.cliio import Console, epilog
10-
from xtl.cli.diffraction.cli_utils import get_image_frames, ZScale
7+
from xtl.cli.diffraction.cli_utils import ZScale
118

129

1310
app = typer.Typer()
@@ -48,12 +45,18 @@ def cli_diffraction_plot_2d(
4845
input_images = images
4946

5047
try:
48+
from xtl.cli.diffraction.cli_utils import get_image_frames
5149
images = get_image_frames(input_images)
5250
except ValueError as e:
5351
cli.print_traceback(e)
5452
cli.print(f'Error: Failed to read all images', style='red')
5553
raise typer.Abort()
5654

55+
# Imports
56+
import matplotlib.pyplot as plt
57+
from mpl_toolkits.axes_grid1 import AxesGrid
58+
import numpy as np
59+
5760
# Initialize figure
5861
fig = plt.figure()
5962
nimages = len(images)

src/xtl/cli/math/spacing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import typer
33

44
from xtl.cli.cliio import Console
5-
from xtl.exceptions.utils import Catcher
65
from xtl.units.crystallography.radial import RadialUnitType, RadialValue, RadialUnit
76

87
app = typer.Typer()
@@ -98,6 +97,8 @@ def cli_math_spacing(
9897
cli.print('Wavelength required to calculate all units', style='red')
9998
raise typer.Abort()
10099

100+
from xtl.exceptions.utils import Catcher
101+
101102
r = RadialValue(value=value, type=from_type)
102103
with Catcher(echo_func=cli.print, traceback_func=cli.print_traceback, silent=True) as catcher:
103104
if to_type:

0 commit comments

Comments
 (0)