Skip to content

Commit 638843d

Browse files
authored
H5md refactor (#196)
* add velocities and particle_state labels * working h5md basic system * added some ModelSystem functions for molecules * add some system types * fixed tests * cleaned comments * returned current docstrings that diverged somehow * ruff * refined is_molecule, updated tests * remove duplicate is_molecule * fix editing mistakes * switched to get_symbols() but left the old function commented for now * re-uncomment get_chemical_symbols() for now
1 parent 5a57ef0 commit 638843d

File tree

3 files changed

+361
-7
lines changed

3 files changed

+361
-7
lines changed

src/nomad_simulations/schema_packages/atoms_state.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,12 @@ class ParticleState(Entity):
551551
This can be extended to include any common quantities in the future.
552552
"""
553553

554-
pass
554+
label = Quantity(
555+
type=str,
556+
description="""
557+
User- or program-package-defined identifier for this particle.
558+
""",
559+
)
555560

556561

557562
class AtomsState(ParticleState):
@@ -660,5 +665,83 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
660665
self.atomic_number = self.resolve_atomic_number(logger=logger)
661666

662667

663-
class CoarseGrainedState(ParticleState):
664-
pass
668+
class CGBeadState(ParticleState):
    """
    State information for a single coarse-grained (CG) bead.

    Plays the same role for CG particles that `AtomsState` plays for atoms: it
    carries the bead type symbol, its labels, and its total mass and charge.
    """

    # ? Open question: what should qualify as the type identifier, and which
    # ? safety checks are needed for it?
    bead_symbol = Quantity(
        type=str,
        description="""
        Symbol(s) describing the (base) CG particle type. Equivalent to chemical_symbol
        for atomic elements.
        """,
    )

    # Overrides the generic `label` of the parent section with a bead-specific
    # description.
    label = Quantity(
        type=str,
        description="""
        User- or program-package-defined identifier for this bead site.
        This could be used to store primary FF labels in cases where only a
        secondary specification is required. Otherwise, `alt_labels` are
        used to document more complex bead identifiers, e.g., bead interactions based
        on connectivity.
        """,
    )

    alt_labels = Quantity(
        type=str,
        shape=['*'],
        description="""
        A list of bead labels for multifaceted bead characterization.
        """,
    )

    mass = Quantity(
        type=np.float64,
        unit='kg',
        description="""
        Total mass of the particle.
        """,
    )

    charge = Quantity(
        type=np.float64,
        unit='coulomb',
        description="""
        Total charge of the particle.
        """,
    )

    # Other possible quantities (design notes, not implemented yet):
    #
    #   diameter: float
    #       The diameter of each particle.
    #       Default: 1.0
    #   body: int
    #       The composite body associated with each particle. The value -1
    #       indicates no body.
    #       Default: -1
    #   moment_inertia: float
    #       The moment_inertia of each particle (I_xx, I_yy, I_zz).
    #       This inertia tensor is diagonal in the body frame of the particle.
    #       The default value is for point particles.
    #       Default: 0, 0, 0
    #   scaled_positions: list of scaled-positions  #! for cell if relevant
    #       Like positions, but given in units of the unit cell.
    #       Can not be set at the same time as positions.
    #       Default: 0, 0, 0
    #   orientation: float
    #       The orientation of each particle. In scalar + vector notation,
    #       this is (r, a_x, a_y, a_z), where the quaternion is
    #       q = r + a_xi + a_yj + a_zk.
    #       A unit quaternion has the property:
    #       sqrt(r^2 + a_x^2 + a_y^2 + a_z^2) = 1.
    #       Default: 0, 0, 0, 0
    #   angmom: float  #? for cell or here?
    #       The angular momentum of each particle as a quaternion.
    #       Default: 0, 0, 0, 0
    #   image: int  #! advance PBC stuff would go in cell I guess
    #       The number of times each particle has wrapped around the box
    #       (i_x, i_y, i_z).
    #       Default: 0, 0, 0

    def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
        # No bead-specific normalization yet; defer entirely to the parent section.
        super().normalize(archive, logger)

src/nomad_simulations/schema_packages/model_system.py

Lines changed: 165 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import ase
2525
import numpy as np
26+
from ase.symbols import symbols2numbers
2627
from matid import Classifier, SymmetryAnalyzer # pylint: disable=import-error
2728
from matid.classification.classifications import (
2829
Atom,
@@ -51,6 +52,7 @@
5152

5253
from nomad_simulations.schema_packages.atoms_state import (
5354
AtomsState,
55+
CGBeadState,
5456
ParticleState,
5557
)
5658
from nomad_simulations.schema_packages.utils import (
@@ -844,7 +846,9 @@ class ModelSystem(System):
844846
type=MEnum(
845847
'atom',
846848
'active_atom',
847-
'molecule / cluster',
849+
'molecule',
850+
'monomer',
851+
'cluster',
848852
'1D',
849853
'surface',
850854
'2D',
@@ -911,7 +915,7 @@ class ModelSystem(System):
911915
type=np.int32,
912916
shape=['*'],
913917
description="""
914-
Global indices of the particles that belong to this subsystem,
918+
Global indices of the particles that belong to this subsystem,
915919
counted from the representative (top-level) ModelSystem.
916920
917921
**Example (SrTiO_3 primitive cell)**
@@ -938,6 +942,16 @@ class ModelSystem(System):
938942
""",
939943
)
940944

945+
velocities = Quantity(
946+
type=np.float64,
947+
shape=['*', 3],
948+
unit='meter / second',
949+
description="""
950+
Velocities of the particles: I.e., the change in cartesian coordinates of the
951+
particle position with time.
952+
""",
953+
)
954+
941955
# TODO improve description and add an example
942956
bond_list = Quantity(
943957
type=np.int32,
@@ -993,7 +1007,7 @@ class ModelSystem(System):
9931007
section_def=ParticleState.m_def,
9941008
repeats=True,
9951009
description="""
996-
Particle state of each of the particles conforming the ModelSystem.
1010+
Particle state of each of the particles conforming the ModelSystem.
9971011
This is a list of `n_particles` elements and the order matches that of `positions`.
9981012
9991013
Example
@@ -1010,6 +1024,7 @@ class ModelSystem(System):
10101024

10111025
sub_systems = SubSection(sub_section=SectionProxy('ModelSystem'), repeats=True)
10121026

1027+
# TODO Will remove this after developing CGBeadState functionality further
10131028
def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]:
10141029
"""
10151030
Gets the chemical symbols from the particle_states that are AtomsState instances.
@@ -1028,6 +1043,55 @@ def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]:
10281043
chemical_symbols.append(particle_state.chemical_symbol)
10291044
return chemical_symbols
10301045

1046+
# TODO: consider exposing the symbols as a property instead of a getter method.
1047+
# ? Intended as the eventual replacement for get_chemical_symbols.
1048+
def get_symbols(self, logger: 'BoundLogger') -> list[str]:
    """
    Collect one symbol per entry of `particle_states`, in order.

    `AtomsState` entries contribute their `chemical_symbol` and `CGBeadState`
    entries their `bead_symbol`. Any other entry, or an entry whose symbol is
    unset/empty, contributes its falsy value (None for unknown state types)
    so that the returned list keeps one slot per particle state.

    Args:
        logger (BoundLogger): The logger to log messages.

    Returns:
        list: The list of symbols of the particles.
    """
    collected = []
    for state in self.particle_states:
        if isinstance(state, AtomsState):
            current = state.chemical_symbol
        elif isinstance(state, CGBeadState):
            current = state.bead_symbol
        else:
            current = None

        # Warn about the gap but keep the placeholder in the list.
        if not current:
            logger.warning('missing symbol in ParticleState.')
        collected.append(current)
    return collected
1067+
1068+
def are_valid_chemical_symbols(self, logger: 'BoundLogger') -> bool:
    """
    Validate that ASE can map all element symbols in the particle_states
    to atomic numbers.

    Args:
        logger (BoundLogger): The logger to log messages.

    Returns:
        bool: True if all chemical symbols are valid, False otherwise
        (including when there are no symbols at all or when any particle
        state has no symbol).
    """
    symbols = self.get_symbols(logger)
    if not symbols:
        return False

    # `get_symbols` keeps a falsy placeholder for particle states without a
    # symbol. ASE would try `int(None)` on such entries and raise TypeError
    # rather than KeyError, so reject them explicitly up front.
    if not all(symbols):
        logger.error('Missing chemical symbol found in particle_states.')
        return False

    try:
        symbols2numbers(symbols)
    except (KeyError, TypeError, ValueError) as e:
        # KeyError: unknown element string; TypeError/ValueError: a
        # non-string entry that cannot be read as an atomic number.
        logger.error(f'Invalid chemical symbol found: {e}')
        return False
    return True
1087+
1088+
# atom_labels = self.traj_parser.get_atom_labels(n)
1089+
# if atom_labels is not None:
1090+
# try:
1091+
# symbols2numbers(atom_labels)
1092+
# except KeyError:
1093+
# atom_labels = ['X'] * len(atom_labels)
1094+
10311095
def to_ase_atoms(self, logger: 'BoundLogger') -> 'Optional[ase.Atoms]':
10321096
"""
10331097
Generates an ASE Atoms object from ModelSystem data.
@@ -1036,7 +1100,7 @@ def to_ase_atoms(self, logger: 'BoundLogger') -> 'Optional[ase.Atoms]':
10361100
- positions from the top-level positions quantity,
10371101
- periodic boundary conditions and lattice vectors from the first cell.
10381102
"""
1039-
symbols = self.get_chemical_symbols(logger)
1103+
symbols = self.get_symbols(logger)
10401104
if not symbols:
10411105
logger.error('Cannot generate ASE Atoms without chemical symbols.')
10421106
return None
@@ -1207,3 +1271,100 @@ def is_ge_structure(self, other: 'ModelSystem') -> bool:
12071271

12081272
def is_ne_structure(self, other: 'ModelSystem') -> bool:
12091273
return not self.is_equal_structure(other)
1274+
1275+
# functions for traversing the ModelSystem hierarchy
1276+
def get_root_system(self) -> 'ModelSystem':
    """
    Walk up the `m_parent` chain until the parent is no longer a ModelSystem.

    Returns:
        ModelSystem: The top-level (root) ModelSystem; `self` when this
        system has no ModelSystem parent.
    """
    current = self
    parent = current.m_parent
    while isinstance(parent, ModelSystem):
        current, parent = parent, parent.m_parent
    return current
1287+
1288+
# functions for working with molecules
1289+
def get_bond_list(self, set_local: bool = False) -> np.ndarray:
    """
    Retrieves the bond list for this system.

    For the root system the stored `bond_list` is returned as is. For a
    subsystem, the root `bond_list` is filtered down to the bonds whose two
    endpoints are both listed in the subsystem's `particle_indices`. The bond
    indices remain in root-level coordinates (no reindexing).

    Args:
        set_local (bool): If True, sets `self.bond_list` to the filtered bonds.
            Has no effect for the root system or when nothing can be filtered
            (missing `particle_indices` or missing root `bond_list`).

    Returns:
        np.ndarray: Bond list for this system (root-level indices). When no
        bonds are available this is an empty int32 array of shape (0, 2), so
        callers can always rely on array attributes such as `.size`.
    """
    # Consistent "no bonds" value: same dtype and column count as real bonds.
    no_bonds = np.empty((0, 2), dtype=np.int32)

    root = self.get_root_system()
    if root is self:  # no ModelSystem parent: this is the root system
        return no_bonds if self.bond_list is None else self.bond_list

    if self.particle_indices is None or root.bond_list is None:
        return no_bonds

    root_bonds = np.asarray(root.bond_list, dtype=np.int32).reshape(-1, 2)
    # A bond is kept only if both of its endpoints belong to this subsystem.
    both_inside = np.isin(root_bonds, np.asarray(self.particle_indices)).all(axis=1)
    bond_list = root_bonds[both_inside]  # boolean indexing yields a copy

    if set_local:
        self.bond_list = bond_list

    return bond_list
1326+
1327+
def is_molecule(self) -> bool:
    """
    Checks if the current subsystem forms a contiguous and isolated molecule:
    - All particles are connected (single connected component).
    - No bonds connect particles inside this subsystem to particles outside it.

    Returns:
        bool: True if the subsystem is an isolated molecule, False otherwise.
        Also False when there are no bonds, or when the set of particles
        cannot be determined (no `particle_indices`, `positions` or
        `n_particles`).
    """
    # Internal bonds for this subsystem
    bonds = self.get_bond_list(set_local=False)

    # Handle case: no bonds (a missing bond list is treated like an empty one)
    if bonds is None or np.size(bonds) == 0:
        return False
    bonds = np.asarray(bonds).reshape(-1, 2)

    # Determine particle indices (fallback to range if None)
    particle_indices = self.particle_indices
    if particle_indices is None:
        n_particles = (
            len(self.positions) if self.positions is not None else self.n_particles
        )
        if n_particles is None:
            # Nothing tells us which particles belong to this system.
            return False
        particle_indices = np.arange(n_particles, dtype=np.int32)
    particle_indices = np.asarray(particle_indices)

    # Imported lazily so the early exits above do not require networkx.
    import networkx as nx

    # --- 1. Connectivity check ---
    graph = nx.Graph()
    graph.add_nodes_from(particle_indices.tolist())
    graph.add_edges_from(bonds.tolist())

    if not nx.is_connected(graph):
        return False

    # --- 2. Isolation check: ensure no bonds cross subsystem boundary ---
    root = self.get_root_system()
    if root.bond_list is not None:
        root_bonds = np.asarray(root.bond_list).reshape(-1, 2)
        inside = np.isin(root_bonds, particle_indices)
        # Exactly one endpoint inside -> cross-boundary bond
        if np.any(inside[:, 0] != inside[:, 1]):
            return False

    return True

0 commit comments

Comments
 (0)