Skip to content

Commit 84abe2c

Browse files
fix: assert no duplicate starting bos (#835)
Signed-off-by: Zhiyu Li <[email protected]>
1 parent 024d173 commit 84abe2c

File tree

3 files changed, +10 −17 lines changed

examples/run_grpo_math.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,11 @@ def hf_data_processor(
8989
add_generation_prompt=True,
9090
add_special_tokens=False,
9191
)
92-
# add bos token if not in chat template
93-
bos_token_in_chat_template = tokenizer.chat_template.startswith(
94-
"{{- bos_token }}"
95-
) or tokenizer.chat_template.startswith("{{ bos_token }}")
9692

9793
user_message["token_ids"] = tokenizer(
9894
message,
9995
return_tensors="pt",
100-
add_special_tokens=not bos_token_in_chat_template,
96+
add_special_tokens=False,
10197
)["input_ids"][0]
10298
user_message["content"] = message
10399
message_log.append(user_message)

nemo_rl/data/datasets.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __getitem__(self, idx: int) -> DatumSpec:
125125
assert isinstance(token_ids, torch.Tensor), (
126126
f"token_ids must be a torch.Tensor, got {type(token_ids)}"
127127
)
128-
assert_start_with_single_bos(token_ids, self.tokenizer)
128+
assert_no_double_bos(token_ids, self.tokenizer)
129129
self._bos_checked = True
130130

131131
return datum_spec
@@ -301,25 +301,20 @@ def dpo_collate_fn(
301301
return train_data
302302

303303

304-
def assert_start_with_single_bos(
305-
token_ids: torch.Tensor, tokenizer: TokenizerType
306-
) -> None:
307-
"""Assert that the first token is a BOS token and the second token is not a BOS token.
304+
def assert_no_double_bos(token_ids: torch.Tensor, tokenizer: TokenizerType) -> None:
305+
"""Assert that there are no double starting BOS tokens in the message.
308306
309307
Args:
310308
token_ids: List of token IDs
311309
tokenizer: Tokenizer
312310
"""
313311
if tokenizer.bos_token_id is not None:
314312
token_ids_list = token_ids.tolist()
315-
if len(token_ids_list) > 0:
316-
assert token_ids_list[0] == tokenizer.bos_token_id, (
317-
f"Expected BOS token at the start of the message, but got {token_ids_list[0]}"
318-
)
319313
if len(token_ids_list) > 1:
320-
assert token_ids_list[1] != tokenizer.bos_token_id, (
321-
f"Expected non-BOS token at the second position of the message, but got {token_ids_list[1]}"
322-
)
314+
assert not (
315+
token_ids_list[0] == tokenizer.bos_token_id
316+
and token_ids_list[1] == tokenizer.bos_token_id
317+
), "Found double BOS token in the first two positions of the message."
323318
else:
324319
print(
325320
f"skip assert_start_single_bos since Tokenizer {tokenizer.name_or_path} has no BOS token"

tests/unit/data/test_data_processor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ def test_math_data_processor():
7373
"Qwen/Qwen2.5-1.5B-Instruct", # no bos token
7474
"google/gemma-3-1b-it",
7575
"Qwen/Qwen3-0.6B", # no bos token
76+
"deepseek-ai/DeepSeek-V3",
77+
"moonshotai/Moonlight-16B-A3B-Instruct",
7678
],
7779
)
7880
@pytest.mark.parametrize(

0 commit comments

Comments (0)