@@ -13,7 +13,7 @@
 
 from trinity.buffer import get_buffer_reader
 from trinity.common.config import Config
-from trinity.common.constants import AlgorithmType
+from trinity.common.constants import AlgorithmType, ReadStrategy
 from trinity.common.experience import Experiences
 from trinity.utils.log import get_logger
 
@@ -81,24 +81,28 @@ def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple |
                     pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
             )
         )
-        else:
-            exps = self.train_buffer.read()
-            if algo_type.is_rft():
-                return self.engine.train_rft_iteration(
-                    Experiences.gather_experiences(
-                        exps,
-                        pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
-                    )
+        elif algo_type.is_rft():
+            if self.config.trainer.get_exp_strategy:
+                strategy = ReadStrategy(self.config.trainer.get_exp_strategy)
+            else:
+                strategy = None
+            exps = self.train_buffer.read(strategy=strategy)
+            return self.engine.train_rft_iteration(
+                Experiences.gather_experiences(
+                    exps,
+                    pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
                 )
-            elif algo_type.is_dpo():
-                return self.engine.train_dpo_iteration(
-                    Experiences.gather_dpo_experiences(
-                        exps,
-                        pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
-                    )
+            )
+        elif algo_type.is_dpo():
+            exps = self.train_buffer.read()
+            return self.engine.train_dpo_iteration(
+                Experiences.gather_dpo_experiences(
+                    exps,
+                    pad_token_id=self.config.buffer.pad_token_id,  # type: ignore
                 )
-        else:
-            raise ValueError(f"Unsupported algorithm type: {algo_type}")
+            )
+        else:
+            raise ValueError(f"Unsupported algorithm type: {algo_type}")
 
     def sync_weight(self) -> None:
         """Sync the model weight."""
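
A note on the new RFT branch: the read strategy is resolved from `trainer.get_exp_strategy` only when that config field is truthy; otherwise `strategy=None` is passed and the buffer reader falls back to its default behavior (the DPO branch still calls `self.train_buffer.read()` with no arguments, so `strategy` presumably defaults to `None` in the reader's signature). Below is a minimal, self-contained sketch of this config-to-enum lookup pattern; the enum members and the `resolve_strategy` helper are illustrative placeholders, not the actual contents of `trinity.common.constants`.

```python
from enum import Enum
from typing import Optional


class ReadStrategy(Enum):
    # Placeholder members for illustration only; the real enum in
    # trinity.common.constants may define different values.
    FIFO = "fifo"
    RANDOM = "random"


def resolve_strategy(get_exp_strategy: Optional[str]) -> Optional[ReadStrategy]:
    """Mirror the branch above: an unset config value means the buffer
    reader should fall back to its own default strategy."""
    if get_exp_strategy:
        # Enum lookup by value; raises ValueError if the config string
        # does not match any member, so typos fail loudly at read time.
        return ReadStrategy(get_exp_strategy)
    return None


# Usage with hypothetical config values:
assert resolve_strategy("fifo") is ReadStrategy.FIFO
assert resolve_strategy(None) is None
assert resolve_strategy("") is None
```

Routing the raw config string through the enum constructor also gives early validation: a misspelled `get_exp_strategy` value raises a `ValueError` rather than silently selecting the default read behavior.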