Feat: Eagle3 HF Online - support nemotron models #463
base: main
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
```
    device = self.model.layers[-1].self_attn.q_proj.weight.device
elif hasattr(self.model.layers[-1].self_attn, "qkv_proj"):
    device = self.model.layers[-1].self_attn.qkv_proj.weight.device
self.eagle_module.to(self.dtype).to(device)
```
TODO: confirm this device detection with @yeyu-nvidia
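For context, a minimal self-contained sketch of the detection logic quoted above; the helper name is an assumption, and it simply mirrors the q_proj / qkv_proj fallback from the diff.

```python
import torch
import torch.nn as nn


def detect_last_layer_device(layers: nn.ModuleList) -> torch.device:
    """Return the device of the last decoder layer's attention projection.

    Llama-style attention exposes separate q/k/v projections (q_proj), while
    fused implementations expose a single qkv_proj, so try both.
    """
    attn = layers[-1].self_attn
    if hasattr(attn, "q_proj"):
        return attn.q_proj.weight.device
    if hasattr(attn, "qkv_proj"):
        return attn.qkv_proj.weight.device
    raise AttributeError("Unrecognized attention projection layout")
```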
Codecov Report
✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main     #463   +/-   ##
=======================================
  Coverage   73.38%   73.38%
=======================================
  Files         180      180
  Lines       18110    18110
=======================================
  Hits        13290    13290
  Misses       4820     4820
```
```
input_ids = output.input_ids[0]
attention_mask = output.attention_mask[0]
loss_mask = torch.ones_like(input_ids)
labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
```
So all labels are IGNORE_TOKEN_ID?
Yes. This is aligned with previous behavior:
```
labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
```
but previously we would update labels here:
```
labels[indices] = input_ids[indices]
```
I see. Do you think we will use labels in the future? Otherwise we can get rid of loss_mask and labels and simplify data loading.
Labels should only be useful when we want to tune the base model. They're currently not used in our code.
We will need labels (even if they're dummy) because the HF Trainer needs them for training. We definitely need loss_mask to exclude padded tokens. I would say we keep them.
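To make that concrete, here is a minimal sketch of the target construction being discussed, assuming IGNORE_TOKEN_ID = -100 (the index HF losses ignore by default) and a hypothetical helper name; the explicit padding exclusion is written out here for illustration.

```python
import torch

IGNORE_TOKEN_ID = -100  # assumed value; HF cross-entropy ignores this index


def build_targets(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Hypothetical helper mirroring the preprocessing in the hunk above."""
    # Dummy labels: every position is ignored, since the base model is not tuned
    # and the HF Trainer only needs a `labels` field to drive the training loop.
    labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
    # loss_mask marks positions that count toward the draft-head loss; padded
    # positions (attention_mask == 0) are excluded, per the discussion above.
    loss_mask = torch.ones_like(input_ids) * attention_mask
    return labels, loss_mask
```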
```
    return ret


class OfflineSupervisedDataset(Dataset):
```
Does this support VLM data?
Not yet. Will add that to this PR later.
```
if wandb and is_master():
    wandb.init()


def on_log(self, args, state, control, **kwargs):
```
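For reference, a self-contained sketch of a logging callback along these lines; the is_master helper and the callback class name are assumptions, not the PR's actual code.

```python
import torch.distributed as dist
from transformers import TrainerCallback

try:
    import wandb
except ImportError:
    wandb = None


def is_master() -> bool:
    """Assumed helper: True on rank 0, or when not running distributed."""
    return not dist.is_initialized() or dist.get_rank() == 0


class DraftMetricsCallback(TrainerCallback):
    """Logs trainer metrics to Weights & Biases from the master process only."""

    def __init__(self):
        if wandb and is_master():
            wandb.init()  # create a single run on the master process

    def on_log(self, args, state, control, **kwargs):
        logs = kwargs.get("logs") or {}
        if wandb and is_master():
            wandb.log(logs, step=state.global_step)
```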
Can you explain how you estimate AR? I'm not sure it's a good idea to expose "estimated AR" as it may mislead users.
This is calculated as 1 + step_1_acc + step_1_acc*step_2_acc + step_1_acc*step_2_acc*step_3_acc.
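In code, that formula looks roughly like the sketch below; the function name and example numbers are made up.

```python
def estimated_ar(step_accs: list[float]) -> float:
    """Estimated acceptance rate from per-step acceptance accuracies."""
    ar, running = 1.0, 1.0
    for acc in step_accs:
        running *= acc  # probability that all draft steps up to this one are accepted
        ar += running   # expected additional accepted tokens contributed by this step
    return ar


# Example with three draft steps:
# 1 + 0.8 + 0.8*0.7 + 0.8*0.7*0.6 = 2.696
print(estimated_ar([0.8, 0.7, 0.6]))
```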
I think we have discussed this and agreed not to use estimated AR. We can either just use acc or run real AR validation.
```
    metadata={"help": "Path to the d2t cache directory."},
)
vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."})
vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."})
```
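For context, a hedged sketch of how these fields plug into argument parsing; the dataclass name and the script name in the comment are assumptions.

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DataArguments:
    """Hypothetical subset of the training arguments shown in the diff above."""

    vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."})
    vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."})


if __name__ == "__main__":
    # e.g. python train.py --vlm_processor <hf_model_path> --vlm_img_dir <path to images>
    (data_args,) = HfArgumentParser(DataArguments).parse_args_into_dataclasses()
    print(data_args.vlm_processor, data_args.vlm_img_dir)
```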
What is the VLM processor?
It's the processor that nano-vl uses to preprocess the image and text into tokens. It is defined here:
https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/blob/main/processing.py#L43
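Roughly, it is loaded and used like a standard HF processor; the snippet below is a sketch under that assumption, with the prompt text and image path as placeholders.

```python
from PIL import Image
from transformers import AutoProcessor

# trust_remote_code is needed because the processor class lives in the model repo.
processor = AutoProcessor.from_pretrained(
    "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16", trust_remote_code=True
)

image = Image.open("example.jpg")  # placeholder image path
inputs = processor(text="Describe this image.", images=image, return_tensors="pt")
print(inputs["input_ids"].shape)
```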
```
for param in self.model.embed_tokens.parameters():
# find base model, lm head, and embeddings paths
self._find_base_model_parts()
self.eagle_module.to(self._base_model.dtype).to(self._base_model_lm_head.weight.device)
```
Need to check whether PTQ/inference fails. We want to make sure eagle_module's device is the same as the last base-model decoder layer's device, but that is not necessarily the same as lm_head.device.
Can you confirm this?
I see the idea. I think it makes sense to put the eagle module on the last layer's device.
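Something along these lines would do that; it is a sketch only, and the helper name is an assumption.

```python
import torch
import torch.nn as nn


def move_eagle_to_last_layer(
    base_layers: nn.ModuleList, eagle_module: nn.Module, dtype: torch.dtype
) -> nn.Module:
    """Place the draft (eagle) module on the device of the last base decoder layer.

    Under multi-GPU device maps this device can differ from lm_head's device,
    which is the concern raised above.
    """
    last_layer_device = next(base_layers[-1].parameters()).device
    return eagle_module.to(dtype).to(last_layer_device)
```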
```
dtypemin = torch.finfo(self._base_llm_config.dtype).min
q_len = seq_length
kv_len = seq_length * (2 + ttt_step)
```
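A minimal sketch of the shape arithmetic implied by this snippet; the sequence length and dtype below are made up for illustration.

```python
import torch

seq_length, dtype = 8, torch.bfloat16
dtype_min = torch.finfo(dtype).min

for ttt_step in range(3):
    q_len = seq_length
    kv_len = seq_length * (2 + ttt_step)
    # Additive attention mask initialized to fully masked; the real code would
    # then unmask the positions each query is allowed to attend to.
    attn_mask = torch.full((q_len, kv_len), dtype_min, dtype=dtype)
    print(ttt_step, tuple(attn_mask.shape))  # step 0 -> (8, 16), step 1 -> (8, 24), step 2 -> (8, 32)
```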
Why 2 + ttt_step?
At the 0th ttt step, we have kv_len = 2 * seq_len.
Why is that?
What does this PR do?
Type of change: New feature
Overview:
- Find base model, lm_head, and embeddings to adapt to different base model naming structures;
- fall back to sdpa in case flex_attn doesn't work;
- use BlockMask for flex_attn, or tensor masks for regular attn.

Usage
For a VLM base model, pass the extra arguments --vlm_processor <hf_model_path> --vlm_img_dir <path to images> in the original launch command. Other usage is unchanged. E.g.
Testing
Tested short training with HF Online training on the following models:
- llama-3.2-1b (data: daring-anteater)

Saw loss decreasing and AR > 1.
Before your PR is "Ready for review"
Additional Information