Skip to content

Commit 34c2c5e

Browse files
committed
optimize and simplify sym code
1 parent cf91a3e commit 34c2c5e

File tree

3 files changed

+252
-146
lines changed

3 files changed

+252
-146
lines changed

pyxtal/XRD_indexer.py

Lines changed: 81 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,14 @@ def get_cell_params(spg, hkls, two_thetas, wave_length=1.54184):
101101
B = np.reshape(ds, [len_solutions, 2])
102102
A = np.reshape(A, [len_solutions, 2, 2])#; print(A.shape, B.shape)
103103
xs = np.linalg.solve(A, B)#; print(xs); import sys; sys.exit()
104-
xs = xs[xs[:, 0] > 0]
105-
xs = xs[xs[:, 1] > 0]
104+
mask1 = np.all(xs[:, :] > 0, axis=1)
105+
hkls = np.reshape(hkls, (len_solutions, 6))
106+
hkls = hkls[mask1]
107+
xs = xs[mask1]
106108
cells = np.sqrt(1/xs)
107-
cells = cells[cells[:, 0] < 50.0]
109+
mask2 = np.all(cells[:, :2] < 50.0, axis=1)
110+
cells = cells[mask2]
111+
hkls = hkls[mask2]
108112

109113
elif 16 <= spg <= 74: # orthorhombic, need a, b, c
110114
# need three hkls to determine a, b, c
@@ -117,11 +121,14 @@ def get_cell_params(spg, hkls, two_thetas, wave_length=1.54184):
117121
B = np.reshape(ds, [len_solutions, 3])
118122
A = np.reshape(A, [len_solutions, 3, 3])#; print(A.shape, B.shape)
119123
xs = np.linalg.solve(A, B)#; print(xs); import sys; sys.exit()
120-
xs = xs[xs[:, 0] > 0]
121-
xs = xs[xs[:, 1] > 0]
122-
xs = xs[xs[:, 2] > 0]
124+
mask1 = np.all(xs[:, :] > 0, axis=1)
125+
hkls_out = np.reshape(hkls, (len_solutions, 9))
126+
hkls_out = hkls_out[mask1]
127+
xs = xs[mask1]
123128
cells = np.sqrt(1/xs)
124-
cells = cells[np.all(cells[:, :3] < 50.0, axis=1)]
129+
mask2 = np.all(cells[:, :3] < 50.0, axis=1)
130+
cells = cells[mask2]
131+
hkls_out = hkls_out[mask2]
125132

126133
elif 3 <= spg <= 15: # monoclinic, need a, b, c, beta
127134
# need four hkls to determine a, b, c, beta
@@ -137,12 +144,16 @@ def get_cell_params(spg, hkls, two_thetas, wave_length=1.54184):
137144
B = np.reshape(ds, [len_solutions, 4])
138145
A = np.reshape(A, [len_solutions, 4, 4])#; print(A.shape, B.shape)
139146
xs = np.linalg.solve(A, B)#; print(xs); import sys; sys.exit()
140-
xs = xs[xs[:, 0] > 0]
141-
xs = xs[xs[:, 1] > 0]
142-
xs = xs[xs[:, 2] > 0]
147+
mask1 = np.all(xs[:, :3] > 0, axis=1)
148+
hkls_out = np.reshape(hkls, (len_solutions, 12))#;print(hkls.shape, mask1.shape, A.shape)
149+
hkls_out = hkls_out[mask1]
150+
xs = xs[mask1]
151+
143152
cos_betas = -xs[:, 3] / (2 * np.sqrt(xs[:, 0] * xs[:, 2]))
144153
masks = np.abs(cos_betas) <= 1/np.sqrt(2)
145154
xs = xs[masks]
155+
hkls_out = hkls_out[masks]
156+
146157
cos_betas = cos_betas[masks]
147158
sin_betas = np.sqrt(1 - cos_betas ** 2)
148159
cells = np.zeros([len(xs), 4])
@@ -154,13 +165,16 @@ def get_cell_params(spg, hkls, two_thetas, wave_length=1.54184):
154165
# force angle to be less than 90
155166
mask = cells[:, 3] > 90.0
156167
cells[mask, 3] = 180.0 - cells[mask, 3]
157-
cells = cells[np.all(cells[:, :3] < 50.0, axis=1)]
168+
169+
mask2 = np.all(cells[:, :3] < 50.0, axis=1)
170+
cells = cells[mask2]
171+
hkls_out = hkls_out[mask2]
158172
#print(cells)
159173
else:
160174
msg = "Only cubic, tetragonal, hexagonal, and orthorhombic systems are supported."
161175
raise NotImplementedError(msg)
162176

