@@ -101,10 +101,14 @@ def get_cell_params(spg, hkls, two_thetas, wave_length=1.54184):
101101 B = np .reshape (ds , [len_solutions , 2 ])
102102 A = np .reshape (A , [len_solutions , 2 , 2 ])#; print(A.shape, B.shape)
103103 xs = np .linalg .solve (A , B )#; print(xs); import sys; sys.exit()
104- xs = xs [xs [:, 0 ] > 0 ]
105- xs = xs [xs [:, 1 ] > 0 ]
104+ mask1 = np .all (xs [:, :] > 0 , axis = 1 )
105+ hkls = np .reshape (hkls , (len_solutions , 6 ))
106+ hkls = hkls [mask1 ]
107+ xs = xs [mask1 ]
106108 cells = np .sqrt (1 / xs )
107- cells = cells [cells [:, 0 ] < 50.0 ]
109+ mask2 = np .all (cells [:, :2 ] < 50.0 , axis = 1 )
110+ cells = cells [mask2 ]
111+ hkls = hkls [mask2 ]
108112
109113 elif 16 <= spg <= 74 : # orthorhombic, need a, b, c
110114 # need three hkls to determine a, b, c
@@ -117,11 +121,14 @@ def get_cell_params(spg, hkls, two_thetas, wave_length=1.54184):
117121 B = np .reshape (ds , [len_solutions , 3 ])
118122 A = np .reshape (A , [len_solutions , 3 , 3 ])#; print(A.shape, B.shape)
119123 xs = np .linalg .solve (A , B )#; print(xs); import sys; sys.exit()
120- xs = xs [xs [:, 0 ] > 0 ]
121- xs = xs [xs [:, 1 ] > 0 ]
122- xs = xs [xs [:, 2 ] > 0 ]
124+ mask1 = np .all (xs [:, :] > 0 , axis = 1 )
125+ hkls_out = np .reshape (hkls , (len_solutions , 9 ))
126+ hkls_out = hkls_out [mask1 ]
127+ xs = xs [mask1 ]
123128 cells = np .sqrt (1 / xs )
124- cells = cells [np .all (cells [:, :3 ] < 50.0 , axis = 1 )]
129+ mask2 = np .all (cells [:, :3 ] < 50.0 , axis = 1 )
130+ cells = cells [mask2 ]
131+ hkls_out = hkls_out [mask2 ]
125132
126133 elif 3 <= spg <= 15 : # monoclinic, need a, b, c, beta
127134 # need four hkls to determine a, b, c, beta
@@ -137,12 +144,16 @@ def get_cell_params(spg, hkls, two_thetas, wave_length=1.54184):
137144 B = np .reshape (ds , [len_solutions , 4 ])
138145 A = np .reshape (A , [len_solutions , 4 , 4 ])#; print(A.shape, B.shape)
139146 xs = np .linalg .solve (A , B )#; print(xs); import sys; sys.exit()
140- xs = xs [xs [:, 0 ] > 0 ]
141- xs = xs [xs [:, 1 ] > 0 ]
142- xs = xs [xs [:, 2 ] > 0 ]
147+ mask1 = np .all (xs [:, :3 ] > 0 , axis = 1 )
148+ hkls_out = np .reshape (hkls , (len_solutions , 12 ))#;print(hkls.shape, mask1.shape, A.shape)
149+ hkls_out = hkls_out [mask1 ]
150+ xs = xs [mask1 ]
151+
143152 cos_betas = - xs [:, 3 ] / (2 * np .sqrt (xs [:, 0 ] * xs [:, 2 ]))
144153 masks = np .abs (cos_betas ) <= 1 / np .sqrt (2 )
145154 xs = xs [masks ]
155+ hkls_out = hkls_out [masks ]
156+
146157 cos_betas = cos_betas [masks ]
147158 sin_betas = np .sqrt (1 - cos_betas ** 2 )
148159 cells = np .zeros ([len (xs ), 4 ])
@@ -154,13 +165,16 @@ def get_cell_params(spg, hkls, two_thetas, wave_length=1.54184):
154165 # force angle to be less than 90
155166 mask = cells [:, 3 ] > 90.0
156167 cells [mask , 3 ] = 180.0 - cells [mask , 3 ]
157- cells = cells [np .all (cells [:, :3 ] < 50.0 , axis = 1 )]
168+
169+ mask2 = np .all (cells [:, :3 ] < 50.0 , axis = 1 )
170+ cells = cells [mask2 ]
171+ hkls_out = hkls_out [mask2 ]
158172 #print(cells)
159173 else :
160174 msg = "Only cubic, tetragonal, hexagonal, and orthorhombic systems are supported."
161175 raise NotImplementedError (msg )
162176
163- return cells
177+ return cells , hkls_out
164178
165179def get_d_hkl_from_cell (spg , cells , h , k , l ):
166180 """
@@ -349,17 +363,20 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
349363
350364 if use_seed :
351365 seed_hkls , seed_thetas = get_seeds (spg , hkls , two_thetas )#; print(seed_hkls, seed_thetas)
352- cells = get_cell_params (spg , seed_hkls , seed_thetas , wave_length )#; print(cells)
366+ cells , hkls = get_cell_params (spg , seed_hkls , seed_thetas , wave_length )#; print(cells)
353367 else :
354- cells = get_cell_params (spg , hkls , two_thetas , wave_length )#; print(cells)
368+ cells , hkls = get_cell_params (spg , hkls , two_thetas , wave_length )#; print(cells)
355369
356- cells = np .array (cells )
370+ # cells = np.array(cells)
357371 if len (cells ) == 0 : return []
358372 # keep cells up to 4 decimal places
359373 if spg < 16 :
360374 cells [:, - 1 ] = np .round (cells [:, - 1 ], decimals = 2 )
361375 cells [:, :3 ] = np .round (cells [:, :3 ], decimals = 4 )
362- cells = np .unique (cells , axis = 0 )#; print(cells) # remove duplicates
376+
377+ _ , unique_ids = np .unique (cells , axis = 0 , return_index = True )
378+ hkls = hkls [unique_ids ]#; print(cells) # remove duplicates
379+ cells = cells [unique_ids ]
363380
364381 # get the maximum h from assuming the cell[-1] is (h00)
365382 d_100s = get_d_hkl_from_cell (spg , cells , 1 , 0 , 0 )
@@ -382,7 +399,6 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
382399 # Filter out None values
383400 valid_mask = expected_thetas != None
384401 valid_thetas = expected_thetas [valid_mask ]
385- # Now try to index all other peaks using this 'a'
386402
387403 if len (valid_thetas ) > 0 :
388404 valid_thetas = np .array (valid_thetas , dtype = float )
@@ -417,8 +433,8 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
417433 solutions .append ({
418434 'cell' : cell ,
419435 'n_matched' : n_matched ,
420- 'n_total' : len (long_thetas ),
421436 'score' : score ,
437+ 'id' : [i , hkls [i ]],
422438 })
423439
424440 return solutions
@@ -430,10 +446,12 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
430446 np .set_printoptions (precision = 4 , suppress = True )
431447
432448 xtal = pyxtal ()
433- xtal .from_seed ('pyxtal/database/cifs/JVASP-62168.cif' )
434- #xtal.from_seed('pyxtal/database/cifs/JVASP-98225.cif')
435- #xtal.from_seed('pyxtal/database/cifs/JVASP-50935.cif')
436- #xtal.from_seed('pyxtal/database/cifs/JVASP-28565.cif')
449+ #xtal.from_seed('pyxtal/database/cifs/JVASP-62168.cif') # 52s -> 16s
450+ #xtal.from_seed('pyxtal/database/cifs/JVASP-98225.cif') # P21/c -> 33s -> 12s
451+ #xtal.from_seed('pyxtal/database/cifs/JVASP-50935.cif') # Pm -> 45s -> 7.6s
452+ #xtal.from_seed('pyxtal/database/cifs/JVASP-28565.cif') # 207s -> 91s -> 80s -> 72s
453+ xtal .from_seed ('pyxtal/database/cifs/JVASP-47532.cif' ) #
454+
437455 xrd = xtal .get_XRD (thetas = [0 , 120 ], SCALED_INTENSITY_TOL = 0.5 )
438456 cell_ref = np .sort (np .array (xtal .lattice .encode ()))
439457 long_thetas = xrd .pxrd [:15 , 0 ]
@@ -445,50 +463,70 @@ def get_cell_from_multi_hkls(spg, hkls, two_thetas, long_thetas=None, wave_lengt
445463 if spg >= 16 :
446464 guesses = xtal .group .generate_hkl_guesses (2 , 3 , 5 , max_square = 29 , total_square = 40 , verbose = True )
447465 else :
448- guesses = xtal .group .generate_hkl_guesses (3 , 3 , 6 , max_square = 38 , total_square = 48 , verbose = True )
466+ if spg in [5 , 8 , 12 , 15 ]:
467+ guesses = xtal .group .generate_hkl_guesses (3 , 3 , 6 , max_square = 38 , total_square = 40 , verbose = True )
468+ else :
469+ guesses = xtal .group .generate_hkl_guesses (3 , 3 , 4 , max_square = 29 , total_square = 35 , verbose = True )
449470 guesses = np .array (guesses )
450471 print ("Total guesses:" , len (guesses ))
451472 sum_squares = np .sum (guesses ** 2 , axis = (1 ,2 ))
452473 sorted_indices = np .argsort (sum_squares )
453474 guesses = guesses [sorted_indices ]
475+ if len (guesses ) > 200000 : guesses = guesses [:200000 ]
454476 #guesses = np.array([[[2, 0, 0], [1, 1, 0], [0, 1, 1], [0, 0, 2]]])
455477 #guesses = np.array([[[2, 0, 0], [1, 1, 0], [0, 0, 2], [2, 0, -2]]])
456- #guesses = np.array([[[0, 2, 0 ], [0, 0, 1 ], [1, 0, 0 ], [1, 0 , -1 ]]])
478+ #guesses = np.array([[[0, 0, -1 ], [1, 1, 0 ], [1, 1, -1 ], [0, 2 , -5 ]]])
457479
458480 # Check the quality of each (hkl, 2theta) solutions
481+ N_add = 5
482+ N_batch = 20
459483 cell2 = np .sort (np .array (xtal .lattice .encode ()))
460484 if spg <= 15 and cell2 [3 ] > 90 : cell2 [3 ] = 180 - cell2 [3 ]
461485 cells_all = np .reshape (cell2 , (1 , len (cell2 )))
462-
463- for guess in guesses [:]:
464- #print('New guess', guess.flatten())
465- found = False
466- n_peaks = len (guess )
467-
468- # Try each combination of n peaks from the first n+1 peaks
469- N_add = 5
470- available_peaks = xrd .pxrd [:n_peaks + N_add , 0 ]
471-
472- thetas = []
473- for peak_combo in combinations (range (n_peaks + N_add ), n_peaks ):
474- thetas .extend (available_peaks [list (peak_combo )])
475- hkls_t = np .tile (guess , (int (len (thetas )/ len (guess )), 1 ))
476-
486+ # Try each combination of n_peaks peaks drawn from the first (n_peaks + N_add) peaks
487+ n_peaks = len (guesses [0 ])
488+ available_peaks = xrd .pxrd [:n_peaks + N_add , 0 ]
489+ thetas = []
490+ for peak_combo in combinations (range (n_peaks + N_add ), n_peaks ):
491+ thetas .extend (available_peaks [list (peak_combo )])
492+ N_thetas = len (thetas ) // n_peaks
493+ thetas = np .array (thetas )
494+ thetas = np .tile (thetas , N_batch )
495+
496+ found = False
497+ d2 = 0
498+ for i in range (len (guesses )// N_batch + 1 ):
499+ if i == len (guesses )// N_batch :
500+ N_batch = len (guesses ) - N_batch * i
501+ if N_batch == 0 :
502+ break
503+ else :
504+ thetas = thetas [:N_thetas * n_peaks * N_batch ]
505+ hkls_b = np .reshape (guesses [N_batch * i :N_batch * (i + 1 )], [N_batch * n_peaks , 3 ])
506+ hkls_t = np .tile (hkls_b , (N_thetas , 1 ))
477507 solutions = get_cell_from_multi_hkls (spg , hkls_t , thetas , long_thetas , use_seed = False )
508+ if i % 1000 == 0 :
509+ print (f"Processed { N_batch * (i )} /{ d2 } , found { len (cells_all )- 1 } cells so far." )
510+
478511 for sol in solutions :
479512 cell1 = np .sort (np .array (sol ['cell' ]))
480- if spg <= 15 and cell1 [3 ] > 90 : cell1 [3 ] = 180 - cell1 [3 ]
481513
482514 # Check if it is a new solution
483515 diffs = np .sum ((cells_all - cell1 )** 2 , axis = 1 )
516+ guess = sol ['id' ][1 ]
517+ score = sol ['score' ]
518+ d2 = np .sum (guess ** 2 )
484519 if len (cells_all [diffs < 0.1 ]) == 0 :
485- print ("Guess:" , guess . flatten (), np . sum ( guess ** 2 ), "->" , cell1 , sol [ ' score' ] )
520+ print (f "Guess: { guess } , { d2 } / { len ( cells_all ) - 1 } -> { cell1 } , { score :.6f } " )
486521 cells_all = np .vstack ((cells_all , cell1 ))
487522
488- if diffs [0 ] < 0.1 : #result['score'] > 0.9999:
489- print ("Guess:" , guess .flatten (), np .sum (guess ** 2 ), "->" , cell1 , sol ['score' ])
523+ # Early stopping for getting high-quality solutions
524+ if diffs [0 ] < 0.1 :
525+ print (f"Guess: { guess } , { d2 } /{ len (cells_all )- 1 } -> { cell1 } , { score :.6f} " )
490526 print ("High score, exiting early." )
491527 found = True
492528 break
493529 if found :
494530 break
531+
532+
0 commit comments