Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
02a54e1
append to trajectory
Nov 25, 2025
db4cfa2
test append to trajectory
Nov 25, 2025
6f78f54
style
Nov 25, 2025
b378c6e
revert small change
Nov 25, 2025
1feff96
maintain step per system
Nov 26, 2025
eed1fc1
format
Nov 26, 2025
031fc56
integrate only for (n_steps - initial_step) steps when continuing
Nov 26, 2025
5b2469f
Merge branch 'main' into append-to-trajectory
danielzuegner Nov 26, 2025
6fd5cc9
change `load_new_trajectories` behavior
Nov 26, 2025
4e820ee
back to `step`
Nov 27, 2025
8b8c449
format
Nov 27, 2025
0a230bf
truncate trajectories
Nov 27, 2025
e138a33
fix tests
Nov 27, 2025
82059d5
fix style
Nov 27, 2025
2e4bb06
style
Nov 27, 2025
65e61d0
fix type hint
Nov 27, 2025
c1065e7
style
Nov 27, 2025
38619af
fix kT indexing in integrate, step counting in optimize
Nov 28, 2025
dde853c
style
Nov 28, 2025
6a1114c
Merge branch 'main' into append-to-trajectory
orionarcher Nov 28, 2025
6cba46f
Merge branch 'main' into append-to-trajectory
orionarcher Dec 9, 2025
787a967
rename variable
Dec 10, 2025
9c637f5
return positions last step
Dec 10, 2025
0031296
truncate to positions last step
Dec 10, 2025
1874dc5
extract methods
Dec 10, 2025
51ca30d
fix tests
Dec 10, 2025
dfd39f0
format
Dec 10, 2025
2aed1bb
Merge branch 'main' into append-to-trajectory
danielzuegner Dec 30, 2025
36d915a
prek
Dec 30, 2025
2470b46
disable auto truncating
Dec 30, 2025
28bb660
style
Dec 30, 2025
6de046a
rename
Dec 30, 2025
e801817
prek
Dec 30, 2025
318f79c
remove unused variable
Jan 7, 2026
6dbd9e3
fix trajectory reopen
Jan 7, 2026
2c10ea3
Merge branch 'main' into append-to-trajectory
orionarcher Jan 14, 2026
3871feb
Merge branch 'main' into append-to-trajectory
danielzuegner Jan 19, 2026
96097ad
use info log
Jan 19, 2026
5fb664d
fix style
Jan 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions tests/test_trajectory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import tempfile
from collections.abc import Callable, Generator
from pathlib import Path

Expand Down Expand Up @@ -834,3 +835,99 @@ def test_write_ase_trajectory_importerror(
with pytest.raises(ImportError, match="ASE is required to convert to ASE trajectory"):
traj.write_ase_trajectory(tmp_path / "dummy.traj")
traj.close()


def test_optimize_append_to_trajectory(
si_double_sim_state: SimState, lj_model: LennardJonesModel
) -> None:
"""Test appending to an existing trajectory when running ts.optimize."""

# Create a temporary trajectory file
with tempfile.TemporaryDirectory() as temp_dir:
traj_files = [f"{temp_dir}/optimize_trajectory_{idx}.h5" for idx in range(2)]

# Initialize model and state
trajectory_reporter = ts.TrajectoryReporter(
traj_files,
state_frequency=1,
)

# First optimization run
opt_state = ts.optimize(
system=si_double_sim_state,
model=lj_model,
max_steps=5,
optimizer=ts.Optimizer.fire,
trajectory_reporter=trajectory_reporter,
steps_between_swaps=100,
)

for traj in trajectory_reporter.trajectories:
with TorchSimTrajectory(traj._file.filename, mode="r") as traj:
# Check that the trajectory file has 5 frames
np.testing.assert_allclose(traj.get_steps("positions"), range(1, 6))

trajectory_reporter_2 = ts.TrajectoryReporter(
traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a")
)
_ = ts.optimize(
system=opt_state,
model=lj_model,
max_steps=7,
optimizer=ts.Optimizer.fire,
trajectory_reporter=trajectory_reporter_2,
steps_between_swaps=100,
)
for traj in trajectory_reporter_2.trajectories:
with TorchSimTrajectory(traj._file.filename, mode="r") as traj:
# Check that the trajectory file now has 12 frames
np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13))


def test_integrate_append_to_trajectory(
si_double_sim_state: SimState, lj_model: LennardJonesModel
) -> None:
"""Test appending to an existing trajectory when running ts.integrate."""

# Create a temporary trajectory file
with tempfile.TemporaryDirectory() as temp_dir:
traj_files = [f"{temp_dir}/integrate_trajectory_{idx}.h5" for idx in range(2)]

# Initialize model and state
trajectory_reporter = ts.TrajectoryReporter(
traj_files,
state_frequency=1,
)

# First integration run
int_state = ts.integrate(
system=si_double_sim_state,
model=lj_model,
timestep=0.1,
n_steps=5,
temperature=300.0,
integrator=ts.Integrator.nvt_langevin,
trajectory_reporter=trajectory_reporter,
)

for traj in trajectory_reporter.trajectories:
with TorchSimTrajectory(traj._file.filename, mode="r") as traj:
# Check that the trajectory file has 5 frames
np.testing.assert_allclose(traj.get_steps("positions"), range(1, 6))

