diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 91f5a340..81a5451f 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -406,7 +406,7 @@ def convergence_fn( """ force_conv = ts.system_wise_max_force(state) < force_tol - if include_cell_forces: + if include_cell_forces and hasattr(state, "cell_forces"): if (cell_forces := getattr(state, "cell_forces", None)) is None: raise ValueError("cell_forces not found in state") cell_forces_norm, _ = cell_forces.norm(dim=2).max(dim=1)