Skip to content

Commit 59ebca6

Browse files
committed
Improving performance
1 parent 4cac54b commit 59ebca6

File tree

4 files changed

+68
-51
lines changed

4 files changed

+68
-51
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ dependencies = [
2626
"voluptuous",
2727
"ase",
2828
"pymatgen>=2022.1.20",
29-
"pydantic>=2.6"
29+
"pydantic>=2.6",
30+
"scipy>=1.13.1",
3031
]
3132

3233
[project.urls]

src/aiida_atomistic/data/structure/getter_mixin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
HubbardGetterMixin,
1414
)
1515

16-
from aiida_atomistic.data.structure.utils import classify_site_kinds, check_kinds_match, efficient_copy
16+
from aiida_atomistic.data.structure.utils import classify_site_kinds, check_kinds_match
1717

1818
try:
1919
import ase # noqa: F401
@@ -397,7 +397,7 @@ def to_dict(self, exclude_kinds=False):
397397
:return: The structure as a dictionary.
398398
:rtype: dict
399399
"""
400-
dict_repr = efficient_copy(self.properties.model_dump(exclude_unset=True, exclude_none=True, warnings=False, exclude={'kinds'} if exclude_kinds else {}))
400+
dict_repr = copy.deepcopy(self.properties.model_dump(exclude_unset=True, exclude_none=True, warnings=False, exclude={'kinds'} if exclude_kinds else {}))
401401

402402
return dict_repr
403403

src/aiida_atomistic/data/structure/models.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def validate_sites(cls, v):
116116

117117
if v is None:
118118
return v
119-
else:
120-
# test if they can be converted to Site
121-
sites = [Site.model_validate(site) if not isinstance(site, Site) else site for site in v]
119+
# else:
120+
# # test if they can be converted to Site
121+
# sites = [Site.model_validate(site) if not isinstance(site, Site) else site for site in v]
122122

123123
_check_valid_sites(v)
124124

@@ -291,20 +291,33 @@ def kinds(self) -> list[Kind]:
291291
#raise ValueError("Kind names must be defined to access kinds.")
292292
return None
293293

294+
# Mapping of kind_name -> site indices
295+
kind_to_indices = defaultdict(list)
296+
for i, name in enumerate(self.kind_names):
297+
kind_to_indices[name].append(i)
298+
299+
positions_array = self.positions
300+
294301
kinds_list = []
295-
kind_name_set = set(self.kind_names)
296-
for idx, site in enumerate(self.sites):
302+
seen_kinds = set()
303+
304+
for site in self.sites:
297305
kind_name = site.kind_name if site.kind_name else site.symbol
298-
if kind_name in kind_name_set:
299-
site_indices = [i for i, name in enumerate(self.kind_names) if name == kind_name]
300-
positions=np.array([self.positions[i] for i in site_indices])
301-
kind = Kind(
302-
**site.model_dump(exclude={'position'}),
303-
site_indices=site_indices,
304-
positions=positions,
305-
)
306-
kinds_list.append(kind)
307-
kind_name_set.remove(kind_name) # Ensure we don't add the same kind multiple
306+
307+
# Skip if we've already processed this kind
308+
if kind_name in seen_kinds:
309+
continue
310+
311+
seen_kinds.add(kind_name)
312+
site_indices = kind_to_indices[kind_name]
313+
positions = positions_array[site_indices]
314+
315+
kind = Kind(
316+
**site.model_dump(exclude={'position'}),
317+
site_indices=site_indices,
318+
positions=positions,
319+
)
320+
kinds_list.append(kind)
308321

309322
return FrozenList(kinds_list)
310323

src/aiida_atomistic/data/structure/utils.py

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import copy
22
import functools
33
import re
4+
import numpy as np
45

56
import typing as t
6-
import numpy as np
7+
8+
from scipy.spatial import cKDTree
79

810
from aiida.common.constants import elements
911
from aiida.common.exceptions import UnsupportedSpeciesError
@@ -98,6 +100,7 @@ def efficient_copy(obj):
98100
99101
Handles both dictionaries and lists, as well as other types.
100102
"""
103+
101104
if obj is None:
102105
return None
103106
elif isinstance(obj, dict):
@@ -167,30 +170,27 @@ def _get_valid_pbc(inputpbc):
167170

168171
return the_pbc
169172

173+
def any_close_pairs(points, eps):
174+
tree = cKDTree(points)
175+
pairs = tree.query_pairs(r=eps)
176+
return len(pairs) > 0,pairs
177+
170178
def _check_valid_sites(sites):
171179
"""Check that no two sites have positions that are too close to each other."""
172180

173181
positions = np.array([site['position'] for site in sites])
182+
174183
n_sites = len(positions)
175184

176185
if n_sites <= 1:
177186
return
178187

179-
# Calculate pairwise distances using broadcasting (this is much more efficient than loops...)
180-
diff = positions[:, np.newaxis, :] - positions[np.newaxis, :, :] # Shape: (n_sites, n_sites, 3)
181-
distances = np.linalg.norm(diff, axis=2) # Shape: (n_sites, n_sites)
182-
183-
# Set diagonal to large value to ignore self-comparisons
184-
np.fill_diagonal(distances, np.inf)
185-
186188
# Check if any distance is below threshold
187189
min_distance = 1e-3 # You can adjust this threshold
188-
close_pairs = np.where(distances < min_distance)
190+
close_pairs = any_close_pairs(positions, eps=min_distance)
189191

190-
if len(close_pairs[0]) > 0:
191-
i, j = close_pairs[0][0], close_pairs[1][0] # Get first problematic pair
192-
raise ValueError(f"Sites {i} and {j} have positions that are too close: "
193-
f"{positions[i]} and {positions[j]} (distance: {distances[i,j]:.6f})")
192+
if close_pairs[0]:
193+
raise ValueError(f"The following sites have positions that are too close (less than {min_distance}): {close_pairs[1]}.")
194194
return
195195

