 from megatron.core.datasets.multimodal_dataset import MockMultimodalDataset, MultimodalDatasetConfig
 from megatron.core.optimizer import OptimizerConfig
 
-from nemo.collections import vlm
+from nemo.collections import llm, vlm
 from nemo.collections.vlm.neva.model.base import neva_data_step, neva_forward_step
 from nemo.tron.api import megatron_pretrain
 from nemo.tron.config import CheckpointConfig, ConfigContainer, LoggerConfig, SchedulerConfig, TrainingConfig
+from nemo.tron.losses import masked_next_token_loss
 from nemo.tron.state import GlobalState
 from nemo.tron.utils.common_utils import print_rank_0
 
 
-def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
-    """Loss function.
-
-    Args:
-        loss_mask (torch.Tensor): Used to mask out some portions of the loss
-        output_tensor (torch.Tensor): The tensor with the losses
-
-    Returns:
-        the loss scalar for this micro-batch
-        the number of non-padded tokens in this microbatch
-        a dict containing reporting metrics on the loss and number of tokens across
-            the data parallel ranks
-    """
-    state = GlobalState()
-    losses = output_tensor.float()
-    loss_mask = loss_mask.view(-1).float()
-    total_tokens = loss_mask.sum()
-    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
-
-    if state.cfg.model_config.context_parallel_size > 1:
-        torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
-
-    # Reduce loss for logging.
-    reporting_loss = loss.clone().detach()
-    torch.distributed.all_reduce(reporting_loss, group=parallel_state.get_data_parallel_group())
-
-    local_num_tokens = loss[1].clone().detach().to(torch.int)
-    return (
-        loss[0] * state.cfg.model_config.context_parallel_size,
-        local_num_tokens,
-        {"lm loss": (reporting_loss[0], reporting_loss[1])},
-    )
-
-
-def forward_step(data_iterator, model):
-    """Forward training step.
-
-    Args:
-        data_iterator : Input data iterator
-        model (NevaModel): The NeVA Model
-    """
-
-    timers = GlobalState().timers
+def forward_step(state: GlobalState, data_iterator, model):
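+    """Forward training step: build a batch, run neva_forward_step, and return the output with a masked loss callback."""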
+    timers = state.timers
 
     # Get the batch.
     timers("batch-generator", log_level=2).start()
@@ -80,7 +40,7 @@ def forward_step(data_iterator, model):
 
     output_tensor = neva_forward_step(model, batch)
 
-    return output_tensor, partial(loss_func, loss_mask)
+    return output_tensor, partial(masked_next_token_loss, loss_mask)
 
 
 def neva_dataset_provider(train_val_test_num_samples: list[int], dataset_config: MultimodalDatasetConfig):