11import copy
22import functools
33import re
4+ import numpy as np
45
56import typing as t
6- import numpy as np
7+
8+ from scipy .spatial import cKDTree
79
810from aiida .common .constants import elements
911from aiida .common .exceptions import UnsupportedSpeciesError
@@ -98,6 +100,7 @@ def efficient_copy(obj):
98100
99101 Handles both dictionaries and lists, as well as other types.
100102 """
103+
101104 if obj is None :
102105 return None
103106 elif isinstance (obj , dict ):
@@ -167,30 +170,27 @@ def _get_valid_pbc(inputpbc):
167170
168171 return the_pbc
169172
def any_close_pairs(points, eps):
    """Find all pairs of points that lie within ``eps`` of each other.

    The neighbour search is delegated to a k-d tree
    (``scipy.spatial.cKDTree.query_pairs``).

    :param points: array-like of point coordinates, one row per point.
    :param eps: maximum distance at which two points are reported as a pair.
    :return: a tuple ``(found, pairs)``; ``pairs`` is the set of index tuples
        returned by ``query_pairs`` and ``found`` is ``True`` when that set is
        not empty.
    """
    points = np.asarray(points, dtype=float)

    # Fewer than two points can never form a pair; returning early also avoids
    # handing an empty array to the tree constructor.
    if len(points) < 2:
        return False, set()

    pairs = cKDTree(points).query_pairs(r=eps)
    return bool(pairs), pairs
177+
def _check_valid_sites(sites, min_distance=1e-3):
    """Check that no two sites have positions that are too close to each other.

    :param sites: sequence of site dictionaries, each providing a ``'position'`` entry.
    :param min_distance: distance threshold passed to ``any_close_pairs``; pairs of
        sites reported by it are considered overlapping. Defaults to ``1e-3``.
    :raises ValueError: if at least one pair of sites is reported as too close.
    """
    positions = np.array([site['position'] for site in sites])

    # With zero or one site there is nothing to compare.
    if len(positions) <= 1:
        return

    has_close_pairs, pairs = any_close_pairs(positions, eps=min_distance)

    if has_close_pairs:
        raise ValueError(
            f"The following sites have positions that are too close (less than {min_distance}): {pairs}."
        )
195195
196196
@@ -390,7 +390,7 @@ def group_symbols(_list):
390390 :param _list: a list of elements representing a chemical formula
391391 :return: a list of length-2 lists of the form [ multiplicity , element ]
392392 """
393- the_list = efficient_copy (_list )
393+ the_list = copy . deepcopy (_list )
394394 the_list .reverse ()
395395 grouped_list = [[1 , the_list .pop ()]]
396396 while the_list :
@@ -468,7 +468,7 @@ def group_together(_list, group_size, offset):
468468 ``group_together(['O','Ba','Ti','Ba','Ti'],2,1) =
469469 ['O',['Ba','Ti'],['Ba','Ti']]``
470470 """
471- the_list = efficient_copy (_list )
471+ the_list = copy . deepcopy (_list )
472472 the_list .reverse ()
473473 grouped_list = []
474474 for _ in range (offset ):
@@ -507,14 +507,14 @@ def group_together_symbols(_list, group_size):
507507 :return the_symbol_list: the new grouped symbol list
508508 :return has_grouped: True if we grouped something
509509 """
510- the_symbol_list = efficient_copy (_list )
510+ the_symbol_list = copy . deepcopy (_list )
511511 has_grouped = False
512512 offset = 0
513513 while not has_grouped and offset < group_size :
514514 grouped_list = group_together (the_symbol_list , group_size , offset )
515515 new_symbol_list = group_symbols (grouped_list )
516516 if len (new_symbol_list ) < len (grouped_list ):
517- the_symbol_list = efficient_copy (new_symbol_list )
517+ the_symbol_list = copy . deepcopy (new_symbol_list )
518518 the_symbol_list = cleanout_symbol_list (the_symbol_list )
519519 has_grouped = True
520520 # print get_formula_from_symbol_list(the_symbol_list)
@@ -530,7 +530,7 @@ def group_all_together_symbols(_list):
530530 """
531531 has_finished = False
532532 group_size = 2
533- the_symbol_list = efficient_copy (_list )
533+ the_symbol_list = copy . deepcopy (_list )
534534
535535 while not has_finished and group_size <= len (_list ) // 2 :
536536 # try to group as much as possible by groups of size group_size
@@ -551,7 +551,7 @@ def group_all_together_symbols(_list):
551551 # successively apply the grouping procedure until the symbol list does not
552552 # change anymore
553553 while new_symbol_list != old_symbol_list :
554- old_symbol_list = efficient_copy (new_symbol_list )
554+ old_symbol_list = copy . deepcopy (new_symbol_list )
555555 new_symbol_list = group_all_together_symbols (old_symbol_list )
556556
557557 return get_formula_from_symbol_list (new_symbol_list , separator = separator )
@@ -879,7 +879,7 @@ def check_is_alloy(data):
879879 :param data: the data to check. The dict of the SiteCore model.
880880 :return: True if the data is an alloy, False otherwise.
881881 """
882- new_data = efficient_copy (data )
882+ new_data = copy . deepcopy (data )
883883 if "weight" not in new_data .keys () or new_data .get ("weight" , None ) is None :
884884 if isinstance (new_data ["symbol" ], list ) or re .search (r'[A-Z][a-z]*[A-Z]' , new_data ["symbol" ]):
885885 return new_data
@@ -1003,20 +1003,23 @@ def build_sites_from_expanded_properties(expanded):
10031003 # Use all keys except positions if you want to exclude arrays, or specify your own
10041004 site_props = set (expanded .keys ()).difference (_GLOBAL_PROPERTIES + _COMPUTED_PROPERTIES + ["sites" , "site_indices" ])
10051005
1006- n_sites = len (expanded .get ("positions" ,[]))
1007- sites = []
1008- for i in range (n_sites ):
1009- site = {}
1010- site ["position" ] = expanded ["positions" ][i ]
1011- for prop in site_props :
1012- site [_CONVERSION_PLURAL_SINGULAR [prop ]] = expanded [prop ][i ]
1013- sites .append (site )
1006+ # Pre-compute the conversion mapping to avoid repeated dict lookups
1007+ prop_conversions = [(prop , _CONVERSION_PLURAL_SINGULAR [prop ], expanded [prop ])
1008+ for prop in site_props ]
10141009
1015- structure_dict = {}
1016- for prop in _GLOBAL_PROPERTIES :
1017- if expanded .get (prop , None ) is not None :
1018- structure_dict [prop ] = expanded [prop ]
1010+ positions = expanded .get ("positions" , [])
1011+ n_sites = len (positions )
10191012
1013+ sites = [
1014+ {"position" : positions [i ], ** {singular : values [i ] for _ , singular , values in prop_conversions }}
1015+ for i in range (n_sites )
1016+ ]
1017+
1018+ structure_dict = {
1019+ prop : expanded [prop ]
1020+ for prop in _GLOBAL_PROPERTIES
1021+ if prop in expanded and expanded [prop ] is not None
1022+ }
10201023 structure_dict ["sites" ] = sites
10211024
10221025 return structure_dict
@@ -1145,7 +1148,7 @@ def sites_from_kinds(kinds):
11451148 positions += list (kind ['positions' ])
11461149 num_sites = len (sites_list )
11471150 for i in range (num_sites ):
1148- sites_list [i ] = efficient_copy (kinds [sites_list [i ]])
1151+ sites_list [i ] = copy . deepcopy (kinds [sites_list [i ]])
11491152 sites_list [i ].pop ('site_indices' , None )
11501153 sites_list [i ].pop ('positions' , None )
11511154 sites_list [i ]['position' ] = positions [i ]
0 commit comments