trajectory_reporter_2 = ts.TrajectoryReporter(
traj_files, state_frequency=1, trajectory_kwargs=dict(mode="a")
)
_ = ts.integrate(
system=int_state,
model=lj_model,
timestep=0.1,
temperature=300.0,
n_steps=7,
integrator=ts.Integrator.nvt_langevin,
trajectory_reporter=trajectory_reporter_2,
)
for traj in trajectory_reporter_2.trajectories:
with TorchSimTrajectory(traj._file.filename, mode="r") as traj:
# Check that the trajectory file now has 12 frames
np.testing.assert_allclose(traj.get_steps("positions"), range(1, 13))
30 changes: 25 additions & 5 deletions torch_sim/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,12 @@ def integrate[T: SimState]( # noqa: C901
trajectory_reporter,
properties=["kinetic_energy", "potential_energy", "temperature"],
)

# Auto-detect initial step from trajectory files for resuming integration
initial_step: int = 1
if trajectory_reporter is not None and trajectory_reporter.mode == "a":
last_step = trajectory_reporter.last_step
if last_step > 0:
initial_step = last_step + 1
final_states: list[T] = []
og_filenames = trajectory_reporter.filenames if trajectory_reporter else None

Expand All @@ -199,9 +204,13 @@ def integrate[T: SimState]( # noqa: C901
)

# run the simulation
for step in range(1, n_steps + 1):
for step in range(initial_step, initial_step + n_steps):
state = step_func(
state=state, model=model, dt=dt, kT=kTs[step - 1], **integrator_kwargs
state=state,
model=model,
dt=dt,
kT=kTs[step - initial_step],
**integrator_kwargs,
)

if trajectory_reporter:
Expand Down Expand Up @@ -457,7 +466,14 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915
trajectory_reporter, properties=["potential_energy"]
)

step: int = 1
# Auto-detect initial step from trajectory files for resuming optimizations
initial_step: int = 1
if trajectory_reporter is not None and trajectory_reporter.mode == "a":
last_step = trajectory_reporter.last_step
if last_step > 0:
initial_step = last_step + 1
step: int = initial_step

last_energy = None
all_converged_states: list[T] = []
convergence_tensor = None
Expand Down Expand Up @@ -485,9 +501,13 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915
and og_filenames is not None
and (step == 1 or len(converged_states) > 0)
):
mode_before = trajectory_reporter.trajectory_kwargs["mode"]
# temporarily set to "append" mode to avoid overwriting existing files
trajectory_reporter.trajectory_kwargs["mode"] = "a"
trajectory_reporter.load_new_trajectories(
filenames=[og_filenames[i] for i in autobatcher.current_idx]
)
trajectory_reporter.trajectory_kwargs["mode"] = mode_before

for _step in range(steps_between_swaps):
if hasattr(state, "energy"):
Expand All @@ -498,7 +518,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915
if trajectory_reporter:
trajectory_reporter.report(state, step, model=model)
step += 1
if step > max_steps:
if step > max_steps + initial_step - 1:
# TODO: max steps should be tracked for each structure in the batch
warnings.warn(f"Optimize has reached max steps: {step}", stacklevel=2)
break
Expand Down
69 changes: 61 additions & 8 deletions torch_sim/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState


if TYPE_CHECKING:
from ase import Atoms
from ase.io.trajectory import TrajectoryReader
Expand Down Expand Up @@ -302,6 +301,43 @@ def close(self) -> None:
for trajectory in self.trajectories:
trajectory.close()

@property
def mode(self) -> Literal["r", "w", "a"]:
"""Get the mode of the first trajectory file.

Returns:
"r" | "w" | "a": Mode from the trajectory_kwargs used during initialization.
"""
if not self.trajectories:
raise ValueError("No trajectories loaded.")
# Key is guaranteed to exist because we set it during initialization.
return self.trajectory_kwargs["mode"]

@property
def last_step(self) -> int:
"""Get the maximum last step across all trajectory files.

Returns the highest step number found across all trajectory files.
This is useful for resuming optimizations from where they left off.

Returns:
int: The maximum last step number across all trajectories, or 0 if
no trajectories exist or all are empty
"""
if not self.trajectories:
return 0

max_step = 0
for trajectory in self.trajectories:
if trajectory._file.isopen:
last_step = trajectory.last_step
else:
with TorchSimTrajectory(trajectory._file.filename, mode="r") as traj:
last_step = traj.last_step
max_step = max(max_step, last_step)

return max_step

def __enter__(self) -> "TrajectoryReporter":
"""Support the context manager protocol.

Expand Down Expand Up @@ -594,7 +630,7 @@ def _validate_array(self, name: str, data: np.ndarray, steps: list[int]) -> None
)

# Validate step is monotonically increasing by checking HDF5 file directly
steps_node = self._file.get_node("/steps/", name=name)
steps_node = self.get_steps(name)
if len(steps_node) > 0:
last_step = steps_node[-1] # Get the last recorded step
if steps[0] <= last_step:
Expand Down Expand Up @@ -658,9 +694,6 @@ def get_array(
def get_steps(
self,
name: str,
start: int | None = None,
stop: int | None = None,
step: int = 1,
) -> np.ndarray:
"""Get the steps for an array.

Expand All @@ -675,9 +708,29 @@ def get_steps(
Returns:
np.ndarray: Array of step numbers with shape [n_selected_frames]
"""
return self._file.root.steps.__getitem__(name).read(
start=start, stop=stop, step=step
)
return self._file.get_node("/steps/", name=name).read()

@property
def last_step(self) -> int:
"""Get the last step number from the trajectory.

Retrieves the maximum step number across all arrays in the trajectory.
If the trajectory is empty or has no arrays, returns 0.

Returns:
int: The last (maximum) step number in the trajectory, or 0 if empty
"""
if not self.array_registry:
return 0

max_step = 0
for name in self.array_registry:
steps_node = self.get_steps(name)
if len(steps_node) > 0:
last_step = int(steps_node[-1])
max_step = max(max_step, last_step)

return max_step

def __str__(self) -> str:
"""Get a string representation of the trajectory.
Expand Down
Loading