Skip to content

Commit cff6b3e

Browse files
committed
Parallelize rows in snip1d
`snip1d` was taking a very long time to compute for high-resolution polar images. Since it is computed on a row-by-row basis, however, I found that parallelizing over the rows using multiprocessing provided a big performance improvement (3x faster for my example).

Some other things I tried that did not result in a speed improvement:

* Using the latest `convolve` from astropy
* Using the latest `convolve_fft` from astropy
* Performing multithreading instead of multiprocessing
* Parallelizing at a finer-grained level in the `snip1d` loops

This also includes a test that verifies that the `snip1d` output does not change.

Signed-off-by: Patrick Avery <patrick.avery@kitware.com>
1 parent 2e193cf commit cff6b3e

File tree

5 files changed

+127
-24
lines changed

5 files changed

+127
-24
lines changed

hexrd/fitting/fitpeak.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def estimate_pk_parms_1d(x, f, pktype='pvoigt'):
119119

120120
# handle background
121121
# ??? make kernel width a kwarg?
122-
bkg = snip1d(np.atleast_2d(f), w=int(2*npts/3.)).flatten()
122+
bkg = snip1d(np.atleast_2d(f), w=int(2*npts/3.), max_workers=1).flatten()
123123

124124
# fit linear bg and grab params
125125
bp, _ = optimize.curve_fit(lin_fit_obj, x, bkg, jac=lin_fit_jac)
@@ -372,8 +372,11 @@ def estimate_mpk_parms_1d(
372372
min_val = np.min(f)
373373

374374
# estimate background with SNIP1d
375-
bkg = snip1d(np.atleast_2d(f),
376-
w=int(np.floor(0.25*len(f)))).flatten()
375+
bkg = snip1d(
376+
np.atleast_2d(f),
377+
w=int(np.floor(0.25*len(f))),
378+
max_workers=1,
379+
).flatten()
377380

378381
# fit linear bg and grab params
379382
bp, _ = optimize.curve_fit(lin_fit_obj, x, bkg, jac=lin_fit_jac)

hexrd/fitting/spectrum.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,11 @@ def _initial_guess(peak_positions, x, f,
204204

205205
# estimate background with snip1d
206206
# !!! using a window size based on abcissa
207-
bkg = snip1d(np.atleast_2d(f),
208-
w=int(np.floor(len(f)/num_pks/2.))).flatten()
207+
bkg = snip1d(
208+
np.atleast_2d(f),
209+
w=int(np.floor(len(f)/num_pks/2.)),
210+
max_workers=1,
211+
).flatten()
209212

210213
bkg_mod = chebyshev.Chebyshev(
211214
[0., 0.], domain=(min(x), max(x))

hexrd/imageutil.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from concurrent.futures import ProcessPoolExecutor
2+
from functools import partial
3+
import os
4+
15
import numpy as np
26
from scipy import signal, ndimage
37

@@ -60,7 +64,13 @@ def fast_snip1d(y, w=4, numiter=2):
6064
return bkg
6165

6266

63-
def snip1d(y, w=4, numiter=2, threshold=None):
67+
def snip1d(
68+
y,
69+
w=4,
70+
numiter=2,
71+
threshold=None,
72+
max_workers=os.process_cpu_count(),
73+
):
6474
"""
6575
Return SNIP-estimated baseline-background for given spectrum y.
6676
@@ -79,29 +89,47 @@ def snip1d(y, w=4, numiter=2, threshold=None):
7989
mask = np.zeros_like(y, dtype=bool)
8090

8191
# step through rows
82-
for k, z in enumerate(zfull):
83-
if np.all(mask[k]):
84-
bkg[k, :] = np.nan
85-
else:
86-
b = z
87-
for i in range(numiter):
88-
for p in range(w, 0, -1):
89-
kernel = np.zeros(p*2 + 1)
90-
kernel[0] = kernel[-1] = 1./2.
91-
b = np.minimum(
92-
b,
93-
convolution.convolve(
94-
z, kernel, boundary='extend', mask=mask[k],
95-
nan_treatment='interpolate', preserve_nan=True
96-
)
97-
)
98-
z = b
99-
bkg[k, :] = _scale_image_snip(b, min_val, invert=True)
92+
tasks = enumerate(zip(zfull, mask))
93+
f = partial(_run_snip1d_row, numiter=numiter, w=w, min_val=min_val)
94+
95+
if max_workers > 1:
96+
# Parallelize over tasks
97+
with ProcessPoolExecutor(max_workers=max_workers) as executor:
98+
for k, result in executor.map(f, tasks):
99+
bkg[k, :] = result
100+
else:
101+
# Run the tasks in this process
102+
for task in tasks:
103+
k, result = f(task)
104+
bkg[k, :] = result
105+
100106
nan_idx = np.isnan(bkg)
101107
bkg[nan_idx] = threshold
102108
return bkg
103109

104110

111+
def _run_snip1d_row(task, numiter, w, min_val):
112+
k, (z, mask) = task
113+
114+
if np.all(mask):
115+
return k, np.nan
116+
117+
b = z
118+
for i in range(numiter):
119+
for p in range(w, 0, -1):
120+
kernel = np.zeros(p*2 + 1)
121+
kernel[0] = kernel[-1] = 1./2.
122+
b = np.minimum(
123+
b,
124+
convolution.convolve(
125+
z, kernel, boundary='extend', mask=mask,
126+
nan_treatment='interpolate', preserve_nan=True
127+
)
128+
)
129+
z = b
130+
return k, _scale_image_snip(b, min_val, invert=True)
131+
132+
105133
def snip1d_quad(y, w=4, numiter=2):
106134
"""Return SNIP-estimated baseline-background for given spectrum y.
107135
3.02 MB
Binary file not shown.

tests/test_snip.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import yaml
5+
6+
import pytest
7+
8+
from hexrd import imageutil
9+
from hexrd.instrument.hedm_instrument import HEDMInstrument
10+
from hexrd.projections.polar import PolarView
11+
12+
13+
@pytest.fixture
def simulated_tardis_path(example_repo_path: Path) -> Path:
    """Return the ``tardis/simulated`` directory under the example repo."""
    return example_repo_path.joinpath('tardis', 'simulated')
16+
17+
18+
@pytest.fixture
def simulated_tardis_images(
    simulated_tardis_path: Path,
) -> dict[str, np.ndarray]:
    """Load every array stored in ``tardis_images.npz`` into a plain dict."""
    path = simulated_tardis_path / 'tardis_images.npz'
    # np.load on an .npz archive returns a lazily-read NpzFile that holds the
    # file open; read all members inside the context manager so the handle
    # is closed before the fixture returns.
    with np.load(path) as npz:
        return dict(npz.items())
25+
26+
27+
@pytest.fixture
def tardis_instrument(simulated_tardis_path: Path) -> HEDMInstrument:
    """Build an ``HEDMInstrument`` from the ``ideal_tardis.yml`` config."""
    config_file = simulated_tardis_path / 'ideal_tardis.yml'
    with config_file.open('r') as stream:
        instrument_config = yaml.safe_load(stream)

    return HEDMInstrument(instrument_config)
34+
35+
36+
@pytest.fixture
def expected_snip1d_results(test_data_dir: Path) -> np.ndarray:
    """Load the stored reference output used to check ``snip1d``."""
    return np.load(test_data_dir / 'expected_snip1d_results.npy')
40+
41+
42+
def test_snip1d(
    tardis_instrument: HEDMInstrument,
    simulated_tardis_images: dict[str, np.ndarray],
    expected_snip1d_results: np.ndarray,
):
    """Check ``snip1d`` on a warped polar image against the stored reference."""
    # Polar view settings
    tth_range = [10, 120]
    eta_bounds = (-180.0, 180.0)
    pixel_size = (0.1, 1.0)

    polar_view = PolarView(
        tth_range,
        tardis_instrument,
        eta_bounds[0],
        eta_bounds[1],
        pixel_size,
    )
    warped = polar_view.warp_image(
        simulated_tardis_images,
        pad_with_nans=True,
        do_interpolation=True,
    )

    # Positional arguments after the image: w=100, numiter=2
    background = imageutil.snip1d(warped, 100, 2)

    assert np.allclose(
        background.filled(np.nan),
        expected_snip1d_results,
        equal_nan=True,
    )

0 commit comments

Comments
 (0)