2 changes: 1 addition & 1 deletion finetune_t0_non_causal_decoder.py
@@ -89,7 +89,7 @@ def get_batch_pipe(data):
attention_mask = get_packed_attention_mask(
# Run non-causal decoder
is_causal=False,
-causal_mask=~(causal_mask.bool()),
+causal_mask=~(causal_mask.bool()), # Turn back into tril being ones
decoder_is_inputs=decoder_is_inputs.bool(),
segment_ids=segment_ids.long(),
)
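For context on the comment added here: a minimal sketch, assuming `causal_mask` is produced by Megatron's `get_ltor_masks_and_position_ids` (which binarizes the tril with `(mask < 0.5)`, so True marks blocked positions), of why the inversion restores the tril-of-ones convention shown in `get_packed_attention_mask`'s docstring example:

```python
import torch

seq_length = 4
tril = torch.tril(torch.ones(1, 1, seq_length, seq_length))

# Assumption: upstream code binarized the tril as (mask < 0.5), so True now
# marks positions that must NOT be attended to (the upper triangle).
causal_mask = tril < 0.5

# get_packed_attention_mask's docstring example uses a tril of ones (True on
# and below the diagonal = allowed), hence ~(causal_mask.bool()) in the call above.
assert torch.equal(~(causal_mask.bool()), tril.bool())
```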
2 changes: 1 addition & 1 deletion megatron/checkpointing.py
@@ -273,7 +273,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True

if args.deepspeed:
load_optimizer_states = False if args.no_load_optim else True
-loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states)
+loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states, load_lr_scheduler_states=load_optimizer_states)
if loaded_dir is None:
print_rank_0('WARNING: could not find the metadata file {} '.format(
load_dir))
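For context on this checkpointing change, a hedged sketch reusing the names from the hunk above (`args`, `model`, `load_dir`); it assumes DeepSpeed's `load_checkpoint` accepts both keyword arguments and defaults `load_lr_scheduler_states` to True, so before this change a resume with `--no-load-optim` would still restore the old LR schedule:

```python
# Illustration only, not the repo's code path: tie both states to one flag.
load_optimizer_states = not args.no_load_optim  # same value as the ternary above

loaded_dir, state_dict = model[0].load_checkpoint(
    load_dir,
    load_optimizer_states=load_optimizer_states,
    # Previously left at DeepSpeed's default (True); now skipped together
    # with the optimizer state when --no-load-optim is passed.
    load_lr_scheduler_states=load_optimizer_states,
)
```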
3 changes: 2 additions & 1 deletion megatron/model/gpt_model.py
@@ -252,7 +252,8 @@ def _to_float16(inputs):
args.num_layers),
layer_number=layer_idx,
# TODO: Change naming of class from GPT to something that encapsulate prefix lm.
-self_attn_mask_type=attn_mask_type))
+self_attn_mask_type=attn_mask_type)
+)

# Undo data format change
def undo(x):
30 changes: 23 additions & 7 deletions megatron/utils.py
@@ -261,11 +261,18 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
- segment_ids: torch.IntTensor [batch_size, sequence_length]
Returns:
- attention_mask: torch.BoolTensor [batch_size, 1, sequence_length, sequence_length]

+Input example for the mask examples:
+att_mask_batch = 1
+seq_length = 7
+decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]])
+segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
+causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(att_mask_batch, 1, seq_length, seq_length)
"""

"""Causal Inputs Mask:
-mask = [[[[1, 1, 0, 0, 0, 0, 0],
-[1, 1, 0, 0, 0, 0, 0],
+mask = [[[[1, 1, 0, 1, 1, 0, 0],
+[1, 1, 0, 1, 1, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
@@ -299,7 +306,7 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 0],
-[0, 0, 0, 0, 0, 0, 0]]]]
+[0, 0, 0, 0, 0, 0, 1]]]]
"""
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]

@@ -311,13 +318,22 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]

+If is_causal=True:
+mask = [[[[1, 0, 0, 0, 0, 0, 0],
+[1, 1, 0, 0, 0, 0, 0],
+[1, 1, 1, 0, 0, 0, 0],
+[0, 0, 0, 1, 0, 0, 0],
+[0, 0, 0, 1, 1, 0, 0],
+[0, 0, 0, 1, 1, 1, 0],
+[0, 0, 0, 0, 0, 0, 0]]]]
Member:

Suggested change:
-[0, 0, 0, 0, 0, 0, 0]]]]
+[0, 0, 0, 0, 0, 0, 1]]]]

Collaborator Author:

I don't think there is a 1, because the last row & column is 100% padding.

Member:

Hmm, I'm wondering if this doesn't screw something up. Essentially you're going to compute softmax on a row with only zeros...

Collaborator Author:

The last row & last column are the attention scores of the last token with respect to the last token. Since the last token is masked out in our loss_mask, it doesn't matter, I think.
Also, it's a row with only -inf, no?

Member:

No, you compute softmax; what should the result of the softmax of a row full of masked-out values be? It feels like that would return lots of NaNs.

Collaborator Author:

Don't we fill it with -inf?
And the softmax of a row where all values are the same is just 1/n, no? Where would it cause NaNs?

Member (@thomasw21, Jul 12, 2022):

You can try writing a test, but I would be pretty sure that the actual results are 0 (with the current kernel).
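To make the exchange concrete, a small sketch in plain PyTorch (illustration only; Megatron's fused softmax kernel, which the last comment refers to, may behave differently): a fully masked row gives NaNs if the mask is applied with -inf, and a uniform 1/n if it is applied with a large finite negative fill.

```python
import torch

scores = torch.zeros(7)                         # raw attention scores for one row
fully_masked = torch.ones(7, dtype=torch.bool)  # every position is masked out

# Filling with -inf: exp(-inf) is 0 everywhere, the normalizer is 0,
# and the softmax of the row comes out as NaN in every position.
print(scores.masked_fill(fully_masked, float("-inf")).softmax(dim=-1))

# Filling with a large finite negative value: all entries are equal,
# so the softmax is uniform (1/7 per position), not NaN.
print(scores.masked_fill(fully_masked, -10000.0).softmax(dim=-1))
```

Which of the two cases applies depends on how the mask is consumed downstream, which is what the thread above is probing.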


"""
-attention_mask = causal_inputs_mask * padding_mask * segment_mask

-# Convert attention mask to binary:
-attention_mask = (attention_mask < 0.5)
+attention_mask = causal_inputs_mask * padding_mask * segment_mask

-return attention_mask
+# True for places we do not want to attend to
+return ~attention_mask

def param_size(parameter):
return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()
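To tie the docstring example to the composition above, a self-contained sketch. The `inputs_mask` and `padding_mask` lines are assumptions (those hunks are collapsed in this diff): inputs are taken to be the positions flagged by `decoder_is_inputs`, and padding the positions with segment_id 0; boolean `&`/`|` stand in for the `*`/`+` used in the function.

```python
import torch

att_mask_batch, seq_length = 1, 7
decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]]).bool()
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]]).long()
causal_mask = torch.tril(
    torch.ones(att_mask_batch, seq_length, seq_length)
).view(att_mask_batch, 1, seq_length, seq_length).bool()

# Input (prompt) tokens may attend to each other bidirectionally.
inputs_mask = decoder_is_inputs[:, None, :, None] & decoder_is_inputs[:, None, None, :]
causal_inputs_mask = causal_mask | inputs_mask

# Assumption: padding is marked by segment_id == 0 and attends to nothing.
not_padding = segment_ids != 0
padding_mask = not_padding[:, None, :, None] & not_padding[:, None, None, :]

# Tokens only see their own packed segment.
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]

attention_mask = causal_inputs_mask & padding_mask & segment_mask
print((~attention_mask).int()[0, 0])  # 1 marks positions that must NOT be attended to
```

The last row and column come out fully blocked because position 6 is padding, which is the case the review thread above is about.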