@@ -62,11 +62,30 @@ def _extract_assistant_text(message: Dict[str, Any]) -> str:
6262 return ""
6363
6464
def _decode_single_token(tokenizer, token_id: int) -> str:
    """Decode a single token id across tokenizer implementations.

    Tokenizers disagree on what ``decode`` accepts: some take a bare ``int``,
    others require a sequence of ids (e.g. ``List[int]``), and some only
    work with a tensor. The input forms are tried in that order and the first
    one that yields a string wins.

    Args:
        tokenizer: Any object exposing a ``decode`` method.
        token_id: The token id to decode.

    Returns:
        The decoded text. If every input form fails (raises, or returns
        something that is not a ``str``), ``str(token_id)`` is returned as a
        best-effort fallback; stop-token detection based on that value will
        likely not match.
    """
    # Lazy builders: the list/tensor forms are only constructed when the
    # cheaper forms before them have already failed.
    input_builders = (
        lambda: token_id,
        lambda: [token_id],
        lambda: torch.tensor([token_id]),
    )

    for build_input in input_builders:
        try:
            decoded = tokenizer.decode(build_input())
        except Exception:
            # Deliberately broad: tokenizer backends raise unrelated exception
            # types for an unsupported input form, and this helper is
            # best-effort by design.
            continue
        # Honour the ``-> str`` contract; callers invoke ``.strip()`` on the
        # result, so a non-string answer is treated like a failed attempt.
        if isinstance(decoded, str):
            return decoded

    # Best-effort fallback; stop-token detection will likely fail.
    return str(token_id)
82+
83+
6584def build_labels (
6685 input_ids_batch : torch .Tensor ,
6786 conversations : Sequence [Sequence [Dict [str , Any ]]],
6887 processor ,
69- ) -> tuple [ torch .Tensor , Optional [ torch . Tensor ]] :
88+ ) -> torch .Tensor :
7089 """Construct label and optional loss-mask tensors aligned to assistant responses."""
7190 tokenizer = getattr (processor , "tokenizer" , processor )
7291
@@ -93,9 +112,8 @@ def build_labels(
93112 answer_start , answer_end = _find_pattern_indices (encoded , assistant_tokens , search_start_index )
94113
95114 if answer_end < len (encoded ):
96- # Convert tensor to list for tokenizers that don't accept tensors (e.g., tiktoken)
97- next_token_id = encoded [answer_end ].item () if hasattr (encoded [answer_end ], 'item' ) else encoded [answer_end ]
98- next_token_str = tokenizer .decode ([next_token_id ])
115+ next_token_id = int (encoded [answer_end ].item ())
116+ next_token_str = _decode_single_token (tokenizer , next_token_id )
99117 if next_token_str .strip () in default_stop_tokens (processor ):
100118 answer_end += 1
101119
@@ -200,6 +218,15 @@ def qwen3_omni_collate_fn(
200218 "qwen_omni_utils is required for qwen3_omni_collate_fn. Install it with: pip install qwen-omni-utils"
201219 )
202220
221+ # Import at call-time to support environments/tests that inject the module
222+ # after this file is initially imported.
223+ try :
224+ from qwen_omni_utils import process_mm_info as _process_mm_info
225+ except ImportError as exc :
226+ raise ImportError (
227+ "qwen_omni_utils is required for qwen3_omni_collate_fn. Install it with: pip install qwen-omni-utils"
228+ ) from exc
229+
203230 conversations = [example ["conversation" ] for example in examples ]
204231 texts = [
205232 processor .apply_chat_template (conversation , add_generation_prompt = False , tokenize = False )
@@ -210,7 +237,7 @@ def qwen3_omni_collate_fn(
210237 all_images = []
211238 all_videos = []
212239 for conversation in conversations :
213- audios , images , videos = process_mm_info (conversation , use_audio_in_video = use_audio_in_video )
240+ audios , images , videos = _process_mm_info (conversation , use_audio_in_video = use_audio_in_video )
214241 all_audios .append (audios )
215242 all_images .append (images )
216243 all_videos .append (videos )
0 commit comments