Skip to content

Commit f1d2547

Browse files
committed
feat: add BLEU metrics and translate RL training comments to English
- Add BLEU metric implementations (BLEU, Sentence-BLEU, Self-BLEU) - Translate Chinese comments to English in RL training examples - Update alignment_reward_fn.py, alignment_rl_dataset.py, base_dataset.py - Update reward_manager.py and grpo_training.sh configurations
1 parent 26f56af commit f1d2547

File tree

6 files changed

+373
-109
lines changed

6 files changed

+373
-109
lines changed

examples/train/rl_training/alignment_reward_fn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@
4747
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "10"))
4848
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
4949
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "2048"))
50-
LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "30.0")) # LLM调用超时(秒)
51-
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "2")) # LLM调用失败重试次数
50+
LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "30.0")) # LLM call timeout (seconds)
51+
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "2")) # LLM call retry attempts on failure
5252

5353
# Debug Configuration
5454
VERBOSE = os.getenv("VERBOSE", "false").lower() == "true"
@@ -127,8 +127,8 @@ def custom_alignment_prompt(user_query, response_a, response_b, reference="", **
127127
max_tokens=MAX_TOKENS,
128128
max_workers=MAX_WORKERS,
129129
verbose=VERBOSE,
130-
llm_timeout=LLM_TIMEOUT, # 添加超时配置
131-
max_retries=MAX_RETRIES, # 添加重试配置
130+
llm_timeout=LLM_TIMEOUT, # Add timeout configuration
131+
max_retries=MAX_RETRIES, # Add retry configuration
132132
)
133133

134134
# Log configuration

examples/train/rl_training/alignment_rl_dataset.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -57,23 +57,23 @@ class DataKeys:
5757

5858

5959
class AlignmentChatRLDataset(BaseChatRLDataset):
60-
"""Alignment任务聊天强化学习数据集
60+
"""Alignment Chat RL Dataset
6161
62-
专门处理包含chosen/rejected格式的alignment数据
62+
Specialized for handling alignment data with chosen/rejected format
6363
"""
6464

6565
def __init__(self, data_files, tokenizer, config, processor=None):
6666
super().__init__(data_files, tokenizer, config, processor)
67-
print("使用 Alignment 模式")
67+
print("Using Alignment mode")
6868

6969
def _build_messages(self, example: dict) -> List[dict]:
70-
"""从样本构建聊天消息 - Alignment模式"""
70+
"""Build chat messages from sample - Alignment mode"""
7171
messages = []
7272

73-
# 优先从x字段构建消息
73+
# Priority: build messages from x field
7474
if "x" in example and example["x"] is not None:
7575
x_data = example["x"]
76-
# 处理numpy数组格式
76+
# Handle numpy array format
7777
if hasattr(x_data, "tolist"):
7878
x_data = x_data.tolist()
7979
elif not isinstance(x_data, (list, tuple)):
@@ -83,26 +83,26 @@ def _build_messages(self, example: dict) -> List[dict]:
8383
if isinstance(msg, dict) and msg.get("role") and msg.get("content"):
8484
messages.append({"role": msg["role"], "content": msg["content"]})
8585

86-
# 如果x字段为空,从chosen字段构建消息(取前面的对话,不包括最后的assistant回复)
86+
# If x field is empty, build messages from chosen field (excluding last assistant reply)
8787
elif DataKeys.CHOSEN in example and example[DataKeys.CHOSEN]:
8888
chosen_messages = example[DataKeys.CHOSEN]
89-
# 处理numpy数组格式
89+
# Handle numpy array format
9090
if hasattr(chosen_messages, "tolist"):
9191
chosen_messages = chosen_messages.tolist()
9292

9393
if isinstance(chosen_messages, (list, tuple)):
9494
for msg in chosen_messages:
9595
if isinstance(msg, dict) and msg.get("role") and msg.get("content"):
96-
# 只添加user消息,不添加assistant消息(因为那是要预测的目标)
96+
# Only add user messages, not assistant messages (as they are the prediction target)
9797
if msg.get("role") == "user":
9898
messages.append(
9999
{"role": msg["role"], "content": msg["content"]}
100100
)
101101

102-
# 如果还是没有找到消息,尝试从rejected字段构建
102+
# If still no messages found, try building from rejected field
103103
elif DataKeys.REJECTED in example and example[DataKeys.REJECTED]:
104104
rejected_messages = example[DataKeys.REJECTED]
105-
# 处理numpy数组格式
105+
# Handle numpy array format
106106
if hasattr(rejected_messages, "tolist"):
107107
rejected_messages = rejected_messages.tolist()
108108

@@ -114,39 +114,39 @@ def _build_messages(self, example: dict) -> List[dict]:
114114
{"role": msg["role"], "content": msg["content"]}
115115
)
116116

117-
# 如果还是没有消息,创建一个默认的用户消息
117+
# If still no messages, create a default user message
118118
if len(messages) == 0:
119-
messages = [{"role": "user", "content": "请协助完成这个任务。"}]
119+
messages = [{"role": "user", "content": "Please help complete this task."}]
120120

121121
return messages
122122

123123
def _format_template(self, messages: List[dict], example: dict) -> str:
124-
"""格式化alignment模板"""
124+
"""Format alignment template"""
125125
return messages
126126

127127
def _extract_ground_truth(self, row_dict):
128-
"""提取alignment真实标签
128+
"""Extract alignment ground truth
129129
130-
对于alignment数据,chosen可以作为一种"更好"的参考
130+
For alignment data, chosen can serve as a "better" reference
131131
"""
132132
try:
133133
ground_truth_info = {}
134134

