Skip to content

Commit 404bc13

Browse files
authored
Bug fix for trainer_state saving (#408)
1 parent ee05b3e commit 404bc13

File tree

9 files changed

+72
-14
lines changed

9 files changed

+72
-14
lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # TRI
8282
- `explore`: Only launches the explorer.
8383
- `bench`: Used for benchmarking.
8484
- `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `<checkpoint_root_dir>/<project>/<name>/`.
85-
- `continue_from_checkpoint`: If set to `true`, the experiment will continue from the latest checkpoint in the checkpoint path (if any); otherwise, it will rename the current experiment to `<name>_<timestamp>` and start a new experiment.
85+
- `continue_from_checkpoint`: If set to `true`, the experiment will continue from the latest checkpoint in the checkpoint path (if any); otherwise, it will rename the current experiment to `<name>_<timestamp>` and start a new experiment. Due to our decoupled design, during recovery from a checkpoint, we can only guarantee that the Trainer's model parameters and its optional auxiliary buffers (`auxiliary_buffers`) are restored to their latest checkpointed states, while the Explorer and Experience Buffer cannot be guaranteed to be restored to the same point in time.
8686
- `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `<project>/<name>`.
8787

8888
---

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # TRI
8282
- `explore`: 仅启动 explorer。
8383
- `bench`: 用于 benchmark 测试。
8484
- `checkpoint_root_dir`: 所有检查点和日志的根目录。该实验的检查点将存储在 `<checkpoint_root_dir>/<project>/<name>/` 路径下。
85-
- `continue_from_checkpoint`: 若设置为 `true`,实验将从检查点路径中的最新检查点继续;否则,会将当前实验重命名为 `<name>_<timestamp>` 并启动新实验。
85+
- `continue_from_checkpoint`: 若设置为 `true`,实验将从检查点路径中的最新检查点继续;否则,会将当前实验重命名为 `<name>_<timestamp>` 并启动新实验。由于我们的分离式设计,从检查点恢复的时候,我们只能保证Trainer的模型参数以及其使用的可选缓冲区(`auxiliary_buffers`)可以恢复到最新检查点的状态,而Explorer和Experience Buffer不能保证恢复到同一时点。
8686
- `ray_namespace`: 当前实验中启动模块的命名空间。若未指定,则默认为 `<project>/<name>`。
8787

8888
---

tests/template/config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ model:
2020
max_response_tokens: 2048
2121
max_model_len: 4096
2222
cluster: # 2 for explorer, 2 for trainer
23-
node_num: 2
24-
gpu_per_node: 2
23+
node_num: ${oc.env:NODE_NUM,2}
24+
gpu_per_node: ${oc.env:GPU_PER_NODE,2}
2525
buffer:
2626
total_epochs: 1
2727
batch_size: 4

tests/trainer/trainer_test.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for trainer."""
22

3+
import json
34
import multiprocessing
45
import os
56
import shutil
@@ -809,7 +810,7 @@ def test_trainer(self):
809810
self.config.algorithm.policy_loss_fn = "mix"
810811
self.config.buffer.batch_size = 4
811812
self.config.buffer.train_batch_size = 32
812-
self.config.buffer.total_epochs = 1
813+
self.config.buffer.total_steps = 2
813814
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
814815
self.config.synchronizer.sync_interval = 1
815816
self.config.trainer.save_interval = 1
@@ -823,6 +824,31 @@ def test_trainer(self):
823824
self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20
824825
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
825826
both(self.config)
827+
ray.shutdown(_exiting_interpreter=True)
828+
829+
# check trainer resume metadata
830+
trainer_meta_file = os.path.join(self.config.checkpoint_job_dir, "trainer_meta.json")
831+
with open(trainer_meta_file) as f:
832+
trainer_meta = json.load(f)
833+
self.assertEqual(trainer_meta["latest_iteration"], 2)
834+
self.assertEqual(
835+
trainer_meta["sample_strategy_state"]["expert_buffer"]["current_index"], 32
836+
)
837+
838+
self.config.buffer.total_steps = None
839+
self.config.buffer.total_epochs = 1
840+
self.config.check_and_update()
841+
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
842+
both(self.config)
843+
844+
# check trainer resume metadata
845+
with open(trainer_meta_file) as f:
846+
trainer_meta = json.load(f)
847+
self.assertEqual(trainer_meta["latest_iteration"], 4)
848+
self.assertEqual(
849+
trainer_meta["sample_strategy_state"]["expert_buffer"]["current_index"], 64
850+
)
851+
826852
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
827853

828854
# test rollout metrics

trinity/algorithm/sample_strategy/mix_sample_strategy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,15 @@ def default_args(cls) -> Dict:
111111
"expert_data_ratio": 0.5,
112112
"sft_dataset_name": "sft_dataset",
113113
}
114+
115+
def state_dict(self) -> dict:
116+
return {
117+
"usal_buffer": self.usual_exp_buffer.state_dict(),
118+
"expert_buffer": self.expert_exp_buffer.state_dict(),
119+
}
120+
121+
def load_state_dict(self, state_dict: dict) -> None:
122+
if state_dict.get("usal_buffer", None):
123+
self.usual_exp_buffer.load_state_dict(state_dict["usal_buffer"])
124+
if state_dict.get("expert_buffer", None):
125+
self.expert_exp_buffer.load_state_dict(state_dict["expert_buffer"])

trinity/algorithm/sample_strategy/sample_strategy.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
4343
def default_args(cls) -> dict:
4444
"""Get the default arguments of the sample strategy."""
4545

46+
@abstractmethod
47+
def state_dict(self) -> dict:
48+
"""Get the state dict of the sample strategy."""
49+
50+
@abstractmethod
51+
def load_state_dict(self, state_dict: dict) -> None:
52+
"""Load the state dict of the sample strategy."""
53+
4654

4755
@SAMPLE_STRATEGY.register_module("default")
4856
class DefaultSampleStrategy(SampleStrategy):
@@ -64,6 +72,13 @@ async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
6472
def default_args(cls) -> dict:
6573
return {}
6674

75+
def state_dict(self) -> dict:
76+
return self.exp_buffer.state_dict()
77+
78+
def load_state_dict(self, state_dict: dict) -> None:
79+
if state_dict:
80+
self.exp_buffer.load_state_dict(state_dict)
81+
6782

6883
@Deprecated
6984
@SAMPLE_STRATEGY.register_module("warmup")

trinity/buffer/reader/queue_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ async def read_async(self, batch_size: Optional[int] = None) -> List:
4343
return exps
4444

4545
def state_dict(self) -> Dict:
46-
# SQL Not supporting state dict yet
46+
# Queue Not supporting state dict yet
4747
return {"current_index": 0}
4848

4949
def load_state_dict(self, state_dict):
50-
# SQL Not supporting state dict yet
50+
# Queue Not supporting state dict yet
5151
return None

trinity/manager/state_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,14 @@ def load_explorer_server_url(self) -> Optional[str]:
101101

102102
def save_trainer(
103103
self,
104-
current_exp_index: int,
105104
current_step: int,
105+
sample_strategy_state: dict,
106106
) -> None:
107107
with open(self.trainer_state_path, "w", encoding="utf-8") as f:
108108
json.dump(
109109
{
110-
"latest_exp_index": current_exp_index,
111110
"latest_iteration": current_step,
111+
"sample_strategy_state": sample_strategy_state,
112112
},
113113
f,
114114
indent=2,

trinity/trainer/trainer.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import ray
1515

1616
from trinity.algorithm import SAMPLE_STRATEGY
17+
from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
1718
from trinity.common.config import Config
1819
from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle
1920
from trinity.common.experience import Experiences
@@ -38,9 +39,6 @@ def __init__(self, config: Config) -> None:
3839
path=config.checkpoint_job_dir, trainer_name=config.trainer.name, config=config
3940
)
4041
trainer_state = self.state.load_trainer()
41-
config.buffer.trainer_input.experience_buffer.index = trainer_state.get(
42-
"latest_exp_index", 0
43-
)
4442
self.last_trainer_sync_step = 0
4543
self.monitor = MONITOR.get(config.monitor.monitor_type)(
4644
project=config.project,
@@ -50,10 +48,17 @@ def __init__(self, config: Config) -> None:
5048
config=config,
5149
)
5250
self._sample_exps_to_log = []
53-
self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)(
51+
self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get(
52+
config.algorithm.sample_strategy
53+
)(
5454
buffer_config=config.buffer,
5555
**config.algorithm.sample_strategy_args,
5656
)
57+
if "latest_exp_index" in trainer_state:
58+
sample_strategy_state = {"current_index": trainer_state["latest_exp_index"]}
59+
else:
60+
sample_strategy_state = trainer_state.get("sample_strategy_state", {})
61+
self.sample_strategy.load_state_dict(sample_strategy_state)
5762
self.save_interval = config.trainer.save_interval
5863
self.last_sync_step = None
5964
self.last_sync_time = None
@@ -190,8 +195,8 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa
190195
self.logger.info(f"Saving checkpoint at step {self.train_step_num}...")
191196
self.engine.save_checkpoint(block_until_saved=block_until_saved, save_as_hf=save_as_hf)
192197
self.state.save_trainer(
193-
current_exp_index=self.engine.train_step_num * self.config.buffer.train_batch_size,
194198
current_step=self.train_step_num,
199+
sample_strategy_state=self.sample_strategy.state_dict(),
195200
)
196201
return metrics
197202

0 commit comments

Comments
 (0)