@@ -57,23 +57,23 @@ class DataKeys:
5757
5858
5959class AlignmentChatRLDataset (BaseChatRLDataset ):
60- """Alignment任务聊天强化学习数据集
60+ """Alignment Chat RL Dataset
6161
62- 专门处理包含chosen/rejected格式的alignment数据
62+ Specialized for handling alignment data with chosen/rejected format
6363 """
6464
def __init__(
    self,
    data_files,
    tokenizer,
    config,
    processor=None,
):
    """Set up the dataset through the base class, then announce the mode.

    The four arguments are handed to the parent constructor positionally and
    in the same order they were received; this subclass stores nothing extra.
    """
    super().__init__(data_files, tokenizer, config, processor)
    # NOTE(review): the literal is reproduced exactly as it appears in the
    # patch, trailing space included -- confirm the space is intended.
    print("Using Alignment mode ")
6868
6969 def _build_messages (self , example : dict ) -> List [dict ]:
70- """从样本构建聊天消息 - Alignment模式 """
70+ """Build chat messages from sample - Alignment mode """
7171 messages = []
7272
73- # 优先从x字段构建消息
73+ # Priority: build messages from x field
7474 if "x" in example and example ["x" ] is not None :
7575 x_data = example ["x" ]
76- # 处理numpy数组格式
76+ # Handle numpy array format
7777 if hasattr (x_data , "tolist" ):
7878 x_data = x_data .tolist ()
7979 elif not isinstance (x_data , (list , tuple )):
@@ -83,26 +83,26 @@ def _build_messages(self, example: dict) -> List[dict]:
8383 if isinstance (msg , dict ) and msg .get ("role" ) and msg .get ("content" ):
8484 messages .append ({"role" : msg ["role" ], "content" : msg ["content" ]})
8585
86- # 如果x字段为空,从chosen字段构建消息(取前面的对话,不包括最后的assistant回复)
86+ # If x field is empty, build messages from chosen field (excluding last assistant reply)
8787 elif DataKeys .CHOSEN in example and example [DataKeys .CHOSEN ]:
8888 chosen_messages = example [DataKeys .CHOSEN ]
89- # 处理numpy数组格式
89+ # Handle numpy array format
9090 if hasattr (chosen_messages , "tolist" ):
9191 chosen_messages = chosen_messages .tolist ()
9292
9393 if isinstance (chosen_messages , (list , tuple )):
9494 for msg in chosen_messages :
9595 if isinstance (msg , dict ) and msg .get ("role" ) and msg .get ("content" ):
96- # 只添加user消息,不添加assistant消息(因为那是要预测的目标)
96+ # Only add user messages, not assistant messages (as they are the prediction target)
9797 if msg .get ("role" ) == "user" :
9898 messages .append (
9999 {"role" : msg ["role" ], "content" : msg ["content" ]}
100100 )
101101
102- # 如果还是没有找到消息,尝试从rejected字段构建
102+ # If still no messages found, try building from rejected field
103103 elif DataKeys .REJECTED in example and example [DataKeys .REJECTED ]:
104104 rejected_messages = example [DataKeys .REJECTED ]
105- # 处理numpy数组格式
105+ # Handle numpy array format
106106 if hasattr (rejected_messages , "tolist" ):
107107 rejected_messages = rejected_messages .tolist ()
108108
@@ -114,39 +114,39 @@ def _build_messages(self, example: dict) -> List[dict]:
114114 {"role" : msg ["role" ], "content" : msg ["content" ]}
115115 )
116116
117- # 如果还是没有消息,创建一个默认的用户消息
117+ # If still no messages, create a default user message
118118 if len (messages ) == 0 :
119- messages = [{"role" : "user" , "content" : "请协助完成这个任务。 " }]
119+ messages = [{"role" : "user" , "content" : "Please help complete this task. " }]
120120
121121 return messages
122122
123123 def _format_template (self , messages : List [dict ], example : dict ) -> str :
124- """格式化alignment模板 """
124+ """Format alignment template """
125125 return messages
126126
127127 def _extract_ground_truth (self , row_dict ):
128- """提取alignment真实标签
128+ """Extract alignment ground truth
129129
130- 对于alignment数据,chosen可以作为一种"更好"的参考
130+ For alignment data, chosen can serve as a "better" reference
131131 """
132132 try :
133133 ground_truth_info = {}
134134
135- # 将chosen和rejected都保存到ground_truth中,供奖励函数使用
135+ # Save both chosen and rejected to ground_truth for reward function use
136136 chosen_key = DataKeys .CHOSEN
137137 rejected_key = DataKeys .REJECTED
138138 source_key = DataKeys .SOURCE
139139
140140 if chosen_key in row_dict and row_dict [chosen_key ] is not None :
141141 chosen_data = row_dict [chosen_key ]
142- # 处理numpy数组格式
142+ # Handle numpy array format
143143 if hasattr (chosen_data , "tolist" ):
144144 chosen_data = chosen_data .tolist ()
145145 ground_truth_info [chosen_key ] = chosen_data
146146
147147 if rejected_key in row_dict and row_dict [rejected_key ] is not None :
148148 rejected_data = row_dict [rejected_key ]
149- # 处理numpy数组格式
149+ # Handle numpy array format
150150 if hasattr (rejected_data , "tolist" ):
151151 rejected_data = rejected_data .tolist ()
152152 ground_truth_info [rejected_key ] = rejected_data
@@ -161,26 +161,26 @@ def _extract_ground_truth(self, row_dict):
161161 return {}
162162
163163 def __getitem__ (self , item ):
164- """获取数据集中的一个项目 """
164+ """Get an item from the dataset """
165165 row_dict = dict (self .dataframe [item ])
166166 messages = self ._build_messages (row_dict )
167167
168- # 格式化提示
168+ # Format prompt
169169 raw_prompt_messages = self ._format_template (messages , row_dict )
170170
171- # 应用聊天模板
171+ # Apply chat template
172172 raw_prompt = self .tokenizer .apply_chat_template (
173173 raw_prompt_messages , add_generation_prompt = True , tokenize = False
174174 )
175175
176- # 分词
176+ # Tokenize
177177 model_inputs = self .tokenizer (
178178 raw_prompt , return_tensors = "pt" , add_special_tokens = False
179179 )
180180 input_ids = model_inputs ["input_ids" ]
181181 attention_mask = model_inputs ["attention_mask" ]
182182
183- # 后处理
183+ # Post-process
184184 input_ids , attention_mask = verl_F .postprocess_data (
185185 input_ids = input_ids ,
186186 attention_mask = attention_mask ,
@@ -190,10 +190,10 @@ def __getitem__(self, item):
190190 truncation = self .truncation ,
191191 )
192192
193- # 计算位置ID
193+ # Compute position IDs
194194 position_ids = compute_position_id_with_mask (attention_mask )
195195
196- # 准备原始提示ID
196+ # Prepare raw prompt IDs
197197 raw_prompt_ids = self .tokenizer .encode (raw_prompt , add_special_tokens = False )
198198 if len (raw_prompt_ids ) > self .max_prompt_length :
199199 if self .truncation == "left" :
@@ -202,18 +202,18 @@ def __getitem__(self, item):
202202 raw_prompt_ids = raw_prompt_ids [: self .max_prompt_length ]
203203 elif self .truncation == "error" :
204204 raise RuntimeError (
205- f"提示长度 { len (raw_prompt_ids )} 超过 { self .max_prompt_length } "
205+ f"Prompt length { len (raw_prompt_ids )} exceeds { self .max_prompt_length } "
206206 )
207207
208- # 构建x字段(用于传递给奖励函数)
208+ # Build x field (for passing to reward function)
209209 x_messages = []
210210
211- # 从原始数据构建完整的对话上下文
211+ # Build complete conversation context from raw data
212212 chosen_key = DataKeys .CHOSEN
213213 if chosen_key in row_dict and row_dict [chosen_key ]:
214- # 使用chosen作为基础构建对话上下文
214+ # Use chosen as base for conversation context
215215 chosen_messages = row_dict [chosen_key ]
216- # 处理numpy数组格式
216+ # Handle numpy array format
217217 if hasattr (chosen_messages , "tolist" ):
218218 chosen_messages = chosen_messages .tolist ()
219219
@@ -224,11 +224,11 @@ def __getitem__(self, item):
224224 {"role" : msg ["role" ], "content" : msg ["content" ]}
225225 )
226226
227- # 如果没有从chosen获取到消息,使用我们构建的messages
227+ # If no messages obtained from chosen, use our built messages
228228 if not x_messages :
229229 x_messages = messages
230230
231- # 构建结果
231+ # Build result
232232 result = {
233233 "input_ids" : input_ids [0 ],
234234 "attention_mask" : attention_mask [0 ],
@@ -240,7 +240,7 @@ def __getitem__(self, item):
240240 "data_source" : row_dict .get ("source" , "alignment" ),
241241 }
242242
243- # 添加x字段,包含对话上下文
243+ # Add x field with conversation context
244244 result ["extra_info" ]["x" ] = x_messages
245245
246246 if self .return_raw_chat :
0 commit comments