1414import ray
1515
1616from trinity .algorithm import SAMPLE_STRATEGY
17+ from trinity .algorithm .sample_strategy .sample_strategy import SampleStrategy
1718from trinity .common .config import Config
1819from trinity .common .constants import RunningStatus , SyncMethod , SyncStyle
1920from trinity .common .experience import Experiences
@@ -38,9 +39,6 @@ def __init__(self, config: Config) -> None:
3839 path = config .checkpoint_job_dir , trainer_name = config .trainer .name , config = config
3940 )
4041 trainer_state = self .state .load_trainer ()
41- config .buffer .trainer_input .experience_buffer .index = trainer_state .get (
42- "latest_exp_index" , 0
43- )
4442 self .last_trainer_sync_step = 0
4543 self .monitor = MONITOR .get (config .monitor .monitor_type )(
4644 project = config .project ,
@@ -50,10 +48,17 @@ def __init__(self, config: Config) -> None:
5048 config = config ,
5149 )
5250 self ._sample_exps_to_log = []
53- self .sample_strategy = SAMPLE_STRATEGY .get (config .algorithm .sample_strategy )(
51+ self .sample_strategy : SampleStrategy = SAMPLE_STRATEGY .get (
52+ config .algorithm .sample_strategy
53+ )(
5454 buffer_config = config .buffer ,
5555 ** config .algorithm .sample_strategy_args ,
5656 )
57+ if "latest_exp_index" in trainer_state :
58+ sample_strategy_state = {"current_index" : trainer_state ["latest_exp_index" ]}
59+ else :
60+ sample_strategy_state = trainer_state .get ("sample_strategy_state" , {})
61+ self .sample_strategy .load_state_dict (sample_strategy_state )
5762 self .save_interval = config .trainer .save_interval
5863 self .last_sync_step = None
5964 self .last_sync_time = None
@@ -190,8 +195,8 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa
190195 self .logger .info (f"Saving checkpoint at step { self .train_step_num } ..." )
191196 self .engine .save_checkpoint (block_until_saved = block_until_saved , save_as_hf = save_as_hf )
192197 self .state .save_trainer (
193- current_exp_index = self .engine .train_step_num * self .config .buffer .train_batch_size ,
194198 current_step = self .train_step_num ,
199+ sample_strategy_state = self .sample_strategy .state_dict (),
195200 )
196201 return metrics
197202
0 commit comments