Skip to content

Commit ab2d4d8

Browse files
authored
Fix prompt type bug in generate_with_search within examples/search-r1 (#1182)
1 parent a52ddbd commit ab2d4d8

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

examples/search-r1/generate_with_search.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import asyncio
55
import re
6 + import numpy as np
67

78
from qa_em_format import compute_score_em
89

@@ -152,7 +153,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
152153
# Handle partial rollout samples: continue generation from existing response
153154
prompt = sample.prompt
154155
if args.apply_chat_template:
155 -         assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
156 +         assert isinstance(prompt, np.ndarray), "prompt should be a np.ndarray when apply_chat_template is True"
156157
prompt_text = state.tokenizer.apply_chat_template(
157158
prompt,
158159
tokenize=False,
@@ -241,6 +242,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
241242
sample.response_length = len(response_token_ids)
242243
sample.response = response
243244
sample.loss_mask = loss_mask
245 +     sample.prompt = prompt_text
244246

245247
# Store log probs if enabled
246248
if SEARCH_R1_CONFIGS["return_logprob"]:

0 commit comments

Comments (0)