@@ -1130,6 +1130,128 @@ def test_async_mcp_tools(self):
1130
1130
env_pool .close ()
1131
1131
1132
1132
1133
class TestThinkingPrompt:
    """Tests for the ``AddThinkingPrompt`` transform on a GSM8K LLM env.

    The environment is built once per class (expensive tokenizer download),
    so each test must clean up any ``AddThinkingPrompt`` instance appended
    by a previous parametrized run before appending its own.
    """

    @pytest.fixture(autouse=True, scope="class")
    def base_env(self):
        """Class-shared GSM8K env with a Qwen2.5-3B tokenizer (no shuffle)."""
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
        env = GSM8KEnv(shuffle=False, tokenizer=tokenizer, max_steps=10)
        return env

    @pytest.mark.skipif(not _has_transformers, reason="requires transformers")
    @pytest.mark.skipif(not _has_datasets, reason="requires gsm8k")
    @pytest.mark.parametrize(
        "role,edit_last_turn",
        [("assistant", True), ("assistant", False), ("user", False)],
    )
    @pytest.mark.parametrize("zero_reward", [True, False])
    @pytest.mark.parametrize("undo_done", [True, False])
    @pytest.mark.parametrize("random_prompt", [True, False])
    def test_thinking_prompt_wrong_answer(
        self,
        role,
        edit_last_turn,
        zero_reward,
        undo_done,
        random_prompt,
        tmp_path,
        base_env,
    ):
        """A wrong answer triggers the thinking prompt: reward/done/history
        are reshaped according to the transform's flags."""
        from torchrl.envs.llm.transforms import AddThinkingPrompt

        # base_env is class-scoped: drop the AddThinkingPrompt appended by a
        # previous parametrized run so we always test a single fresh instance.
        if isinstance(base_env.transform[-1], AddThinkingPrompt):
            base_env.transform.pop()
        env = base_env.reset_dataloader()
        env = env.append_transform(
            AddThinkingPrompt(
                # Reward threshold chosen so the wrong answer below always
                # satisfies the condition and the prompt is injected.
                cond=lambda td: td["reward"] < 50,
                role=role,
                edit_last_turn=edit_last_turn,
                zero_reward=zero_reward,
                undo_done=undo_done,
                random_prompt=random_prompt,
            )
        )
        reset = env.reset()
        # shuffle=False guarantees the first GSM8K question is deterministic.
        assert reset[0]["history"][-1].content.startswith(
            "Natalia sold clips to 48 of her friends in April"
        )
        # Correct reasoning but deliberately wrong final answer (322 != 72).
        policy_answer = (
            "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
            "To find the total, I need to add April and May: 48 + 24 = 72. Therefore, Natalia sold 72 clips altogether in April and May.</think>\n <answer>322 clips</answer><|im_end|>"
        )
        reset["text_response"] = [policy_answer]
        s = env.step(reset)
        if zero_reward:
            assert (s["next", "reward"] == 0).all()
        else:
            assert (s["next", "reward"] != 0).all()
        if undo_done:
            assert (s["next", "done"] == 0).all()
        else:
            assert (s["next", "done"] != 0).all()
        if edit_last_turn:
            # The prompt is merged into the last turn: history length unchanged.
            assert s["next", "history"].shape == (1, 3)
        else:
            # The prompt is appended as a new turn.
            assert s["next", "history"].shape == (1, 4)
        if role == "assistant":
            assert s[0]["next", "history", "role"][-1] == "assistant"
        else:
            assert s[0]["next", "history", "role"][-1] == "user"

    @pytest.mark.skipif(not _has_transformers, reason="requires transformers")
    @pytest.mark.skipif(not _has_datasets, reason="requires gsm8k")
    @pytest.mark.parametrize(
        "role,edit_last_turn",
        [("assistant", True), ("assistant", False), ("user", False)],
    )
    @pytest.mark.parametrize("zero_reward", [True, False])
    @pytest.mark.parametrize("undo_done", [True, False])
    @pytest.mark.parametrize("random_prompt", [True, False])
    def test_thinking_prompt_correct_answer(
        self,
        role,
        edit_last_turn,
        zero_reward,
        undo_done,
        random_prompt,
        tmp_path,
        base_env,
    ):
        """When ``cond`` returns False (answer is correct, reward >= 50),
        the transform must leave reward, done and history untouched."""
        from torchrl.envs.llm.transforms import AddThinkingPrompt

        # base_env is class-scoped: drop the AddThinkingPrompt appended by a
        # previous parametrized run so we always test a single fresh instance.
        if isinstance(base_env.transform[-1], AddThinkingPrompt):
            base_env.transform.pop()
        env = base_env.reset_dataloader()
        env = env.append_transform(
            AddThinkingPrompt(
                cond=lambda td: td["reward"] < 50,
                role=role,
                edit_last_turn=edit_last_turn,
                zero_reward=zero_reward,
                undo_done=undo_done,
                random_prompt=random_prompt,
            )
        )
        reset = env.reset()
        # shuffle=False guarantees the first GSM8K question is deterministic.
        assert reset[0]["history"][-1].content.startswith(
            "Natalia sold clips to 48 of her friends in April"
        )
        # Correct reasoning AND correct final answer (72).
        policy_answer = (
            "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
            "To find the total, I need to add April and May: 48 + 24 = 72. Therefore, Natalia sold 72 clips altogether in April and May.</think>\n <answer>72</answer><|im_end|>"
        )
        reset["text_response"] = [policy_answer]
        s = env.step(reset)
        # No flag should have fired: reward kept, episode done, history intact.
        assert (s["next", "reward"] != 0).all(), s["next", "reward"]
        assert s[0]["next", "history", "role"][-1] == "assistant"
        assert s["next", "done"].all()
        assert len(s[0]["next", "history", "content"]) == 3
1133
1255
if __name__ == "__main__":
    # Swallow known argparse flags; forward everything else to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *unknown])
0 commit comments