Commit 1a89a88
authored
Feat: Eagle3 HF Online - support nemotron and VLMs (#463)
## What does this PR do?
**Type of change:** New feature <!-- Use one of the following: Bug fix,
new feature, new example, new tests, documentation. -->
**Overview:**
- Support the nano and nano-VL in eagle3 online mode:
- Added submodule path detection for `base model`, `lm_head`, and
`embeddings` to adapt different base model naming structure;
- Refactored data loading/preprocessing to support VLM;
- Attn backend improvement:
- Added option of `sdpa` in case `flex_attn` doesn't work.
- Added a unified TTT mask function that produce either `BlockMask` for
flex_attn or tensor masks for regular attn.
- Logging improvements:
- Added estimated AR validation during training. This is available for
both online and offline.
- Plot estimated AR and training acc to wandb for better training
visualization;
- Fix: PTQ on speculative decoding model.
## Usage
<!-- You can potentially add a usage example below. -->
For VLM as base model, pass in extra arguments `--vlm_processor
<hf_model_path> --vlm_img_dir <path to images>` in original launching
commands. Other usage unchanged.
E.g.
```bash
./launch_train.sh --model $MODEL \
--output_dir $OUTPUT_DIR \
--data $DATA \
--num_gpu 1 \
--num_epochs 2 \
--train_bs 2 \
--lr 3e-5 \
--eagle_config eagle_config.json \
--training_seq_len 4096 \
--vlm_processor $MODEL \
--vlm_img_dir <path to images>
```
## Testing
<!-- Mention how have you tested your change if applicable. -->
Tested short training with HF Online training on following models:
- `llama-3.2-1b` - data: daring-anteater
- The new nano (Hyrbid LLM) - data: daring-anteater
- The nano-VL - data: Llama-Nemotron-VLM-Dataset-v1/ocr_1
See loss decreasing and AR > 1.
## Before your PR is "*Ready for review*"
<!-- If you haven't finished some of the above items you can still open
`Draft` PR. -->
- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes/No <!--- If No, explain
why. -->
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update
[Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**:
Yes/No <!--- Only for new features, API changes, critical bug fixes or
bw breaking changes. -->
## Additional Information
<!-- E.g. related issue. -->
---------
Signed-off-by: h-guo18 <[email protected]>1 parent f161794 commit 1a89a88
File tree
5 files changed
+434
-145
lines changed- examples/speculative_decoding
- modelopt/torch/speculative/plugins
5 files changed
+434
-145
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
9 | | - | |
| 9 | + | |
| 10 | + | |
10 | 11 | | |
0 commit comments