196196

@@ -390,7 +390,7 @@ def group_symbols(_list):
390390
:param _list: a list of elements representing a chemical formula
391391
:return: a list of length-2 lists of the form [ multiplicity , element ]
392392
"""
393-
the_list = efficient_copy(_list)
393+
the_list = copy.deepcopy(_list)
394394
the_list.reverse()
395395
grouped_list = [[1, the_list.pop()]]
396396
while the_list:
@@ -468,7 +468,7 @@ def group_together(_list, group_size, offset):
468468
``group_together(['O','Ba','Ti','Ba','Ti'],2,1) =
469469
['O',['Ba','Ti'],['Ba','Ti']]``
470470
"""
471-
the_list = efficient_copy(_list)
471+
the_list = copy.deepcopy(_list)
472472
the_list.reverse()
473473
grouped_list = []
474474
for _ in range(offset):
@@ -507,14 +507,14 @@ def group_together_symbols(_list, group_size):
507507
:return the_symbol_list: the new grouped symbol list
508508
:return has_grouped: True if we grouped something
509509
"""
510-
the_symbol_list = efficient_copy(_list)
510+
the_symbol_list = copy.deepcopy(_list)
511511
has_grouped = False
512512
offset = 0
513513
while not has_grouped and offset < group_size:
514514
grouped_list = group_together(the_symbol_list, group_size, offset)
515515
new_symbol_list = group_symbols(grouped_list)
516516
if len(new_symbol_list) < len(grouped_list):
517-
the_symbol_list = efficient_copy(new_symbol_list)
517+
the_symbol_list = copy.deepcopy(new_symbol_list)
518518
the_symbol_list = cleanout_symbol_list(the_symbol_list)
519519
has_grouped = True
520520
# print get_formula_from_symbol_list(the_symbol_list)
@@ -530,7 +530,7 @@ def group_all_together_symbols(_list):
530530
"""
531531
has_finished = False
532532
group_size = 2
533-
the_symbol_list = efficient_copy(_list)
533+
the_symbol_list = copy.deepcopy(_list)
534534

535535
while not has_finished and group_size <= len(_list) // 2:
536536
# try to group as much as possible by groups of size group_size
@@ -551,7 +551,7 @@ def group_all_together_symbols(_list):
551551
# successively apply the grouping procedure until the symbol list does not
552552
# change anymore
553553
while new_symbol_list != old_symbol_list:
554-
old_symbol_list = efficient_copy(new_symbol_list)
554+
old_symbol_list = copy.deepcopy(new_symbol_list)
555555
new_symbol_list = group_all_together_symbols(old_symbol_list)
556556

557557
return get_formula_from_symbol_list(new_symbol_list, separator=separator)
@@ -879,7 +879,7 @@ def check_is_alloy(data):
879879
:param data: the data to check. The dict of the SiteCore model.
880880
:return: True if the data is an alloy, False otherwise.
881881
"""
882-
new_data = efficient_copy(data)
882+
new_data = copy.deepcopy(data)
883883
if "weight" not in new_data.keys() or new_data.get("weight", None) is None:
884884
if isinstance(new_data["symbol"], list) or re.search(r'[A-Z][a-z]*[A-Z]', new_data["symbol"]):
885885
return new_data
@@ -1003,20 +1003,23 @@ def build_sites_from_expanded_properties(expanded):
10031003
# Use all keys except positions if you want to exclude arrays, or specify your own
10041004
site_props = set(expanded.keys()).difference(_GLOBAL_PROPERTIES + _COMPUTED_PROPERTIES + ["sites", "site_indices"])
10051005

1006-
n_sites = len(expanded.get("positions",[]))
1007-
sites = []
1008-
for i in range(n_sites):
1009-
site = {}
1010-
site["position"] = expanded["positions"][i]
1011-
for prop in site_props:
1012-
site[_CONVERSION_PLURAL_SINGULAR[prop]] = expanded[prop][i]
1013-
sites.append(site)
1006+
# Pre-compute the conversion mapping to avoid repeated dict lookups
1007+
prop_conversions = [(prop, _CONVERSION_PLURAL_SINGULAR[prop], expanded[prop])
1008+
for prop in site_props]
10141009

1015-
structure_dict = {}
1016-
for prop in _GLOBAL_PROPERTIES:
1017-
if expanded.get(prop, None) is not None:
1018-
structure_dict[prop] = expanded[prop]
1010+
positions = expanded.get("positions", [])
1011+
n_sites = len(positions)
10191012

1013+
sites = [
1014+
{"position": positions[i], **{singular: values[i] for _, singular, values in prop_conversions}}
1015+
for i in range(n_sites)
1016+
]
1017+
1018+
structure_dict = {
1019+
prop: expanded[prop]
1020+
for prop in _GLOBAL_PROPERTIES
1021+
if prop in expanded and expanded[prop] is not None
1022+
}
10201023
structure_dict["sites"] = sites
10211024

10221025
return structure_dict
@@ -1145,7 +1148,7 @@ def sites_from_kinds(kinds):
11451148
positions += list(kind['positions'])
11461149
num_sites = len(sites_list)
11471150
for i in range(num_sites):
1148-
sites_list[i] = efficient_copy(kinds[sites_list[i]])
1151+
sites_list[i] = copy.deepcopy(kinds[sites_list[i]])
11491152
sites_list[i].pop('site_indices', None)
11501153
sites_list[i].pop('positions', None)
11511154
sites_list[i]['position'] = positions[i]

0 commit comments

Comments (0)