Commit c35d2a7

test: Get Eagle tests working (NVIDIA#3593)
Signed-off-by: Balaram Buddharaju <[email protected]>
Parent: e70961f

File tree

5 files changed (+79, -12)

examples/eagle/convert_checkpoint.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -295,6 +295,14 @@ def copy(tensors):
         args.n_positions = hf_config.max_position_embeddings
         args.dtype = str(
             hf_config.torch_dtype)[6:] if args.dtype == 'auto' else args.dtype
+        if 'head_dim' in hf_config:
+            args.head_dim = hf_config.head_dim
+        else:
+            args.head_dim = args.n_embd // args.n_head
+        if 'head_size' in hf_config:
+            args.head_size = hf_config.head_size
+        else:
+            args.head_size = args.head_dim
 
         if args.eagle_model_dir is None:
             hf_config_eagle = hf_config.eagle
@@ -305,6 +313,14 @@ def copy(tensors):
             args.n_kv_head_eagle = hf_config_eagle['num_key_value_heads']
             args.rms_norm_eps_eagle = hf_config_eagle['rms_norm_eps']
             args.n_positions_eagle = hf_config_eagle['max_position_embeddings']
+            if 'head_dim' in hf_config_eagle:
+                args.head_dim_eagle = hf_config_eagle['head_dim']
+            else:
+                args.head_dim_eagle = args.n_embd_eagle // args.n_head_eagle
+            if 'head_size' in hf_config_eagle:
+                args.head_size_eagle = hf_config_eagle['head_size']
+            else:
+                args.head_size_eagle = args.head_dim_eagle
         else:
             hf_config_eagle = LlamaConfig.from_pretrained(args.eagle_model_dir)
             args.n_head_eagle = hf_config_eagle.num_attention_heads
@@ -314,6 +330,14 @@ def copy(tensors):
             args.n_kv_head_eagle = hf_config_eagle.num_key_value_heads
             args.rms_norm_eps_eagle = hf_config_eagle.rms_norm_eps
             args.n_positions_eagle = hf_config_eagle.max_position_embeddings
+            if 'head_dim' in hf_config_eagle:
+                args.head_dim_eagle = hf_config_eagle.head_dim
+            else:
+                args.head_dim_eagle = args.n_embd_eagle // args.n_head_eagle
+            if 'head_size' in hf_config_eagle:
+                args.head_size_eagle = hf_config_eagle.head_size
+            else:
+                args.head_size_eagle = args.head_dim_eagle
 
     elif args.meta_ckpt_dir is not None:
         assert False, "meta ckpt is not supported yet"
@@ -370,6 +394,8 @@ def copy(tensors):
         },
         'use_parallel_embedding': args.use_parallel_embedding,
         'embedding_sharding_dim': args.embedding_sharding_dim,
+        'head_dim': args.head_dim_eagle,
+        'head_size': args.head_size_eagle
     }
 
     config = {
@@ -402,7 +428,9 @@ def copy(tensors):
         'max_draft_len': args.max_draft_len,
         'num_eagle_layers': args.num_eagle_layers,
         'max_non_leaves_per_layer': args.max_non_leaves_per_layer,
-        'eagle_net_config': eagle_net_config
+        'eagle_net_config': eagle_net_config,
+        'head_dim': args.head_dim,
+        'head_size': args.head_size
     }
 
     assert args.max_draft_len <= 256, "args.max_draft_len > 256 is not supported"
```
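The head_dim/head_size fallback matters because some checkpoints, Mistral-NeMo in particular, declare an explicit head_dim that differs from hidden_size // num_heads, so deriving the head dimension implicitly produces wrong attention shapes at conversion time. A minimal standalone sketch of the fallback logic (`resolve_head_dims` and the sample numbers are illustrative, not part of the commit):

```python
# Illustrative sketch (not the committed code) of the head_dim/head_size
# fallback added above.
from types import SimpleNamespace


def resolve_head_dims(cfg, n_embd, n_head):
    """Prefer an explicit head_dim; otherwise derive hidden // heads.
    head_size defaults to head_dim when absent."""
    head_dim = getattr(cfg, 'head_dim', None)
    if head_dim is None:
        head_dim = n_embd // n_head
    head_size = getattr(cfg, 'head_size', head_dim)
    return head_dim, head_size


# Mistral-NeMo-like config: an explicit head_dim (128) that differs from
# hidden_size // num_heads (5120 // 32 = 160), which is why the implicit
# derivation broke conversion for this model.
assert resolve_head_dims(SimpleNamespace(head_dim=128), 5120, 32) == (128, 128)
# A Llama-like config with no explicit head_dim falls back to the quotient.
assert resolve_head_dims(SimpleNamespace(), 4096, 32) == (128, 128)
```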

tensorrt_llm/models/eagle/config.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -88,6 +88,14 @@ def from_hugging_face(
         n_positions = hf_config.max_position_embeddings
         hidden_act = hf_config.hidden_act
         dtype = str(hf_config.torch_dtype)[6:] if dtype == 'auto' else dtype
+        if hasattr(hf_config, 'head_dim'):
+            head_dim = hf_config.head_dim
+        else:
+            head_dim = hf_config.n_embd // hf_config.n_head
+        if hasattr(hf_config, 'head_size'):
+            head_size = hf_config.head_size
+        else:
+            head_size = head_dim
 
         if speculative_config_or_dir is None:
             hf_config_eagle = hf_config.eagle
@@ -143,6 +151,8 @@ def from_hugging_face(
             },
             'use_parallel_embedding': kwargs['use_parallel_embedding'],
             'embedding_sharding_dim': kwargs['embedding_sharding_dim'],
+            'head_dim': head_dim,
+            'head_size': head_size
         }
 
         config = {
```
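config.py expresses the same fallback with `hasattr`. An equivalent, more compact form using `getattr` defaults (a sketch, not the committed code):

```python
from types import SimpleNamespace

# Same config-object shape the committed code reads; values are sample data.
hf_config = SimpleNamespace(n_embd=5120, n_head=32, head_dim=128)

# getattr's third argument replaces the hasattr/else chain. Note it is
# evaluated eagerly, so n_embd and n_head must always be present, as they
# are on the configs this path handles.
head_dim = getattr(hf_config, 'head_dim', hf_config.n_embd // hf_config.n_head)
head_size = getattr(hf_config, 'head_size', head_dim)
assert (head_dim, head_size) == (128, 128)
```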

tests/integration/defs/common.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -945,11 +945,22 @@ def get_dummy_spec_decoding_heads(hf_model_dir,
     )
 
     quant_cfg = getattr(mtq, "FP8_DEFAULT_CFG")
+    # Following quantizers are needed for KV cache quantization.
     quant_cfg["quant_cfg"]["*output_quantizer"] = {
         "num_bits": (4, 3),
         "axis": None,
         "enable": True,
     }
+    quant_cfg["quant_cfg"]["*k_bmm_quantizer"] = {
+        "num_bits": (4, 3),
+        "axis": None,
+        "enable": True,
+    }
+    quant_cfg["quant_cfg"]["*v_bmm_quantizer"] = {
+        "num_bits": (4, 3),
+        "axis": None,
+        "enable": True,
+    }
 
     calibrate_loop = dataset_utils.create_forward_loop(
         calib_dataloader, dataloader=calib_dataloader)
```
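For context, the wildcard keys are ModelOpt quantizer name patterns: `*k_bmm_quantizer` and `*v_bmm_quantizer` match the KV-cache batched-matmul quantizers, and `(4, 3)` selects FP8 E4M3 (4 exponent bits, 3 mantissa bits). A sketch of how such a config is typically consumed (assumes nvidia-modelopt is installed; `model` and `calibrate_loop` are stand-ins for objects the test helper builds):

```python
# Sketch, not the committed code.
import copy

import modelopt.torch.quantization as mtq

# Deep-copy so the module-level default config is not mutated in place
# (the committed code edits mtq.FP8_DEFAULT_CFG directly).
quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
for pattern in ("*output_quantizer", "*k_bmm_quantizer", "*v_bmm_quantizer"):
    quant_cfg["quant_cfg"][pattern] = {
        "num_bits": (4, 3),  # FP8 E4M3
        "axis": None,        # per-tensor scaling
        "enable": True,
    }

# Calibration then runs the forward loop under this config:
# model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
```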

tests/integration/defs/examples/test_eagle.py

Lines changed: 27 additions & 11 deletions
```diff
@@ -270,17 +270,6 @@ def test_codellama_eagle_1gpu(code_llama_model_root,
                           llm_datasets_root=llm_datasets_root,
                           llm_rouge_root=llm_rouge_root)
 
-    test_with_dummy_eagle(hf_model_root=code_llama_model_root,
-                          eagle_example_root=eagle_example_root,
-                          llm_venv=llm_venv,
-                          cmodel_dir=cmodel_dir,
-                          engine_dir=engine_dir,
-                          batch_size=batch_size,
-                          data_type=data_type,
-                          use_dynamic_tree=use_dynamic_tree,
-                          llm_datasets_root=llm_datasets_root,
-                          llm_rouge_root=llm_rouge_root)
-
 
 @pytest.mark.parametrize("use_dynamic_tree", [False, True],
                          ids=['eagle1', 'eagle2'])
@@ -309,6 +298,33 @@ def test_mistral_eagle_1gpu(llm_mistral_model_root,
                           llm_rouge_root=llm_rouge_root)
 
 
+@pytest.mark.parametrize("use_dynamic_tree", [False, True],
+                         ids=['eagle1', 'eagle2'])
+@pytest.mark.parametrize("mistral_nemo_model_root", ['Mistral-Nemo-12b-Base'],
+                         indirect=True)
+def test_mistral_nemo_eagle_1gpu(mistral_nemo_model_root,
+                                 eagle_example_root,
+                                 llm_datasets_root,
+                                 llm_rouge_root,
+                                 llm_venv,
+                                 cmodel_dir,
+                                 engine_dir,
+                                 use_dynamic_tree,
+                                 batch_size=8,
+                                 data_type='bfloat16'):
+
+    test_with_dummy_eagle(hf_model_root=mistral_nemo_model_root,
+                          eagle_example_root=eagle_example_root,
+                          llm_venv=llm_venv,
+                          cmodel_dir=cmodel_dir,
+                          engine_dir=engine_dir,
+                          batch_size=batch_size,
+                          data_type=data_type,
+                          use_dynamic_tree=use_dynamic_tree,
+                          llm_datasets_root=llm_datasets_root,
+                          llm_rouge_root=llm_rouge_root)
+
+
 @pytest.mark.parametrize("use_dynamic_tree", [False, True],
                          ids=['eagle1', 'eagle2'])
 @pytest.mark.parametrize("llm_qwen_model_root", [
```
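The new test relies on pytest's `indirect=True` parametrization: the string 'Mistral-Nemo-12b-Base' is routed through the `mistral_nemo_model_root` fixture rather than passed to the test directly. A self-contained sketch of the pattern (the fixture body here is a stand-in; the repo's real fixture resolves the id to a local checkpoint path):

```python
import pytest


@pytest.fixture
def mistral_nemo_model_root(request):
    # Stand-in: the real fixture maps the parametrized id to a model
    # directory on the test machine.
    return f"/models/{request.param}"


@pytest.mark.parametrize("mistral_nemo_model_root", ["Mistral-Nemo-12b-Base"],
                         indirect=True)
def test_resolves_model_root(mistral_nemo_model_root):
    assert mistral_nemo_model_root.endswith("Mistral-Nemo-12b-Base")
```

The generated node ids, e.g. `test_mistral_nemo_eagle_1gpu[Mistral-Nemo-12b-Base-eagle1]`, are exactly what the QA test list below adds.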

tests/integration/test_lists/qa/examples_test_list.txt

Lines changed: 2 additions & 0 deletions
```diff
@@ -500,6 +500,7 @@ examples/test_eagle.py::test_llama_eagle_1gpu[llama-v2-7b-hf-eagle1]
 examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.2-1b-eagle1]
 examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.1-8b-eagle1]
 examples/test_eagle.py::test_mistral_eagle_1gpu[mistral-7b-v0.1-eagle1]
+examples/test_eagle.py::test_mistral_nemo_eagle_1gpu[Mistral-Nemo-12b-Base-eagle1]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen_7b_chat-eagle1]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen1.5_7b_chat-eagle1]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen2_7b_instruct-eagle1]
@@ -514,6 +515,7 @@ examples/test_eagle.py::test_llama_eagle_1gpu[llama-v2-7b-hf-eagle2]
 examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.2-1b-eagle2]
 examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.1-8b-eagle2]
 examples/test_eagle.py::test_mistral_eagle_1gpu[mistral-7b-v0.1-eagle2]
+examples/test_eagle.py::test_mistral_nemo_eagle_1gpu[Mistral-Nemo-12b-Base-eagle2]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen_7b_chat-eagle2]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen1.5_7b_chat-eagle2]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen2_7b_instruct-eagle2]
```
