
Commit b8d1faa

[Feature] Add MIX algorithm (#83)
1 parent aeabfe5 commit b8d1faa

10 files changed: +750 -1 lines changed

Lines changed: 303 additions & 0 deletions

# Integrate a New Algorithm

This guide introduces how to integrate a new algorithm into Trinity-RFT.
As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX, which optimizes the following policy objective:

$$
\mathcal{J}_{\text{MIX}}(\theta) =
\mathcal{J}_{\text{GRPO}}(\theta)
+
\mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'}
\left[
\frac{1}{T'_b} \sum_{t=1}^{T'_b}
\log \pi_\theta(o'_{b,t} \mid q'_b, o'_{b,<t})
\right]}_{\text{Auxiliary Loss on Expert Data}}.
$$

The first term corresponds to the standard GRPO objective, which aims to maximize the expected reward. The second term is an auxiliary loss defined on expert data, encouraging the policy to imitate expert behavior. $\mu$ is a weighting factor that controls the relative importance of the two terms.

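For intuition, here is a minimal conceptual sketch of this objective in PyTorch. It is an illustration only (the actual implementation is developed in Step 3), and the tensor names are assumptions:

```python
import torch

def mix_loss(
    grpo_loss: torch.Tensor,       # scalar GRPO loss on usual rollouts (negative of J_GRPO)
    expert_logprob: torch.Tensor,  # (B', T') per-token log-probs of expert responses
    expert_mask: torch.Tensor,     # (B', T') 0/1 mask over expert response tokens
    mu: float = 0.1,
) -> torch.Tensor:
    # Auxiliary term: average negative log-likelihood per expert sequence.
    per_seq_nll = -(expert_logprob * expert_mask).sum(-1) / expert_mask.sum(-1).clamp(min=1)
    aux_loss = per_seq_nll.mean()      # average over the B' expert sequences
    return grpo_loss + mu * aux_loss   # minimizing this maximizes J_MIX
```
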
## Step 0: Prepare the Expert Data

We prompt a powerful LLM to generate responses with a chain-of-thought (CoT) process for some pre-defined questions. The collected data are viewed as experiences from an expert. We store them in a JSON file `expert_data.json` with the following format:

```json
{
    "question": "What is the average of 4, 6, and 8?",
    "response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."
}
...
```

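How this data is collected is up to you. Below is a minimal sketch that queries an OpenAI-compatible endpoint and writes the records to `expert_data.json`; the model name, prompt wording, and output layout are assumptions, so adjust them to whatever your data loader expects:

```python
import json
from openai import OpenAI  # assumes access to an OpenAI-compatible endpoint

client = OpenAI()
questions = ["What is the average of 4, 6, and 8?"]

records = []
for question in questions:
    resp = client.chat.completions.create(
        model="gpt-4o",  # any sufficiently strong "expert" model
        messages=[{"role": "user", "content": f"Answer step by step: {question}"}],
    )
    records.append({"question": question, "response": resp.choices[0].message.content})

# Dump all records; switch to one JSON object per line if your loader expects JSONL.
with open("expert_data.json", "w") as f:
    json.dump(records, f, indent=2, ensure_ascii=False)
```
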
## Step 1: Define the Algorithm

In `trinity/algorithm/algorithm.py`, we introduce a new algorithm type `MIX`.

```python
@ALGORITHM_TYPE.register_module("mix")
class MIXAlgorithm(AlgorithmType):
    """MIX algorithm."""

    use_critic: bool = False
    use_reference: bool = True
    use_advantage: bool = True
    use_rollout: bool = True
    can_balance_batch: bool = True
    schema: type = ExperienceModel

    @classmethod
    def get_default_config(cls) -> Dict:
        return {
            "repeat_times": 8,
            "policy_loss_fn": "mix",
            "advantage_fn": "grpo",
            "sample_strategy": "mix",
        }
```

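The `ALGORITHM_TYPE.register_module("mix")` decorator registers the class under the name `mix` so it can be selected from the configuration. As a rough sketch of how such a registry typically works (simplified, not Trinity-RFT's actual implementation):

```python
from typing import Dict, Type

class Registry:
    """Minimal registry sketch: maps a string name to a class."""

    def __init__(self) -> None:
        self._modules: Dict[str, Type] = {}

    def register_module(self, name: str):
        def decorator(cls: Type) -> Type:
            self._modules[name] = cls  # remember the class under its registered name
            return cls
        return decorator

    def get(self, name: str) -> Type:
        return self._modules[name]

# With a registry like this, `algorithm_type: mix` in the config can be
# resolved back to MIXAlgorithm by name, e.g. registry.get("mix").
```
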
## Step 2: Define the Sampling Strategy

In each step, we need to read two kinds of experiences: usual experiences and expert experiences. For this purpose, we define a new experience sampling strategy named `MixSampleStrategy`.

```python
class MixSampleStrategy(SampleStrategy):
    """Sample strategy for the MIX algorithm."""

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        super().__init__(buffer_config, trainer_type)
        self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
        tot_batch_size = buffer_config.read_batch_size
        expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)

        # usual experience buffer
        usual_buffer_config = copy.deepcopy(buffer_config)
        usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size
        self.usual_exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, usual_buffer_config  # type: ignore
        )

        if buffer_config.trainer_input.sft_warmup_dataset is None:
            raise ValueError(
                "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm"
            )

        # expert experience buffer
        expert_buffer_config = copy.deepcopy(buffer_config)
        expert_buffer_config.read_batch_size = expert_batch_size
        self.expert_exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
        )

    def sample(self, step: int) -> Tuple[Any, Dict, List]:
        metrics = {}
        with Timer(metrics, "read_time"):
            usual_exp_list = self.usual_exp_buffer.read()
            for exp in usual_exp_list:
                if exp.info is None:
                    exp.info = {}
                exp.info["is_expert"] = False

            expert_exp_list = self.expert_exp_buffer.read()
            for exp in expert_exp_list:
                exp.reward = 0.0
                exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
                if exp.info is None:
                    exp.info = {}
                exp.info["is_expert"] = True

        exp_list = usual_exp_list + expert_exp_list
        repr_samples = representative_sample(exp_list)

        is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool)

        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore

        if self.trainer_type == "verl":
            with Timer(metrics, "convert_time"):
                data = to_data_proto_mix(exps, is_expert_mask)
            return data, metrics, repr_samples
        else:
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
```

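As a quick sanity check on the split above, here is the batch-size arithmetic with an illustrative total `read_batch_size` of 256 and the `expert_data_ratio` of 0.25 used later in Step 4 (the concrete numbers are for illustration only):

```python
from math import ceil

tot_batch_size = 256      # buffer_config.read_batch_size (illustrative)
expert_data_ratio = 0.25  # algorithm.sample_strategy_args['expert_data_ratio']

expert_batch_size = ceil(expert_data_ratio * tot_batch_size)  # 64 expert experiences per step
usual_batch_size = tot_batch_size - expert_batch_size         # 192 usual experiences per step
print(expert_batch_size, usual_batch_size)                    # 64 192
```

These two numbers reappear in Step 4 as `read_batch_size_expert` and `read_batch_size_usual`.
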
We also need to add an `is_expert_mask` field when converting the experiences to `DataProto`, so that downstream components can tell the two kinds of data apart.

```diff
+ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.Tensor) -> DataProto:
      attention_mask = experiences.attention_masks
      cumsum = torch.cumsum(attention_mask, dim=-1)
      position_ids = torch.clip(cumsum - 1, 0, None).long()
      batch_dict = {
          "uid": np.array(experiences.run_ids),
          "position_ids": position_ids,
          "input_ids": experiences.tokens.long(),
          "responses": experiences.tokens[:, experiences.prompt_length :].long(),
          "attention_mask": attention_mask.long(),
          "response_mask": (
              experiences.action_masks[:, experiences.prompt_length :].long()
              if hasattr(experiences, "action_masks") and experiences.action_masks is not None
              else attention_mask[:, experiences.prompt_length :].long()
          ),
+         "is_expert_mask": is_expert_mask,
      }
      if experiences.rewards is not None:
          token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
          eos_mask_idx = cumsum.argmax(dim=-1)
          token_level_rewards[
              torch.arange(experiences.batch_size), eos_mask_idx
          ] = experiences.rewards
          token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
          batch_dict.update(
              {
                  "token_level_scores": token_level_rewards,
                  "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
              }
          )
      return DataProto.from_single_dict(batch_dict)
```

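Downstream, `is_expert_mask` is simply a boolean tensor over the batch dimension; the policy loss in Step 3 uses it to route each row to either the GRPO branch or the SFT branch. A tiny standalone illustration of that masking in plain PyTorch (independent of Trinity-RFT):

```python
import torch

logprob = torch.randn(4, 8)  # 4 sequences, 8 response tokens each
is_expert_mask = torch.tensor([False, True, False, True])

usual_logprob = logprob[~is_expert_mask]   # rows that feed the GRPO loss
expert_logprob = logprob[is_expert_mask]   # rows that feed the SFT loss
assert usual_logprob.shape[0] + expert_logprob.shape[0] == logprob.shape[0]
```
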
## Step 3: Define the Policy Loss Function

We define a `MIXPolicyLossFn` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes a weighted sum of two loss terms, one on the usual experiences and one on the expert experiences.

```python
@POLICY_LOSS_FN.register_module("mix")
class MIXPolicyLossFn(PolicyLossFn):
    def __init__(
        self,
        mu: float = 0.1,
        clip_range: Optional[float] = None,
        clip_range_low: Optional[float] = None,
        clip_range_high: Optional[float] = None,
        use_dynamic_bsz: Optional[bool] = None,
        repeat_times: Optional[int] = None,
        ppo_mini_batch_size: Optional[int] = None,
        ppo_micro_batch_size_per_gpu: Optional[int] = None,
        ngpus_trainer: Optional[int] = None,
        read_batch_size_usual: Optional[int] = None,
        read_batch_size_expert: Optional[int] = None,
        use_token_level_loss_in_sft: bool = True,
    ) -> None:
        self.mu = mu
        self.use_dynamic_bsz = use_dynamic_bsz
        self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer  # type: ignore
        self.gradient_accumulation = (
            ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu  # type: ignore
        )
        self.read_batch_size_usual = read_batch_size_usual
        self.read_batch_size_expert = read_batch_size_expert
        self.grpo_loss_fn = PPOPolicyLossFn(
            clip_range=clip_range,
            clip_range_low=clip_range_low,
            clip_range_high=clip_range_high,
        )
        self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft)

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        is_expert_mask = kwargs.get("is_expert_mask", None)
        if is_expert_mask is None:
            raise ValueError("is_expert_mask is required in MIX")
        assert (
            len(is_expert_mask) == logprob.shape[0]
        ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"

        n_usual_exp = torch.sum(~is_expert_mask).item()
        n_expert_exp = torch.sum(is_expert_mask).item()

        if self.use_dynamic_bsz:
            per_micro_batch_weight_usual = self.experience_per_gpu / (
                logprob.shape[0] * self.read_batch_size_usual
            )
            per_micro_batch_weight_expert = self.experience_per_gpu / (
                logprob.shape[0] * self.read_batch_size_expert
            )
        else:
            per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual  # type: ignore
            per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert  # type: ignore

        # GRPO Loss (usual)
        if n_usual_exp > 0:
            grpo_loss, grpo_metrics = self.grpo_loss_fn(
                logprob[~is_expert_mask],
                old_logprob[~is_expert_mask],
                action_mask[~is_expert_mask],
                advantages[~is_expert_mask],
                **kwargs,
            )
            grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
            grpo_metrics = {
                k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
            }
        else:
            grpo_loss = torch.tensor(0.0, device=logprob.device)
            grpo_metrics = {}

        # SFT Loss (expert)
        if n_expert_exp > 0:
            sft_loss, sft_metrics = self.sft_loss_fn(
                logprob[is_expert_mask],
                action_mask[is_expert_mask],
            )
            sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
            sft_metrics = {
                k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
            }
        else:
            sft_loss = torch.tensor(0.0, device=logprob.device)
            sft_metrics = {}

        loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss

        metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()}
        metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()})
        metrics.update({"loss": loss.item()})

        return loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "mu": 0.1,
            "clip_range": 0.2,
        }

    @property
    def select_keys(self) -> List[str]:
        return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
```

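The `per_micro_batch_weight_*` factors rescale each micro-batch's contribution, so that once the trainer accumulates the micro-batch losses, the usual and expert parts are each normalized by their own global read batch size. A hand-check of these weights with the configuration from Step 4 below (and `use_dynamic_bsz: False`); this is purely illustrative arithmetic, not extra code in the repository:

```python
ppo_mini_batch_size = 32
repeat_times = 8
ppo_micro_batch_size_per_gpu = 4
read_batch_size_usual = 192
read_batch_size_expert = 64

gradient_accumulation = ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu  # 64
per_micro_batch_weight_usual = gradient_accumulation / read_batch_size_usual    # 64 / 192 = 1/3
per_micro_batch_weight_expert = gradient_accumulation / read_batch_size_expert  # 64 / 64  = 1.0
```
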
## Step 4: Run the Experiment

With the newly defined classes and functions above, we can run the experiment without modifying any other part of the pipeline.
An example of the most important configuration entries is shown below, including the weighting factor $\mu$ (set via `algorithm.policy_loss_fn_args['mu']`) and the batch size of expert experiences $B'$, computed as the product of `buffer.batch_size`, `algorithm.sample_strategy_args['expert_data_ratio']`, and `algorithm.repeat_times`.
For the full configuration, please refer to [`mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/mix_math.yaml) and [`train_mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/train_mix_math.yaml).

```yaml
algorithm:
  algorithm_type: mix
  repeat_times: 8
  sample_strategy_args:
    expert_data_ratio: 0.25
  policy_loss_fn_args:
    mu: 0.1
    clip_range: 0.2
    use_token_level_loss_in_sft: False
    use_dynamic_bsz: False
    repeat_times: 8
    ppo_mini_batch_size: 32
    ppo_micro_batch_size_per_gpu: 4
    ngpus_trainer: 4
    read_batch_size_expert: 64
    read_batch_size_usual: 192
```
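
To connect the prose above to these numbers: assuming `buffer.batch_size` is 32 (it is not shown in this excerpt), the expert batch size $B'$ and the two read batch sizes are consistent with the split derived in Step 2:

```python
from math import ceil

buffer_batch_size = 32   # assumed value of buffer.batch_size (not shown in the excerpt above)
repeat_times = 8
expert_data_ratio = 0.25

read_batch_size = buffer_batch_size * repeat_times                  # 256 experiences per step
read_batch_size_expert = ceil(expert_data_ratio * read_batch_size)  # 64, i.e. B' = 32 * 0.25 * 8
read_batch_size_usual = read_batch_size - read_batch_size_expert    # 192
```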

examples/mix_math/README.md

Lines changed: 7 additions & 0 deletions

# Example: MIX on MATH dataset

This example shows the usage of the new MIX algorithm on the MATH dataset.

For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md).

The config files are located in [`mix_math.yaml`](mix_math.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).
