Skip to content

Commit 0585acf

Browse files
committed
add test that is fixed by change
1 parent f65c434 commit 0585acf

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

tests/test_autobatching.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,19 @@ def test_calculate_scaling_metric(si_sim_state: ts.SimState) -> None:
107107
calculate_memory_scaler(si_sim_state, "invalid_metric")
108108

109109

110+
def test_calculate_scaling_metric_non_periodic(benzene_sim_state: ts.SimState) -> None:
111+
"""Test calculation of scaling metrics for a non-periodic state."""
112+
# Test that calculate passes
113+
n_atoms_metric = calculate_memory_scaler(benzene_sim_state, "n_atoms")
114+
assert n_atoms_metric == benzene_sim_state.n_atoms
115+
116+
# Test n_atoms_x_density metric works for non-periodic systems
117+
n_atoms_x_density_metric = calculate_memory_scaler(
118+
benzene_sim_state, "n_atoms_x_density"
119+
)
120+
assert n_atoms_x_density_metric > 0
121+
122+
110123
def test_split_state(si_double_sim_state: ts.SimState) -> None:
111124
"""Test splitting a batched state into individual states."""
112125
split_states = si_double_sim_state.split()

torch_sim/autobatching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def calculate_memory_scaler(
367367
volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000
368368
else:
369369
bbox = state.positions.max(dim=0).values - state.positions.min(dim=0).values
370-
volume = bbox.prod() / 1000
370+
volume = bbox.clamp(min=1.0).prod() / 1000 # min 1 Å for planar molecules
371371
number_density = state.n_atoms / volume.item()
372372
return state.n_atoms * number_density
373373
raise ValueError(

0 commit comments

Comments
 (0)