# Integrate a New Algorithm


This guide introduces how to integrate a new algorithm into Trinity-RFT.
As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX, which optimizes the following policy objective:

$$
\mathcal{J}_{\text{Mix}}(\theta) =
\mathcal{J}_{\text{GRPO}}(\theta)
+
\mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'}
\left[
  \frac{1}{T'_b} \sum_{t=1}^{T'_b}
  \log \pi_\theta(o'_{b,t} \mid q'_b, o'_{b,<t})
\right]}_{\text{Auxiliary Loss on Expert Data}}.
$$
The first term is the standard GRPO objective, which maximizes the expected reward. The second term is an auxiliary loss defined on expert data, encouraging the policy to imitate expert behavior. $\mu$ is a weighting factor that controls the relative importance of the two terms.
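
For intuition, the auxiliary term is simply an average log-likelihood of the expert responses. Below is a minimal, self-contained sketch (not Trinity-RFT code; the function and tensor names are illustrative) of how this term can be turned into a loss given per-token log-probabilities and a response mask:

```python
import torch


def expert_auxiliary_loss(logprob: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    """Negative mean log-likelihood of expert response tokens.

    logprob:     [B', T'] per-token log pi_theta(o'_{b,t} | q'_b, o'_{b,<t})
    action_mask: [B', T'] 1 for response tokens, 0 for prompt/padding
    """
    # Average log-probability per sequence (the inner 1/T'_b sum), then over the batch (1/B').
    per_seq = (logprob * action_mask).sum(dim=-1) / action_mask.sum(dim=-1).clamp(min=1)
    # Negate so that minimizing this loss maximizes the underbraced term in the objective.
    return -per_seq.mean()
```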


## Step 0: Prepare the Expert Data

We prompt a powerful LLM to generate responses with a chain-of-thought (CoT) process for some pre-defined questions. The collected data are viewed as experiences from an expert. We store them in a JSON file `expert_data.json` with the following format:

```json
{
  "question": "What is the average of 4, 6, and 8?",
  "response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."
}
...
```
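
If you need to assemble such a file yourself, a sketch along the following lines is enough (the records below are placeholders; whether the loader expects a JSON list or one object per line depends on your buffer reader, so adjust accordingly):

```python
import json

# Placeholder records; in practice, collect them by prompting a stronger LLM for CoT responses.
expert_records = [
    {
        "question": "What is the average of 4, 6, and 8?",
        "response": (
            "I add the numbers together and divide by the count: "
            "4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."
        ),
    },
]

# Written as a JSON list here; switch to one JSON object per line for a JSONL-style reader.
with open("expert_data.json", "w", encoding="utf-8") as f:
    json.dump(expert_records, f, ensure_ascii=False, indent=2)
```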


## Step 1: Define the Algorithm

In `trinity/algorithm/algorithm.py`, we introduce a new algorithm type `MIX`.

```python
@ALGORITHM_TYPE.register_module("mix")
class MIXAlgorithm(AlgorithmType):
    """MIX algorithm."""

    use_critic: bool = False
    use_reference: bool = True
    use_advantage: bool = True
    use_rollout: bool = True
    can_balance_batch: bool = True
    schema: type = ExperienceModel

    @classmethod
    def get_default_config(cls) -> Dict:
        return {
            "repeat_times": 8,
            "policy_loss_fn": "mix",
            "advantage_fn": "grpo",
            "sample_strategy": "mix",
        }
```
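
The defaults returned by `get_default_config` wire MIX to the `grpo` advantage function and to the `mix` policy loss function and `mix` sample strategy registered in the following steps, so selecting `algorithm_type: mix` in the config (as in Step 4) picks these components up by default.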


## Step 2: Define the Sampling Strategy

At each training step, we need to read two kinds of experiences: usual (rollout) experiences and expert experiences. For this purpose, we define a new experience sampling strategy named `MixSampleStrategy`.


```python
class MixSampleStrategy(SampleStrategy):
    """Sample strategy that mixes usual experiences with expert experiences."""

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        super().__init__(buffer_config, trainer_type)
        self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
        tot_batch_size = buffer_config.read_batch_size
        expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)

        # usual experience buffer
        usual_buffer_config = copy.deepcopy(buffer_config)
        usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size
        self.usual_exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, usual_buffer_config  # type: ignore
        )

        if buffer_config.trainer_input.sft_warmup_dataset is None:
            raise ValueError(
                "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm"
            )

        # expert experience buffer
        expert_buffer_config = copy.deepcopy(buffer_config)
        expert_buffer_config.read_batch_size = expert_batch_size
        self.expert_exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
        )

    def sample(self, step: int) -> Tuple[Any, Dict, List]:
        metrics = {}
        with Timer(metrics, "read_time"):
            usual_exp_list = self.usual_exp_buffer.read()
            for exp in usual_exp_list:
                if exp.info is None:
                    exp.info = {}
                exp.info["is_expert"] = False

            expert_exp_list = self.expert_exp_buffer.read()
            for exp in expert_exp_list:
                exp.reward = 0.0
                exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
                if exp.info is None:
                    exp.info = {}
                exp.info["is_expert"] = True

            exp_list = usual_exp_list + expert_exp_list
            repr_samples = representative_sample(exp_list)

        is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool)

        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore

        if self.trainer_type == "verl":
            with Timer(metrics, "convert_time"):
                data = to_data_proto_mix(exps, is_expert_mask)
            return data, metrics, repr_samples
        else:
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
```
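
Note that expert experiences get a placeholder reward of `0.0` and all-zero `logprobs`, since these fields are not used by the SFT-style loss applied to them in Step 3; what distinguishes the two kinds of data downstream is the `is_expert` flag, which is collected into `is_expert_mask`.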

We also need to add an `is_expert_mask` field when converting the experiences to a `DataProto`, so that the trainer can tell the two kinds of data apart.

```diff
+def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.Tensor) -> DataProto:
     attention_mask = experiences.attention_masks
     cumsum = torch.cumsum(attention_mask, dim=-1)
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
         "uid": np.array(experiences.run_ids),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
         "attention_mask": attention_mask.long(),
         "response_mask": (
             experiences.action_masks[:, experiences.prompt_length :].long()
             if hasattr(experiences, "action_masks") and experiences.action_masks is not None
             else attention_mask[:, experiences.prompt_length :].long()
         ),
+        "is_expert_mask": is_expert_mask,
     }
     if experiences.rewards is not None:
         token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
         eos_mask_idx = cumsum.argmax(dim=-1)
         token_level_rewards[
             torch.arange(experiences.batch_size), eos_mask_idx
         ] = experiences.rewards
         token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
         batch_dict.update(
             {
                 "token_level_scores": token_level_rewards,
                 "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
             }
         )
     return DataProto.from_single_dict(batch_dict)
```
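
After this change, `is_expert_mask` travels with the rest of the batch inside the `DataProto`, and the policy loss function defined in Step 3 retrieves it through its `select_keys` property.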


## Step 3: Define the Policy Loss Function

We define a `MIXPolicyLossFn` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which combines two loss terms: a GRPO loss computed on usual experiences and an SFT loss computed on expert experiences.

```python
@POLICY_LOSS_FN.register_module("mix")
class MIXPolicyLossFn(PolicyLossFn):
    def __init__(
        self,
        mu: float = 0.1,
        clip_range: Optional[float] = None,
        clip_range_low: Optional[float] = None,
        clip_range_high: Optional[float] = None,
        use_dynamic_bsz: Optional[bool] = None,
        repeat_times: Optional[int] = None,
        ppo_mini_batch_size: Optional[int] = None,
        ppo_micro_batch_size_per_gpu: Optional[int] = None,
        ngpus_trainer: Optional[int] = None,
        read_batch_size_usual: Optional[int] = None,
        read_batch_size_expert: Optional[int] = None,
        use_token_level_loss_in_sft: bool = True,
    ) -> None:
        self.mu = mu
        self.use_dynamic_bsz = use_dynamic_bsz
        self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer  # type: ignore
        self.gradient_accumulation = (
            ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu  # type: ignore
        )
        self.read_batch_size_usual = read_batch_size_usual
        self.read_batch_size_expert = read_batch_size_expert
        self.grpo_loss_fn = PPOPolicyLossFn(
            clip_range=clip_range,
            clip_range_low=clip_range_low,
            clip_range_high=clip_range_high,
        )
        self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft)

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        is_expert_mask = kwargs.get("is_expert_mask", None)
        if is_expert_mask is None:
            raise ValueError("is_expert_mask is required in MIX")
        assert (
            len(is_expert_mask) == logprob.shape[0]
        ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"

        n_usual_exp = torch.sum(~is_expert_mask).item()
        n_expert_exp = torch.sum(is_expert_mask).item()

        if self.use_dynamic_bsz:
            per_micro_batch_weight_usual = self.experience_per_gpu / (
                logprob.shape[0] * self.read_batch_size_usual
            )
            per_micro_batch_weight_expert = self.experience_per_gpu / (
                logprob.shape[0] * self.read_batch_size_expert
            )
        else:
            per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual  # type: ignore
            per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert  # type: ignore

        if n_usual_exp > 0:
            grpo_loss, grpo_metrics = self.grpo_loss_fn(
                logprob[~is_expert_mask],
                old_logprob[~is_expert_mask],
                action_mask[~is_expert_mask],
                advantages[~is_expert_mask],
                **kwargs,
            )
            grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
            grpo_metrics = {
                k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
            }
        else:
            grpo_loss = torch.tensor(0.0, device=logprob.device)
            grpo_metrics = {}

        # SFT Loss (expert)
        if n_expert_exp > 0:
            sft_loss, sft_metrics = self.sft_loss_fn(
                logprob[is_expert_mask],
                action_mask[is_expert_mask],
            )
            sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
            sft_metrics = {
                k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
            }
        else:
            sft_loss = torch.tensor(0.0, device=logprob.device)
            sft_metrics = {}

        loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss

        metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()}
        metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()})
        metrics.update({"loss": loss.item()})

        return loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "mu": 0.1,
            "clip_range": 0.2,
        }

    @property
    def select_keys(self) -> List[str]:
        return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
```
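
Note that the implementation combines the two terms as `(1 - mu) * grpo_loss + mu * sft_loss`, so `mu` interpolates between pure GRPO training (`mu = 0`) and pure SFT on the expert data (`mu = 1`).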

## Step 4: Run the Experiment

With the classes and functions defined above, we can run the experiment without modifying any other part of the pipeline.
An example with some important configurations is shown below, including the weighting factor $\mu$ (`algorithm.policy_loss_fn_args['mu']`) and the batch size of expert experiences $B'$, which equals the product of `buffer.batch_size`, `algorithm.sample_strategy_args['expert_data_ratio']`, and `algorithm.repeat_times`.
For the full configuration, please refer to [`mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/mix_math.yaml) and [`train_mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/train_mix_math.yaml).

```yaml
algorithm:
  algorithm_type: mix
  repeat_times: 8
  sample_strategy_args:
    expert_data_ratio: 0.25
  policy_loss_fn_args:
    mu: 0.1
    clip_range: 0.2
    use_token_level_loss_in_sft: False
    use_dynamic_bsz: False
    repeat_times: 8
    ppo_mini_batch_size: 32
    ppo_micro_batch_size_per_gpu: 4
    ngpus_trainer: 4
    read_batch_size_expert: 64
    read_batch_size_usual: 192
```
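
As a quick sanity check on these numbers, a back-of-the-envelope sketch is shown below; `buffer.batch_size = 32` is assumed here, implied by the values above but not shown in this excerpt:

```python
from math import ceil

batch_size = 32           # assumed buffer.batch_size (not shown in the excerpt above)
repeat_times = 8          # algorithm.repeat_times
expert_data_ratio = 0.25  # algorithm.sample_strategy_args['expert_data_ratio']

read_batch_size = batch_size * repeat_times                         # 256 experiences per step
read_batch_size_expert = ceil(expert_data_ratio * read_batch_size)  # 64 (= B'), matches the config
read_batch_size_usual = read_batch_size - read_batch_size_expert    # 192, matches the config
```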