
Commit 4a85e76

fix eagle3 fp8 target model + bf16 draft model
Signed-off-by: Dylan Chen <[email protected]>
1 parent: 6ce0624

File tree: 7 files changed, +53 −34 lines

cpp/tensorrt_llm/common/attentionOp.cpp
cpp/tensorrt_llm/common/attentionOp.h
cpp/tensorrt_llm/thop/attentionOp.cpp
tensorrt_llm/_torch/attention_backend/trtllm.py
tensorrt_llm/_torch/models/modeling_speculative.py
tensorrt_llm/_torch/modules/attention.py
tests/unittest/_torch/speculative/test_eagle3.py


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 6 additions & 3 deletions
@@ -120,6 +120,7 @@ struct FusedQKVMaskedAttentionDispatchParams
     bool block_sparse_attention = false;
     BlockSparseParams block_sparse_params;
     int32_t const* mrope_position_deltas;
+    bool is_eagle3 = false;
 };

 template <typename T, typename KVCacheBuffer>
@@ -645,7 +646,8 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
     params.ia3_key_weights = reinterpret_cast<DataType const*>(input_params.ia3_key_weights);
     params.ia3_value_weights = reinterpret_cast<DataType const*>(input_params.ia3_value_weights);

-    if (input_params.quant_option.hasStaticActivationScaling() || input_params.fp8_context_fmha)
+    if ((input_params.quant_option.hasStaticActivationScaling() || input_params.fp8_context_fmha)
+        && !input_params.is_eagle3)
     {
         // qkv_scale_out is nullptr currently (no scale).
         params.qkv_scale_quant_orig = input_params.qkv_scale_out;
@@ -2407,6 +2409,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     dispatch_params.block_sparse_attention = mMaskType == AttentionMaskType::BLOCKSPARSE;
     dispatch_params.block_sparse_params = mBlockSparseParams;
     dispatch_params.mrope_position_deltas = params.mrope_position_deltas;
+    dispatch_params.is_eagle3 = mIsEagle3;

     using DataType = typename SATypeConverter<T>::Type;
     if (!isCrossAttention())
@@ -2614,7 +2617,7 @@ int AttentionOp::initialize() noexcept
     fmhaParams.dataTypeOut = mFP8AttenOutput ? DATA_TYPE_E4M3 : data_type;

     // FP8 FMHA should be used with fp8 workflow together.
-    if (mFP8ContextFMHA || mFP8ContextMLA)
+    if ((mFP8ContextFMHA || mFP8ContextMLA) && !mIsEagle3)
    {
         data_type = DATA_TYPE_E4M3;
     }
@@ -2624,7 +2627,7 @@ int AttentionOp::initialize() noexcept
     // The KV input data type. The default is same as dataType.
     fmhaParams.dataTypeKv = fmhaParams.dataType;
     // If the kernel must read from KV cache, set the dtype correctly.
-    if (mPagedKVCache && mPagedContextFMHA)
+    if (mPagedKVCache && mPagedContextFMHA && !mIsEagle3)
     {
         if (mKVCacheQuantMode.hasFp8KvCache())
         {
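The net effect of these guards: when an attention op belongs to the BF16 Eagle3 draft model running alongside an FP8 target model, the FP8 scale setup, the E4M3 FMHA data type, and the FP8 paged-KV read path are all skipped. A minimal sketch of that gating decision, in Python with hypothetical stand-in names (the real logic is the C++ above):

```python
# Sketch only: mirrors the new "&& !is_eagle3" guards, not the C++ implementation.
def pick_fmha_dtype(fp8_context_fmha: bool, fp8_context_mla: bool,
                    is_eagle3: bool, base_dtype: str = "bf16") -> str:
    # FP8 FMHA is only selected for the FP8 target model's layers; the BF16
    # Eagle3 draft layers keep their original data type.
    if (fp8_context_fmha or fp8_context_mla) and not is_eagle3:
        return "e4m3"
    return base_dtype

assert pick_fmha_dtype(True, False, is_eagle3=False) == "e4m3"  # target-model layer
assert pick_fmha_dtype(True, False, is_eagle3=True) == "bf16"   # Eagle3 draft layer
```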

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 6 additions & 5 deletions
@@ -420,6 +420,7 @@ class AttentionOp
     bool mIsSpecDecodingEnabled = false;
     bool mUseSpecDecoding = false;
     bool mIsSpecDecTree = true;
+    bool mIsEagle3 = false;
     bool mSpecDecodingIsGenerationLengthVariable = false;
     int32_t mSpecDecodingMaxGenerationLength = 1;
     bool mIsMLAEnabled = false;
@@ -470,11 +471,11 @@ class AttentionOp
         mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, mMaxDistance,
         mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mChunkPrefillBufferBatchSize, mFP8AttenOutput,
         mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree,
-        mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA,
-        mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads,
-        mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast,
-        mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
-        mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
+        mIsEagle3, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, mIsMLAEnabled,
+        mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads,
+        mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank,
+        mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache,
+        mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
     };

 private:
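The new member is also added to the long member list above, which appears to serve as the op's serialized state, presumably so that an FP8 target attention op and a BF16 Eagle3 draft op never look identical. A small sketch of why the flag has to be part of any such state key, using hypothetical Python stand-ins:

```python
# Hypothetical illustration: if is_eagle3 were left out of the op's state
# tuple, a draft-layer op and a target-layer op with otherwise equal settings
# would compare equal and could share cached state they must not share.
from dataclasses import dataclass

@dataclass(frozen=True)
class AttnOpState:
    fp8_context_fmha: bool
    is_spec_dec_tree: bool
    is_eagle3: bool  # newly added field

target_op = AttnOpState(fp8_context_fmha=True, is_spec_dec_tree=False, is_eagle3=False)
draft_op = AttnOpState(fp8_context_fmha=True, is_spec_dec_tree=False, is_eagle3=True)
assert target_op != draft_op  # the flag keeps the two configurations distinct
```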

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 4 additions & 2 deletions
@@ -609,11 +609,13 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to

     op->mAttentionChunkSize = attention_chunk_size;

-    TORCH_CHECK(spec_decoding_bool_params.size() == 3,
-        "Expecting 3 bools for spec-dec mode, is_spec_decoding_enabled, use_spec_decoding, and is_spec_dec_tree.");
+    TORCH_CHECK(spec_decoding_bool_params.size() == 4,
+        "Expecting 4 bools for spec-dec mode, is_spec_decoding_enabled, use_spec_decoding, is_spec_dec_tree, and "
+        "is_eagle3.");
     op->mIsSpecDecodingEnabled = spec_decoding_bool_params[0]; // is_spec_decoding_enabled
     op->mUseSpecDecoding = spec_decoding_bool_params[1]; // use_spec_decoding
     op->mIsSpecDecTree = spec_decoding_bool_params[2]; // is_spec_dec_tree
+    op->mIsEagle3 = spec_decoding_bool_params[3]; // is_eagle3

     if (is_mla_enable)
     {
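The TORCH_CHECK above updates the positional contract between the Python backend and the C++ op: there are now four spec-decoding booleans, and the fourth one is the new flag. A small sketch of that ordering (names taken from the error message; the helper itself is hypothetical):

```python
# Hypothetical helper showing the agreed ordering of the four booleans.
def unpack_spec_decoding_bool_params(params):
    assert len(params) == 4, (
        "Expecting 4 bools: is_spec_decoding_enabled, use_spec_decoding, "
        "is_spec_dec_tree, is_eagle3")
    names = ("is_spec_decoding_enabled", "use_spec_decoding",
             "is_spec_dec_tree", "is_eagle3")
    return dict(zip(names, params))

print(unpack_spec_decoding_bool_params([True, True, False, True]))
# {'is_spec_decoding_enabled': True, 'use_spec_decoding': True,
#  'is_spec_dec_tree': False, 'is_eagle3': True}
```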

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 6 additions & 1 deletion
@@ -184,6 +184,7 @@ def plan(
         is_spec_decoding_enabled: bool = False,
         use_spec_decoding: bool = False,
         is_spec_dec_tree: bool = False,
+        is_eagle3: bool = False,
         spec_decoding_position_offsets: Optional[torch.Tensor] = None,
         spec_decoding_packed_mask: Optional[torch.Tensor] = None,
         spec_decoding_generation_lengths: Optional[torch.Tensor] = None,
@@ -271,6 +272,7 @@ def plan(
         self.is_spec_decoding_enabled = is_spec_decoding_enabled
         self.use_spec_decoding = use_spec_decoding
         self.is_spec_dec_tree = is_spec_dec_tree
+        self.is_eagle3 = is_eagle3
         self.spec_decoding_position_offsets = spec_decoding_position_offsets
         self.spec_decoding_packed_mask = spec_decoding_packed_mask
         self.spec_decoding_generation_lengths = spec_decoding_generation_lengths
@@ -414,7 +416,7 @@ def run(
         ]
         spec_decoding_bool_params = [
             self.is_spec_decoding_enabled, self.use_spec_decoding,
-            self.is_spec_dec_tree
+            self.is_spec_dec_tree, self.is_eagle3
         ]
         spec_decoding_tensor_params = [
             self.spec_decoding_generation_lengths,
@@ -1237,6 +1239,8 @@ def forward(
         # Context MLA uses separate qkv instead of paged_context_fmha
         use_paged_context_fmha = False

+        is_eagle3 = kwargs.get("is_eagle3", False)
+
         use_nvfp4_output = False
         if enable_attn_nvfp4_output and self.has_nvfp4 and self.support_nvfp4_output(
         ):
@@ -1287,6 +1291,7 @@ def forward(
             is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
             use_spec_decoding=metadata.use_spec_decoding,
             is_spec_dec_tree=metadata.is_spec_dec_tree,
+            is_eagle3=is_eagle3,
             spec_decoding_position_offsets=metadata.
             spec_decoding_position_offsets,
             spec_decoding_packed_mask=metadata.spec_decoding_packed_mask,
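End to end, the flag travels forward(**kwargs) → plan(is_eagle3=...) → run(), where it becomes the fourth entry of spec_decoding_bool_params. A condensed sketch of that data flow with hypothetical class and method bodies (not the real backend class):

```python
# Condensed, hypothetical mock of the plan/run plumbing above; only the shape
# of the data flow is intended to match the diff.
class PlanRunSketch:

    def plan(self, *, is_spec_decoding_enabled=False, use_spec_decoding=False,
             is_spec_dec_tree=False, is_eagle3=False):
        self.is_spec_decoding_enabled = is_spec_decoding_enabled
        self.use_spec_decoding = use_spec_decoding
        self.is_spec_dec_tree = is_spec_dec_tree
        self.is_eagle3 = is_eagle3  # stored exactly like the other spec-dec flags

    def run(self):
        # The flag rides along as the fourth boolean handed to the C++ op.
        return [self.is_spec_decoding_enabled, self.use_spec_decoding,
                self.is_spec_dec_tree, self.is_eagle3]

    def forward(self, **kwargs):
        # forward() reads the kwarg with a False default, as in the diff above.
        self.plan(is_eagle3=kwargs.get("is_eagle3", False))
        return self.run()

assert PlanRunSketch().forward(is_eagle3=True) == [False, False, False, True]
```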

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def __init__(
             dtype=config.torch_dtype,
             config=model_config,
         )
+        self.is_eagle3 = True

         tp_size = model_config.mapping.tp_size
         # Override the QKV projection. The number of input features

tensorrt_llm/_torch/modules/attention.py

Lines changed: 2 additions & 1 deletion
@@ -396,7 +396,8 @@ def _attn_impl(
             enable_attn_nvfp4_output=enable_attn_nvfp4_output,
             output=output[:num_tokens, :] if output is not None else None,
             output_sf=output_sf,
-            attention_sinks=attention_sinks)
+            attention_sinks=attention_sinks,
+            is_eagle3=getattr(self, "is_eagle3", False))
         if isinstance(attn_output, tuple):
             assert len(
                 attn_output
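These two Python changes form the handshake: the Eagle3 draft attention in modeling_speculative.py marks itself with `self.is_eagle3 = True`, and the shared `_attn_impl` path in attention.py reads the flag with `getattr`, so regular target-model attention modules (which never set it) default to False. A minimal sketch with hypothetical classes:

```python
# Hypothetical stand-ins for the target and Eagle3-draft attention modules.
class TargetAttention:
    pass

class Eagle3DraftAttention(TargetAttention):
    def __init__(self):
        self.is_eagle3 = True  # set once in __init__, as in modeling_speculative.py

def is_eagle3_layer(attn_module) -> bool:
    # Mirrors the getattr(self, "is_eagle3", False) lookup in attention.py.
    return getattr(attn_module, "is_eagle3", False)

assert is_eagle3_layer(Eagle3DraftAttention()) is True
assert is_eagle3_layer(TargetAttention()) is False
```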

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 28 additions & 22 deletions
@@ -17,34 +17,35 @@


 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,fp8_target",
     [
-        [True, "TRTLLM", True, False, False, False, True],
-        [True, "TRTLLM", True, False, False, False, False],
-        [False, "TRTLLM", True, False, False, False, True],
-        [False, "TRTLLM", True, False, False, False, False],
-        [True, "FLASHINFER", True, False, False, False, True],
-        [False, "FLASHINFER", True, False, False, False, True],
-        [False, "TRTLLM", False, True, True, False, True],
-        [True, "TRTLLM", False, True, True, False, True],
-        [True, "TRTLLM", True, False, True, True, True],
-        [True, "TRTLLM", True, False, True, False, True],
+        [True, "TRTLLM", True, False, False, False, True, False],
+        [True, "TRTLLM", True, False, False, False, False, False],
+        [False, "TRTLLM", True, False, False, False, True, False],
+        [False, "TRTLLM", True, False, False, False, False, False],
+        [True, "FLASHINFER", True, False, False, False, True, False],
+        [False, "FLASHINFER", True, False, False, False, True, False],
+        [False, "TRTLLM", False, True, True, False, True, False],
+        [True, "TRTLLM", False, True, True, False, True, False],
+        [True, "TRTLLM", True, False, True, True, True, False],
+        [True, "TRTLLM", True, False, True, False, True, False],
         # TODO: nvbugs/5461761
-        # [True, "TRTLLM", True, False, False, True, True],
-        [True, "TRTLLM", False, False, False, False, True],
-        [False, "TRTLLM", False, False, False, False, True],
-        [True, "TRTLLM", False, False, False, False, False],
-        [False, "TRTLLM", False, False, False, False, False],
-        [True, "TRTLLM", False, False, False, True, True],
-        [True, "TRTLLM", False, False, False, True, False],
-        [True, "FLASHINFER", False, False, False, False, True],
-        [False, "FLASHINFER", False, False, False, False, True],
+        # [True, "TRTLLM", True, False, False, True, True, False],
+        [True, "TRTLLM", False, False, False, False, True, False],
+        [False, "TRTLLM", False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, False, False],
+        [False, "TRTLLM", False, False, False, False, False, False],
+        [True, "TRTLLM", False, False, False, True, True, False],
+        [True, "TRTLLM", False, False, False, True, False, False],
+        [True, "FLASHINFER", False, False, False, False, True, False],
+        [False, "FLASHINFER", False, False, False, False, True, False],
+        [True, "TRTLLM", True, True, True, True, True, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
-                      use_chain_drafter: bool):
+                      use_chain_drafter: bool, fp8_target: bool):
     # Eagle3 one model works with overlap scheduler and block reuse.
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35:
@@ -53,13 +54,18 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     models_path = llm_models_root()
     eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
     target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+    kv_cache_dtype = 'auto'
+    if fp8_target:
+        target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
+        kv_cache_dtype = 'fp8'

     # bs > 1 gives non-deterministic when doing IFB. There are slight chances
     # that ref and spec does not match 100%
     max_batch_size = 1
     max_draft_len = 4
     kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
-                                    max_tokens=8192)
+                                    max_tokens=8192,
+                                    dtype=kv_cache_dtype)
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=[1]) if use_cuda_graph else None

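For reference, the single new parametrization row exercises the FP8-target plus BF16-Eagle3-draft combination. Spelled out against the parameter names in the decorator (a readability aid only, not part of the test):

```python
# The new row [True, "TRTLLM", True, True, True, True, True, True], expanded:
new_case = dict(
    use_cuda_graph=True,
    attn_backend="TRTLLM",
    disable_overlap_scheduler=True,
    enable_block_reuse=True,
    use_one_model=True,
    enable_chunked_prefill=True,
    use_chain_drafter=True,
    fp8_target=True,  # FP8 target checkpoint and 'fp8' KV cache, BF16 Eagle3 draft
)
```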