4 changes: 2 additions & 2 deletions tests/template/config.yaml
@@ -20,8 +20,8 @@ model:
max_response_tokens: 2048
max_model_len: 4096
cluster: # 2 for explorer, 2 for trainer
node_num: 2
gpu_per_node: 2
node_num: ${oc.env:NODE_NUM,2}
gpu_per_node: ${oc.env:GPU_PER_NODE,2}
buffer:
total_epochs: 1
batch_size: 4
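For reference (not part of the diff): `${oc.env:VAR,default}` is OmegaConf's environment-variable resolver, so the cluster size can now be overridden from the environment while keeping 2/2 as the default. A standalone sketch of the behavior, trimmed down to the cluster block from the diff above; note that values read from the environment resolve as strings unless wrapped in `oc.decode`:

import os

from omegaconf import OmegaConf

# Standalone sketch, not repository code.
cfg = OmegaConf.create(
    """
    cluster:
      node_num: ${oc.env:NODE_NUM,2}
      gpu_per_node: ${oc.env:GPU_PER_NODE,2}
    """
)

os.environ.pop("NODE_NUM", None)
print(cfg.cluster.node_num)   # 2 -- the default, since NODE_NUM is unset

os.environ["NODE_NUM"] = "4"
print(cfg.cluster.node_num)   # 4 -- re-resolved from the environment (as a string)
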
28 changes: 27 additions & 1 deletion tests/trainer/trainer_test.py
@@ -1,5 +1,6 @@
"""Tests for trainer."""

import json
import multiprocessing
import os
import shutil
@@ -809,7 +810,7 @@ def test_trainer(self):
self.config.algorithm.policy_loss_fn = "mix"
self.config.buffer.batch_size = 4
self.config.buffer.train_batch_size = 32
self.config.buffer.total_epochs = 1
self.config.buffer.total_steps = 2
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.synchronizer.sync_interval = 1
self.config.trainer.save_interval = 1
@@ -823,6 +824,31 @@ def test_trainer(self):
self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
both(self.config)
ray.shutdown(_exiting_interpreter=True)

# check trainer resume metadata
trainer_meta_file = os.path.join(self.config.checkpoint_job_dir, "trainer_meta.json")
with open(trainer_meta_file) as f:
trainer_meta = json.load(f)
self.assertEqual(trainer_meta["latest_iteration"], 2)
self.assertEqual(
trainer_meta["sample_strategy_state"]["expert_buffer"]["current_index"], 32
)

self.config.buffer.total_steps = None
self.config.buffer.total_epochs = 1
self.config.check_and_update()
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
both(self.config)

# check trainer resume metadata
with open(trainer_meta_file) as f:
trainer_meta = json.load(f)
self.assertEqual(trainer_meta["latest_iteration"], 4)
self.assertEqual(
trainer_meta["sample_strategy_state"]["expert_buffer"]["current_index"], 64
)

parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))

# test rollout metrics
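Putting the metadata assertions above together (editor's sketch, not part of the diff): with `train_batch_size = 32` and the mix strategy's default `expert_data_ratio` of 0.5 shown in `mix_sample_strategy.py` below, each step would consume 16 expert experiences, which is consistent with the asserted cursor values. Only `latest_iteration` and the expert buffer's `current_index` are actually asserted by the test; any other contents of `trainer_meta.json` are assumptions.

# Hypothetical reconstruction of trainer_meta.json, based only on the asserted fields.
after_first_run = {
    "latest_iteration": 2,   # total_steps = 2
    "sample_strategy_state": {
        # 2 steps * 32 train_batch_size * 0.5 expert_data_ratio = 32
        "expert_buffer": {"current_index": 32},
    },
}

after_resume = {
    "latest_iteration": 4,   # the resumed run ends at step 4
    "sample_strategy_state": {
        # 4 steps * 32 * 0.5 = 64 -- the cursor survived the restart
        "expert_buffer": {"current_index": 64},
    },
}
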
12 changes: 12 additions & 0 deletions trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -111,3 +111,15 @@ def default_args(cls) -> Dict:
"expert_data_ratio": 0.5,
"sft_dataset_name": "sft_dataset",
}

def state_dict(self) -> dict:
return {
"usal_buffer": self.usual_exp_buffer.state_dict(),
"expert_buffer": self.expert_exp_buffer.state_dict(),
}

def load_state_dict(self, state_dict: dict) -> None:
if state_dict.get("usal_buffer", None):
self.usual_exp_buffer.load_state_dict(state_dict["usal_buffer"])
if state_dict.get("expert_buffer", None):
self.expert_exp_buffer.load_state_dict(state_dict["expert_buffer"])
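The point of the two entries is that the usual-experience and expert buffers keep independent read cursors, and both have to survive a trainer restart. A self-contained toy round trip (editor's sketch; `_ToyReader` is a stand-in, not the real buffer reader class):

class _ToyReader:
    """Stand-in for a buffer reader that tracks how far it has read."""

    def __init__(self) -> None:
        self.current_index = 0

    def read(self, batch_size: int) -> None:
        self.current_index += batch_size

    def state_dict(self) -> dict:
        return {"current_index": self.current_index}

    def load_state_dict(self, state_dict: dict) -> None:
        self.current_index = state_dict["current_index"]


# Before a (simulated) restart: the two buffers have advanced by different amounts.
usual, expert = _ToyReader(), _ToyReader()
usual.read(16)
expert.read(16)
usual.read(16)
state = {"usual_buffer": usual.state_dict(), "expert_buffer": expert.state_dict()}

# After the restart: fresh readers pick up exactly where the old ones stopped.
usual2, expert2 = _ToyReader(), _ToyReader()
if state.get("usual_buffer"):
    usual2.load_state_dict(state["usual_buffer"])
if state.get("expert_buffer"):
    expert2.load_state_dict(state["expert_buffer"])
assert (usual2.current_index, expert2.current_index) == (32, 16)
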
15 changes: 15 additions & 0 deletions trinity/algorithm/sample_strategy/sample_strategy.py
@@ -43,6 +43,14 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
def default_args(cls) -> dict:
"""Get the default arguments of the sample strategy."""

@abstractmethod
def state_dict(self) -> dict:
"""Get the state dict of the sample strategy."""

@abstractmethod
def load_state_dict(self, state_dict: dict) -> None:
"""Load the state dict of the sample strategy."""


@SAMPLE_STRATEGY.register_module("default")
class DefaultSampleStrategy(SampleStrategy):
@@ -64,6 +72,13 @@ async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
def default_args(cls) -> dict:
return {}

def state_dict(self) -> dict:
return self.exp_buffer.state_dict()

def load_state_dict(self, state_dict: dict) -> None:
if state_dict:
self.exp_buffer.load_state_dict(state_dict)


@Deprecated
@SAMPLE_STRATEGY.register_module("warmup")
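Because the two new methods are abstract on the base class, any custom strategy registered with SAMPLE_STRATEGY now has to implement both hooks (an empty dict from state_dict is fine, as DefaultSampleStrategy treats it as "nothing to restore"). A standalone reminder of how that contract is enforced, using toy classes rather than the real registry:

from abc import ABC, abstractmethod


class Strategy(ABC):  # toy stand-in for SampleStrategy
    @abstractmethod
    def state_dict(self) -> dict: ...

    @abstractmethod
    def load_state_dict(self, state_dict: dict) -> None: ...


class Incomplete(Strategy):  # forgets to add the new hooks
    pass


class Complete(Strategy):
    def state_dict(self) -> dict:
        return {}

    def load_state_dict(self, state_dict: dict) -> None:
        pass


Complete()       # fine
try:
    Incomplete()  # TypeError: abstract methods are missing
except TypeError as err:
    print(err)
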
4 changes: 2 additions & 2 deletions trinity/buffer/reader/queue_reader.py
@@ -43,9 +43,9 @@ async def read_async(self, batch_size: Optional[int] = None) -> List:
return exps

def state_dict(self) -> Dict:
# SQL Not supporting state dict yet
# Queue Not supporting state dict yet
return {"current_index": 0}

def load_state_dict(self, state_dict):
# SQL Not supporting state dict yet
# Queue Not supporting state dict yet
return None
4 changes: 2 additions & 2 deletions trinity/manager/state_manager.py
@@ -101,14 +101,14 @@ def load_explorer_server_url(self) -> Optional[str]:

def save_trainer(
self,
current_exp_index: int,
current_step: int,
sample_strategy_state: dict,
) -> None:
with open(self.trainer_state_path, "w", encoding="utf-8") as f:
json.dump(
{
"latest_exp_index": current_exp_index,
"latest_iteration": current_step,
"sample_strategy_state": sample_strategy_state,
},
f,
indent=2,
15 changes: 10 additions & 5 deletions trinity/trainer/trainer.py
@@ -14,6 +14,7 @@
import ray

from trinity.algorithm import SAMPLE_STRATEGY
from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
from trinity.common.config import Config
from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle
from trinity.common.experience import Experiences
@@ -38,9 +39,6 @@ def __init__(self, config: Config) -> None:
path=config.checkpoint_job_dir, trainer_name=config.trainer.name, config=config
)
trainer_state = self.state.load_trainer()
config.buffer.trainer_input.experience_buffer.index = trainer_state.get(
"latest_exp_index", 0
)
self.last_trainer_sync_step = 0
self.monitor = MONITOR.get(config.monitor.monitor_type)(
project=config.project,
@@ -50,10 +48,17 @@ def __init__(self, config: Config) -> None:
config=config,
)
self._sample_exps_to_log = []
self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)(
self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get(
config.algorithm.sample_strategy
)(
buffer_config=config.buffer,
**config.algorithm.sample_strategy_args,
)
if "latest_exp_index" in trainer_state:
sample_strategy_state = {"current_index": trainer_state["latest_exp_index"]}
else:
sample_strategy_state = trainer_state.get("sample_strategy_state", {})
self.sample_strategy.load_state_dict(sample_strategy_state)
self.save_interval = config.trainer.save_interval
self.last_sync_step = None
self.last_sync_time = None
@@ -190,8 +195,8 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa
self.logger.info(f"Saving checkpoint at step {self.train_step_num}...")
self.engine.save_checkpoint(block_until_saved=block_until_saved, save_as_hf=save_as_hf)
self.state.save_trainer(
current_exp_index=self.engine.train_step_num * self.config.buffer.train_batch_size,
current_step=self.train_step_num,
sample_strategy_state=self.sample_strategy.state_dict(),
)
return metrics

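The `latest_exp_index` branch keeps old checkpoints loadable: a state file written before this change only records how many experiences were consumed, and that count maps directly onto the default strategy's `current_index`. A standalone sketch mirroring the compatibility branch in the trainer's `__init__` above (the example state-file contents are hypothetical):

def to_sample_strategy_state(trainer_state: dict) -> dict:
    """Mirrors the backward-compatibility branch in the trainer's __init__."""
    if "latest_exp_index" in trainer_state:
        # Old format: only the consumed-experience count was stored.
        return {"current_index": trainer_state["latest_exp_index"]}
    # New format: the strategy's own state dict was stored verbatim.
    return trainer_state.get("sample_strategy_state", {})


# Old-style state file (hypothetical contents).
assert to_sample_strategy_state({"latest_exp_index": 128}) == {"current_index": 128}

# New-style state file as written by save_trainer() above (hypothetical contents).
new_state = {
    "latest_iteration": 4,
    "sample_strategy_state": {"expert_buffer": {"current_index": 64}},
}
assert to_sample_strategy_state(new_state) == {"expert_buffer": {"current_index": 64}}

# Nothing saved yet: an empty dict means "start from scratch".
assert to_sample_strategy_state({}) == {}
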