3 changes: 3 additions & 0 deletions .gitignore
@@ -94,3 +94,6 @@ modules.rst
 
 # wandb
 wandb/
+
+# checkpoints
+checkpoints/
@@ -244,7 +244,7 @@ You can set more config items for this OP (e.g. notification when annotation is
 
 When you start running with the RFT config, the data module will start the OP `human_preference_annotation_mapper`, and then you can find a new project on the "Projects" page of the label-studio server.
 
-![]("../../assets/data-projects.png")
+![](../../assets/data-projects.png)
 
 You can click and enter into this project, and all the samples that need to be annotated are listed on the page.
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/example_multi_turn.md
@@ -122,5 +122,5 @@ and include them in the init files in `trinity/common/workflows/__init__.py`
 
 Then you are all set! It should be pretty simple😄, and both environments converge.
 
-![]("../../assets/alfworld_reward_curve.png")
-![]("../../assets/webshop_reward_curve.png")
+![](../../assets/alfworld_reward_curve.png)
+![](../../assets/webshop_reward_curve.png)
1 change: 1 addition & 0 deletions examples/grpo_gsm8k/gsm8k.yaml
@@ -69,6 +69,7 @@ trainer:
   trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
   sft_warmup_iteration: 0 # Set to integer to enable sft warmup
   eval_interval: 50
+  # get_exp_strategy: 'LFU'
 monitor:
   cache_root_dir: ""
   project: "Trinity-RFT-gsm8k"
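The new knob ships commented out, so this example's defaults are unchanged. Uncommenting `get_exp_strategy: 'LFU'` feeds the string into the `TrainerConfig.get_exp_strategy` field added in `trinity/common/config.py` below; the trainer parses it with `ReadStrategy(...)` before handing it to the buffer reader (see `trinity/trainer/trainer.py` at the end of this diff).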
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -93,3 +93,6 @@ exclude = '''
   | dist
 )/
 '''
+
+[tool.isort]
+known_third_party = ["wandb"]
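isort's `known_third_party` pins `wandb` to the third-party import group regardless of whether it is importable in the environment running the hook, which keeps import ordering deterministic across machines.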
2 changes: 1 addition & 1 deletion trinity/buffer/reader/sql_reader.py
@@ -49,7 +49,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
             sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id))
 
         else:
-            raise NotImplementedError("Unsupported strategy by SQLStorage")
+            raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")
 
         exp_list = []
         while len(exp_list) < self.batch_size:
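The excerpt shows only one ordering branch and the fallback; which `ReadStrategy` member selects the priority ordering is not visible here. A minimal self-contained sketch of the dispatch shape, with an assumed enum member and a stand-in table model (names here are illustrations, not Trinity-RFT's API):

from enum import Enum
from typing import Optional

from sqlalchemy import Column, Integer, desc
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class ExperienceModel(Base):
    """Stand-in for self.table_model_cls in the real reader."""

    __tablename__ = "experience"
    id = Column(Integer, primary_key=True)
    priority = Column(Integer, default=0)


class ReadStrategy(Enum):
    """Assumed shape; the real enum lives in trinity.common.constants."""

    LFU = "LFU"


def sort_order(strategy: Optional[ReadStrategy]):
    # One branch produces the ordering shown in the diff: highest priority
    # first, with the newest id as the tie-breaker.
    if strategy is None or strategy is ReadStrategy.LFU:  # assumed mapping
        return (desc(ExperienceModel.priority), desc(ExperienceModel.id))
    # Everything else fails loudly, now naming the strategy in the message.
    raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")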
1 change: 1 addition & 0 deletions trinity/common/config.py
@@ -181,6 +181,7 @@ class TrainerConfig:
 
     # train algorithm
     algorithm_type: AlgorithmType = AlgorithmType.PPO
+    get_exp_strategy: Optional[str] = None
 
     # warmup config
     sft_warmup_iteration: int = 0
1 change: 0 additions & 1 deletion trinity/common/verl_config.py
@@ -244,7 +244,6 @@ class Trainer:
     training_rollout_mode: str = "parallel"
     enable_exp_buffer: bool = True
     steps_per_epoch: int = 1280
-    get_exp_strategy: Optional[str] = None
     sync_freq: int = 0
     sft_warmup_iteration: int = 0
     max_actor_ckpt_to_keep: Optional[int] = None
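Read together with `trinity/common/config.py` above, these two one-line changes are a relocation: `get_exp_strategy` moves out of the verl-specific `Trainer` dataclass into the framework-level `TrainerConfig`, which is where the trainer below reads it (`self.config.trainer.get_exp_strategy`).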
6 changes: 4 additions & 2 deletions trinity/explorer/explorer.py
@@ -149,8 +149,10 @@ def get_weight(self, name: str) -> torch.Tensor:
 
     def explore(self) -> None:
         """Explore the entire dataset."""
-        explore_status, _ = self.explore_step()
-        while explore_status:
+        while True:
+            explore_status, _ = self.explore_step()
+            if not explore_status:
+                break
             self.sync_weight()
         self.logger.info("Explorer finished.")
 
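As flattened here, the old control flow read `explore_status` once before the loop and never refreshed it inside, so a successful first step would repeat `sync_weight()` indefinitely. The rewrite is the standard loop-and-a-half; stripped of the class, the shape is roughly this (a sketch, with `step` and `sync` standing in for the real methods):

from typing import Callable, Tuple


def explore_all(step: Callable[[], Tuple[bool, object]], sync: Callable[[], None]) -> None:
    while True:
        ok, _ = step()  # take one exploration step and read a fresh status
        if not ok:
            break  # dataset exhausted: stop before syncing again
        sync()  # only sync weights after a successful step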
38 changes: 21 additions & 17 deletions trinity/trainer/trainer.py
@@ -13,7 +13,7 @@
 
 from trinity.buffer import get_buffer_reader
 from trinity.common.config import Config
-from trinity.common.constants import AlgorithmType
+from trinity.common.constants import AlgorithmType, ReadStrategy
 from trinity.common.experience import Experiences
 from trinity.utils.log import get_logger
 
@@ -81,24 +81,28 @@ def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple
                     pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
                 )
             )
-        else:
-            exps = self.train_buffer.read()
-            if algo_type.is_rft():
-                return self.engine.train_rft_iteration(
-                    Experiences.gather_experiences(
-                        exps,
-                        pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
-                    )
+        elif algo_type.is_rft():
+            if self.config.trainer.get_exp_strategy:
+                strategy = ReadStrategy(self.config.trainer.get_exp_strategy)
+            else:
+                strategy = None
+            exps = self.train_buffer.read(strategy=strategy)
+            return self.engine.train_rft_iteration(
+                Experiences.gather_experiences(
+                    exps,
+                    pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
                 )
-            elif algo_type.is_dpo():
-                return self.engine.train_dpo_iteration(
-                    Experiences.gather_dpo_experiences(
-                        exps,
-                        pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
-                    )
+            )
+        elif algo_type.is_dpo():
+            exps = self.train_buffer.read()
+            return self.engine.train_dpo_iteration(
+                Experiences.gather_dpo_experiences(
+                    exps,
+                    pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
                 )
-            else:
-                raise ValueError(f"Unsupported algorithm type: {algo_type}")
+            )
+        else:
+            raise ValueError(f"Unsupported algorithm type: {algo_type}")
 
     def sync_weight(self) -> None:
         """Sync the model weight."""
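Putting the pieces together, the new data path for RFT reads is: YAML string → `TrainerConfig.get_exp_strategy` → `ReadStrategy` → `buffer.read(strategy=...)`. A minimal sketch of that plumbing, assuming `ReadStrategy` is a string-valued enum (which `ReadStrategy(self.config.trainer.get_exp_strategy)` in the diff implies) and that `None` means the reader's default ordering:

from enum import Enum
from typing import List, Optional


class ReadStrategy(Enum):
    """Assumed string-valued enum, mirroring trinity.common.constants."""

    LFU = "LFU"


class BufferReader:
    """Stand-in for the object returned by get_buffer_reader."""

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        return []  # the real reader pulls a batch from SQL storage


def read_rft_batch(reader: BufferReader, get_exp_strategy: Optional[str]) -> List:
    # Mirrors train_iteration: an unset config falls back to the default
    # order; an unknown string raises ValueError from the enum lookup.
    strategy = ReadStrategy(get_exp_strategy) if get_exp_strategy else None
    return reader.read(strategy=strategy)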