Skip to content

Commit 44625f1

Browse files
committed
enforced type checking on PrunerConfig classes
1 parent de870ea commit 44625f1

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

prism_pruner/pruner.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ class PrunerConfig:
4949

5050
def __post_init__(self) -> None:
5151
"""Validate inputs and initialize computed fields."""
52+
# validate input types
53+
assert type(self.structures) is np.ndarray
54+
5255
self.mask = np.ones(shape=(self.structures.shape[0],), dtype=np.bool_)
5356

5457
if len(self.energies) != 0:
@@ -58,6 +61,7 @@ def __post_init__(self) -> None:
5861

5962
# Set defaults for optional parameters
6063
if len(self.energies) == 0:
64+
assert type(self.energies) is np.ndarray
6165
self.energies = np.zeros(self.structures.shape[0], dtype=float)
6266

6367
assert len(self.energies) == len(self.structures), (
@@ -85,6 +89,14 @@ class RMSDRotCorrPrunerConfig(PrunerConfig):
8589
graph: Graph = field(kw_only=True)
8690
heavy_atoms_only: bool = True
8791

92+
def __post_init__(self) -> None:
93+
"""Add type enforcing to the parent's __post_init__."""
94+
super().__post_init__()
95+
96+
# validate input types
97+
assert type(self.atoms) is np.ndarray
98+
assert type(self.graph) is Graph
99+
88100
def evaluate_sim(self, i1: int, i2: int) -> bool:
89101
"""Return whether the structures are similar."""
90102
rmsd, max_dev = rotationally_corrected_rmsd_and_max(
@@ -116,6 +128,13 @@ class RMSDPrunerConfig(PrunerConfig):
116128
max_dev: float = field(kw_only=True)
117129
heavy_atoms_only: bool = True
118130

131+
def __post_init__(self) -> None:
132+
"""Add type enforcing to the parent's __post_init__."""
133+
super().__post_init__()
134+
135+
# validate input types
136+
assert type(self.atoms) is np.ndarray
137+
119138
def evaluate_sim(self, i1: int, i2: int) -> bool:
120139
"""Return whether the structures are similar."""
121140
if self.heavy_atoms_only:
@@ -146,8 +165,12 @@ class MOIPrunerConfig(PrunerConfig):
146165
max_dev: float = 0.01
147166

148167
def __post_init__(self) -> None:
149-
"""Add the moi_vecs calc to the parent's __post_init__."""
168+
"""Add type enforcing and moi_vecs to the parent's __post_init__."""
150169
super().__post_init__()
170+
171+
# validate input types
172+
assert type(self.masses) is np.ndarray
173+
151174
self.moi_vecs = {
152175
c: get_inertia_moments(
153176
coord,

0 commit comments

Comments
 (0)