Skip to content

Commit 8795012

Browse files
authored
Make 'get_exp_strategy' effective in sql (#26)
1 parent d66b3de commit 8795012

File tree

10 files changed

+37
-24
lines changed

10 files changed

+37
-24
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,6 @@ modules.rst
9494

9595
# wandb
9696
wandb/
97+
98+
# checkpoints
99+
checkpoints/

docs/sphinx_doc/source/tutorial/example_data_functionalities.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ You can set more config items for this OP (e.g. notification when annotation is
244244

245245
When you start running with the RFT config, the data module will start the OP `human_preference_annotation_mapper`, and then you can find a new project on the "Projects" page of the label-studio server.
246246

247-
![]("../../assets/data-projects.png")
247+
![](../../assets/data-projects.png)
248248

249249
You can click to enter this project; all the samples that need to be annotated are listed on the page.
250250

docs/sphinx_doc/source/tutorial/example_multi_turn.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,5 @@ and include them in the init files in `trinity/common/workflows/__init__.py`
122122

123123
Then you are all set! It should be pretty simple 😄, and training converges in both environments.
124124

125-
![]("../../assets/alfworld_reward_curve.png")
126-
![]("../../assets/webshop_reward_curve.png")
125+
![](../../assets/alfworld_reward_curve.png)
126+
![](../../assets/webshop_reward_curve.png)

examples/grpo_gsm8k/gsm8k.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ trainer:
6969
trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
7070
sft_warmup_iteration: 0 # Set to integer to enable sft warmup
7171
eval_interval: 50
72+
# get_exp_strategy: 'LFU'
7273
monitor:
7374
cache_root_dir: ""
7475
project: "Trinity-RFT-gsm8k"

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,6 @@ exclude = '''
9393
| dist
9494
)/
9595
'''
96+
97+
[tool.isort]
98+
known_third_party = ["wandb"]

trinity/buffer/reader/sql_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
4949
sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id))
5050

5151
else:
52-
raise NotImplementedError("Unsupported strategy by SQLStorage")
52+
raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")
5353

5454
exp_list = []
5555
while len(exp_list) < self.batch_size:

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ class TrainerConfig:
181181

182182
# train algorithm
183183
algorithm_type: AlgorithmType = AlgorithmType.PPO
184+
get_exp_strategy: Optional[str] = None
184185

185186
# warmup config
186187
sft_warmup_iteration: int = 0

trinity/common/verl_config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ class Trainer:
244244
training_rollout_mode: str = "parallel"
245245
enable_exp_buffer: bool = True
246246
steps_per_epoch: int = 1280
247-
get_exp_strategy: Optional[str] = None
248247
sync_freq: int = 0
249248
sft_warmup_iteration: int = 0
250249
max_actor_ckpt_to_keep: Optional[int] = None

trinity/explorer/explorer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,10 @@ def get_weight(self, name: str) -> torch.Tensor:
149149

150150
def explore(self) -> None:
151151
"""Explore the entire dataset."""
152-
explore_status, _ = self.explore_step()
153-
while explore_status:
152+
while True:
153+
explore_status, _ = self.explore_step()
154+
if not explore_status:
155+
break
154156
self.sync_weight()
155157
self.logger.info("Explorer finished.")
156158

trinity/trainer/trainer.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from trinity.buffer import get_buffer_reader
1515
from trinity.common.config import Config
16-
from trinity.common.constants import AlgorithmType
16+
from trinity.common.constants import AlgorithmType, ReadStrategy
1717
from trinity.common.experience import Experiences
1818
from trinity.utils.log import get_logger
1919

@@ -81,24 +81,28 @@ def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple
8181
pad_token_id=self.config.buffer.pad_token_id, # type: ignore
8282
)
8383
)
84-
else:
85-
exps = self.train_buffer.read()
86-
if algo_type.is_rft():
87-
return self.engine.train_rft_iteration(
88-
Experiences.gather_experiences(
89-
exps,
90-
pad_token_id=self.config.buffer.pad_token_id, # type: ignore
91-
)
84+
elif algo_type.is_rft():
85+
if self.config.trainer.get_exp_strategy:
86+
strategy = ReadStrategy(self.config.trainer.get_exp_strategy)
87+
else:
88+
strategy = None
89+
exps = self.train_buffer.read(strategy=strategy)
90+
return self.engine.train_rft_iteration(
91+
Experiences.gather_experiences(
92+
exps,
93+
pad_token_id=self.config.buffer.pad_token_id, # type: ignore
9294
)
93-
elif algo_type.is_dpo():
94-
return self.engine.train_dpo_iteration(
95-
Experiences.gather_dpo_experiences(
96-
exps,
97-
pad_token_id=self.config.buffer.pad_token_id, # type: ignore
98-
)
95+
)
96+
elif algo_type.is_dpo():
97+
exps = self.train_buffer.read()
98+
return self.engine.train_dpo_iteration(
99+
Experiences.gather_dpo_experiences(
100+
exps,
101+
pad_token_id=self.config.buffer.pad_token_id, # type: ignore
99102
)
100-
else:
101-
raise ValueError(f"Unsupported algorithm type: {algo_type}")
103+
)
104+
else:
105+
raise ValueError(f"Unsupported algorithm type: {algo_type}")
102106

103107
def sync_weight(self) -> None:
104108
"""Sync the model weight."""

0 commit comments

Comments
 (0)