@@ -232,8 +232,7 @@ def forward(
232232 action_logits = action_logits .to (device )
233233 actions = actions .to (device )
234234
235- vision_supervision_loss = 0.0
236-
235+ vision_supervision_loss = 0.0
237236
238237 triples = list (zip (other_verb_logits_list , other_noun_logits_list , other_action_logits_list ))
239238
@@ -244,39 +243,39 @@ def forward(
244243 elif getattr (self .config , 'vision_token_training' , None ) and self .config .vision_token_training == 'all_layers' :
245244 pass
246245 # by default, distillation uses all layers
247- # First check if any process has valid examples across all triples
248- world_has_valid = torch .tensor (actions [:, 0 ].any () > 0 , device = actions .device )
249- torch .distributed .all_reduce (world_has_valid , op = torch .distributed .ReduceOp .MAX )
250-
251- if world_has_valid : # If any process has valid examples
252- for other_verb_logits , other_noun_logits , other_action_logits in triples :
253- valid_mask = actions [:, 0 ] > 0
254-
255- if valid_mask .any (): # This process has valid examples
256- valid_verb_logits = other_verb_logits [valid_mask ]
257- valid_noun_logits = other_noun_logits [valid_mask ]
258- valid_action_logits = other_action_logits [valid_mask ]
259-
260- valid_verb_targets = actions [valid_mask , 0 ]
261- valid_noun_targets = actions [valid_mask , 1 ]
262- valid_action_targets = actions [valid_mask , 2 ]
246+ # First check if any process has valid examples across all triples
247+ world_has_valid = torch .tensor (actions [:, 0 ].any () > 0 , device = actions .device )
248+ torch .distributed .all_reduce (world_has_valid , op = torch .distributed .ReduceOp .MAX )
263249
264- other_verb_loss = loss_fct ( valid_verb_logits , valid_verb_targets )
265- other_noun_loss = loss_fct ( valid_noun_logits , valid_noun_targets )
266- other_action_loss = loss_fct ( valid_action_logits , valid_action_targets )
250+ if world_has_valid : # If any process has valid examples
251+ for other_verb_logits , other_noun_logits , other_action_logits in triples :
252+ valid_mask = actions [:, 0 ] > 0
267253
268- vision_supervision_loss += 0.5 * other_verb_loss + 0.5 * other_noun_loss + 0.1 * other_action_loss
269- else : # This process has no valid examples but others do
270- # Add dummy loss to maintain gradient flow
271- vision_supervision_loss += 0.0 * (other_verb_logits .sum () + other_noun_logits .sum () + other_action_logits .sum ())
272-
273- vision_supervision_loss /= (len (triples ) + 1 )
274- loss += vision_supervision_loss * 0.1
275- else :
276- # If no process has valid examples, add dummy loss to prevent hanging
277- dummy_loss = sum (sum (t .sum () * 0.0 for t in triple ) for triple in triples )
278- vision_supervision_loss = dummy_loss / (len (triples ) + 1 )
279- loss += vision_supervision_loss * 0.1
254+ if valid_mask .any (): # This process has valid examples
255+ valid_verb_logits = other_verb_logits [valid_mask ]
256+ valid_noun_logits = other_noun_logits [valid_mask ]
257+ valid_action_logits = other_action_logits [valid_mask ]
258+
259+ valid_verb_targets = actions [valid_mask , 0 ]
260+ valid_noun_targets = actions [valid_mask , 1 ]
261+ valid_action_targets = actions [valid_mask , 2 ]
262+
263+ other_verb_loss = loss_fct (valid_verb_logits , valid_verb_targets )
264+ other_noun_loss = loss_fct (valid_noun_logits , valid_noun_targets )
265+ other_action_loss = loss_fct (valid_action_logits , valid_action_targets )
266+
267+ vision_supervision_loss += 0.5 * other_verb_loss + 0.5 * other_noun_loss + 0.1 * other_action_loss
268+ else : # This process has no valid examples but others do
269+ # Add dummy loss to maintain gradient flow
270+ vision_supervision_loss += 0.0 * (other_verb_logits .sum () + other_noun_logits .sum () + other_action_logits .sum ())
271+
272+ vision_supervision_loss /= (len (triples ) + 1 )
273+ loss += vision_supervision_loss * 0.1
274+ else :
275+ # If no process has valid examples, add dummy loss to prevent hanging
276+ dummy_loss = sum (sum (t .sum () * 0.0 for t in triple ) for triple in triples )
277+ vision_supervision_loss = dummy_loss / (len (triples ) + 1 )
278+ loss += vision_supervision_loss * 0.1
280279
281280 if getattr (self .config , 'vision_token_training' , None ) and 'distillation' in self .config .vision_token_training :
282281
0 commit comments