forked from agentscope-ai/Trinity-RFT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
195 lines (173 loc) · 8.25 KB
/
utils.py
File metadata and controls
195 lines (173 loc) · 8.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""Utils for ccompatibility issues with verl."""
import os
from logging import Logger
import numpy as np
import torch
from verl import DataProto
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from trinity.common.config import Config
from trinity.common.experience import Experiences
def to_data_proto(experiences: Experiences, logger: Logger) -> DataProto:  # noqa: C901
    """Convert a batch of Trinity ``Experiences`` into a verl ``DataProto``.

    Always included: ``uid``/``unique_ids`` (from ``experiences.eids``),
    ``position_ids``, ``input_ids``, ``responses``, ``attention_mask`` and
    ``response_mask``. Reward fields (``token_level_scores``,
    ``old_log_probs``), ``advantages``, ``returns``, ``multi_modal_inputs`` and
    any ``custom_fields`` are added only when present on ``experiences``.

    Args:
        experiences: Batch to convert. ``tokens`` and ``attention_masks`` are
            indexed as ``[batch, position]`` here. The first ``prompt_length``
            columns are treated as the prompt and the rest as the response.
        logger: Used to warn when both ``rewards`` and ``token_level_rewards``
            are supplied.

    Returns:
        The ``DataProto`` produced by ``DataProto.from_single_dict``.

    Raises:
        AssertionError: If rewards are supplied but ``experiences.logprobs`` is
            ``None``.
    """
    attention_mask = experiences.attention_masks
    prompt_length = experiences.prompt_length
    cumsum = torch.cumsum(attention_mask, dim=-1)
    # Index of each token among the attended tokens. Leading padding yields
    # cumsum - 1 == -1, which is clamped to 0.
    position_ids = torch.clip(cumsum - 1, 0, None).long()

    # Prefer an explicit action mask. Otherwise fall back to the response
    # slice of the attention mask.
    action_masks = getattr(experiences, "action_masks", None)
    if action_masks is not None:
        response_mask = action_masks.long()
    else:
        response_mask = attention_mask[:, prompt_length:].long()

    batch_dict = {
        "uid": np.array([eid.tid for eid in experiences.eids]),
        "unique_ids": np.array([eid.uid for eid in experiences.eids]),
        "position_ids": position_ids,
        "input_ids": experiences.tokens.long(),
        "responses": experiences.tokens[:, prompt_length:].long(),
        "attention_mask": attention_mask.long(),
        "response_mask": response_mask,
    }

    if experiences.rewards is not None or experiences.token_level_rewards is not None:
        assert experiences.logprobs is not None
        if experiences.token_level_rewards is not None:
            if experiences.rewards is not None:
                logger.warning(
                    "Both experiences.rewards and experiences.token_level_rewards are provided. "
                    "Using experiences.token_level_rewards."
                )
            token_level_rewards = experiences.token_level_rewards
        else:
            # Scatter each sequence-level reward onto the last attended token.
            # argmax returns the first index of the maximum cumsum, i.e. the
            # last position whose attention mask is non-zero.
            # Allocate on the mask's device so the indexed assignment below
            # never mixes devices.
            device = attention_mask.device
            token_level_rewards = torch.zeros_like(
                attention_mask, dtype=experiences.rewards.dtype
            )
            eos_mask_idx = cumsum.argmax(dim=-1)
            row_idx = torch.arange(experiences.batch_size, device=device)
            token_level_rewards[row_idx, eos_mask_idx] = experiences.rewards.to(device)
            token_level_rewards = token_level_rewards[:, prompt_length:]
        batch_dict.update(
            {
                "token_level_scores": token_level_rewards,
                "old_log_probs": experiences.logprobs,  # type: ignore
            }
        )

    if experiences.advantages is not None:
        batch_dict["advantages"] = experiences.advantages
    if experiences.returns is not None:
        batch_dict["returns"] = experiences.returns

    if experiences.multi_modal_inputs is not None:
        batch_size = len(batch_dict["unique_ids"])
        # One dict of per-sample multi-modal inputs per row, stored in an
        # object-dtype numpy array.
        batch_dict["multi_modal_inputs"] = np.array(
            [
                {key: value[i] for key, value in experiences.multi_modal_inputs.items()}
                for i in range(batch_size)
            ],
            dtype=object,
        )

    # Copy over any extra fields the producer declared, if actually present.
    for field in experiences.custom_fields or ():
        if hasattr(experiences, field):
            batch_dict[field] = getattr(experiences, field)

    return DataProto.from_single_dict(batch_dict)
def _masked_stats(prefix: str, values: torch.Tensor, mask: torch.Tensor) -> dict:
    """Return ``{prefix}/mean``, ``/max`` and ``/min`` of ``values`` where ``mask`` is True.

    If the mask is empty or selects no element, the statistics of a single zero
    are reported instead. ``torch.max``/``torch.min`` raise on an empty tensor,
    and ``torch.mean`` would give NaN.
    """
    selected = torch.masked_select(values, mask) if mask.numel() > 0 else None
    if selected is None or selected.numel() == 0:
        selected = torch.zeros(1)
    return {
        f"{prefix}/mean": torch.mean(selected).detach().item(),
        f"{prefix}/max": torch.max(selected).detach().item(),
        f"{prefix}/min": torch.min(selected).detach().item(),
    }


def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict:
    """
    Compute summary metrics from a batch of data for PPO training.

    Modified from verl.trainer.ppo.metric_utils.compute_data_metrics.
    Reports mean/max/min statistics for scores, rewards, advantages, returns
    and prompt/response lengths.

    Args:
        batch: A DataProto whose ``batch`` holds at least ``responses`` and
            ``attention_mask``. Optionally it also holds ``token_level_scores``
            together with ``token_level_rewards``, ``advantages`` and
            ``returns``.
        use_critic: Accepted for signature compatibility with verl. It is
            currently unused, and no value-function metrics are computed.
            Defaults to False.

    Returns:
        A dictionary of metrics:
            - critic/score/mean, max, min: per-sequence sums of token_level_scores
              (only if both token_level_scores and token_level_rewards exist)
            - critic/rewards/mean, max, min: per-sequence sums of token_level_rewards
              (same condition)
            - critic/advantages/mean, max, min: over response positions whose
              attention mask is set (if advantages exist)
            - critic/returns/mean, max, min: same, for returns (if returns exist)
            - response_length/mean, max, min, clip_ratio
            - prompt_length/mean, max, min, clip_ratio
    """
    metrics = {}

    if "token_level_rewards" in batch.batch and "token_level_scores" in batch.batch:
        sequence_score = batch.batch["token_level_scores"].sum(-1)
        sequence_reward = batch.batch["token_level_rewards"].sum(-1)
        metrics.update(
            {
                # score
                "critic/score/mean": torch.mean(sequence_score).detach().item(),
                "critic/score/max": torch.max(sequence_score).detach().item(),
                "critic/score/min": torch.min(sequence_score).detach().item(),
                # reward
                "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
                "critic/rewards/max": torch.max(sequence_reward).detach().item(),
                "critic/rewards/min": torch.min(sequence_reward).detach().item(),
            }
        )

    attention_mask = batch.batch["attention_mask"]
    max_response_length = batch.batch["responses"].shape[-1]
    # Split point between prompt and response columns. Slicing with an
    # explicit split index avoids the `[:, -0:]` pitfall, where a zero-width
    # response would otherwise select the whole mask.
    max_prompt_length = max(attention_mask.shape[-1] - max_response_length, 0)
    response_mask = attention_mask[:, max_prompt_length:].bool()

    # Per-sample lengths as computed by verl's _compute_response_info.
    response_info = _compute_response_info(batch)
    prompt_length = response_info["prompt_length"]
    response_length = response_info["response_length"]
    metrics.update(
        {
            # response length
            "response_length/mean": torch.mean(response_length).detach().item(),
            "response_length/max": torch.max(response_length).detach().item(),
            "response_length/min": torch.min(response_length).detach().item(),
            # Fraction of samples whose response fills the whole response window.
            "response_length/clip_ratio": torch.mean(
                torch.eq(response_length, max_response_length).float()
            )
            .detach()
            .item(),
            # prompt length
            "prompt_length/mean": torch.mean(prompt_length).detach().item(),
            "prompt_length/max": torch.max(prompt_length).detach().item(),
            "prompt_length/min": torch.min(prompt_length).detach().item(),
            # Fraction of samples whose prompt fills the whole prompt window.
            "prompt_length/clip_ratio": torch.mean(
                torch.eq(prompt_length, max_prompt_length).float()
            )
            .detach()
            .item(),
        }
    )

    if "advantages" in batch.batch:
        metrics.update(
            _masked_stats("critic/advantages", batch.batch["advantages"], response_mask)
        )
    if "returns" in batch.batch:
        metrics.update(_masked_stats("critic/returns", batch.batch["returns"], response_mask))

    return metrics
def get_latest_hf_checkpoint_path(config: Config):
    """Return the HuggingFace-format actor directory of the latest verl checkpoint.

    Args:
        config: Trinity config. ``config.trainer.trainer_type`` must be
            ``"verl"``. ``config.checkpoint_job_dir`` is searched with verl's
            ``find_latest_ckpt_path``.

    Returns:
        The path ``<latest checkpoint>/actor/huggingface``.

    Raises:
        ValueError: If the trainer is not verl, no checkpoint can be located,
            or the located checkpoint has no ``actor/huggingface`` directory.
    """
    if config.trainer.trainer_type != "verl":
        raise ValueError("This function is only for verl trainer.")
    checkpoint_dir = find_latest_ckpt_path(config.checkpoint_job_dir)
    # find_latest_ckpt_path can return None (e.g. no checkpoint tracker yet).
    # Report that as a ValueError like the other failure modes, instead of
    # letting os.path.join raise a TypeError.
    if checkpoint_dir is None:
        raise ValueError(f"No checkpoint found in {config.checkpoint_job_dir}")
    hf_checkpoint_dir = os.path.join(checkpoint_dir, "actor", "huggingface")
    if not os.path.exists(hf_checkpoint_dir):
        raise ValueError(f"No huggingface checkpoint found in {hf_checkpoint_dir}")
    return hf_checkpoint_dir