Skip to content

Commit 6556453

Browse files
Merge branch 'main' into fix_symmetry_constraint
2 parents fb6e234 + c5d4c30 commit 6556453

File tree

14 files changed

+2663
-17
lines changed

14 files changed

+2663
-17
lines changed

examples/scripts/2_structural_optimization.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,72 @@
386386
print(f"Initial pressure: {initial_pressure} GPa")
387387
print(f"Final pressure: {final_pressure} GPa")
388388

389+
# ============================================================================
390+
# SECTION 7: Batched MACE L-BFGS
391+
# ============================================================================
392+
print("\n" + "=" * 70)
393+
print("SECTION 7: Batched MACE L-BFGS")
394+
print("=" * 70)
395+
396+
# Recreate structures with perturbations
397+
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
398+
si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)
399+
400+
cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2))
401+
cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape)
402+
403+
fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2))
404+
fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape)
405+
406+
atoms_list = [si_dc, cu_dc, fe_dc]
407+
408+
state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
409+
results = model(state)
410+
state = ts.lbfgs_init(state=state, model=model, alpha=70.0, step_size=1.0)
411+
412+
print("\nRunning L-BFGS:")
413+
for step in range(N_steps):
414+
if step % 20 == 0:
415+
print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")
416+
state = ts.lbfgs_step(state=state, model=model, max_history=100)
417+
418+
print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV")
419+
print(f"Final energies: {[energy.item() for energy in state.energy]} eV")
420+
421+
422+
# ============================================================================
423+
# SECTION 8: Batched MACE BFGS
424+
# ============================================================================
425+
print("\n" + "=" * 70)
426+
print("SECTION 8: Batched MACE BFGS")
427+
print("=" * 70)
428+
429+
# Recreate structures with perturbations
430+
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
431+
si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)
432+
433+
cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2))
434+
cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape)
435+
436+
fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2))
437+
fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape)
438+
439+
atoms_list = [si_dc, cu_dc, fe_dc]
440+
441+
state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
442+
results = model(state)
443+
state = ts.bfgs_init(state=state, model=model, alpha=70.0)
444+
445+
print("\nRunning BFGS:")
446+
for step in range(N_steps):
447+
if step % 20 == 0:
448+
print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")
449+
state = ts.bfgs_step(state=state, model=model)
450+
451+
print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV")
452+
print(f"Final energies: {[energy.item() for energy in state.energy]} eV")
453+
454+
389455
print("\n" + "=" * 70)
390456
print("Structural optimization examples completed!")
391457
print("=" * 70)

