@@ -605,3 +605,179 @@ def test_in_flight_max_iterations(
605605 # Verify iteration_count tracking
606606 for idx in range (len (states )):
607607 assert batcher .iteration_count [idx ] == max_iterations
608+
609+
@pytest.mark.parametrize(
    "num_steps_per_batch",
    [
        5,  # At 5 steps, not every state will converge before the next batch.
        10,  # At 10 steps, all states will converge before the next batch
    ],
)
def test_in_flight_with_bfgs(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
    num_steps_per_batch: int,
) -> None:
    """Test InFlightAutoBatcher with BFGS optimizer.

    Builds ten slightly perturbed BFGS states, drives them through the
    batcher with a force-based convergence criterion, and checks that every
    state is eventually handed back as completed.
    """
    si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
    fe_bfgs_state = ts.bfgs_init(
        fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
    )

    # Clone so the two shared init states are not mutated, then jitter the
    # positions so each copy has real optimization work to do.
    bfgs_states = [state.clone() for state in [si_bfgs_state, fe_bfgs_state] * 5]
    for bfgs_state in bfgs_states:
        bfgs_state.positions += torch.randn_like(bfgs_state.positions) * 0.01

    batcher = InFlightAutoBatcher(
        model=lj_model,
        memory_scales_with="n_atoms",
        max_memory_scaler=6000,
    )
    batcher.load_states(bfgs_states)

    def convergence_fn(state: ts.BFGSState) -> torch.Tensor:
        """Per-system convergence: max per-atom force norm below 0.5."""
        system_wise_max_force = torch.zeros(
            state.n_systems, device=state.device, dtype=torch.float64
        )
        max_forces = state.forces.norm(dim=1)  # per-atom force magnitude
        system_wise_max_force = system_wise_max_force.scatter_reduce(
            dim=0, index=state.system_idx, src=max_forces, reduce="amax"
        )
        return system_wise_max_force < 5e-1

    # BUG FIX: `state` previously relied on the leaked loop variable from the
    # perturbation loop above. Initialize it explicitly; the first
    # next_batch call is identified by `convergence_tensor is None`.
    all_completed_states: list[ts.BFGSState] = []
    convergence_tensor = None
    state = None
    while True:
        state, completed_states = batcher.next_batch(state, convergence_tensor)

        all_completed_states.extend(completed_states)
        if state is None:
            break

        for _ in range(num_steps_per_batch):
            state = ts.bfgs_step(state=state, model=lj_model)
        convergence_tensor = convergence_fn(state)

    assert len(all_completed_states) == len(bfgs_states)
664+
665+
def test_binning_auto_batcher_with_bfgs(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
) -> None:
    """Test BinningAutoBatcher with BFGS optimizer.

    Runs a fixed number of BFGS steps on each binned batch and checks that
    the batcher returns every input state exactly once.
    """
    si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
    fe_bfgs_state = ts.bfgs_init(
        fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
    )

    # Clone so the two shared init states are not mutated, then jitter the
    # positions so each copy has real optimization work to do.
    bfgs_states = [state.clone() for state in [si_bfgs_state, fe_bfgs_state] * 5]
    for bfgs_state in bfgs_states:
        bfgs_state.positions += torch.randn_like(bfgs_state.positions) * 0.01

    batcher = BinningAutoBatcher(
        model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000
    )
    batcher.load_states(bfgs_states)

    all_finished_states: list[ts.SimState] = []
    total_batches = 0
    for total_batches, (batch, _) in enumerate(batcher, start=1):
        for _ in range(5):
            batch = ts.bfgs_step(state=batch, model=lj_model)
        all_finished_states.extend(batch.split())

    # FIX: the batch counter was previously computed but never checked (dead
    # code behind a noqa). Assert the batcher actually produced batches.
    assert total_batches >= 1
    assert len(all_finished_states) == len(bfgs_states)
696+
697+
@pytest.mark.parametrize(
    "num_steps_per_batch",
    [
        5,  # At 5 steps, not every state will converge before the next batch.
        10,  # At 10 steps, all states will converge before the next batch
    ],
)
def test_in_flight_with_lbfgs(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
    num_steps_per_batch: int,
) -> None:
    """Test InFlightAutoBatcher with L-BFGS optimizer.

    Builds ten slightly perturbed L-BFGS states, drives them through the
    batcher with a force-based convergence criterion, and checks that every
    state is eventually handed back as completed.
    """
    si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
    fe_lbfgs_state = ts.lbfgs_init(
        fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
    )

    # Clone so the two shared init states are not mutated, then jitter the
    # positions so each copy has real optimization work to do.
    lbfgs_states = [state.clone() for state in [si_lbfgs_state, fe_lbfgs_state] * 5]
    for lbfgs_state in lbfgs_states:
        lbfgs_state.positions += torch.randn_like(lbfgs_state.positions) * 0.01

    batcher = InFlightAutoBatcher(
        model=lj_model,
        memory_scales_with="n_atoms",
        max_memory_scaler=6000,
    )
    batcher.load_states(lbfgs_states)

    def convergence_fn(state: ts.LBFGSState) -> torch.Tensor:
        """Per-system convergence: max per-atom force norm below 0.5."""
        system_wise_max_force = torch.zeros(
            state.n_systems, device=state.device, dtype=torch.float64
        )
        max_forces = state.forces.norm(dim=1)  # per-atom force magnitude
        system_wise_max_force = system_wise_max_force.scatter_reduce(
            dim=0, index=state.system_idx, src=max_forces, reduce="amax"
        )
        return system_wise_max_force < 5e-1

    # BUG FIX: `state` previously relied on the leaked loop variable from the
    # perturbation loop above. Initialize it explicitly; the first
    # next_batch call is identified by `convergence_tensor is None`.
    all_completed_states: list[ts.LBFGSState] = []
    convergence_tensor = None
    state = None
    while True:
        state, completed_states = batcher.next_batch(state, convergence_tensor)

        all_completed_states.extend(completed_states)
        if state is None:
            break

        for _ in range(num_steps_per_batch):
            state = ts.lbfgs_step(state=state, model=lj_model)
        convergence_tensor = convergence_fn(state)

    assert len(all_completed_states) == len(lbfgs_states)
752+
753+
def test_binning_auto_batcher_with_lbfgs(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
) -> None:
    """Test BinningAutoBatcher with L-BFGS optimizer.

    Runs a fixed number of L-BFGS steps on each binned batch and checks that
    the batcher returns every input state exactly once.
    """
    si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
    fe_lbfgs_state = ts.lbfgs_init(
        fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
    )

    # Clone so the two shared init states are not mutated, then jitter the
    # positions so each copy has real optimization work to do.
    lbfgs_states = [state.clone() for state in [si_lbfgs_state, fe_lbfgs_state] * 5]
    for lbfgs_state in lbfgs_states:
        lbfgs_state.positions += torch.randn_like(lbfgs_state.positions) * 0.01

    batcher = BinningAutoBatcher(
        model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000
    )
    batcher.load_states(lbfgs_states)

    all_finished_states: list[ts.SimState] = []
    total_batches = 0
    for total_batches, (batch, _) in enumerate(batcher, start=1):
        for _ in range(5):
            batch = ts.lbfgs_step(state=batch, model=lj_model)
        all_finished_states.extend(batch.split())

    # FIX: the batch counter was previously computed but never checked (dead
    # code behind a noqa). Assert the batcher actually produced batches.
    assert total_batches >= 1
    assert len(all_finished_states) == len(lbfgs_states)
0 commit comments