
Commit 8579a98

Apply fixes used during training
1 parent c4b4824 commit 8579a98

File tree

3 files changed (+18 lines, -7 lines)


nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py

Lines changed: 12 additions & 6 deletions
@@ -51,7 +51,9 @@ def _get_header_conversation_type_mask_role(source, special_tokens):
     if TYPE_INSTRUCTION[data_type] != '':
         conversation = conversation + '\n' + TYPE_INSTRUCTION[data_type]
     mask_role = source.get('mask', 'User')
-    header = f"{special_tokens['system_turn_start']}{SYSTEM_TOKEN}{END_NAME_SIGNAL}{conversation}{END_SIGNAL}"
+    system_token = source.get("system_token", SYSTEM_TOKEN)
+    header = f"{special_tokens['system_turn_start']}{system_token}{END_NAME_SIGNAL}{conversation}{END_SIGNAL}"
+    # logging.info(f"DBG HEADER:\n```{header}```")
     conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, data_type, special_tokens)
     return header, conversation, data_type, mask_role

@@ -60,13 +62,14 @@ def get_prompt_template_example(special_tokens):
     source = {
         'system': '{system message}',
         'conversations': [
-            {'from': 'User', 'value': '{turn 1 user message}', 'label': None},
-            {'from': 'Assistant', 'value': '{turn 1 assistant message}', 'label': '{turn 1 assistant label}'},
-            {'from': 'User', 'value': '{turn 2 user message}', 'label': None},
-            {'from': 'Assistant', 'value': '{turn 2 assistant message}', 'label': '{turn 2 assistant label}'},
+            {'from': '{user role}', 'value': '{turn 1 user message}', 'label': None},
+            {'from': '{assistant role}', 'value': '{turn 1 assistant message}', 'label': '{turn 1 assistant label}'},
+            {'from': '{user role}', 'value': '{turn 2 user message}', 'label': None},
+            {'from': '{assistant role}', 'value': '{turn 2 assistant message}', 'label': '{turn 2 assistant label}'},
         ],
-        "mask": "User",
+        "mask": "{user role}",
         "type": "VALUE_TO_TEXT",
+        "system_token": '{system token}',
     }
     _, conversation, _, _ = _get_header_conversation_type_mask_role(source, special_tokens)
     return conversation
@@ -273,6 +276,7 @@ def preprocess(
         id1 = tokenizer.text_to_ids(PREFIX_STR + s["value"])
         id2 = tokenizer.text_to_ids(PREFIX_STR)
         tokenized_sentence = id1[len(id2) :]
+        # logging.info(f"CONV DBG: {tokenized_sentence[0:20]} ... {tokenized_sentence[-20:]}")
         ids.append(torch.tensor(tokenized_sentence))
         tokenized_lens.append(len(tokenized_sentence))
     speakers = [sentence["from"] for sentence in source['conversations']]
@@ -326,6 +330,8 @@ def _build_samples_mapping(self):
         id2 = self.tokenizer.text_to_ids(PREFIX_STR)
         self.num_turn_start_tokens = len(id1) - len(id2)
 
+        # logging.info(f"DATASET DBG:\n{self.special_tokens=}\n{self.label_start_tokens=}, {self.name_end_token_ids=}, {self.num_turn_start_tokens=}")
+
     def _process_example(self, example):
         """
         Create an example by concatenating text and answer.
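
Note: the change above lets a data example override the default system role token through a new "system_token" key in the source dict. The following is a minimal standalone sketch of that behavior, not the NeMo implementation; the constant values and the '<extra_id_0>' turn-start marker are placeholder assumptions chosen only for illustration.

# Standalone sketch of the per-example system token override; constants are illustrative.
SYSTEM_TOKEN = "System"
END_NAME_SIGNAL = "\n"
END_SIGNAL = "\n"

def build_header(source, special_tokens):
    conversation = source['system']
    # New behavior: a per-example "system_token" key overrides the module-level default.
    system_token = source.get("system_token", SYSTEM_TOKEN)
    return f"{special_tokens['system_turn_start']}{system_token}{END_NAME_SIGNAL}{conversation}{END_SIGNAL}"

special_tokens = {'system_turn_start': '<extra_id_0>'}
print(build_header({'system': 'You are a helpful assistant.'}, special_tokens))
print(build_header({'system': 'You are a helpful assistant.', 'system_token': 'SYS'}, special_tokens))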

nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py

Lines changed: 2 additions & 0 deletions
@@ -228,6 +228,7 @@ def __getitem__(self, idx):
         else:
             auto_gen_idx = False
         try:
+            idx = int(idx)
             example = self.indexed_dataset[idx]
             if auto_gen_idx:
                 example['__AUTOGENERATED__'] = True
@@ -542,6 +543,7 @@ def __getitem__(self, idx):
         # assert idx < len(self.samples_mapping)
         idx = self.samples_mapping[idx]
 
+        idx = int(idx)
         input_ids = self.indexed_dataset[idx]['input_ids']
         seq_boundaries = self.indexed_dataset[idx]['seq_start_id'] + [len(input_ids)]
         loss_mask = self.indexed_dataset[idx]['loss_mask']
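
Note: both __getitem__ changes cast the index to a plain Python int before indexing the underlying dataset. A likely reason is that indices drawn from a NumPy samples mapping are numpy integer scalars, which some indexed-dataset backends reject; the toy sketch below (assumed backend behavior, not NeMo code) shows that failure mode and the fix.

import numpy as np

class StrictDataset:
    # Stand-in for an indexed dataset that only accepts plain Python ints as indices.
    def __init__(self, items):
        self._items = items

    def __getitem__(self, idx):
        if type(idx) is not int:
            raise TypeError(f"expected int index, got {type(idx).__name__}")
        return self._items[idx]

samples_mapping = np.array([3, 1, 2])      # e.g. a shuffled samples mapping
ds = StrictDataset(['a', 'b', 'c', 'd'])

idx = samples_mapping[0]   # numpy.int64, not int
idx = int(idx)             # the cast added in this commit
print(ds[idx])             # 'd'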

nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py

Lines changed: 4 additions & 1 deletion
@@ -1296,6 +1296,9 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers
                 'loss_mask': batch['loss_mask'],
             }
 
+            # if torch.distributed.get_rank() == 0:
+            #     logging.info(f"*****DEBUG OUTPUT*****\nTOKENS:\n{batch['tokens'][0].tolist()}\nPOSITION_IDS:\n{batch['position_ids'][0].tolist()}\nLABELS:\n{batch['labels'][0].tolist()}\nLOSS_MASK:\n{batch['loss_mask'][0].tolist()}\nATTENTION_MASK:\n{None if batch['attention_mask'] is None else batch['attention_mask'][0].tolist()}\n")
+
             if not self.mcore_gpt:
                 forward_args['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers
             if not self.use_loss_mask:
@@ -1592,7 +1595,7 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
         losses = output_tensor.float()
         loss_mask = loss_mask.view(-1).float()
         # TODO: add nemo version here
-        loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub  # sequence level nll
+        loss = torch.sum(losses.view(-1) * loss_mask) / max(1, num_valid_tokens_in_ub)  # sequence level nll
         if parallel_state.get_context_parallel_world_size() > 1:
             torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
         return loss
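
Note: the loss_func change guards the normalization against a zero token count. If every position in a microbatch is masked out, num_valid_tokens_in_ub is 0 and the old division produced NaN; clamping the denominator to at least 1 yields a zero loss instead. A simplified standalone sketch of that guard (not the full NeMo method, which also handles context parallelism) is:

import torch

def sequence_level_nll(loss_mask, num_valid_tokens_in_ub, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    # Clamp the denominator so an all-masked microbatch gives 0.0 instead of NaN.
    return torch.sum(losses.view(-1) * loss_mask) / max(1, num_valid_tokens_in_ub)

per_token_loss = torch.tensor([2.0, 1.5, 0.5, 3.0])
mask = torch.zeros(4)  # every token masked out
print(sequence_level_nll(mask, int(mask.sum().item()), per_token_loss))  # tensor(0.)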
