24 changes: 21 additions & 3 deletions finetune_t0_non_causal_decoder.py
@@ -47,6 +47,21 @@ def model_provider(pre_process=True, post_process=True):
see_memory_usage(f"After Building Model", force=True)
return model

def visualize_model_inputs(tokens, attention_mask, labels, loss_mask):
import os
if os.path.exists("batchoutput.json"):
return
out = {
"tokens": tokens[0,:].tolist(),
"labels": labels[0,:].tolist(),
"attention_mask": attention_mask[0,:].tolist(),
"loss_mask": loss_mask[0,:].tolist(),
}
import json
with open('batchoutput.json', 'w') as fp:
json.dump(out, fp)


def get_batch_pipe(data):
"""
Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
@@ -83,17 +98,20 @@ def get_batch_pipe(data):
)
# Only compute loss over causal target tokens, i.e. ignore input_tokens & padding
loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:]
loss_on_non_pad_only = (tokens != tokenizer.pad)
loss_on_non_pad_only = (labels != tokenizer.pad)
loss_mask *= loss_on_targets_only * loss_on_non_pad_only

attention_mask = get_packed_attention_mask(
# Run non-causal decoder
is_causal=False,
causal_mask=~(causal_mask.bool()),
is_causal=True,
Member: let's rename this file finetune_t0_causal_decoder then

Collaborator (Author): What about just finetune_t0.py?

Member: Right, but do we hardcode this every time? I'd rather have this one be the script for the causal decoder.

Collaborator (Author): Added an argument prefixlm.

causal_mask=~(causal_mask.bool()), # Turn back into tril being ones
decoder_is_inputs=decoder_is_inputs.bool(),
segment_ids=segment_ids.long(),
)

# Helper script
# visualize_model_inputs(tokens, attention_mask, labels, loss_mask)

if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")

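For context on the `loss_on_non_pad_only` fix above (padding is now checked against `labels` rather than `tokens`), here is a small self-contained sketch of the loss-mask logic. The token ids, pad id, and packing layout below are made up purely for illustration and are not taken from the repo.

import torch

PAD = 0  # hypothetical pad id, for illustration only

# One packed sample, already shifted so that labels[i] is the token to predict at position i:
# doc 1 = [input, input, target, target], doc 2 = [input, target], then one pad position.
labels            = torch.tensor([[11, 12, 13, 14, 21, 22, PAD]])
decoder_is_inputs = torch.tensor([[1, 1, 0, 0, 1, 0, 0]], dtype=torch.bool)  # aligned with labels here
loss_mask         = torch.ones_like(labels, dtype=torch.float)

loss_on_targets_only = ~decoder_is_inputs        # only train on target tokens ...
loss_on_non_pad_only = labels != PAD             # ... and never on padding (checked on labels, not tokens)
loss_mask *= loss_on_targets_only * loss_on_non_pad_only
print(loss_mask)  # tensor([[0., 0., 1., 1., 0., 1., 0.]])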
2 changes: 1 addition & 1 deletion megatron/checkpointing.py
@@ -273,7 +273,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True

if args.deepspeed:
load_optimizer_states = False if args.no_load_optim else True
loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states)
loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states, load_lr_scheduler_states=load_optimizer_states)
if loaded_dir is None:
print_rank_0('WARNING: could not find the metadata file {} '.format(
load_dir))
16 changes: 15 additions & 1 deletion megatron/model/gpt_model.py
@@ -158,6 +158,16 @@ def load_state_dict(self, state_dict, strict=True):
state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)

def visualize_outputs(losses):
import os
if os.path.exists("losses.json"):
return
out = {
"losses": losses[0,:].tolist(),
}
import json
with open('losses.json', 'w') as fp:
json.dump(out, fp)

def get_cross_entropy(is_prefix: bool):
def CrossEntropy(output, labels):
Expand All @@ -167,6 +177,9 @@ def CrossEntropy(output, labels):

losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)

# Helper script
# visualize_outputs(losses)

if is_prefix:
micro_batch_size, sequence_length = loss_mask.shape
average_tokens_per_sample: torch.Tensor
@@ -252,7 +265,8 @@ def _to_float16(inputs):
args.num_layers),
layer_number=layer_idx,
# TODO: Change naming of class from GPT to something that encapsulate prefix lm.
self_attn_mask_type=attn_mask_type))
self_attn_mask_type=attn_mask_type)
)

# Undo data format change
def undo(x):
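As background for the `CrossEntropy` changes above: the per-token losses returned by `mpu.vocab_parallel_cross_entropy` are reduced using the loss mask. Below is a minimal sketch of the usual (non-prefix) reduction; it illustrates the idea only and is not the repo's exact code. The `is_prefix` branch (truncated above) additionally computes an `average_tokens_per_sample` normalisation.

import torch

def masked_lm_loss(losses: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # losses, loss_mask: [micro_batch_size, sequence_length]
    loss_mask = loss_mask.view(-1).float()
    # Average only over positions that actually contribute to the loss (targets, non-padding).
    return torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()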
30 changes: 23 additions & 7 deletions megatron/utils.py
@@ -261,11 +261,18 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
- segment_ids: torch.IntTensor [batch_size, sequence_length]
Returns:
- attention_mask: torch.BoolTensor [batch_size, 1, sequence_length, sequence_length]

Input example for the mask examples:
att_mask_batch = 1
seq_length = 7
decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]])
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(att_mask_batch, 1, seq_length, seq_length)
"""

"""Causal Inputs Mask:
mask = [[[[1, 1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
mask = [[[[1, 1, 0, 1, 1, 0, 0],
[1, 1, 0, 1, 1, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
@@ -299,7 +306,7 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]
[0, 0, 0, 0, 0, 0, 1]]]]
"""
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]

@@ -311,13 +318,22 @@
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]

If is_causal=True:
mask = [[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]
Member (suggested change): make the last row of the example [0, 0, 0, 0, 0, 0, 1]]]] instead of [0, 0, 0, 0, 0, 0, 0]]]].

Collaborator (Author): I don't think there is a 1, because the last row & column is 100% padding.

Member: Hmm, I'm wondering if this doesn't screw something up. Essentially you're going to compute softmax on a row with only zeros ...

Collaborator (Author): The last row & last column are the attention scores of the last token with respect to the last token. Since the last token is masked out in our loss_mask, it doesn't matter, I think. Also, it's a row with only -inf, no?

Member: No, you compute softmax; what should be the result of the softmax of a row full of masked-out values? It feels like that would return lots of NaNs.

Collaborator (Author): Don't we fill it with -inf? And the softmax of a row where all values are the same is just 1/n, no? Where would it cause NaNs?

Member (@thomasw21, Jul 12, 2022): You can try writing a test, but I would be pretty sure that the actual results are 0 (with the current kernel).


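To make the disagreement above concrete, here is what plain torch.softmax does on a fully masked row; whether the fused masked-softmax kernel used in training behaves the same way is exactly the open question in this thread.

import torch

seq_length = 7
scores = torch.zeros(seq_length)                          # raw attention scores for the all-padding position
fully_masked = torch.ones(seq_length, dtype=torch.bool)   # True = do not attend (the whole row)

# Masking with -inf: exp(-inf) = 0 everywhere, the normaliser is 0, and 0/0 gives NaN.
print(torch.softmax(scores.masked_fill(fully_masked, float("-inf")), dim=-1))
# tensor([nan, nan, nan, nan, nan, nan, nan])

# Masking with a large finite negative value keeps the row finite: a uniform 1/n distribution.
print(torch.softmax(scores.masked_fill(fully_masked, -10000.0), dim=-1))
# tensor([0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429])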
"""
attention_mask = causal_inputs_mask * padding_mask * segment_mask

# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
attention_mask = causal_inputs_mask * padding_mask * segment_mask

return attention_mask
# True for places we do not want to attend to
return ~attention_mask

def param_size(parameter):
return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()
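For reference, a minimal standalone sketch of how the three component masks combine for the docstring's example inputs. The construction of causal_inputs_mask and padding_mask below is inferred from the docstring examples and may differ in detail from the actual helper; segment_mask and the final combination match the code shown above.

import torch

# Example inputs from the docstring: one sample, two packed documents plus a trailing pad token.
seq_length = 7
decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]], dtype=torch.bool)
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]], dtype=torch.long)
causal_mask = torch.tril(torch.ones(1, seq_length, seq_length)).view(1, 1, seq_length, seq_length).bool()

# Prefix-LM ("non-causal") case: input positions may additionally attend to each other bidirectionally.
inputs_mask = decoder_is_inputs[:, None, :, None] & decoder_is_inputs[:, None, None, :]
causal_inputs_mask = causal_mask | inputs_mask   # with is_causal=True this would just be causal_mask

# Padding positions (segment id 0) neither attend nor are attended to.
not_padding = segment_ids != 0
padding_mask = not_padding[:, None, :, None] & not_padding[:, None, None, :]

# Tokens only attend within their own packed document.
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]

attention_mask = causal_inputs_mask & padding_mask & segment_mask
print(attention_mask[0, 0].int())   # reproduces the combined mask shown in the docstring
# The helper then returns ~attention_mask, i.e. True marks positions we do NOT attend to.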
4 changes: 4 additions & 0 deletions tools/preprocess_data.py
@@ -82,6 +82,8 @@ def encode(self, json_line):
ids = {}
for key in self.args.json_keys:
text = data[key]
if self.args.prepend_space:
text = f" {text}"
doc_ids = []
for sentence in Encoder.splitter.tokenize(text):
sentence_ids = Encoder.tokenizer.tokenize(sentence)
@@ -117,6 +119,8 @@ def get_args():
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument('--prepend-space', action='store_true',
help='Prepends a space to the beginning of a document')
Member: Add a mention of the context in which it's useful; typically, it is when you compute targets.

group.add_argument("--tokenizer-name-or-path", type=str, default=None,
help="Name or path of the huggingface tokenizer.")
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
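Regarding the reviewer's request above: --prepend-space matters because BPE tokenizers encode a word differently with and without a leading space, which is relevant when the document is a target continuation that will be concatenated after an input. A quick illustration with the HuggingFace GPT-2 tokenizer, used here only as an example (not necessarily the tokenizer configured for this pipeline):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")   # example tokenizer, for illustration only
print(tok.tokenize("world"))    # ['world']   -- no leading-space marker
print(tok.tokenize(" world"))   # ['Ġworld']  -- the token you would get mid-sentence, e.g. after "hello"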