Skip to content

Commit 261b7e6

Browse files
authored
Merge pull request #905 from HEXRD/return-pull-spots-data
Add option to return pull_spots data to fit_grains
2 parents 4e17bfc + a42f5fa commit 261b7e6

File tree

3 files changed

+99
-2
lines changed

3 files changed

+99
-2
lines changed

hexrd/core/instrument/hedm_instrument.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2082,6 +2082,7 @@ def pull_spots(
20822082
ang_centers[patch_id],
20832083
meas_angs,
20842084
meas_xy,
2085+
xy_centers[patch_id],
20852086
]
20862087
)
20872088
if write_text:

hexrd/hedm/fitgrains.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def fit_grain_FF_reduced(grain_id):
104104
prefix = paramMP['spots_filename']
105105
spots_filename = None if prefix is None else prefix % grain_id
106106

107+
return_pull_spots_data: bool = paramMP.get('return_pull_spots_data', False)
108+
107109
grain = grains_table[grain_id]
108110
grain_params = grain[3:15]
109111

@@ -137,7 +139,10 @@ def fit_grain_FF_reduced(grain_id):
137139
f'Not enough valid reflections ({num_refl_valid}) to fit, ' f'exiting',
138140
RuntimeWarning,
139141
)
140-
return grain_id, completeness, np.inf, grain_params
142+
result = (grain_id, completeness, np.inf, grain_params)
143+
if return_pull_spots_data:
144+
result += ((complvec, results),)
145+
return result
141146
else:
142147
grain_params = fitGrain(
143148
grain_params,
@@ -204,7 +209,10 @@ def fit_grain_FF_reduced(grain_id):
204209
simOnly=False,
205210
return_value_flag=2,
206211
)
207-
return grain_id, completeness, chisq, grain_params
212+
result = (grain_id, completeness, chisq, grain_params)
213+
if return_pull_spots_data:
214+
result += ((complvec, results),)
215+
return result
208216

209217

210218
def determine_valid_reflections(results, instrument, analysis_dirname):
@@ -370,6 +378,7 @@ def fit_grains(
370378
ids_to_refine=None,
371379
write_spots_files=True,
372380
check_if_canceled_func=None,
381+
return_pull_spots_data: bool = False,
373382
):
374383
"""
375384
Performs optimization of grain parameters.
@@ -428,6 +437,7 @@ def fit_grains(
428437
ome_period=ome_period,
429438
analysis_dirname=cfg.analysis_dir,
430439
spots_filename=spots_filename,
440+
return_pull_spots_data=return_pull_spots_data,
431441
)
432442

433443
# =====================================================================
@@ -482,4 +492,8 @@ def fit_grains(
482492
pool.join()
483493
elapsed = timeit.default_timer() - start
484494
logger.info("fitting took %f seconds", elapsed)
495+
if return_pull_spots_data:
496+
spots_data = {result[0]: result[4] for result in fit_results}
497+
fit_results = [result[:4] for result in fit_results]
498+
return fit_results, spots_data
485499
return fit_results

tests/test_fit-grains.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,85 @@ def test_fit_grains(
101101
)
102102

103103
assert cresult
104+
105+
106+
def test_fit_grains_return_pull_spots_data(
107+
single_ge_include_path: Path,
108+
test_config: config.root.RootConfig,
109+
grains_reference_file_path: Path,
110+
) -> None:
111+
os.chdir(str(single_ge_include_path))
112+
113+
grains_table: np.ndarray = np.loadtxt(grains_reference_file_path, ndmin=2)
114+
115+
result = fit_grains(
116+
test_config,
117+
grains_table,
118+
show_progress=False,
119+
ids_to_refine=None,
120+
write_spots_files=False,
121+
return_pull_spots_data=True,
122+
)
123+
124+
# Should return a (fit_results, spots_data) tuple
125+
assert isinstance(result, tuple)
126+
assert len(result) == 2
127+
128+
fit_results, spots_data = result
129+
130+
# fit_results should be a list of 4-element tuples
131+
assert isinstance(fit_results, list)
132+
assert len(fit_results) > 0
133+
for grain_result in fit_results:
134+
assert len(grain_result) == 4
135+
grain_id, completeness, chisq, grain_params = grain_result
136+
assert isinstance(grain_id, (int, np.integer))
137+
assert isinstance(completeness, float)
138+
assert isinstance(grain_params, np.ndarray)
139+
assert grain_params.shape == (12,)
140+
141+
# spots_data should be a dict keyed by grain_id
142+
assert isinstance(spots_data, dict)
143+
assert len(spots_data) == len(fit_results)
144+
145+
for grain_id, (complvec, results) in spots_data.items():
146+
# complvec is a list of booleans
147+
assert isinstance(complvec, list)
148+
149+
# results is a dict keyed by detector name
150+
assert isinstance(results, dict)
151+
assert len(results) > 0
152+
153+
for det_key, det_results in results.items():
154+
assert isinstance(det_key, str)
155+
assert isinstance(det_results, list)
156+
assert len(det_results) > 0
157+
158+
for spot in det_results:
159+
# Each spot should have 9 elements (including pred_xy)
160+
assert len(spot) == 9, (
161+
f'Expected 9 elements per spot, got {len(spot)}'
162+
)
163+
164+
peak_id = spot[0]
165+
hkl = spot[2]
166+
pred_angs = spot[5]
167+
meas_angs = spot[6]
168+
meas_xy = spot[7]
169+
pred_xy = spot[8]
170+
171+
assert isinstance(peak_id, (int, np.integer))
172+
assert isinstance(hkl, np.ndarray)
173+
assert hkl.shape == (3,)
174+
assert isinstance(pred_angs, np.ndarray)
175+
assert pred_angs.shape == (3,)
176+
177+
# meas_angs/meas_xy may be None for invalid spots
178+
if peak_id >= 0:
179+
assert isinstance(meas_angs, np.ndarray)
180+
assert meas_angs.shape == (3,)
181+
assert isinstance(meas_xy, np.ndarray)
182+
assert meas_xy.shape == (2,)
183+
184+
assert isinstance(pred_xy, np.ndarray)
185+
assert pred_xy.shape == (2,)

0 commit comments

Comments
 (0)