@@ -92,7 +92,6 @@ def forward(
         dpo_forward: Optional[bool] = False,
         cache_position=None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        print('actions', actions)
         if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels, action_idx) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
 
@@ -233,8 +232,7 @@ def forward(
         action_logits = action_logits.to(device)
         actions = actions.to(device)
 
-            vision_supervision_loss = 0.0
-
+        vision_supervision_loss = 0.0
 
         triples = list(zip(other_verb_logits_list, other_noun_logits_list, other_action_logits_list))
 
@@ -245,39 +243,39 @@ def forward(
         elif getattr(self.config, 'vision_token_training', None) and self.config.vision_token_training == 'all_layers':
             pass
             # by default, distillation uses all layers
-            # First check if any process has valid examples across all triples
-            world_has_valid = torch.tensor(actions[:, 0].any() > 0, device=actions.device)
-            torch.distributed.all_reduce(world_has_valid, op=torch.distributed.ReduceOp.MAX)
-
-            if world_has_valid:  # If any process has valid examples
-                for other_verb_logits, other_noun_logits, other_action_logits in triples:
-                    valid_mask = actions[:, 0] > 0
-
-                    if valid_mask.any():  # This process has valid examples
-                        valid_verb_logits = other_verb_logits[valid_mask]
-                        valid_noun_logits = other_noun_logits[valid_mask]
-                        valid_action_logits = other_action_logits[valid_mask]
-
-                        valid_verb_targets = actions[valid_mask, 0]
-                        valid_noun_targets = actions[valid_mask, 1]
-                        valid_action_targets = actions[valid_mask, 2]
-
-                        other_verb_loss = loss_fct( valid_verb_logits, valid_verb_targets )
-                        other_noun_loss = loss_fct( valid_noun_logits, valid_noun_targets )
-                        other_action_loss = loss_fct( valid_action_logits, valid_action_targets )
-
-                        vision_supervision_loss += 0.5 * other_verb_loss + 0.5 * other_noun_loss + 0.1 * other_action_loss
-                    else:  # This process has no valid examples but others do
-                        # Add dummy loss to maintain gradient flow
-                        vision_supervision_loss += 0.0 * (other_verb_logits.sum() + other_noun_logits.sum() + other_action_logits.sum())
-
-                vision_supervision_loss /= (len(triples) + 1)
-                loss += vision_supervision_loss * 0.1
-            else:
-                # If no process has valid examples, add dummy loss to prevent hanging
-                dummy_loss = sum(sum(t.sum() * 0.0 for t in triple) for triple in triples)
-                vision_supervision_loss = dummy_loss / (len(triples) + 1)
-                loss += vision_supervision_loss * 0.1
+        # First check if any process has valid examples across all triples
+        world_has_valid = torch.tensor(actions[:, 0].any() > 0, device=actions.device)
+        torch.distributed.all_reduce(world_has_valid, op=torch.distributed.ReduceOp.MAX)
+
+        if world_has_valid:  # If any process has valid examples
+            for other_verb_logits, other_noun_logits, other_action_logits in triples:
+                valid_mask = actions[:, 0] > 0
+
+                if valid_mask.any():  # This process has valid examples
+                    valid_verb_logits = other_verb_logits[valid_mask]
+                    valid_noun_logits = other_noun_logits[valid_mask]
+                    valid_action_logits = other_action_logits[valid_mask]
+
+                    valid_verb_targets = actions[valid_mask, 0]
+                    valid_noun_targets = actions[valid_mask, 1]
+                    valid_action_targets = actions[valid_mask, 2]
+
+                    other_verb_loss = loss_fct(valid_verb_logits, valid_verb_targets)
+                    other_noun_loss = loss_fct(valid_noun_logits, valid_noun_targets)
+                    other_action_loss = loss_fct(valid_action_logits, valid_action_targets)
+
+                    vision_supervision_loss += 0.5 * other_verb_loss + 0.5 * other_noun_loss + 0.1 * other_action_loss
+                else:  # This process has no valid examples but others do
+                    # Add dummy loss to maintain gradient flow
+                    vision_supervision_loss += 0.0 * (other_verb_logits.sum() + other_noun_logits.sum() + other_action_logits.sum())
+
+            vision_supervision_loss /= (len(triples) + 1)
+            loss += vision_supervision_loss * 0.1
+        else:
+            # If no process has valid examples, add dummy loss to prevent hanging
+            dummy_loss = sum(sum(t.sum() * 0.0 for t in triple) for triple in triples)
+            vision_supervision_loss = dummy_loss / (len(triples) + 1)
+            loss += vision_supervision_loss * 0.1
 
         if getattr(self.config, 'vision_token_training', None) and 'distillation' in self.config.vision_token_training:
 
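Note on the synchronization pattern in this hunk: when a loss term is computed only on ranks that happen to have valid targets, DDP can hang, because ranks that skip the term never contribute gradients to the bucketed all-reduce. The code handles this by (1) all-reducing a has-valid flag with `ReduceOp.MAX` so every rank agrees whether the loss applies anywhere, and (2) adding a `0.0 *`-weighted dummy loss on ranks without valid rows so their backward pass still produces (zero) gradients for the same parameters. Below is a minimal, self-contained sketch of the same idea, with hypothetical names (`logits`, `targets`) standing in for the verb/noun/action triples used here:

```python
import torch
import torch.distributed as dist

def ddp_safe_masked_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Cross-entropy over rows with a valid target (> 0), safe under DDP.

    Every rank returns a tensor connected to the autograd graph, so the
    gradient all-reduce never blocks on a rank that skipped the real loss.
    """
    loss_fct = torch.nn.CrossEntropyLoss()
    valid_mask = targets > 0

    # Agree across ranks on whether *any* rank has valid rows.
    has_valid = torch.tensor(int(valid_mask.any()), device=logits.device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(has_valid, op=dist.ReduceOp.MAX)

    if has_valid.item() and valid_mask.any():
        # This rank has real targets: compute the actual loss.
        return loss_fct(logits[valid_mask], targets[valid_mask])

    # Either no rank has targets, or only other ranks do: a zero-weighted
    # term touches every logit, producing zero gradients that keep DDP's
    # bucketed all-reduce in lockstep across ranks.
    return 0.0 * logits.sum()
```

The `0.0 * logits.sum()` trick works because it reaches the same parameters as the real loss through the model head, so every rank reduces gradients for an identical parameter set; dropping it on empty ranks is exactly the hang this commit's dummy-loss branches guard against.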