Skip to content

Commit 909b52b

Browse files
author
Haozhe Qi
committed
fixed a bug
1 parent 2401861 commit 909b52b

File tree

1 file changed

+32
-34
lines changed

1 file changed

+32
-34
lines changed

llava/model/language_model/llava_qwen.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ def forward(
9292
dpo_forward: Optional[bool] = False,
9393
cache_position=None,
9494
) -> Union[Tuple, CausalLMOutputWithPast]:
95-
print ('actions', actions)
9695
if inputs_embeds is None:
9796
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels, action_idx) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
9897

@@ -233,8 +232,7 @@ def forward(
233232
action_logits = action_logits.to(device)
234233
actions = actions.to(device)
235234

236-
vision_supervision_loss = 0.0
237-
235+
vision_supervision_loss = 0.0
238236

239237
triples = list(zip(other_verb_logits_list, other_noun_logits_list, other_action_logits_list))
240238

@@ -245,39 +243,39 @@ def forward(
245243
elif getattr(self.config, 'vision_token_training', None) and self.config.vision_token_training == 'all_layers':
246244
pass
247245
# by default, distillation uses all layers
248-
# First check if any process has valid examples across all triples
249-
world_has_valid = torch.tensor(actions[:, 0].any() > 0, device=actions.device)
250-
torch.distributed.all_reduce(world_has_valid, op=torch.distributed.ReduceOp.MAX)
251-
252-
if world_has_valid: # If any process has valid examples
253-
for other_verb_logits, other_noun_logits, other_action_logits in triples:
254-
valid_mask = actions[:, 0] > 0
255-
256-
if valid_mask.any(): # This process has valid examples
257-
valid_verb_logits = other_verb_logits[valid_mask]
258-
valid_noun_logits = other_noun_logits[valid_mask]
259-
valid_action_logits = other_action_logits[valid_mask]
260-
261-
valid_verb_targets = actions[valid_mask, 0]
262-
valid_noun_targets = actions[valid_mask, 1]
263-
valid_action_targets = actions[valid_mask, 2]
246+
# First check if any process has valid examples across all triples
247+
world_has_valid = torch.tensor(actions[:, 0].any() > 0, device=actions.device)
248+
torch.distributed.all_reduce(world_has_valid, op=torch.distributed.ReduceOp.MAX)
264249

265-
other_verb_loss = loss_fct(valid_verb_logits, valid_verb_targets)
266-
other_noun_loss = loss_fct(valid_noun_logits, valid_noun_targets)
267-
other_action_loss = loss_fct(valid_action_logits, valid_action_targets)
250+
if world_has_valid: # If any process has valid examples
251+
for other_verb_logits, other_noun_logits, other_action_logits in triples:
252+
valid_mask = actions[:, 0] > 0
268253

269-
vision_supervision_loss += 0.5 * other_verb_loss + 0.5 * other_noun_loss + 0.1 * other_action_loss
270-
else: # This process has no valid examples but others do
271-
# Add dummy loss to maintain gradient flow
272-
vision_supervision_loss += 0.0 * (other_verb_logits.sum() + other_noun_logits.sum() + other_action_logits.sum())
273-
274-
vision_supervision_loss /= (len(triples) + 1)
275-
loss += vision_supervision_loss * 0.1
276-
else:
277-
# If no process has valid examples, add dummy loss to prevent hanging
278-
dummy_loss = sum(sum(t.sum() * 0.0 for t in triple) for triple in triples)
279-
vision_supervision_loss = dummy_loss / (len(triples) + 1)
280-
loss += vision_supervision_loss * 0.1
254+
if valid_mask.any(): # This process has valid examples
255+
valid_verb_logits = other_verb_logits[valid_mask]
256+
valid_noun_logits = other_noun_logits[valid_mask]
257+
valid_action_logits = other_action_logits[valid_mask]
258+
259+
valid_verb_targets = actions[valid_mask, 0]
260+
valid_noun_targets = actions[valid_mask, 1]
261+
valid_action_targets = actions[valid_mask, 2]
262+
263+
other_verb_loss = loss_fct(valid_verb_logits, valid_verb_targets)
264+
other_noun_loss = loss_fct(valid_noun_logits, valid_noun_targets)
265+
other_action_loss = loss_fct(valid_action_logits, valid_action_targets)
266+
267+
vision_supervision_loss += 0.5 * other_verb_loss + 0.5 * other_noun_loss + 0.1 * other_action_loss
268+
else: # This process has no valid examples but others do
269+
# Add dummy loss to maintain gradient flow
270+
vision_supervision_loss += 0.0 * (other_verb_logits.sum() + other_noun_logits.sum() + other_action_logits.sum())
271+
272+
vision_supervision_loss /= (len(triples) + 1)
273+
loss += vision_supervision_loss * 0.1
274+
else:
275+
# If no process has valid examples, add dummy loss to prevent hanging
276+
dummy_loss = sum(sum(t.sum() * 0.0 for t in triple) for triple in triples)
277+
vision_supervision_loss = dummy_loss / (len(triples) + 1)
278+
loss += vision_supervision_loss * 0.1
281279

282280
if getattr(self.config, 'vision_token_training', None) and 'distillation' in self.config.vision_token_training:
283281

0 commit comments

Comments
 (0)