-
Notifications
You must be signed in to change notification settings - Fork 228
Enable loading ckpt for t0 finetuning #309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
90b8f46
abdd703
0fcb19c
63daa46
89460c0
fb8ecb8
a55d2fb
2dfe5d1
ca740f1
cb0313b
b62dcaf
b15ca2d
dc8d0ab
0a32459
2699721
1e77844
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,21 @@ def model_provider(pre_process=True, post_process=True): | |
| see_memory_usage(f"After Building Model", force=True) | ||
| return model | ||
|
|
||
| def visualize_model_inputs(tokens, attention_mask, labels, loss_mask): | ||
| import os | ||
| if os.path.exists("batchoutput.json"): | ||
| return | ||
| out = { | ||
| "tokens": tokens[0,:].tolist(), | ||
| "labels": labels[0,:].tolist(), | ||
| "attention_mask": attention_mask[0,:].tolist(), | ||
| "loss_mask": loss_mask[0,:].tolist(), | ||
| } | ||
| import json | ||
| with open('batchoutput.json', 'w') as fp: | ||
| json.dump(out, fp) | ||
|
|
||
|
|
||
| def get_batch_pipe(data): | ||
| """ | ||
| Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion | ||
|
|
@@ -83,17 +98,20 @@ def get_batch_pipe(data): | |
| ) | ||
| # Only compute loss over causal target tokens, i.e. ignore input_tokens & padding | ||
| loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:] | ||
| loss_on_non_pad_only = (tokens != tokenizer.pad) | ||
| loss_on_non_pad_only = (labels != tokenizer.pad) | ||
| loss_mask *= loss_on_targets_only * loss_on_non_pad_only | ||
|
|
||
| attention_mask = get_packed_attention_mask( | ||
| # Run non-causal decoder | ||
| is_causal=False, | ||
| causal_mask=~(causal_mask.bool()), | ||
| is_causal=True, | ||
|
||
| causal_mask=~(causal_mask.bool()), # Turn back into tril being ones | ||
| decoder_is_inputs=decoder_is_inputs.bool(), | ||
| segment_ids=segment_ids.long(), | ||
| ) | ||
|
|
||
| # Helper script | ||
| # visualize_model_inputs(tokens, attention_mask, labels, loss_mask) | ||
|
|
||
Muennighoff marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]: | ||
| raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -261,11 +261,18 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode | |||||
| - segment_ids: torch.IntTensor [batch_size, sequence_length] | ||||||
| Returns: | ||||||
| - attention_mask: torch.BoolTensor [batch_size, 1, sequence_length, sequence_length] | ||||||
|
|
||||||
| Input example for the mask examples: | ||||||
| att_mask_batch = 1 | ||||||
| seq_length = 7 | ||||||
| decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]]) | ||||||
| segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]]) | ||||||
| causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(att_mask_batch, 1, seq_length, seq_length) | ||||||
| """ | ||||||
|
|
||||||
| """Causal Inputs Mask: | ||||||
| mask = [[[[1, 1, 0, 0, 0, 0, 0], | ||||||
| [1, 1, 0, 0, 0, 0, 0], | ||||||
| mask = [[[[1, 1, 0, 1, 1, 0, 0], | ||||||
| [1, 1, 0, 1, 1, 0, 0], | ||||||
Muennighoff marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| [1, 1, 1, 0, 0, 0, 0], | ||||||
| [1, 1, 1, 1, 1, 0, 0], | ||||||
| [1, 1, 1, 1, 1, 0, 0], | ||||||
|
|
@@ -299,7 +306,7 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode | |||||
| [0, 0, 0, 1, 1, 1, 0], | ||||||
| [0, 0, 0, 1, 1, 1, 0], | ||||||
| [0, 0, 0, 1, 1, 1, 0], | ||||||
| [0, 0, 0, 0, 0, 0, 0]]]] | ||||||
| [0, 0, 0, 0, 0, 0, 1]]]] | ||||||
Muennighoff marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| """ | ||||||
| segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :] | ||||||
|
|
||||||
|
|
@@ -311,13 +318,22 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode | |||||
| [0, 0, 0, 1, 1, 0, 0], | ||||||
| [0, 0, 0, 1, 1, 1, 0], | ||||||
| [0, 0, 0, 0, 0, 0, 0]]]] | ||||||
|
|
||||||
| If is_causal=True: | ||||||
| mask = [[[[1, 0, 0, 0, 0, 0, 0], | ||||||
| [1, 1, 0, 0, 0, 0, 0], | ||||||
| [1, 1, 1, 0, 0, 0, 0], | ||||||
| [0, 0, 0, 1, 0, 0, 0], | ||||||
| [0, 0, 0, 1, 1, 0, 0], | ||||||
| [0, 0, 0, 1, 1, 1, 0], | ||||||
| [0, 0, 0, 0, 0, 0, 0]]]] | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think there is a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hum I'm wondering if this doesn't screw something up. Essentially you're going to compute softmax on a row with only zeros ... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The last row & last col are the attention scores of the last token with respect to the last token. Since the last token is masked out in our loss_mask it doesn't matter I think. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No you compute softmax, what should be the result of the softmax of a row full of masked out values .... It feels like that would return lots of Nans. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't we fill it with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can try writing a test but I would be pretty sure that the actual results are 0. (with current kernel) |
||||||
|
|
||||||
| """ | ||||||
| attention_mask = causal_inputs_mask * padding_mask * segment_mask | ||||||
|
|
||||||
| # Convert attention mask to binary: | ||||||
| attention_mask = (attention_mask < 0.5) | ||||||
| attention_mask = causal_inputs_mask * padding_mask * segment_mask | ||||||
|
|
||||||
| return attention_mask | ||||||
| # True for places we do not want to attend to | ||||||
| return ~attention_mask | ||||||
|
|
||||||
| def param_size(parameter): | ||||||
| return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement() | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -82,6 +82,8 @@ def encode(self, json_line): | |
| ids = {} | ||
| for key in self.args.json_keys: | ||
| text = data[key] | ||
| if self.args.prepend_space: | ||
| text = f" {text}" | ||
| doc_ids = [] | ||
| for sentence in Encoder.splitter.tokenize(text): | ||
| sentence_ids = Encoder.tokenizer.tokenize(sentence) | ||
|
|
@@ -117,6 +119,8 @@ def get_args(): | |
| help='Path to the BPE merge file (if necessary).') | ||
| group.add_argument('--append-eod', action='store_true', | ||
| help='Append an <eod> token to the end of a document.') | ||
| group.add_argument('--prepend-space', action='store_true', | ||
| help='Prepends a space to the beginning of a document') | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a mention in which context it's useful, typically it is when you compute targets. |
||
| group.add_argument("--tokenizer-name-or-path", type=str, default=None, | ||
| help="Name or path of the huggingface tokenizer.") | ||
| group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.