@@ -49,6 +49,9 @@ class PrunerConfig:
4949
5050 def __post_init__ (self ) -> None :
5151 """Validate inputs and initialize computed fields."""
52+ # validate input types
53+ assert type (self .structures ) is np .ndarray
54+
5255 self .mask = np .ones (shape = (self .structures .shape [0 ],), dtype = np .bool_ )
5356
5457 if len (self .energies ) != 0 :
@@ -58,6 +61,7 @@ def __post_init__(self) -> None:
5861
5962 # Set defaults for optional parameters
6063 if len (self .energies ) == 0 :
64+ assert type (self .energies ) is np .ndarray
6165 self .energies = np .zeros (self .structures .shape [0 ], dtype = float )
6266
6367 assert len (self .energies ) == len (self .structures ), (
@@ -85,6 +89,14 @@ class RMSDRotCorrPrunerConfig(PrunerConfig):
8589 graph : Graph = field (kw_only = True )
8690 heavy_atoms_only : bool = True
8791
92+ def __post_init__ (self ) -> None :
93+ """Add type enforcing to the parent's __post_init__."""
94+ super ().__post_init__ ()
95+
96+ # validate input types
97+ assert type (self .atoms ) is np .ndarray
98+ assert type (self .graph ) is Graph
99+
88100 def evaluate_sim (self , i1 : int , i2 : int ) -> bool :
89101 """Return whether the structures are similar."""
90102 rmsd , max_dev = rotationally_corrected_rmsd_and_max (
@@ -116,6 +128,13 @@ class RMSDPrunerConfig(PrunerConfig):
116128 max_dev : float = field (kw_only = True )
117129 heavy_atoms_only : bool = True
118130
131+ def __post_init__ (self ) -> None :
132+ """Add type enforcing to the parent's __post_init__."""
133+ super ().__post_init__ ()
134+
135+ # validate input types
136+ assert type (self .atoms ) is np .ndarray
137+
119138 def evaluate_sim (self , i1 : int , i2 : int ) -> bool :
120139 """Return whether the structures are similar."""
121140 if self .heavy_atoms_only :
@@ -146,8 +165,12 @@ class MOIPrunerConfig(PrunerConfig):
146165 max_dev : float = 0.01
147166
148167 def __post_init__ (self ) -> None :
149- """Add the moi_vecs calc to the parent's __post_init__."""
168+ """Add type enforcing and moi_vecs to the parent's __post_init__."""
150169 super ().__post_init__ ()
170+
171+ # validate input types
172+ assert type (self .masses ) is np .ndarray
173+
151174 self .moi_vecs = {
152175 c : get_inertia_moments (
153176 coord ,
0 commit comments