Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hexrd/core/instrument/hedm_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -2082,6 +2082,7 @@ def pull_spots(
ang_centers[patch_id],
meas_angs,
meas_xy,
xy_centers[patch_id],
]
)
if write_text:
Expand Down
18 changes: 16 additions & 2 deletions hexrd/hedm/fitgrains.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def fit_grain_FF_reduced(grain_id):
prefix = paramMP['spots_filename']
spots_filename = None if prefix is None else prefix % grain_id

return_pull_spots_data: bool = paramMP.get('return_pull_spots_data', False)

grain = grains_table[grain_id]
grain_params = grain[3:15]

Expand Down Expand Up @@ -137,7 +139,10 @@ def fit_grain_FF_reduced(grain_id):
f'Not enough valid reflections ({num_refl_valid}) to fit, ' f'exiting',
RuntimeWarning,
)
return grain_id, completeness, np.inf, grain_params
result = (grain_id, completeness, np.inf, grain_params)
if return_pull_spots_data:
result += ((complvec, results),)
return result
else:
grain_params = fitGrain(
grain_params,
Expand Down Expand Up @@ -204,7 +209,10 @@ def fit_grain_FF_reduced(grain_id):
simOnly=False,
return_value_flag=2,
)
return grain_id, completeness, chisq, grain_params
result = (grain_id, completeness, chisq, grain_params)
if return_pull_spots_data:
result += ((complvec, results),)
return result


def determine_valid_reflections(results, instrument, analysis_dirname):
Expand Down Expand Up @@ -370,6 +378,7 @@ def fit_grains(
ids_to_refine=None,
write_spots_files=True,
check_if_canceled_func=None,
return_pull_spots_data: bool = False,
):
"""
Performs optimization of grain parameters.
Expand Down Expand Up @@ -428,6 +437,7 @@ def fit_grains(
ome_period=ome_period,
analysis_dirname=cfg.analysis_dir,
spots_filename=spots_filename,
return_pull_spots_data=return_pull_spots_data,
)

# =====================================================================
Expand Down Expand Up @@ -482,4 +492,8 @@ def fit_grains(
pool.join()
elapsed = timeit.default_timer() - start
logger.info("fitting took %f seconds", elapsed)
if return_pull_spots_data:
spots_data = {result[0]: result[4] for result in fit_results}
fit_results = [result[:4] for result in fit_results]
return fit_results, spots_data
return fit_results
82 changes: 82 additions & 0 deletions tests/test_fit-grains.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,85 @@ def test_fit_grains(
)

assert cresult


def test_fit_grains_return_pull_spots_data(
single_ge_include_path: Path,
test_config: config.root.RootConfig,
grains_reference_file_path: Path,
) -> None:
os.chdir(str(single_ge_include_path))

grains_table: np.ndarray = np.loadtxt(grains_reference_file_path, ndmin=2)

result = fit_grains(
test_config,
grains_table,
show_progress=False,
ids_to_refine=None,
write_spots_files=False,
return_pull_spots_data=True,
)

# Should return a (fit_results, spots_data) tuple
assert isinstance(result, tuple)
assert len(result) == 2

fit_results, spots_data = result

# fit_results should be a list of 4-element tuples
assert isinstance(fit_results, list)
assert len(fit_results) > 0
for grain_result in fit_results:
assert len(grain_result) == 4
grain_id, completeness, chisq, grain_params = grain_result
assert isinstance(grain_id, (int, np.integer))
assert isinstance(completeness, float)
assert isinstance(grain_params, np.ndarray)
assert grain_params.shape == (12,)

# spots_data should be a dict keyed by grain_id
assert isinstance(spots_data, dict)
assert len(spots_data) == len(fit_results)

for grain_id, (complvec, results) in spots_data.items():
# complvec is a list of booleans
assert isinstance(complvec, list)

# results is a dict keyed by detector name
assert isinstance(results, dict)
assert len(results) > 0

for det_key, det_results in results.items():
assert isinstance(det_key, str)
assert isinstance(det_results, list)
assert len(det_results) > 0

for spot in det_results:
# Each spot should have 9 elements (including pred_xy)
assert len(spot) == 9, (
f'Expected 9 elements per spot, got {len(spot)}'
)

peak_id = spot[0]
hkl = spot[2]
pred_angs = spot[5]
meas_angs = spot[6]
meas_xy = spot[7]
pred_xy = spot[8]

assert isinstance(peak_id, (int, np.integer))
assert isinstance(hkl, np.ndarray)
assert hkl.shape == (3,)
assert isinstance(pred_angs, np.ndarray)
assert pred_angs.shape == (3,)

# meas_angs/meas_xy may be None for invalid spots
if peak_id >= 0:
assert isinstance(meas_angs, np.ndarray)
assert meas_angs.shape == (3,)
assert isinstance(meas_xy, np.ndarray)
assert meas_xy.shape == (2,)

assert isinstance(pred_xy, np.ndarray)
assert pred_xy.shape == (2,)
Loading