Commit 1a89a88

Feat: Eagle3 HF Online - support nemotron and VLMs (#463)
## What does this PR do?

**Type of change:** New feature

**Overview:**

- Support the nano and nano-VL in eagle3 online mode:
  - Added submodule path detection for `base model`, `lm_head`, and `embeddings` to adapt to different base model naming structures;
  - Refactored data loading/preprocessing to support VLMs;
- Attn backend improvement:
  - Added an `sdpa` option in case `flex_attn` doesn't work.
  - Added a unified TTT mask function that produces either a `BlockMask` for flex_attn or tensor masks for regular attn.
- Logging improvements:
  - Added estimated AR validation during training. This is available for both online and offline.
  - Plot estimated AR and training acc to wandb for better training visualization;
- Fix: PTQ on speculative decoding model.

## Usage

For a VLM as the base model, pass the extra arguments `--vlm_processor <hf_model_path> --vlm_img_dir <path to images>` in the original launch command. Other usage is unchanged. E.g.

```bash
./launch_train.sh --model $MODEL \
    --output_dir $OUTPUT_DIR \
    --data $DATA \
    --num_gpu 1 \
    --num_epochs 2 \
    --train_bs 2 \
    --lr 3e-5 \
    --eagle_config eagle_config.json \
    --training_seq_len 4096 \
    --vlm_processor $MODEL \
    --vlm_img_dir <path to images>
```

## Testing

Tested short training runs with HF Online training on the following models:

- `llama-3.2-1b`
  - data: daring-anteater
- The new nano (Hybrid LLM)
  - data: daring-anteater
- The nano-VL
  - data: Llama-Nemotron-VLM-Dataset-v1/ocr_1

In each run the loss decreases and AR > 1.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

---------

Signed-off-by: h-guo18 <[email protected]>
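
The submodule path detection mentioned in the overview could look roughly like the sketch below. It is not the PR's implementation: the candidate paths, the function names, and the stand-in model objects are all hypothetical, chosen only to show the idea of probing several dotted attribute paths and keeping the first one that exists. The same attribute walk would apply to real `torch.nn.Module` objects.

```python
from types import SimpleNamespace

# Hypothetical candidate attribute paths per role; the real list used by the
# PR is not shown in this commit page.
CANDIDATE_PATHS = {
    "base_model": ["model", "language_model.model", "backbone"],
    "lm_head": ["lm_head", "language_model.lm_head"],
    "embeddings": [
        "model.embed_tokens",
        "language_model.model.embed_tokens",
        "backbone.embeddings",
    ],
}


def resolve_path(root, dotted_path):
    """Follow a dotted attribute path; return None if any hop is missing."""
    obj = root
    for name in dotted_path.split("."):
        obj = getattr(obj, name, None)
        if obj is None:
            return None
    return obj


def find_submodule_paths(model, candidates=CANDIDATE_PATHS):
    """Return the first matching path for each role; raise if a role has none."""
    found = {}
    for role, paths in candidates.items():
        for path in paths:
            if resolve_path(model, path) is not None:
                found[role] = path
                break
        else:
            raise ValueError(f"no known path for {role!r}")
    return found


# Stand-ins for two differently structured models (plain objects, not HF classes).
llm = SimpleNamespace(model=SimpleNamespace(embed_tokens="emb"), lm_head="head")
vlm = SimpleNamespace(
    language_model=SimpleNamespace(
        model=SimpleNamespace(embed_tokens="emb"), lm_head="head"
    )
)

llm_paths = find_submodule_paths(llm)
vlm_paths = find_submodule_paths(vlm)
print(llm_paths)
print(vlm_paths)
```

Returning the path string rather than the module itself keeps the result easy to log and to store in a config, at the cost of one extra lookup when the module is needed.
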
1 parent f161794 commit 1a89a88

5 files changed: +434 -145 lines changed

examples/speculative_decoding/eagle_config.json

2 additions & 1 deletion

@@ -6,5 +6,6 @@
         "original_max_position_embeddings": 8192,
         "rope_type": "llama3"
     },
-    "initializer_range": 0.02
+    "initializer_range": 0.02,
+    "_attn_implementation": "sdpa"
 }
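
As a rough illustration of how the new `_attn_implementation` key might be consumed, the sketch below parses a fragment of the config and picks a backend. The helper name `pick_attn_backend`, the set of accepted backend names, and the `flex_attention` default are assumptions made for this example, not code from the PR.

```python
import json

# Fragment mirroring the changed tail of eagle_config.json (other keys omitted).
config_text = '{"initializer_range": 0.02, "_attn_implementation": "sdpa"}'

# Backend names as used by Hugging Face Transformers; treating exactly these
# as the accepted set is an assumption of this sketch.
KNOWN_BACKENDS = {"sdpa", "flex_attention", "eager"}


def pick_attn_backend(config, default="flex_attention"):
    """Return the configured backend, or `default` when the key is absent."""
    backend = config.get("_attn_implementation", default)
    if backend not in KNOWN_BACKENDS:
        raise ValueError(f"unsupported attention backend: {backend!r}")
    return backend


config = json.loads(config_text)
print(pick_attn_backend(config))
print(pick_attn_backend({"initializer_range": 0.02}))
```

With the key present the helper returns `sdpa`; with the key removed it falls back to the assumed default.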