tests/test_autobatching.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,3 +605,179 @@ def test_in_flight_max_iterations(
605605
# Verify iteration_count tracking
606606
for idx in range(len(states)):
607607
assert batcher.iteration_count[idx] == max_iterations
608+
609+
610+
@pytest.mark.parametrize(
611+
"num_steps_per_batch",
612+
[
613+
5, # At 5 steps, not every state will converge before the next batch.
614+
10, # At 10 steps, all states will converge before the next batch
615+
],
616+
)
617+
def test_in_flight_with_bfgs(
618+
si_sim_state: ts.SimState,
619+
fe_supercell_sim_state: ts.SimState,
620+
lj_model: LennardJonesModel,
621+
num_steps_per_batch: int,
622+
) -> None:
623+
"""Test InFlightAutoBatcher with BFGS optimizer."""
624+
si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
625+
fe_bfgs_state = ts.bfgs_init(
626+
fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
627+
)
628+
629+
bfgs_states = [si_bfgs_state, fe_bfgs_state] * 5
630+
bfgs_states = [state.clone() for state in bfgs_states]
631+
for state in bfgs_states:
632+
state.positions += torch.randn_like(state.positions) * 0.01
633+
634+
batcher = InFlightAutoBatcher(
635+
model=lj_model,
636+
memory_scales_with="n_atoms",
637+
max_memory_scaler=6000,
638+
)
639+
batcher.load_states(bfgs_states)
640+
641+
def convergence_fn(state: ts.BFGSState) -> torch.Tensor:
642+
system_wise_max_force = torch.zeros(
643+
state.n_systems, device=state.device, dtype=torch.float64
644+
)
645+
max_forces = state.forces.norm(dim=1)
646+
system_wise_max_force = system_wise_max_force.scatter_reduce(
647+
dim=0, index=state.system_idx, src=max_forces, reduce="amax"
648+
)
649+
return system_wise_max_force < 5e-1
650+
651+
all_completed_states, convergence_tensor = [], None
652+
while True:
653+
state, completed_states = batcher.next_batch(state, convergence_tensor)
654+
655+
all_completed_states.extend(completed_states)
656+
if state is None:
657+
break
658+
659+
for _ in range(num_steps_per_batch):
660+
state = ts.bfgs_step(state=state, model=lj_model)
661+
convergence_tensor = convergence_fn(state)
662+
663+
assert len(all_completed_states) == len(bfgs_states)
664+
665+
666+
def test_binning_auto_batcher_with_bfgs(
667+
si_sim_state: ts.SimState,
668+
fe_supercell_sim_state: ts.SimState,
669+
lj_model: LennardJonesModel,
670+
) -> None:
671+
"""Test BinningAutoBatcher with BFGS optimizer."""
672+
si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
673+
fe_bfgs_state = ts.bfgs_init(
674+
fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
675+
)
676+
677+
bfgs_states = [si_bfgs_state, fe_bfgs_state] * 5
678+
bfgs_states = [state.clone() for state in bfgs_states]
679+
for state in bfgs_states:
680+
state.positions += torch.randn_like(state.positions) * 0.01
681+
682+
batcher = BinningAutoBatcher(
683+
model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000
684+
)
685+
batcher.load_states(bfgs_states)
686+
687+
all_finished_states: list[ts.SimState] = []
688+
total_batches = 0
689+
for batch, _ in batcher:
690+
total_batches += 1 # noqa: SIM113
691+
for _ in range(5):
692+
batch = ts.bfgs_step(state=batch, model=lj_model)
693+
all_finished_states.extend(batch.split())
694+
695+
assert len(all_finished_states) == len(bfgs_states)
696+
697+
698+
@pytest.mark.parametrize(
699+
"num_steps_per_batch",
700+
[
701+
5, # At 5 steps, not every state will converge before the next batch.
702+
10, # At 10 steps, all states will converge before the next batch
703+
],
704+
)
705+
def test_in_flight_with_lbfgs(
706+
si_sim_state: ts.SimState,
707+
fe_supercell_sim_state: ts.SimState,
708+
lj_model: LennardJonesModel,
709+
num_steps_per_batch: int,
710+
) -> None:
711+
"""Test InFlightAutoBatcher with L-BFGS optimizer."""
712+
si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
713+
fe_lbfgs_state = ts.lbfgs_init(
714+
fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
715+
)
716+
717+
lbfgs_states = [si_lbfgs_state, fe_lbfgs_state] * 5
718+
lbfgs_states = [state.clone() for state in lbfgs_states]
719+
for state in lbfgs_states:
720+
state.positions += torch.randn_like(state.positions) * 0.01
721+
722+
batcher = InFlightAutoBatcher(
723+
model=lj_model,
724+
memory_scales_with="n_atoms",
725+
max_memory_scaler=6000,
726+
)
727+
batcher.load_states(lbfgs_states)
728+
729+
def convergence_fn(state: ts.LBFGSState) -> torch.Tensor:
730+
system_wise_max_force = torch.zeros(
731+
state.n_systems, device=state.device, dtype=torch.float64
732+
)
733+
max_forces = state.forces.norm(dim=1)
734+
system_wise_max_force = system_wise_max_force.scatter_reduce(
735+
dim=0, index=state.system_idx, src=max_forces, reduce="amax"
736+
)
737+
return system_wise_max_force < 5e-1
738+
739+
all_completed_states, convergence_tensor = [], None
740+
while True:
741+
state, completed_states = batcher.next_batch(state, convergence_tensor)
742+
743+
all_completed_states.extend(completed_states)
744+
if state is None:
745+
break
746+
747+
for _ in range(num_steps_per_batch):
748+
state = ts.lbfgs_step(state=state, model=lj_model)
749+
convergence_tensor = convergence_fn(state)
750+
751+
assert len(all_completed_states) == len(lbfgs_states)
752+
753+
754+
def test_binning_auto_batcher_with_lbfgs(
755+
si_sim_state: ts.SimState,
756+
fe_supercell_sim_state: ts.SimState,
757+
lj_model: LennardJonesModel,
758+
) -> None:
759+
"""Test BinningAutoBatcher with L-BFGS optimizer."""
760+
si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
761+
fe_lbfgs_state = ts.lbfgs_init(
762+
fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
763+
)
764+
765+
lbfgs_states = [si_lbfgs_state, fe_lbfgs_state] * 5
766+
lbfgs_states = [state.clone() for state in lbfgs_states]
767+
for state in lbfgs_states:
768+
state.positions += torch.randn_like(state.positions) * 0.01
769+
770+
batcher = BinningAutoBatcher(
771+
model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000
772+
)
773+
batcher.load_states(lbfgs_states)
774+
775+
all_finished_states: list[ts.SimState] = []
776+
total_batches = 0
777+
for batch, _ in batcher:
778+
total_batches += 1 # noqa: SIM113
779+
for _ in range(5):
780+
batch = ts.lbfgs_step(state=batch, model=lj_model)
781+
all_finished_states.extend(batch.split())
782+
783+
assert len(all_finished_states) == len(lbfgs_states)

0 commit comments

Comments
 (0)