@@ -253,6 +253,88 @@ def has_data(modality_list):
253253 return batch
254254
255255
def nemotron_parse_collate_fn(
    examples: Sequence[Dict[str, Any]],
    processor,
    task_prompt: str = "</s><s><predict_bbox><predict_classes><output_markdown>",
) -> Dict[str, torch.Tensor]:
    """
    Collate function for NVIDIA Nemotron-Parse models.

    The Nemotron-Parse processor does not expose a chat template, so we build
    the prompt + answer string manually, mask the prompt tokens, and keep the
    image preprocessing handled by the processor.

    Args:
        examples: Batch items, each carrying a ``"conversation"`` list of
            role/content messages; user content may embed an image as a
            ``{"type": "image", "image": ...}`` dict.
        processor: Hugging Face processor wrapping image preprocessing and a
            tokenizer.
        task_prompt: Task prompt prepended to every target text.

    Returns:
        Batch dict containing (at least) ``input_ids``, ``labels``,
        ``decoder_input_ids`` and ``decoder_attention_mask``, with every
        sequence-shaped tensor trimmed to length ``seq_len - 1``.

    Raises:
        ValueError: If the tokenizer exposes neither a pad token id nor any
            usable decoder start token id.
    """
    conversations = [example["conversation"] for example in examples]

    # Pull one image and one assistant answer out of each conversation.
    images: List[Any] = []
    targets: List[str] = []
    for conversation in conversations:
        image = None
        assistant_text = ""

        for message in conversation:
            role = message.get("role")
            content = message.get("content")

            if role == "user":
                # The image lives inside the user message's content list.
                if isinstance(content, list):
                    for item in content:
                        if isinstance(item, dict) and item.get("type") == "image":
                            image = item.get("image")
                            break
            elif role == "assistant" and not assistant_text:
                assistant_text = _extract_assistant_text(message)

            if image is not None and assistant_text:
                break

        images.append(image)
        targets.append(assistant_text)

    # Prompt is concatenated directly with the answer (no separator token).
    texts = [f"{task_prompt}{target}" for target in targets]

    batch = processor(images=images, text=texts, padding=True, return_tensors="pt")

    if "pixel_values" in batch:
        # Match the dtype the model is trained/finetuned in.
        batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)

    labels = build_labels(
        batch["input_ids"],
        conversations,
        processor,
    )

    # Shift labels left by one so position t predicts token t + 1.
    batch["labels"] = labels[:, 1:]

    tokenizer = getattr(processor, "tokenizer", processor)
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    # BUGFIX: resolve the decoder start token with explicit ``is None``
    # checks instead of ``or`` chaining — a token id of 0 is falsy but
    # perfectly valid, and ``or`` would silently skip it and fall through
    # to the next candidate.
    decoder_start_token_id = getattr(tokenizer, "decoder_start_token_id", None)
    if decoder_start_token_id is None:
        decoder_start_token_id = getattr(tokenizer, "bos_token_id", None)
    if decoder_start_token_id is None:
        decoder_start_token_id = getattr(tokenizer, "eos_token_id", None)
    if pad_token_id is None or decoder_start_token_id is None:
        raise ValueError("Nemotron-Parse collate_fn requires pad_token_id and decoder_start_token_id.")

    # Build right-shifted decoder inputs: [start, t0, t1, ..., t_{n-2}].
    decoder_input_ids = batch["input_ids"].clone()
    decoder_input_ids[:, 0] = decoder_start_token_id
    decoder_input_ids[:, 1:] = batch["input_ids"][:, :-1]

    decoder_attention_mask = (decoder_input_ids != pad_token_id).long()

    batch["decoder_input_ids"] = decoder_input_ids[:, 1:]
    batch["decoder_attention_mask"] = decoder_attention_mask[:, 1:]

    # Drop the trailing position from every tensor still at the original
    # sequence length (input_ids, attention_mask, ...) so all sequence
    # tensors end up with length seq_len - 1.  Tensors already sliced
    # above (labels, decoder_*) no longer match input_shape and are left
    # untouched.
    input_shape = batch["input_ids"].shape
    for key, value in list(batch.items()):
        if isinstance(value, torch.Tensor) and value.shape == input_shape:
            batch[key] = value[:, :-1]

    return batch
336+
337+
256338def default_collate_fn (
257339 examples : Sequence [Dict [str , Any ]],
258340 processor ,
@@ -297,5 +379,6 @@ def default_collate_fn(
# Dispatch table mapping a processor class name to its collate function.
# "default" is the fallback entry used when no processor-specific match exists.
COLLATE_FNS = {
    "Qwen2_5_VLProcessor": qwen2_5_collate_fn,
    "Qwen3OmniMoeProcessor": qwen3_omni_collate_fn,
    "NemotronParseProcessor": nemotron_parse_collate_fn,
    "default": default_collate_fn,
}
0 commit comments