@@ -603,23 +603,22 @@ using torch_ext::trtllm::attention::Runner;
 using torch_ext::trtllm::attention::AttentionInputType;

 void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
-    std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
-    std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
-    torch::Tensor host_total_kv_lens, torch::Tensor context_lengths, torch::Tensor host_context_lengths,
-    torch::Tensor host_request_types, std::optional<torch::Tensor> kv_cache_block_offsets,
-    std::optional<torch::Tensor> host_kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
-    std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
-    std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
-    std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
-    std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
-    std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
-    std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
-    int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
-    int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
-    int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
-    int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
-    double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
-    double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
+    std::optional<torch::Tensor> output_sf, std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length,
+    torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, torch::Tensor context_lengths,
+    torch::Tensor host_context_lengths, torch::Tensor host_request_types,
+    std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
+    std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
+    std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
+    std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
+    std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
+    std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
+    std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
+    bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
+    int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
+    std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
+    int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
+    int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
+    int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
     std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
     bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
     std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,
@@ -658,8 +657,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     }

     auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(qkv_or_q.scalar_type());
-    bool const is_fp8_out = out_dtype.has_value() && out_dtype.value() == torch::kFloat8_e4m3fn;
-    bool const is_fp4_out = out_dtype.has_value() && out_dtype.value() == torch::kUInt8;
+    auto const out_dtype = output.scalar_type();
+    bool const is_fp8_out = out_dtype == torch::kFloat8_e4m3fn;
+    // Torch does not support native nvfp4 type.
+    bool const is_fp4_out = out_dtype == torch::kUInt8;

     RunnerPtr runner;
     if (dtype == nvinfer1::DataType::kHALF)
@@ -674,13 +675,13 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         }
         else
         {
-            TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat16);
+            TLLM_CHECK(out_dtype == torch::kFloat16);
             runner = std::make_shared<Runner<half>>();
         }
     }
     else if (dtype == nvinfer1::DataType::kFLOAT)
     {
-        TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat32);
+        TLLM_CHECK(out_dtype == torch::kFloat32);
         runner = std::make_shared<Runner<float>>();
     }
 #ifdef ENABLE_BF16
@@ -696,7 +697,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         }
         else
         {
-            TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kBFloat16);
+            TLLM_CHECK(out_dtype == torch::kBFloat16);
             runner = std::make_shared<Runner<__nv_bfloat16>>();
         }
     }
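
For reference, a minimal standalone sketch (not part of the diff; the helper name and the mismatch check are assumptions) of the dtype handling this change adopts, where the output dtype is read from the pre-allocated output tensor instead of a separate `std::optional<torch::ScalarType>` argument:

```cpp
// Hypothetical helper, illustrating the pattern only.
#include <torch/torch.h>
#include <stdexcept>

void checkOutputDtype(torch::Tensor const& output, torch::ScalarType activation_dtype)
{
    // Derive the output dtype from the output tensor itself.
    auto const out_dtype = output.scalar_type();
    bool const is_fp8_out = out_dtype == torch::kFloat8_e4m3fn;
    // Torch has no native nvfp4 dtype, so packed FP4 output is carried as uint8.
    bool const is_fp4_out = out_dtype == torch::kUInt8;

    // Unless the output is quantized, it is expected to match the activation dtype.
    if (!is_fp8_out && !is_fp4_out && out_dtype != activation_dtype)
    {
        throw std::runtime_error("output dtype must match the activation dtype unless quantized");
    }
}
```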