`docs/sphinx_doc/source/tutorial/example_mix_algo.md` (new file, +303 lines)
# Integrate a New Algorithm


This guide introduces how to integrate a new algorithm into Trinity-RFT.
As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX, which optimizes the following policy objective:

$$
\mathcal{J}_{\text{Mix}}(\theta) =
\mathcal{J}_{\text{GRPO}}(\theta)
+
\mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'}
\left[
\frac{1}{T'_b} \sum_{t=1}^{T'_b}
\log \pi_\theta(o'_{b,t} \mid q'_b, o'_{b,<t})
\right]}_{\text{Auxiliary Loss on Expert Data}}.
$$
The first term is the standard GRPO objective, which aims to maximize the expected reward. The second term is an auxiliary loss defined on expert data, which encourages the policy to imitate expert behavior. $\mu$ is a weighting factor that controls the relative importance of the two terms.


## Step 0: Prepare the Expert Data

We prompt a powerful LLM to generate responses with chain-of-thought (CoT) reasoning for some pre-defined questions. The collected data are treated as experiences from an expert. We store them in a JSON file `expert_data.json` with the following format:

```json
{
"question": "What is the average of 4, 6, and 8?",
"response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."
}
...
```
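
Before training, it can be useful to sanity-check the expert file. The snippet below is a minimal sketch, assuming `expert_data.json` stores a JSON array of question/response objects (adapt the loading logic if you use a different layout such as JSON Lines):

```python
import json

# Minimal sketch: load the expert data and verify that every record has the
# "question" and "response" fields shown above. Assumes a JSON array layout.
with open("expert_data.json", "r", encoding="utf-8") as f:
    expert_data = json.load(f)

for record in expert_data:
    assert "question" in record and "response" in record, f"Malformed record: {record}"

print(f"Loaded {len(expert_data)} expert question/response pairs.")
```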


## Step 1: Define the Algorithm

In `trinity/algorithm/algorithm.py`, we introduce a new algorithm type `MIX`.

```python
@ALGORITHM_TYPE.register_module("mix")
class MIXAlgorithm(AlgorithmType):
    """MIX algorithm."""

    use_critic: bool = False
    use_reference: bool = True
    use_advantage: bool = True
    use_rollout: bool = True
    can_balance_batch: bool = True
    schema: type = ExperienceModel

    @classmethod
    def get_default_config(cls) -> Dict:
        return {
            "repeat_times": 8,
            "policy_loss_fn": "mix",
            "advantage_fn": "grpo",
            "sample_strategy": "mix",
        }
```
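
As a quick check (a hypothetical snippet, assuming the module path `trinity/algorithm/algorithm.py` from above), the defaults returned by `get_default_config` show how selecting `algorithm_type: mix` wires up the matching policy loss, advantage function, and sample strategy defined in the next steps:

```python
# Hypothetical usage sketch; the import path follows the file named in Step 1.
from trinity.algorithm.algorithm import MIXAlgorithm

print(MIXAlgorithm.get_default_config())
# {'repeat_times': 8, 'policy_loss_fn': 'mix', 'advantage_fn': 'grpo', 'sample_strategy': 'mix'}
```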


## Step 2: Define the Sampling Strategy

In each training step, we need to read two kinds of experiences: usual experiences and expert experiences. For this purpose, we define a new experience sampling strategy named `MixSampleStrategy`.


```python
class MixSampleStrategy(SampleStrategy):
    """Sample strategy that mixes usual experiences with expert experiences."""

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        super().__init__(buffer_config, trainer_type)
        self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
        tot_batch_size = buffer_config.read_batch_size
        expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)

        # usual experience buffer
        usual_buffer_config = copy.deepcopy(buffer_config)
        usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size
        self.usual_exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, usual_buffer_config  # type: ignore
        )

        if buffer_config.trainer_input.sft_warmup_dataset is None:
            raise ValueError(
                "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm"
            )

        # expert experience buffer
        expert_buffer_config = copy.deepcopy(buffer_config)
        expert_buffer_config.read_batch_size = expert_batch_size
        self.expert_exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
        )

    def sample(self, step: int) -> Tuple[Any, Dict, List]:
        metrics = {}
        with Timer(metrics, "read_time"):
            usual_exp_list = self.usual_exp_buffer.read()
            for exp in usual_exp_list:
                if exp.info is None:
                    exp.info = {}
                exp.info["is_expert"] = False

            expert_exp_list = self.expert_exp_buffer.read()
            for exp in expert_exp_list:
                # Expert data has no rollout reward or logprobs, so fill in
                # placeholders before gathering it with the usual experiences.
                exp.reward = 0.0
                exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
                if exp.info is None:
                    exp.info = {}
                exp.info["is_expert"] = True

        exp_list = usual_exp_list + expert_exp_list
        repr_samples = representative_sample(exp_list)

        is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool)

        with Timer(metrics, "gather_time"):
            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore

        if self.trainer_type == "verl":
            with Timer(metrics, "convert_time"):
                data = to_data_proto_mix(exps, is_expert_mask)
            return data, metrics, repr_samples
        else:
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
```

We also need to add an `is_expert_mask` field when converting experiences to `DataProto`, so that the policy loss can distinguish usual experiences from expert experiences.

```diff
+def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto:
     attention_mask = experiences.attention_masks
     cumsum = torch.cumsum(attention_mask, dim=-1)
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
         "uid": np.array(experiences.run_ids),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
         "attention_mask": attention_mask.long(),
         "response_mask": (
             experiences.action_masks[:, experiences.prompt_length :].long()
             if hasattr(experiences, "action_masks") and experiences.action_masks is not None
             else attention_mask[:, experiences.prompt_length :].long()
         ),
+        "is_expert_mask": is_expert_mask,
     }
     if experiences.rewards is not None:
         token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
         eos_mask_idx = cumsum.argmax(dim=-1)
         token_level_rewards[
             torch.arange(experiences.batch_size), eos_mask_idx
         ] = experiences.rewards
         token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
         batch_dict.update(
             {
                 "token_level_scores": token_level_rewards,
                 "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
             }
         )
     return DataProto.from_single_dict(batch_dict)
```


## Step 3: Define the Policy Loss Function

We define an `MIXPolicyLossFn` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes a weighted combination of two loss terms: a GRPO loss on the usual experiences and an SFT loss on the expert experiences.

```python
@POLICY_LOSS_FN.register_module("mix")
class MIXPolicyLossFn(PolicyLossFn):
    def __init__(
        self,
        mu: float = 0.1,
        clip_range: Optional[float] = None,
        clip_range_low: Optional[float] = None,
        clip_range_high: Optional[float] = None,
        use_dynamic_bsz: Optional[bool] = None,
        repeat_times: Optional[int] = None,
        ppo_mini_batch_size: Optional[int] = None,
        ppo_micro_batch_size_per_gpu: Optional[int] = None,
        ngpus_trainer: Optional[int] = None,
        read_batch_size_usual: Optional[int] = None,
        read_batch_size_expert: Optional[int] = None,
        use_token_level_loss_in_sft: bool = True,
    ) -> None:
        self.mu = mu
        self.use_dynamic_bsz = use_dynamic_bsz
        self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer  # type: ignore
        self.gradient_accumulation = (
            ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu  # type: ignore
        )
        self.read_batch_size_usual = read_batch_size_usual
        self.read_batch_size_expert = read_batch_size_expert
        self.grpo_loss_fn = PPOPolicyLossFn(
            clip_range=clip_range,
            clip_range_low=clip_range_low,
            clip_range_high=clip_range_high,
        )
        self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft)

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        is_expert_mask = kwargs.get("is_expert_mask", None)
        if is_expert_mask is None:
            raise ValueError("is_expert_mask is required in MIX")
        assert (
            len(is_expert_mask) == logprob.shape[0]
        ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"

        n_usual_exp = torch.sum(~is_expert_mask).item()
        n_expert_exp = torch.sum(is_expert_mask).item()

        if self.use_dynamic_bsz:
            per_micro_batch_weight_usual = self.experience_per_gpu / (
                logprob.shape[0] * self.read_batch_size_usual
            )
            per_micro_batch_weight_expert = self.experience_per_gpu / (
                logprob.shape[0] * self.read_batch_size_expert
            )
        else:
            per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual  # type: ignore
            per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert  # type: ignore

        # GRPO loss (usual experiences)
        if n_usual_exp > 0:
            grpo_loss, grpo_metrics = self.grpo_loss_fn(
                logprob[~is_expert_mask],
                old_logprob[~is_expert_mask],
                action_mask[~is_expert_mask],
                advantages[~is_expert_mask],
                **kwargs,
            )
            grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
            grpo_metrics = {
                k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
            }
        else:
            grpo_loss = torch.tensor(0.0, device=logprob.device)
            grpo_metrics = {}

        # SFT loss (expert experiences)
        if n_expert_exp > 0:
            sft_loss, sft_metrics = self.sft_loss_fn(
                logprob[is_expert_mask],
                action_mask[is_expert_mask],
            )
            sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
            sft_metrics = {
                k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
            }
        else:
            sft_loss = torch.tensor(0.0, device=logprob.device)
            sft_metrics = {}

        loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss

        metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()}
        metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()})
        metrics.update({"loss": loss.item()})

        return loss, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "mu": 0.1,
            "clip_range": 0.2,
        }

    @property
    def select_keys(self) -> List[str]:
        return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
```
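
To make the combination concrete, here is a self-contained toy sketch in plain PyTorch. It is purely illustrative (not Trinity code): it omits the per-micro-batch re-weighting and advantage computation above, assumes a clip range of 0.2, and uses random tensors. It simply splits a batch by `is_expert_mask`, applies a clipped policy-gradient loss to the usual rows and a token-level negative log-likelihood to the expert rows, and mixes them with $\mu$:

```python
import torch

# Toy sketch: clipped policy-gradient loss on usual experiences plus an
# SFT-style NLL loss on expert experiences, mixed with weight mu.
torch.manual_seed(0)
mu = 0.1
batch, seq_len = 6, 16
logprob = torch.randn(batch, seq_len)      # token log-probs under the current policy
old_logprob = torch.randn(batch, seq_len)  # token log-probs at rollout time
advantages = torch.randn(batch, seq_len)   # advantages (only the usual rows matter)
action_mask = torch.ones(batch, seq_len)   # 1 for response tokens, 0 for padding
is_expert_mask = torch.tensor([False, False, False, False, True, True])

# Usual experiences: PPO/GRPO-style clipped surrogate objective.
ratio = torch.exp(logprob[~is_expert_mask] - old_logprob[~is_expert_mask])
clipped = torch.clamp(ratio, 1.0 - 0.2, 1.0 + 0.2)
adv = advantages[~is_expert_mask]
mask_u = action_mask[~is_expert_mask]
pg_loss = -(torch.min(ratio * adv, clipped * adv) * mask_u).sum() / mask_u.sum()

# Expert experiences: token-level negative log-likelihood (behavior cloning).
mask_e = action_mask[is_expert_mask]
sft_loss = -(logprob[is_expert_mask] * mask_e).sum() / mask_e.sum()

loss = (1 - mu) * pg_loss + mu * sft_loss
print(f"pg_loss={pg_loss.item():.4f}, sft_loss={sft_loss.item():.4f}, mixed={loss.item():.4f}")
```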

## Step 4: Run the Experiment

With the newly defined classes and functions above, we can run experiments without modifying any other part of the pipeline.
The example below shows some important configurations, including the weighting factor $\mu$ (set via `algorithm.policy_loss_fn_args['mu']`) and the batch size of expert experiences $B'$, which is the product of `buffer.batch_size`, `algorithm.sample_strategy_args['expert_data_ratio']`, and `algorithm.repeat_times`.
For the full configuration, please refer to [`mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/mix_math.yaml) and [`train_mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/train_mix_math.yaml).

```yaml
algorithm:
  algorithm_type: mix
  repeat_times: 8
  sample_strategy_args:
    expert_data_ratio: 0.25
  policy_loss_fn_args:
    mu: 0.1
    clip_range: 0.2
    use_token_level_loss_in_sft: False
    use_dynamic_bsz: False
    repeat_times: 8
    ppo_mini_batch_size: 32
    ppo_micro_batch_size_per_gpu: 4
    ngpus_trainer: 4
    read_batch_size_expert: 64
    read_batch_size_usual: 192
```
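
As a consistency check for the numbers above (assuming `buffer.batch_size: 32` in the full `mix_math.yaml`, which is not shown in this snippet): each step reads $32 \times 8 = 256$ experiences in total, of which $B' = 256 \times 0.25 = 64$ come from the expert buffer (matching `read_batch_size_expert: 64`) and the remaining $256 \times 0.75 = 192$ come from the usual experience buffer (matching `read_batch_size_usual: 192`).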
`examples/mix_math/README.md` (new file, +7 lines)
# Example: MIX on MATH dataset

This example shows the usage of the new MIX algorithm on the MATH dataset.

For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md).

The config files are located in [`mix_math.yaml`](mix_math.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).