Skip to content

Commit ac0e115

Browse files
authored
Performances (aiidateam#86)
* Add key sorting in npz writing for hashing * Improve performance: symbols, kind names, and site_indices are now stored in the repository, so the node size in the db no longer grows with the number of sites.
1 parent f42f0f5 commit ac0e115

File tree

4 files changed

+61
-12
lines changed

4 files changed

+61
-12
lines changed

src/aiida_atomistic/data/structure/models.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class StructureBaseModel(BaseModel):
4040
sites: list[Site] = Field(
4141
default=[],
4242
description="List of sites in the structure",
43-
json_schema_extra={"store_in": "db", "property_type": "global"},
43+
json_schema_extra={"property_type": "global"},
4444
)
4545

4646
# global and more specific properties
@@ -197,7 +197,7 @@ def positions(self) -> np.ndarray:
197197
return None
198198
return np.array([site.position for site in self.sites])
199199

200-
@computed_field(json_schema_extra={"store_in": "db","singular_form": "kind_name"})
200+
@computed_field(json_schema_extra={"store_in": "repository","singular_form": "kind_name"})
201201
@property
202202
def kind_names(self) -> t.List[str]:
203203
"""
@@ -210,7 +210,7 @@ def kind_names(self) -> t.List[str]:
210210
return None
211211
return FrozenList([site.kind_name if site.kind_name is not None else site.symbol for site in self.sites])
212212

213-
@computed_field(json_schema_extra={"store_in": "db","singular_form": "symbol"})
213+
@computed_field(json_schema_extra={"store_in": "repository","singular_form": "symbol"})
214214
@property
215215
def symbols(self) -> t.List[str]:
216216
"""
@@ -393,6 +393,12 @@ def min_magnetization(self) -> t.Optional[float]:
393393
def n_sites(self) -> int:
394394
"""Total number of sites in the structure."""
395395
return len(self.sites)
396+
397+
@computed_field(json_schema_extra={"store_in": "db"})
398+
@property
399+
def n_kinds(self) -> int:
400+
"""Total number of kinds in the structure."""
401+
return len(self.kinds) if self.kinds is not None else 0
396402

397403
def __repr__(self) -> str:
398404
from pprint import pformat

src/aiida_atomistic/data/structure/structure.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ def _is_numeric_array(value) -> bool:
8383
return arr.dtype.kind in {"i", "f", "u", "c"}
8484
return False
8585

86+
@staticmethod
87+
def _is_string_array(value) -> bool:
88+
"""Return True if the value is a string array (list of strings or numpy unicode array)."""
89+
if isinstance(value, np.ndarray):
90+
return value.dtype.kind in {"U", "S"}
91+
if isinstance(value, list) and value and all(isinstance(v, str) for v in value):
92+
return True
93+
return False
94+
8695
@classmethod
8796
def get_queryable_properties(cls, include_internal: bool = False) -> dict:
8897
"""
@@ -272,7 +281,7 @@ def detect_storage_backend(cls, prop_name: str) -> str:
272281
store_in = json_schema_extra.get(cls._storage_metadata_key, '').lower()
273282
else:
274283
store_in = 'db' # default to db for unknown properties
275-
284+
276285
return store_in
277286

278287
def _store_properties(self):
@@ -309,6 +318,23 @@ def _store_properties(self):
309318
repository_dict[prop_name] = arr
310319
if self._store_shape_metadata:
311320
database_dict[f'shape|{prop_name}'] = list(arr.shape)
321+
elif target == "repository" and self._is_string_array(value):
322+
arr = np.asarray(value, dtype=str)
323+
repository_dict[prop_name] = arr
324+
if self._store_shape_metadata:
325+
database_dict[f'shape|{prop_name}'] = list(arr.shape)
326+
elif prop_name == 'site_indices':
327+
# site_indices is a ragged list-of-lists (one per kind, variable length because of different number of sites per kind).
328+
# Encode as two flat 1D int arrays using CSR format so the npz stays
329+
# homogeneous (allow_pickle=False compatible):
330+
# site_indices_flat : all indices concatenated, shape (total_sites,)
331+
# site_indices_offsets : cumulative start positions, shape (n_kinds + 1,)
332+
# e.g. [[0,1],[2],[3,4,5]] → flat=[0,1,2,3,4,5], offsets=[0,2,3,6]
333+
flat = np.array([idx for sublist in value for idx in sublist], dtype=np.int64)
334+
lengths = np.array([len(sublist) for sublist in value], dtype=np.int64)
335+
offsets = np.concatenate([[0], np.cumsum(lengths)]).astype(np.int64)
336+
repository_dict['site_indices_flat'] = flat
337+
repository_dict['site_indices_offsets'] = offsets
312338
else:
313339
database_dict[prop_name] = value
314340

@@ -343,7 +369,23 @@ def _load_properties_from_npz(self) -> dict:
343369
with self.base.repository.open(self._properties_filename, mode='rb') as handle:
344370
npz_data = np.load(handle, allow_pickle=False)
345371
# Convert to regular dict (npz returns NpzFile object)
346-
properties = {key: npz_data[key] for key in npz_data.files}
372+
# String arrays (dtype 'U' or 'S') are converted back to Python lists
373+
properties = {}
374+
for key in npz_data.files:
375+
arr = npz_data[key]
376+
if arr.dtype.kind in {"U", "S"}:
377+
properties[key] = arr.tolist()
378+
else:
379+
properties[key] = arr
380+
381+
# Decode CSR-encoded site_indices back into list-of-lists
382+
if 'site_indices_flat' in properties and 'site_indices_offsets' in properties:
383+
flat = properties.pop('site_indices_flat')
384+
offsets = properties.pop('site_indices_offsets')
385+
properties['site_indices'] = [
386+
flat[offsets[i]:offsets[i + 1]].tolist()
387+
for i in range(len(offsets) - 1)
388+
]
347389

348390
# Cache if stored
349391
if self.is_stored:

tests/data/test_kinds.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,16 @@ def test_kinds_compression_storage(self, complex_example_structure_dict_for_kind
178178
structure = StructureData(**complex_example_structure_dict_for_kinds)
179179

180180
# Check that n_kinds is in attributes
181-
assert "kind_names" in structure.base.attributes.all
181+
assert "n_kinds" in structure.base.attributes.all
182182

183183
# Check that sites maintain their structure (not compressed, site-based model)
184-
stored_symbols = structure.base.attributes.get("symbols")
185-
assert len(stored_symbols) == 8 # 8 sites (site-based, not kind-compressed)
184+
n_sites = structure.base.attributes.get("n_sites")
185+
assert n_sites == 8 # 8 sites (site-based, not kind-compressed)
186186

187187
# But the number of distinct kinds (n_kinds) should have been recorded
188-
kind_names = structure.base.attributes.get("kind_names")
189-
assert kind_names is not None
190-
assert len(set(kind_names)) == 4 # 4 unique kinds
188+
n_kinds = structure.base.attributes.get("n_kinds")
189+
assert n_kinds is not None
190+
assert n_kinds == 4 # 4 unique kinds
191191

192192

193193
class TestKindsWorkflow:

tests/data/test_structure.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,8 @@ def test_store_properties_with_kind_compression(aiida_profile_clean):
644644
)
645645

646646
# Should have n_kinds in attributes
647-
assert 'kind_names' in structure.base.attributes.all
647+
assert 'n_kinds' in structure.base.attributes.all
648+
assert structure.base.attributes.all['n_kinds'] == 1
648649

649650

650651
def test_load_properties_from_npz(aiida_profile_clean):

0 commit comments

Comments
 (0)