Skip to content

Commit 36d2796

Browse files
author
Ye Shaokai
committed
Merge branch 'shaokai/dev' of github.com:yeshaokai/LLaVA-NeXT into shaokai/dev
2 parents 7f351c5 + 909b52b commit 36d2796

File tree

1 file changed

+32
-33
lines changed

1 file changed

+32
-33
lines changed

llava/model/language_model/llava_qwen.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,7 @@ def forward(
232232
action_logits = action_logits.to(device)
233233
actions = actions.to(device)
234234

235-
vision_supervision_loss = 0.0
236-
235+
vision_supervision_loss = 0.0
237236

238237
triples = list(zip(other_verb_logits_list, other_noun_logits_list, other_action_logits_list))
239238

@@ -244,39 +243,39 @@ def forward(
244243
elif getattr(self.config, 'vision_token_training', None) and self.config.vision_token_training == 'all_layers':
245244
pass
246245
# by default, distillation uses all layers
247-
# First check if any process has valid examples across all triples
248-
world_has_valid = torch.tensor(actions[:, 0].any() > 0, device=actions.device)
249-
torch.distributed.all_reduce(world_has_valid, op=torch.distributed.ReduceOp.MAX)
250-
251-
if world_has_valid: # If any process has valid examples
252-
for other_verb_logits, other_noun_logits, other_action_logits in triples:
253-
valid_mask = actions[:, 0] > 0
254-
255-
if valid_mask.any(): # This process has valid examples
256-
valid_verb_logits = other_verb_logits[valid_mask]
257-
valid_noun_logits = other_noun_logits[valid_mask]
258-
valid_action_logits = other_action_logits[valid_mask]
259-
260-
valid_verb_targets = actions[valid_mask, 0]
261-
valid_noun_targets = actions[valid_mask, 1]
262-
valid_action_targets = actions[valid_mask, 2]
246+
# First check if any process has valid examples across all triples
247+
world_has_valid = torch.tensor(actions[:, 0].any() > 0, device=actions.device)
248+
torch.distributed.all_reduce(world_has_valid, op=torch.distributed.ReduceOp.MAX)
263249

264-
other_verb_loss = loss_fct(valid_verb_logits, valid_verb_targets)
265-
other_noun_loss = loss_fct(valid_noun_logits, valid_noun_targets)
266-
other_action_loss = loss_fct(valid_action_logits, valid_action_targets)
250+
if world_has_valid: # If any process has valid examples
251+
for other_verb_logits, other_noun_logits, other_action_logits in triples:
252+
valid_mask = actions[:, 0] > 0
267253

268-
vision_supervision_loss += 0.5 * other_verb_loss + 0.5 * other_noun_loss + 0.1 * other_action_loss
269-
else: # This process has no valid examples but others do
270-
# Add dummy loss to maintain gradient flow
271-
vision_supervision_loss += 0.0 * (other_verb_logits.sum() + other_noun_logits.sum() + other_action_logits.sum())
272-
273-
vision_supervision_loss /= (len(triples) + 1)
274-
loss += vision_supervision_loss * 0.1
275-
else:
276-
# If no process has valid examples, add dummy loss to prevent hanging
277-
dummy_loss = sum(sum(t.sum() * 0.0 for t in triple) for triple in triples)
278-
vision_supervision_loss = dummy_loss / (len(triples) + 1)
279-
loss += vision_supervision_loss * 0.1
254+
if valid_mask.any(): # This process has valid examples
255+
valid_verb_logits = other_verb_logits[valid_mask]
256+
valid_noun_logits = other_noun_logits[valid_mask]
257+
valid_action_logits = other_action_logits[valid_mask]
258+
259+
valid_verb_targets = actions[valid_mask, 0]
260+
valid_noun_targets = actions[valid_mask, 1]
261+
valid_action_targets = actions[valid_mask, 2]
262+
263+
other_verb_loss = loss_fct(valid_verb_logits, valid_verb_targets)
264+
other_noun_loss = loss_fct(valid_noun_logits, valid_noun_targets)
265+
other_action_loss = loss_fct(valid_action_logits, valid_action_targets)
266+
267+
vision_supervision_loss += 0.5 * other_verb_loss + 0.5 * other_noun_loss + 0.1 * other_action_loss
268+
else: # This process has no valid examples but others do
269+
# Add dummy loss to maintain gradient flow
270+
vision_supervision_loss += 0.0 * (other_verb_logits.sum() + other_noun_logits.sum() + other_action_logits.sum())
271+
272+
vision_supervision_loss /= (len(triples) + 1)
273+
loss += vision_supervision_loss * 0.1
274+
else:
275+
# If no process has valid examples, add dummy loss to prevent hanging
276+
dummy_loss = sum(sum(t.sum() * 0.0 for t in triple) for triple in triples)
277+
vision_supervision_loss = dummy_loss / (len(triples) + 1)
278+
loss += vision_supervision_loss * 0.1
280279

281280
if getattr(self.config, 'vision_token_training', None) and 'distillation' in self.config.vision_token_training:
282281

0 commit comments

Comments (0)