Skip to content

Commit 866ce4b

Browse files
committed
fix
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 2017f5f commit 866ce4b

File tree

1 file changed

+32
-5
lines changed

1 file changed

+32
-5
lines changed

nemo_automodel/components/datasets/vlm/collate_fns.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,30 @@ def _extract_assistant_text(message: Dict[str, Any]) -> str:
6262
return ""
6363

6464

65+
def _decode_single_token(tokenizer, token_id: int) -> str:
66+
"""Decode a single token id across tokenizer implementations.
67+
68+
Some tokenizers accept an `int` token id, while others require a sequence of
69+
ids (e.g., `List[int]`). We try the common forms in order.
70+
"""
71+
try:
72+
return tokenizer.decode(token_id)
73+
except Exception:
74+
try:
75+
return tokenizer.decode([token_id])
76+
except Exception:
77+
try:
78+
return tokenizer.decode(torch.tensor([token_id]))
79+
except Exception:
80+
# Best-effort fallback; stop-token detection will likely fail.
81+
return str(token_id)
82+
83+
6584
def build_labels(
6685
input_ids_batch: torch.Tensor,
6786
conversations: Sequence[Sequence[Dict[str, Any]]],
6887
processor,
69-
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
88+
) -> torch.Tensor:
7089
"""Construct label tensors aligned to assistant responses."""
7190
tokenizer = getattr(processor, "tokenizer", processor)
7291

@@ -93,9 +112,8 @@ def build_labels(
93112
answer_start, answer_end = _find_pattern_indices(encoded, assistant_tokens, search_start_index)
94113

95114
if answer_end < len(encoded):
96-
# Convert tensor to list for tokenizers that don't accept tensors (e.g., tiktoken)
97-
next_token_id = encoded[answer_end].item() if hasattr(encoded[answer_end], 'item') else encoded[answer_end]
98-
next_token_str = tokenizer.decode([next_token_id])
115+
next_token_id = int(encoded[answer_end].item())
116+
next_token_str = _decode_single_token(tokenizer, next_token_id)
99117
if next_token_str.strip() in default_stop_tokens(processor):
100118
answer_end += 1
101119

@@ -200,6 +218,15 @@ def qwen3_omni_collate_fn(
200218
"qwen_omni_utils is required for qwen3_omni_collate_fn. Install it with: pip install qwen-omni-utils"
201219
)
202220

221+
# Import at call-time to support environments/tests that inject the module
222+
# after this file is initially imported.
223+
try:
224+
from qwen_omni_utils import process_mm_info as _process_mm_info
225+
except ImportError as exc:
226+
raise ImportError(
227+
"qwen_omni_utils is required for qwen3_omni_collate_fn. Install it with: pip install qwen-omni-utils"
228+
) from exc
229+
203230
conversations = [example["conversation"] for example in examples]
204231
texts = [
205232
processor.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False)
@@ -210,7 +237,7 @@ def qwen3_omni_collate_fn(
210237
all_images = []
211238
all_videos = []
212239
for conversation in conversations:
213-
audios, images, videos = process_mm_info(conversation, use_audio_in_video=use_audio_in_video)
240+
audios, images, videos = _process_mm_info(conversation, use_audio_in_video=use_audio_in_video)
214241
all_audios.append(audios)
215242
all_images.append(images)
216243
all_videos.append(videos)

0 commit comments

Comments
 (0)