Skip to content

Commit 4fbe9bd

Browse files
committed
bug fix for trainer_state saving
1 parent 01423d6 commit 4fbe9bd

File tree

7 files changed

+72
-13
lines changed

7 files changed

+72
-13
lines changed

tests/template/config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ model:
2020
max_response_tokens: 2048
2121
max_model_len: 4096
2222
cluster: # 2 for explorer, 2 for trainer
23-
node_num: 2
24-
gpu_per_node: 2
23+
node_num: ${oc.env:NODE_NUM,2}
24+
gpu_per_node: ${oc.env:GPU_PER_NODE,2}
2525
buffer:
2626
total_epochs: 1
2727
batch_size: 4

tests/trainer/trainer_test.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for trainer."""
22

3+
import json
34
import multiprocessing
45
import os
56
import shutil
@@ -809,7 +810,7 @@ def test_trainer(self):
809810
self.config.algorithm.policy_loss_fn = "mix"
810811
self.config.buffer.batch_size = 4
811812
self.config.buffer.train_batch_size = 32
812-
self.config.buffer.total_epochs = 1
813+
self.config.buffer.total_steps = 2
813814
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
814815
self.config.synchronizer.sync_interval = 1
815816
self.config.trainer.save_interval = 1
@@ -823,6 +824,31 @@ def test_trainer(self):
823824
self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20
824825
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
825826
both(self.config)
827+
ray.shutdown(_exiting_interpreter=True)
828+
829+
# check trainer resume metadata
830+
trainer_meta_file = os.path.join(self.config.checkpoint_job_dir, "trainer_meta.json")
831+
with open(trainer_meta_file) as f:
832+
trainer_meta = json.load(f)
833+
self.assertEqual(trainer_meta["latest_iteration"], 2)
834+
self.assertEqual(
835+
trainer_meta["sample_strategy_state"]["expert_buffer"]["current_index"], 32
836+
)
837+
838+
self.config.buffer.total_steps = None
839+
self.config.buffer.total_epochs = 1
840+
self.config.check_and_update()
841+
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
842+
both(self.config)
843+
844+
# check trainer resume metadata
845+
with open(trainer_meta_file) as f:
846+
trainer_meta = json.load(f)
847+
self.assertEqual(trainer_meta["latest_iteration"], 4)
848+
self.assertEqual(
849+
trainer_meta["sample_strategy_state"]["expert_buffer"]["current_index"], 64
850+
)
851+
826852
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
827853

828854
# test rollout metrics
@@ -852,7 +878,8 @@ def test_trainer(self):
852878
self.assertTrue(len(os.listdir(os.path.join(checkpoint_dir, "actor"))) > 0)
853879

854880
def tearDown(self):
    """Remove the checkpoint directory produced by the test run."""
    # NOTE(review): the test reads `trainer_meta.json` from this directory and
    # expects `latest_iteration == 2` after the first run; a leftover file from
    # a previous run would presumably be picked up as resume state. Confirm
    # whether setUp also cleans the directory before relying on that.
    shutil.rmtree(self.config.checkpoint_job_dir)
856883

857884

858885
class TestMultiModalGRPO(BaseTrainerCase):

trinity/algorithm/sample_strategy/mix_sample_strategy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,15 @@ def default_args(cls) -> Dict:
111111
"expert_data_ratio": 0.5,
112112
"sft_dataset_name": "sft_dataset",
113113
}
114+
115+
def state_dict(self) -> dict:
    """Collect the states of both experience buffers.

    Returns:
        A dict holding the usual buffer's state under ``"usual_buffer"`` and
        the expert buffer's state under ``"expert_buffer"``.
    """
    return {
        "usual_buffer": self.usual_exp_buffer.state_dict(),
        "expert_buffer": self.expert_exp_buffer.state_dict(),
    }


def load_state_dict(self, state_dict: dict) -> None:
    """Restore both experience buffers from a saved state.

    Args:
        state_dict: A dict as produced by ``state_dict``. A missing or empty
            entry leaves the corresponding buffer untouched.
    """
    # Earlier checkpoints stored the usual buffer under the misspelled key
    # "usal_buffer"; keep accepting it so those checkpoints still resume.
    usual_state = state_dict.get("usual_buffer") or state_dict.get("usal_buffer")
    if usual_state:
        self.usual_exp_buffer.load_state_dict(usual_state)

    expert_state = state_dict.get("expert_buffer")
    if expert_state:
        self.expert_exp_buffer.load_state_dict(expert_state)

trinity/algorithm/sample_strategy/sample_strategy.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
4343
def default_args(cls) -> dict:
4444
"""Get the default arguments of the sample strategy."""
4545

46+
@abstractmethod
def state_dict(self) -> dict:
    """Return the current state of the sample strategy as a dict.

    Subclasses must implement this. The trainer hands the returned dict to
    the state manager, which writes it with ``json.dump``, so the content
    should be JSON-serializable.
    """
49+
50+
@abstractmethod
def load_state_dict(self, state_dict: dict) -> None:
    """Restore the sample strategy from a previously saved state.

    Subclasses must implement this. The trainer passes an empty dict when
    the loaded trainer state contains no saved sample-strategy state.
    """
53+
4654

4755
@SAMPLE_STRATEGY.register_module("default")
4856
class DefaultSampleStrategy(SampleStrategy):
@@ -64,6 +72,13 @@ async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
6472
def default_args(cls) -> dict:
6573
return {}
6674

75+
def state_dict(self) -> dict:
    """Return the state reported by the underlying experience buffer."""
    buffer_state = self.exp_buffer.state_dict()
    return buffer_state
77+
78+
def load_state_dict(self, state_dict: dict) -> None:
    """Forward a saved state to the underlying experience buffer.

    Args:
        state_dict: State to restore; a falsy value (e.g. an empty dict) is
            ignored and the buffer is left untouched.
    """
    if not state_dict:
        return
    self.exp_buffer.load_state_dict(state_dict)
81+
6782

6883
@Deprecated
6984
@SAMPLE_STRATEGY.register_module("warmup")

trinity/buffer/reader/queue_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ async def read_async(self, batch_size: Optional[int] = None) -> List:
4343
return exps
4444

4545
def state_dict(self) -> Dict:
46-
# SQL Not supporting state dict yet
46+
# Queue reader does not support state dict yet
4747
return {"current_index": 0}
4848

4949
def load_state_dict(self, state_dict):
50-
# SQL Not supporting state dict yet
50+
# Queue reader does not support state dict yet
5151
return None

trinity/manager/state_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,14 @@ def load_explorer_server_url(self) -> Optional[str]:
101101

102102
def save_trainer(
103103
self,
104-
current_exp_index: int,
105104
current_step: int,
105+
sample_strategy_state: dict,
106106
) -> None:
107107
with open(self.trainer_state_path, "w", encoding="utf-8") as f:
108108
json.dump(
109109
{
110-
"latest_exp_index": current_exp_index,
111110
"latest_iteration": current_step,
111+
"sample_strategy_state": sample_strategy_state,
112112
},
113113
f,
114114
indent=2,

trinity/trainer/trainer.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import ray
1515

1616
from trinity.algorithm import SAMPLE_STRATEGY
17+
from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
1718
from trinity.common.config import Config
1819
from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle
1920
from trinity.common.experience import Experiences
@@ -38,9 +39,6 @@ def __init__(self, config: Config) -> None:
3839
path=config.checkpoint_job_dir, trainer_name=config.trainer.name, config=config
3940
)
4041
trainer_state = self.state.load_trainer()
41-
config.buffer.trainer_input.experience_buffer.index = trainer_state.get(
42-
"latest_exp_index", 0
43-
)
4442
self.last_trainer_sync_step = 0
4543
self.monitor = MONITOR.get(config.monitor.monitor_type)(
4644
project=config.project,
@@ -50,10 +48,17 @@ def __init__(self, config: Config) -> None:
5048
config=config,
5149
)
5250
self._sample_exps_to_log = []
53-
self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)(
51+
self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get(
52+
config.algorithm.sample_strategy
53+
)(
5454
buffer_config=config.buffer,
5555
**config.algorithm.sample_strategy_args,
5656
)
57+
if "latest_exp_index" in trainer_state:
58+
sample_strategy_state = {"current_index": trainer_state["latest_exp_index"]}
59+
else:
60+
sample_strategy_state = trainer_state.get("sample_strategy_state", {})
61+
self.sample_strategy.load_state_dict(sample_strategy_state)
5762
self.save_interval = config.trainer.save_interval
5863
self.last_sync_step = None
5964
self.last_sync_time = None
@@ -190,8 +195,8 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa
190195
self.logger.info(f"Saving checkpoint at step {self.train_step_num}...")
191196
self.engine.save_checkpoint(block_until_saved=block_until_saved, save_as_hf=save_as_hf)
192197
self.state.save_trainer(
193-
current_exp_index=self.engine.train_step_num * self.config.buffer.train_batch_size,
194198
current_step=self.train_step_num,
199+
sample_strategy_state=self.sample_strategy.state_dict(),
195200
)
196201
return metrics
197202

0 commit comments

Comments
 (0)