Skip to content

Commit 2baefd9

Browse files
authored
Merge pull request #777 from HEXRD/parallelize-snip1d
Parallelize rows in snip1d
2 parents 0749c34 + 9a3c34c commit 2baefd9

File tree

5 files changed

+121
-24
lines changed

5 files changed

+121
-24
lines changed

hexrd/fitting/fitpeak.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def estimate_pk_parms_1d(x, f, pktype='pvoigt'):
119119

120120
# handle background
121121
# ??? make kernel width a kwarg?
122-
bkg = snip1d(np.atleast_2d(f), w=int(2*npts/3.)).flatten()
122+
bkg = snip1d(np.atleast_2d(f), w=int(2*npts/3.), max_workers=1).flatten()
123123

124124
# fit linear bg and grab params
125125
bp, _ = optimize.curve_fit(lin_fit_obj, x, bkg, jac=lin_fit_jac)
@@ -372,8 +372,11 @@ def estimate_mpk_parms_1d(
372372
min_val = np.min(f)
373373

374374
# estimate background with SNIP1d
375-
bkg = snip1d(np.atleast_2d(f),
376-
w=int(np.floor(0.25*len(f)))).flatten()
375+
bkg = snip1d(
376+
np.atleast_2d(f),
377+
w=int(np.floor(0.25*len(f))),
378+
max_workers=1,
379+
).flatten()
377380

378381
# fit linear bg and grab params
379382
bp, _ = optimize.curve_fit(lin_fit_obj, x, bkg, jac=lin_fit_jac)

hexrd/fitting/spectrum.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,11 @@ def _initial_guess(peak_positions, x, f,
204204

205205
# estimate background with snip1d
206206
# !!! using a window size based on abcissa
207-
bkg = snip1d(np.atleast_2d(f),
208-
w=int(np.floor(len(f)/num_pks/2.))).flatten()
207+
bkg = snip1d(
208+
np.atleast_2d(f),
209+
w=int(np.floor(len(f)/num_pks/2.)),
210+
max_workers=1,
211+
).flatten()
209212

210213
bkg_mod = chebyshev.Chebyshev(
211214
[0., 0.], domain=(min(x), max(x))

hexrd/imageutil.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from concurrent.futures import ProcessPoolExecutor
2+
from functools import partial
3+
import os
4+
15
import numpy as np
26
from scipy import signal, ndimage
37

@@ -60,7 +64,7 @@ def fast_snip1d(y, w=4, numiter=2):
6064
return bkg
6165

6266

63-
def snip1d(y, w=4, numiter=2, threshold=None):
67+
def snip1d(y, w=4, numiter=2, threshold=None, max_workers=os.cpu_count()):
6468
"""
6569
Return SNIP-estimated baseline-background for given spectrum y.
6670
@@ -79,29 +83,47 @@ def snip1d(y, w=4, numiter=2, threshold=None):
7983
mask = np.zeros_like(y, dtype=bool)
8084

8185
# step through rows
82-
for k, z in enumerate(zfull):
83-
if np.all(mask[k]):
84-
bkg[k, :] = np.nan
85-
else:
86-
b = z
87-
for i in range(numiter):
88-
for p in range(w, 0, -1):
89-
kernel = np.zeros(p*2 + 1)
90-
kernel[0] = kernel[-1] = 1./2.
91-
b = np.minimum(
92-
b,
93-
convolution.convolve(
94-
z, kernel, boundary='extend', mask=mask[k],
95-
nan_treatment='interpolate', preserve_nan=True
96-
)
97-
)
98-
z = b
99-
bkg[k, :] = _scale_image_snip(b, min_val, invert=True)
86+
tasks = enumerate(zip(zfull, mask))
87+
f = partial(_run_snip1d_row, numiter=numiter, w=w, min_val=min_val)
88+
89+
if max_workers > 1:
90+
# Parallelize over tasks
91+
with ProcessPoolExecutor(max_workers=max_workers) as executor:
92+
for k, result in executor.map(f, tasks):
93+
bkg[k, :] = result
94+
else:
95+
# Run the tasks in this process
96+
for task in tasks:
97+
k, result = f(task)
98+
bkg[k, :] = result
99+
100100
nan_idx = np.isnan(bkg)
101101
bkg[nan_idx] = threshold
102102
return bkg
103103

104104

105+
def _run_snip1d_row(task, numiter, w, min_val):
106+
k, (z, mask) = task
107+
108+
if np.all(mask):
109+
return k, np.nan
110+
111+
b = z
112+
for i in range(numiter):
113+
for p in range(w, 0, -1):
114+
kernel = np.zeros(p*2 + 1)
115+
kernel[0] = kernel[-1] = 1./2.
116+
b = np.minimum(
117+
b,
118+
convolution.convolve(
119+
z, kernel, boundary='extend', mask=mask,
120+
nan_treatment='interpolate', preserve_nan=True
121+
)
122+
)
123+
z = b
124+
return k, _scale_image_snip(b, min_val, invert=True)
125+
126+
105127
def snip1d_quad(y, w=4, numiter=2):
106128
"""Return SNIP-estimated baseline-background for given spectrum y.
107129
3.02 MB
Binary file not shown.

tests/test_snip.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import yaml
5+
6+
import pytest
7+
8+
from hexrd import imageutil
9+
from hexrd.instrument.hedm_instrument import HEDMInstrument
10+
from hexrd.projections.polar import PolarView
11+
12+
13+
@pytest.fixture
14+
def simulated_tardis_path(example_repo_path: Path) -> Path:
15+
return example_repo_path / 'tardis' / 'simulated'
16+
17+
18+
@pytest.fixture
19+
def simulated_tardis_images(
20+
simulated_tardis_path: Path,
21+
) -> dict[str, np.ndarray]:
22+
path = simulated_tardis_path / 'tardis_images.npz'
23+
npz = np.load(path)
24+
return {k: v for k, v in npz.items()}
25+
26+
27+
@pytest.fixture
28+
def tardis_instrument(simulated_tardis_path: Path) -> HEDMInstrument:
29+
path = simulated_tardis_path / 'ideal_tardis.yml'
30+
with open(path, 'r') as rf:
31+
conf = yaml.safe_load(rf)
32+
33+
return HEDMInstrument(conf)
34+
35+
36+
@pytest.fixture
37+
def expected_snip1d_results(test_data_dir: Path) -> np.ndarray:
38+
path = test_data_dir / 'expected_snip1d_results.npy'
39+
return np.load(path)
40+
41+
42+
def test_snip1d(
43+
tardis_instrument: HEDMInstrument,
44+
simulated_tardis_images: dict[str, np.ndarray],
45+
expected_snip1d_results: np.ndarray,
46+
):
47+
instr = tardis_instrument
48+
img_dict = simulated_tardis_images
49+
ref = expected_snip1d_results
50+
51+
# Create the PolarView
52+
tth_range = [10, 120]
53+
eta_min = -180.0
54+
eta_max = 180.0
55+
pixel_size = (0.1, 1.0)
56+
57+
pv = PolarView(tth_range, instr, eta_min, eta_max, pixel_size)
58+
img = pv.warp_image(img_dict, pad_with_nans=True,
59+
do_interpolation=True)
60+
61+
snip_width = 100
62+
numiter = 2
63+
output = imageutil.snip1d(
64+
img,
65+
snip_width,
66+
numiter,
67+
)
68+
69+
assert np.allclose(output.filled(np.nan), ref, equal_nan=True)

0 commit comments

Comments
 (0)