
Commit 35616c7

Commit message: update
Signed-off-by: Maanu Grover <[email protected]>
Parent: b91c6f3

1 file changed (+5, -45 lines)

nemo/tron/examples/neva.py

Lines changed: 5 additions & 45 deletions
@@ -20,57 +20,17 @@
 from megatron.core.datasets.multimodal_dataset import MockMultimodalDataset, MultimodalDatasetConfig
 from megatron.core.optimizer import OptimizerConfig
 
-from nemo.collections import vlm
+from nemo.collections import llm, vlm
 from nemo.collections.vlm.neva.model.base import neva_data_step, neva_forward_step
 from nemo.tron.api import megatron_pretrain
 from nemo.tron.config import CheckpointConfig, ConfigContainer, LoggerConfig, SchedulerConfig, TrainingConfig
+from nemo.tron.losses import masked_next_token_loss
 from nemo.tron.state import GlobalState
 from nemo.tron.utils.common_utils import print_rank_0
 
 
-def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
-    """Loss function.
-
-    Args:
-        loss_mask (torch.Tensor): Used to mask out some portions of the loss
-        output_tensor (torch.Tensor): The tensor with the losses
-
-    Returns:
-        the loss scalar for this micro-batch
-        the number of non-padded tokens in this microbatch
-        a dict containing reporting metrics on the loss and number of tokens across
-            the data parallel ranks
-    """
-    state = GlobalState()
-    losses = output_tensor.float()
-    loss_mask = loss_mask.view(-1).float()
-    total_tokens = loss_mask.sum()
-    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
-
-    if state.cfg.model_config.context_parallel_size > 1:
-        torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
-
-    # Reduce loss for logging.
-    reporting_loss = loss.clone().detach()
-    torch.distributed.all_reduce(reporting_loss, group=parallel_state.get_data_parallel_group())
-
-    local_num_tokens = loss[1].clone().detach().to(torch.int)
-    return (
-        loss[0] * state.cfg.model_config.context_parallel_size,
-        local_num_tokens,
-        {"lm loss": (reporting_loss[0], reporting_loss[1])},
-    )
-
-
-def forward_step(data_iterator, model):
-    """Forward training step.
-
-    Args:
-        data_iterator : Input data iterator
-        model (NevaModel): The NeVA Model
-    """
-
-    timers = GlobalState().timers
+def forward_step(state: GlobalState, data_iterator, model):
+    timers = state.timers
 
     # Get the batch.
     timers("batch-generator", log_level=2).start()
@@ -80,7 +40,7 @@ def forward_step(data_iterator, model):
 
     output_tensor = neva_forward_step(model, batch)
 
-    return output_tensor, partial(loss_func, loss_mask)
+    return output_tensor, partial(masked_next_token_loss, loss_mask)
 
 
 def neva_dataset_provider(train_val_test_num_samples: list[int], dataset_config: MultimodalDatasetConfig):
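
For context, a minimal sketch of how the refactored forward_step is consumed. The driver lines below are hypothetical (the actual training loop is driven by megatron_pretrain from nemo.tron.api, not shown in this diff); the (loss, num_tokens, metrics) return shape of the loss callable is inferred from the deleted loss_func above, on the assumption that masked_next_token_loss is a drop-in replacement for it.

# Hypothetical usage sketch -- not part of this commit or the NeMo API.
# forward_step now takes GlobalState explicitly instead of constructing it
# internally, and returns the model output together with a loss callable
# that already has loss_mask bound via functools.partial.
state = GlobalState()
output_tensor, loss_fn = forward_step(state, data_iterator, model)

# Assuming masked_next_token_loss mirrors the deleted loss_func, the callable
# returns the masked loss, the local count of non-padded tokens, and a
# metrics dict whose "lm loss" entry holds (loss sum, token count) reduced
# across data-parallel ranks.
loss, num_tokens, metrics = loss_fn(output_tensor)
loss_sum, token_count = metrics["lm loss"]
print_rank_0(f"lm loss: {loss_sum / token_count}")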
