
Commit c045636
[TRTLLM-7735][feat] Attention NVFP4 out support for torch compile (#9740)
Signed-off-by: Jin Li <[email protected]>
1 parent: d70aedd
17 files changed: +328 additions, −267 deletions
cpp/tensorrt_llm/nanobind/thop/bindings.cpp
Lines changed: 3 additions & 3 deletions

@@ -39,9 +39,9 @@ void initBindings(nb::module_& m)
     m.def("attention", &torch_ext::attention,
         // Parameters with default values using std::nullopt for optional arguments
         nb::arg("q"), nb::arg("k") = std::nullopt, nb::arg("v") = std::nullopt, nb::arg("output"),
-        nb::arg("output_sf") = std::nullopt, nb::arg("out_dtype") = std::nullopt, nb::arg("workspace_") = std::nullopt,
-        nb::arg("sequence_length"), nb::arg("host_past_key_value_lengths"), nb::arg("host_total_kv_lens"),
-        nb::arg("context_lengths"), nb::arg("host_context_lengths"), nb::arg("host_request_types"),
+        nb::arg("output_sf") = std::nullopt, nb::arg("workspace_") = std::nullopt, nb::arg("sequence_length"),
+        nb::arg("host_past_key_value_lengths"), nb::arg("host_total_kv_lens"), nb::arg("context_lengths"),
+        nb::arg("host_context_lengths"), nb::arg("host_request_types"),
         nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_block_offsets") = std::nullopt,
         nb::arg("host_kv_cache_pool_pointers") = std::nullopt, nb::arg("host_kv_cache_pool_mapping") = std::nullopt,
         nb::arg("cache_indirection") = std::nullopt, nb::arg("kv_scale_orig_quant") = std::nullopt,

cpp/tensorrt_llm/pybind/thop/bindings.cpp
Lines changed: 3 additions & 3 deletions

@@ -39,9 +39,9 @@ void initBindings(pybind11::module_& m)
     m.def("attention", &torch_ext::attention,
         // Parameters with default values using std::nullopt for optional arguments
         py::arg("q"), py::arg("k") = std::nullopt, py::arg("v") = std::nullopt, py::arg("output"),
-        py::arg("output_sf") = std::nullopt, py::arg("out_dtype") = std::nullopt, py::arg("workspace_") = std::nullopt,
-        py::arg("sequence_length"), py::arg("host_past_key_value_lengths"), py::arg("host_total_kv_lens"),
-        py::arg("context_lengths"), py::arg("host_context_lengths"), py::arg("host_request_types"),
+        py::arg("output_sf") = std::nullopt, py::arg("workspace_") = std::nullopt, py::arg("sequence_length"),
+        py::arg("host_past_key_value_lengths"), py::arg("host_total_kv_lens"), py::arg("context_lengths"),
+        py::arg("host_context_lengths"), py::arg("host_request_types"),
         py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_block_offsets") = std::nullopt,
         py::arg("host_kv_cache_pool_pointers") = std::nullopt, py::arg("host_kv_cache_pool_mapping") = std::nullopt,
         py::arg("cache_indirection") = std::nullopt, py::arg("kv_scale_orig_quant") = std::nullopt,

cpp/tensorrt_llm/thop/attentionOp.cpp
Lines changed: 23 additions & 22 deletions

@@ -603,23 +603,22 @@ using torch_ext::trtllm::attention::Runner;
 using torch_ext::trtllm::attention::AttentionInputType;
 
 void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
-    std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
-    std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
-    torch::Tensor host_total_kv_lens, torch::Tensor context_lengths, torch::Tensor host_context_lengths,
-    torch::Tensor host_request_types, std::optional<torch::Tensor> kv_cache_block_offsets,
-    std::optional<torch::Tensor> host_kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
-    std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
-    std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
-    std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
-    std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
-    std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
-    std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
-    int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
-    int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
-    int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
-    int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
-    double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
-    double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
+    std::optional<torch::Tensor> output_sf, std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length,
+    torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, torch::Tensor context_lengths,
+    torch::Tensor host_context_lengths, torch::Tensor host_request_types,
+    std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
+    std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
+    std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
+    std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
+    std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
+    std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
+    std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
+    bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
+    int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
+    std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
+    int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
+    int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
+    int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
     std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
     bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
     std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,

@@ -658,8 +657,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     }
 
     auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(qkv_or_q.scalar_type());
-    bool const is_fp8_out = out_dtype.has_value() && out_dtype.value() == torch::kFloat8_e4m3fn;
-    bool const is_fp4_out = out_dtype.has_value() && out_dtype.value() == torch::kUInt8;
+    auto const out_dtype = output.scalar_type();
+    bool const is_fp8_out = out_dtype == torch::kFloat8_e4m3fn;
+    // Torch does not support native nvfp4 type.
+    bool const is_fp4_out = out_dtype == torch::kUInt8;
 
     RunnerPtr runner;
     if (dtype == nvinfer1::DataType::kHALF)

@@ -674,13 +675,13 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         }
         else
         {
-            TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat16);
+            TLLM_CHECK(out_dtype == torch::kFloat16);
             runner = std::make_shared<Runner<half>>();
         }
     }
     else if (dtype == nvinfer1::DataType::kFLOAT)
     {
-        TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat32);
+        TLLM_CHECK(out_dtype == torch::kFloat32);
         runner = std::make_shared<Runner<float>>();
     }
 #ifdef ENABLE_BF16

@@ -696,7 +697,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     }
     else
     {
-        TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kBFloat16);
+        TLLM_CHECK(out_dtype == torch::kBFloat16);
         runner = std::make_shared<Runner<__nv_bfloat16>>();
     }
 }
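With out_dtype removed from the signature, the op derives the target dtype from output.scalar_type() and keeps the same FP8/FP4 special-casing. A rough Python mirror of that selection logic, with the TLLM_CHECKs rendered as assertions (a sketch assuming a torch build that provides float8_e4m3fn, not the actual dispatch code):

import torch

def resolve_output_flags(q: torch.Tensor, output: torch.Tensor):
    out_dtype = output.dtype
    is_fp8_out = out_dtype == torch.float8_e4m3fn
    # Torch does not support a native nvfp4 type, so NVFP4 output rides on uint8.
    is_fp4_out = out_dtype == torch.uint8
    if not (is_fp8_out or is_fp4_out):
        # Otherwise the output dtype has to match the compute dtype of q
        # (fp16 / fp32 / bf16 in the branches above).
        assert out_dtype == q.dtype, f"unexpected output dtype {out_dtype}"
    return is_fp8_out, is_fp4_out

q = torch.randn(4, 256, dtype=torch.float16)
print(resolve_output_flags(q, q.new_empty(4, 128, dtype=torch.uint8)))  # (False, True)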

cpp/tensorrt_llm/thop/attentionOp.h
Lines changed: 16 additions & 17 deletions

@@ -39,23 +39,22 @@ namespace torch_ext
  * - Speculative decoding
  */
 void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
-    std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
-    std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
-    torch::Tensor host_total_kv_lens, torch::Tensor context_lengths, torch::Tensor host_context_lengths,
-    torch::Tensor host_request_types, std::optional<torch::Tensor> kv_cache_block_offsets,
-    std::optional<torch::Tensor> host_kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
-    std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
-    std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
-    std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
-    std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
-    std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
-    std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
-    int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
-    int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
-    int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
-    int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
-    double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
-    double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
+    std::optional<torch::Tensor> output_sf, std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length,
+    torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, torch::Tensor context_lengths,
+    torch::Tensor host_context_lengths, torch::Tensor host_request_types,
+    std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
+    std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
+    std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
+    std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
+    std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
+    std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
+    std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
+    bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
+    int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
+    std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
+    int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
+    int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
+    int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
     std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
     bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
     std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,

tensorrt_llm/_torch/attention_backend/interface.py
Lines changed: 7 additions & 3 deletions

@@ -705,9 +705,13 @@ def support_fused_qkv(cls) -> bool:
     def support_mla(cls) -> bool:
         return False
 
-    @classmethod
-    def support_nvfp4_output(cls) -> bool:
-        return False
+    def create_output(self, q: torch.Tensor, **kwargs) -> List[torch.Tensor]:
+        """
+        Create the output tensors for the attention operation.
+        """
+        num_tokens = q.shape[0]
+        hidden_size = self.num_heads * self.head_dim
+        return [q.new_empty([num_tokens, hidden_size], dtype=q.dtype)]
 
 
 @dataclass(kw_only=True, unsafe_hash=True)
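On the Python side, the per-backend support_nvfp4_output classmethod is replaced by an overridable create_output hook: the base class allocates a single output matching q's dtype, and a backend that produces NVFP4 can instead return a packed uint8 output plus its scale-factor tensor. A hypothetical override (the has_nvfp4_output attribute and the shapes are assumptions for illustration, not the actual TRT-LLM backend):

import torch
from typing import List

class NVFP4CapableBackend:  # hypothetical stand-in for an AttentionBackend subclass
    def __init__(self, num_heads: int, head_dim: int, has_nvfp4_output: bool):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.has_nvfp4_output = has_nvfp4_output

    def create_output(self, q: torch.Tensor, **kwargs) -> List[torch.Tensor]:
        num_tokens = q.shape[0]
        hidden_size = self.num_heads * self.head_dim
        if self.has_nvfp4_output:
            # Packed NVFP4 payload carried as uint8 plus a scale-factor tensor;
            # the exact shapes and layout are assumed here.
            output = q.new_empty([num_tokens, hidden_size // 2], dtype=torch.uint8)
            output_sf = q.new_empty([num_tokens, hidden_size // 16], dtype=torch.uint8)
            return [output, output_sf]
        # Default path mirrors the base implementation in the diff above.
        return [q.new_empty([num_tokens, hidden_size], dtype=q.dtype)]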