163-
return cells
177+
return cells, hkls_out
164178

165179
def get_d_hkl_from_cell(spg, cells, h, k, l):
166180
"""
@@ -349,17 +363,20 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
349363

350364
if use_seed:
351365
seed_hkls, seed_thetas = get_seeds(spg, hkls, two_thetas)#; print(seed_hkls, seed_thetas)
352-
cells = get_cell_params(spg, seed_hkls, seed_thetas, wave_length)#; print(cells)
366+
cells, hkls = get_cell_params(spg, seed_hkls, seed_thetas, wave_length)#; print(cells)
353367
else:
354-
cells = get_cell_params(spg, hkls, two_thetas, wave_length)#; print(cells)
368+
cells, hkls = get_cell_params(spg, hkls, two_thetas, wave_length)#; print(cells)
355369

356-
cells = np.array(cells)
370+
#cells = np.array(cells)
357371
if len(cells) == 0: return []
358372
# keep cells up to 4 decimal places
359373
if spg < 16:
360374
cells[:, -1] = np.round(cells[:, -1], decimals=2)
361375
cells[:, :3] = np.round(cells[:, :3], decimals=4)
362-
cells = np.unique(cells, axis=0)#; print(cells) # remove duplicates
376+
377+
_, unique_ids = np.unique(cells, axis=0, return_index=True)
378+
hkls = hkls[unique_ids]#; print(cells) # remove duplicates
379+
cells = cells[unique_ids]
363380

364381
# get the maximum h from assuming the cell[-1] is (h00)
365382
d_100s = get_d_hkl_from_cell(spg, cells, 1, 0, 0)
@@ -382,7 +399,6 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
382399
# Filter out None values
383400
valid_mask = expected_thetas != None
384401
valid_thetas = expected_thetas[valid_mask]
385-
# Now try to index all other peaks using this 'a'
386402

387403
if len(valid_thetas) > 0:
388404
valid_thetas = np.array(valid_thetas, dtype=float)
@@ -417,8 +433,8 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
417433
solutions.append({
418434
'cell': cell,
419435
'n_matched': n_matched,
420-
'n_total': len(long_thetas),
421436
'score': score,
437+
'id': [i, hkls[i]],
422438
})
423439

424440
return solutions
@@ -430,10 +446,12 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
430446
np.set_printoptions(precision=4, suppress=True)
431447

432448
xtal = pyxtal()
433-
xtal.from_seed('pyxtal/database/cifs/JVASP-62168.cif')
434-
#xtal.from_seed('pyxtal/database/cifs/JVASP-98225.cif')
435-
#xtal.from_seed('pyxtal/database/cifs/JVASP-50935.cif')
436-
#xtal.from_seed('pyxtal/database/cifs/JVASP-28565.cif')
449+
#xtal.from_seed('pyxtal/database/cifs/JVASP-62168.cif') # 52s -> 16s
450+
#xtal.from_seed('pyxtal/database/cifs/JVASP-98225.cif') # P21/c -> 33s -> 12s
451+
#xtal.from_seed('pyxtal/database/cifs/JVASP-50935.cif') # Pm -> 45s -> 7.6s
452+
#xtal.from_seed('pyxtal/database/cifs/JVASP-28565.cif') # 207s -> 91s -> 80s -> 72s
453+
xtal.from_seed('pyxtal/database/cifs/JVASP-47532.cif') #
454+
437455
xrd = xtal.get_XRD(thetas=[0, 120], SCALED_INTENSITY_TOL=0.5)
438456
cell_ref = np.sort(np.array(xtal.lattice.encode()))
439457
long_thetas = xrd.pxrd[:15, 0]
@@ -445,50 +463,70 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
445463
if spg >= 16:
446464
guesses = xtal.group.generate_hkl_guesses(2, 3, 5, max_square=29, total_square=40, verbose=True)
447465
else:
448-
guesses = xtal.group.generate_hkl_guesses(3, 3, 6, max_square=38, total_square=48, verbose=True)
466+
if spg in [5, 8, 12, 15]:
467+
guesses = xtal.group.generate_hkl_guesses(3, 3, 6, max_square=38, total_square=40, verbose=True)
468+
else:
469+
guesses = xtal.group.generate_hkl_guesses(3, 3, 4, max_square=29, total_square=35, verbose=True)
449470
guesses = np.array(guesses)
450471
print("Total guesses:", len(guesses))
451472
sum_squares = np.sum(guesses**2, axis=(1,2))
452473
sorted_indices = np.argsort(sum_squares)
453474
guesses = guesses[sorted_indices]
475+
if len(guesses) > 200000: guesses = guesses[:200000]
454476
#guesses = np.array([[[2, 0, 0], [1, 1, 0], [0, 1, 1], [0, 0, 2]]])
455477
#guesses = np.array([[[2, 0, 0], [1, 1, 0], [0, 0, 2], [2, 0, -2]]])
456-
#guesses = np.array([[[0, 2, 0], [0, 0, 1], [1, 0, 0], [1, 0, -1]]])
478+
#guesses = np.array([[[0, 0, -1], [1, 1, 0], [1, 1, -1], [0, 2, -5]]])
457479

458480
# Check the quality of each (hkl, 2theta) solutions
481+
N_add = 5
482+
N_batch = 20
459483
cell2 = np.sort(np.array(xtal.lattice.encode()))
460484
if spg <= 15 and cell2[3] > 90: cell2[3] = 180 - cell2[3]
461485
cells_all = np.reshape(cell2, (1, len(cell2)))
462-
463-
for guess in guesses[:]:
464-
#print('New guess', guess.flatten())
465-
found = False
466-
n_peaks = len(guess)
467-
468-
# Try each combination of n peaks from the first n+1 peaks
469-
N_add = 5
470-
available_peaks = xrd.pxrd[:n_peaks + N_add, 0]
471-
472-
thetas = []
473-
for peak_combo in combinations(range(n_peaks + N_add), n_peaks):
474-
thetas.extend(available_peaks[list(peak_combo)])
475-
hkls_t = np.tile(guess, (int(len(thetas)/len(guess)), 1))
476-
486+
# Try each combination of n_peaks peaks drawn from the first n_peaks + N_add peaks
487+
n_peaks = len(guesses[0])
488+
available_peaks = xrd.pxrd[:n_peaks + N_add, 0]
489+
thetas = []
490+
for peak_combo in combinations(range(n_peaks + N_add), n_peaks):
491+
thetas.extend(available_peaks[list(peak_combo)])
492+
N_thetas = len(thetas) // n_peaks
493+
thetas = np.array(thetas)
494+
thetas = np.tile(thetas, N_batch)
495+
496+
found = False
497+
d2 = 0
498+
for i in range(len(guesses)//N_batch + 1):
499+
if i == len(guesses)//N_batch:
500+
N_batch = len(guesses) - N_batch * i
501+
if N_batch == 0:
502+
break
503+
else:
504+
thetas = thetas[:N_thetas * n_peaks * N_batch]
505+
hkls_b = np.reshape(guesses[N_batch*i:N_batch*(i+1)], [N_batch*n_peaks, 3])
506+
hkls_t = np.tile(hkls_b, (N_thetas, 1))
477507
solutions = get_cell_from_multi_hkls(spg, hkls_t, thetas, long_thetas, use_seed=False)
508+
if i % 1000 == 0:
509+
print(f"Processed {N_batch*(i)}/{d2}, found {len(cells_all)-1} cells so far.")
510+
478511
for sol in solutions:
479512
cell1 = np.sort(np.array(sol['cell']))
480-
if spg <= 15 and cell1[3] > 90: cell1[3] = 180 - cell1[3]
481513

482514
# Check if it is a new solution
483515
diffs = np.sum((cells_all - cell1)**2, axis=1)
516+
guess = sol['id'][1]
517+
score = sol['score']
518+
d2 = np.sum(guess**2)
484519
if len(cells_all[diffs < 0.1]) == 0:
485-
print("Guess:", guess.flatten(), np.sum(guess**2), "->", cell1, sol['score'])
520+
print(f"Guess: {guess}, {d2}/{len(cells_all)-1} -> {cell1}, {score:.6f}")
486521
cells_all = np.vstack((cells_all, cell1))
487522

488-
if diffs[0] < 0.1: #result['score'] > 0.9999:
489-
print("Guess:", guess.flatten(), np.sum(guess**2), "->", cell1, sol['score'])
523+
# Early stopping for getting high-quality solutions
524+
if diffs[0] < 0.1:
525+
print(f"Guess: {guess}, {d2}/{len(cells_all)-1} -> {cell1}, {score:.6f}")
490526
print("High score, exiting early.")
491527
found = True
492528
break
493529
if found:
494530
break
531+
532+
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#############################################################
2+
# ______ _ _ _ #
3+
# (_____ \ \ \ / / | | #
4+
# _____) ) _ \ \/ / |_ ____| | #
5+
# | ____/ | | | ) (| _)/ _ | | #
6+
# | | | |_| |/ /\ \ |_( (_| | |___ #
7+
# |_| \__ /_/ \_\___)__|_|_____) #
8+
# (____/ #
9+
#---------------------(version 1.1.1)--------------------#
10+
# A Python package for random crystal generation #
11+
# url: https://github.com/qzhu2017/pyxtal #
12+
# @Zhu's group at U. North Carolina at Charlotte #
13+
#############################################################
14+
data_from_pyxtal
15+
16+
_symmetry_space_group_name_H-M 'P2/m'
17+
_symmetry_Int_Tables_number 10
18+
_symmetry_cell_setting monoclinic
19+
_cell_length_a 5.343986
20+
_cell_length_b 3.078251
21+
_cell_length_c 7.531201
22+
_cell_angle_alpha 90.000000
23+
_cell_angle_beta 90.116449
24+
_cell_angle_gamma 90.000000
25+
_cell_volume 123.889003
26+
27+
loop_
28+
_symmetry_equiv_pos_site_id
29+
_symmetry_equiv_pos_as_xyz
30+
1 'x, y, z'
31+
2 '-x, y, -z'
32+
3 '-x, -y, -z'
33+
4 'x, -y, z'
34+
35+
loop_
36+
_atom_site_label
37+
_atom_site_type_symbol
38+
_atom_site_symmetry_multiplicity
39+
_atom_site_fract_x
40+
_atom_site_fract_y
41+
_atom_site_fract_z
42+
_atom_site_occupancy
43+
Mn Mn 2 0.330707 0.000000 0.662283 1
44+
Mn Mn 1 0.500000 0.500000 0.000000 1
45+
Mn Mn 1 0.000000 0.000000 0.000000 1
46+
Fe Fe 2 0.835235 0.500000 0.664951 1
47+
O O 2 0.667897 0.000000 0.829994 1
48+
O O 2 0.166981 0.500000 0.832570 1
49+
O O 1 0.500000 0.500000 0.500000 1
50+
O O 1 0.000000 0.000000 0.500000 1
51+
#END
52+

0 commit comments

Comments
 (0)