Skip to content

Commit c6f2dac

Browse files
committed
tweak the code to get better performance
1 parent 2d6aff9 commit c6f2dac

File tree

1 file changed

+52
-43
lines changed

1 file changed

+52
-43
lines changed

pyxtal/XRD_indexer.py

Lines changed: 52 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,7 @@ def calc_two_theta_from_cell(spg, hkls, cells, wave_length=1.54184):
227227
h, k, l = hkls[:, 0], hkls[:, 1], hkls[:, 2]
228228
if spg >= 195: # cubic
229229
a = cells[0]
230-
d = a / np.sqrt(h**2 + k**2 + l**2)
230+
d = a / np.sqrt(h**2 + k**2 + l**2)#; print('ddddd', d)
231231
elif spg >= 143: # hexagonal
232232
a, c = cells[0], cells[1]
233233
d = 1 / np.sqrt((4/3) * (h**2 + h*k + k**2) / a**2 + l**2 / c**2)
@@ -246,12 +246,11 @@ def calc_two_theta_from_cell(spg, hkls, cells, wave_length=1.54184):
246246
raise NotImplementedError("triclinic systems are not supported.")
247247
sin_theta = wave_length / (2 * d)
248248
# Handle cases where sin_theta > 1
249-
valid = sin_theta <= 1
250-
thetas = np.zeros_like(sin_theta)
251-
thetas[valid] = np.arcsin(sin_theta[valid])
252-
two_thetas = 2 * np.degrees(thetas)
253-
two_thetas[~valid] = np.nan # Mark invalid values as NaN
254-
return two_thetas
249+
valid = sin_theta <= 1#; print(d[~valid])
250+
two_thetas = 2 * np.degrees(np.arcsin(sin_theta[valid]))
251+
two_thetas = np.round(two_thetas, decimals=3)
252+
two_thetas, ids = np.unique(two_thetas, return_index=True)
253+
return two_thetas, hkls[valid][ids]
255254

256255
def get_seeds(spg, hkls, two_thetas):
257256
"""
@@ -338,7 +337,7 @@ def get_seeds(spg, hkls, two_thetas):
338337

339338

340339
def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_length=1.54184,
341-
tolerance=0.05, use_seed=True, min_score=0.999):
340+
tolerance=0.1, use_seed=True, min_score=0.999):
342341
"""
343342
Estimate the cell parameters from multiple (hkl, two_theta) inputs.
344343
The idea is to use the Bragg's law and the lattice spacing formula to estimate the lattice parameters.
@@ -347,10 +346,10 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
347346
Args:
348347
hkls: list of (h, k, l) tuples
349348
two_thetas: list of 2theta values
350-
long_thetas: list of all observed 2theta values
349+
long_thetas: array of all observed 2theta values
351350
spg (int): space group number
352351
wave_length: X-ray wavelength, default is Cu K-alpha
353-
tolerance: tolerance for matching 2theta values, default is 0.05 degrees
352+
tolerance: tolerance for matching 2theta values, default is 0.1 degrees
354353
use_seed: whether to use seed hkls for initial cell estimation
355354
min_score: threshold score for consideration
356355
@@ -391,9 +390,9 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
391390
d_100s = get_d_hkl_from_cell(spg, cells, 1, 0, 0)
392391
d_010s = get_d_hkl_from_cell(spg, cells, 0, 1, 0)
393392
d_001s = get_d_hkl_from_cell(spg, cells, 0, 0, 1)
394-
theta_100s = np.degrees(np.arcsin(wave_length / (2 * d_100s)))
395-
theta_010s = np.degrees(np.arcsin(wave_length / (2 * d_010s)))
396-
theta_001s = np.degrees(np.arcsin(wave_length / (2 * d_001s)))#; print(len(cells))
393+
theta_100s = 2*np.degrees(np.arcsin(wave_length / (2 * d_100s)))
394+
theta_010s = 2*np.degrees(np.arcsin(wave_length / (2 * d_010s)))
395+
theta_001s = 2*np.degrees(np.arcsin(wave_length / (2 * d_001s)))
397396
h_maxs = np.array(long_thetas[-1] / theta_100s, dtype=int); h_maxs[h_maxs > 100] = 100
398397
k_maxs = np.array(long_thetas[-1] / theta_010s, dtype=int); k_maxs[k_maxs > 100] = 100
399398
l_maxs = np.array(long_thetas[-1] / theta_001s, dtype=int); l_maxs[l_maxs > 100] = 100
@@ -404,47 +403,45 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
404403
k_max=k_maxs[i],
405404
l_max=l_maxs[i],
406405
level=level))
407-
expected_thetas = calc_two_theta_from_cell(spg, test_hkls, cell, wave_length)
408-
# Filter out None values
409-
valid_mask = expected_thetas != None
410-
valid_thetas = expected_thetas[valid_mask]
411-
412-
if len(valid_thetas) > 0:
413-
valid_thetas = np.array(valid_thetas, dtype=float)
414-
matched_peaks = [] # (index, hkl, obs_theta, error)
415-
416-
for peak_idx, obs_theta in enumerate(long_thetas):
417-
best_match = None
418-
best_error = float('inf')
419-
errors = np.abs(obs_theta - valid_thetas)
420-
within_tolerance = errors < tolerance
421-
422-
if np.any(within_tolerance):
423-
min_idx = np.argmin(errors[within_tolerance])
424-
valid_indices = np.where(within_tolerance)[0]
425-
best_idx = valid_indices[min_idx]
426-
best_error = errors[best_idx]
427-
#best_match = (peak_idx, tuple(valid_hkls[best_idx]), obs_theta, best_error)
428-
best_match = (peak_idx, obs_theta, best_error)
429-
430-
#print(cell, peak_idx, best_match)
431-
if best_match is not None: matched_peaks.append(best_match)
406+
exp_thetas, exp_hkls = calc_two_theta_from_cell(spg, test_hkls, cell, wave_length)
407+
if len(exp_thetas) == 0: continue
408+
409+
errors_matrix = np.abs(long_thetas[:, np.newaxis] - exp_thetas[np.newaxis, :])
410+
within_tolerance = errors_matrix < tolerance
411+
has_match = np.any(within_tolerance, axis=1)
412+
best_errors = np.min(errors_matrix, axis=1)
413+
#print('best errors', exp_thetas); import sys; sys.exit()
414+
415+
# Filter to only those with valid matches
416+
valid_matches = has_match & (best_errors < tolerance)
417+
matched_peaks = [] # (index, hkl, obs_theta, error)
418+
valid_peak_indices = np.where(valid_matches)[0]
419+
420+
for peak_idx in valid_peak_indices:
421+
obs_theta = long_thetas[peak_idx]
422+
error = best_errors[peak_idx]
423+
matched_peaks.append((peak_idx, obs_theta, error))
424+
#print(error)
432425