135-
# 将chosen和rejected都保存到ground_truth中,供奖励函数使用
135+
# Save both chosen and rejected to ground_truth for reward function use
136136
chosen_key = DataKeys.CHOSEN
137137
rejected_key = DataKeys.REJECTED
138138
source_key = DataKeys.SOURCE
139139

140140
if chosen_key in row_dict and row_dict[chosen_key] is not None:
141141
chosen_data = row_dict[chosen_key]
142-
# 处理numpy数组格式
142+
# Handle numpy array format
143143
if hasattr(chosen_data, "tolist"):
144144
chosen_data = chosen_data.tolist()
145145
ground_truth_info[chosen_key] = chosen_data
146146

147147
if rejected_key in row_dict and row_dict[rejected_key] is not None:
148148
rejected_data = row_dict[rejected_key]
149-
# 处理numpy数组格式
149+
# Handle numpy array format
150150
if hasattr(rejected_data, "tolist"):
151151
rejected_data = rejected_data.tolist()
152152
ground_truth_info[rejected_key] = rejected_data
@@ -161,26 +161,26 @@ def _extract_ground_truth(self, row_dict):
161161
return {}
162162

163163
def __getitem__(self, item):
164-
"""获取数据集中的一个项目"""
164+
"""Get an item from the dataset"""
165165
row_dict = dict(self.dataframe[item])
166166
messages = self._build_messages(row_dict)
167167

168-
# 格式化提示
168+
# Format prompt
169169
raw_prompt_messages = self._format_template(messages, row_dict)
170170

171-
# 应用聊天模板
171+
# Apply chat template
172172
raw_prompt = self.tokenizer.apply_chat_template(
173173
raw_prompt_messages, add_generation_prompt=True, tokenize=False
174174
)
175175

176-
# 分词
176+
# Tokenize
177177
model_inputs = self.tokenizer(
178178
raw_prompt, return_tensors="pt", add_special_tokens=False
179179
)
180180
input_ids = model_inputs["input_ids"]
181181
attention_mask = model_inputs["attention_mask"]
182182

183-
# 后处理
183+
# Post-process
184184
input_ids, attention_mask = verl_F.postprocess_data(
185185
input_ids=input_ids,
186186
attention_mask=attention_mask,
@@ -190,10 +190,10 @@ def __getitem__(self, item):
190190
truncation=self.truncation,
191191
)
192192

193-
# 计算位置ID
193+
# Compute position IDs
194194
position_ids = compute_position_id_with_mask(attention_mask)
195195

196-
# 准备原始提示ID
196+
# Prepare raw prompt IDs
197197
raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
198198
if len(raw_prompt_ids) > self.max_prompt_length:
199199
if self.truncation == "left":
@@ -202,18 +202,18 @@ def __getitem__(self, item):
202202
raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
203203
elif self.truncation == "error":
204204
raise RuntimeError(
205-
f"提示长度 {len(raw_prompt_ids)} 超过 {self.max_prompt_length}"
205+
f"Prompt length {len(raw_prompt_ids)} exceeds {self.max_prompt_length}"
206206
)
207207

208-
# 构建x字段(用于传递给奖励函数)
208+
# Build x field (for passing to reward function)
209209
x_messages = []
210210

211-
# 从原始数据构建完整的对话上下文
211+
# Build complete conversation context from raw data
212212
chosen_key = DataKeys.CHOSEN
213213
if chosen_key in row_dict and row_dict[chosen_key]:
214-
# 使用chosen作为基础构建对话上下文
214+
# Use chosen as base for conversation context
215215
chosen_messages = row_dict[chosen_key]
216-
# 处理numpy数组格式
216+
# Handle numpy array format
217217
if hasattr(chosen_messages, "tolist"):
218218
chosen_messages = chosen_messages.tolist()
219219

@@ -224,11 +224,11 @@ def __getitem__(self, item):
224224
{"role": msg["role"], "content": msg["content"]}
225225
)
226226

227-
# 如果没有从chosen获取到消息,使用我们构建的messages
227+
# If no messages obtained from chosen, use our built messages
228228
if not x_messages:
229229
x_messages = messages
230230

231-
# 构建结果
231+
# Build result
232232
result = {
233233
"input_ids": input_ids[0],
234234
"attention_mask": attention_mask[0],
@@ -240,7 +240,7 @@ def __getitem__(self, item):
240240
"data_source": row_dict.get("source", "alignment"),
241241
}
242242

243-
# 添加x字段,包含对话上下文
243+
# Add x field with conversation context
244244
result["extra_info"]["x"] = x_messages
245245

246246
if self.return_raw_chat:

0 commit comments

Comments
 (0)