433426
# Score this solution
434427
n_matched = len(matched_peaks)
435428
coverage = n_matched / len(long_thetas)
436429
avg_error = np.mean([match[-1] for match in matched_peaks])
437-
consistency_score = 1.0 / (1.0 + avg_error) # lower error = higher score
430+
consistency_score = 1.0 / (1.0 + avg_error)
438431
score = coverage * consistency_score
439-
#print("Cell:", cell, hkls[i], "Score:", score)
432+
#unmatches = exp_thetas[~within_tolerance.all(axis=0)]
433+
#mask = (unmatches > long_thetas[0]) & (unmatches < long_thetas[-1])
434+
#unmatches = exp_hkls[mask]
440435

441436
if score > min_score:
442437
solutions.append({
443438
'cell': cell,
444439
'n_matched': n_matched,
445440
'score': score,
446441
'id': hkls[i],
442+
#'unmatched_thetas': unmatches,
447443
})
444+
#print(cell, len(unmatches), unmatches)
448445

449446
return solutions
450447

@@ -463,12 +460,12 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
463460
#'pyxtal/database/cifs/JVASP-28634.cif', # P3m1, 0.1s
464461
#'pyxtal/database/cifs/JVASP-85365.cif', # P4/mmm, 0.6s
465462
#'pyxtal/database/cifs/JVASP-62168.cif', # Pnma, 33s
466-
'pyxtal/database/cifs/JVASP-98225.cif', # P21/c, 14s
463+
#'pyxtal/database/cifs/JVASP-98225.cif', # P21/c, 14s
467464
#'pyxtal/database/cifs/JVASP-50935.cif', # Pm, 10s
468465
#'pyxtal/database/cifs/JVASP-28565.cif', # Cm, 100s
469466
#'pyxtal/database/cifs/JVASP-36885.cif', # Cm, 100s
470467
#'pyxtal/database/cifs/JVASP-42300.cif', # C2, 178s
471-
#'pyxtal/database/cifs/JVASP-47532.cif', # P2/m,
468+
'pyxtal/database/cifs/JVASP-47532.cif', # P2/m,
472469
]:
473470
t0 = time()
474471
xtal.from_seed(cif)
@@ -560,7 +557,8 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
560557
t1 = time()
561558
data.append((cif, spg, d2, len(cells_all), score, t1-t0))
562559

563-
for d in data: print(d)
560+
for d in data:
561+
print(d)
564562
"""
565563
('pyxtal/database/cifs/JVASP-97915.cif', 225, 11, 1, 0.9944178674744656, 0.8724310398101807)
566564
('pyxtal/database/cifs/JVASP-86205.cif', 204, 4, 1, 0.9999808799880138, 0.10700225830078125)
@@ -570,4 +568,15 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
570568
('pyxtal/database/cifs/JVASP-98225.cif', 14, 14, 3, 0.9999291925203678, 35.536349296569824)
571569
('pyxtal/database/cifs/JVASP-50935.cif', 6, 6, 25, 0.9998350565457528, 11.00560998916626)
572570
('pyxtal/database/cifs/JVASP-28565.cif', 8, 35, 3824, 0.9995855322890919, 490.50669598579407)
571+
572+
('pyxtal/database/cifs/JVASP-97915.cif', 225, 11, 1, 0.9942656034459425, 0.9070873260498047)
573+
('pyxtal/database/cifs/JVASP-86205.cif', 204, 4, 1, 0.9997322519800832, 0.09690594673156738)
574+
('pyxtal/database/cifs/JVASP-28634.cif', 156, 2, 1, 0.9997539680387753, 0.07663106918334961)
575+
('pyxtal/database/cifs/JVASP-85365.cif', 123, 16, 9, 0.9997570762997596, 0.3383021354675293)
576+
('pyxtal/database/cifs/JVASP-62168.cif', 62, 34, 15, 0.9997220764747486, 13.73779296875)
577+
('pyxtal/database/cifs/JVASP-98225.cif', 14, 14, 1, 0.9997723881589121, 21.532269716262817)
578+
('pyxtal/database/cifs/JVASP-50935.cif', 6, 6, 25, 0.9997130006576738, 10.676042079925537)
579+
('pyxtal/database/cifs/JVASP-28565.cif', 8, 35, 2885, 0.9995837967820846, 218.40485000610352)
580+
('pyxtal/database/cifs/JVASP-36885.cif', 6, 5, 25, 0.9997727973911672, 10.984601259231567)
581+
('pyxtal/database/cifs/JVASP-42300.cif', 5, 25, 1, 0.9993642422227232, 84.00993585586548)
573582
"""

0 commit comments

Comments
 (0)