diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp
index 2ff34176..cf92c9e5 100644
--- a/xllm/core/framework/batch/batch_input_builder.cpp
+++ b/xllm/core/framework/batch/batch_input_builder.cpp
@@ -89,8 +89,8 @@ ForwardInput BatchInputBuilder::build_forward_input(
     uint32_t num_decoding_tokens,
     uint32_t min_decoding_batch_size) {
   process_sequences(0, static_cast<uint32_t>(num_sequences_));
+  process_batch_forward_type();
   padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size);
-
   return state_to_forward_input();
 }
@@ -102,6 +102,7 @@ RawForwardInput BatchInputBuilder::build_raw_forward_input(uint32_t start_idx,
   } else {
     process_sequences_multithreaded(start_idx, end_idx);
   }
+  process_batch_forward_type();
   return state_to_raw_forward_input();
 }
@@ -207,7 +208,6 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
     state_.unique_token_lens_vec.insert(state_.unique_token_lens_vec.end(),
                                         state.unique_token_lens_vec.begin(),
                                         state.unique_token_lens_vec.end());
-    state_.empty_kv_cache = state_.empty_kv_cache && state.empty_kv_cache;
    state_.max_seq_len = std::max(state_.max_seq_len, state.max_seq_len);
    state_.q_max_seq_len = std::max(state_.q_max_seq_len, state.q_max_seq_len);
 #if defined(USE_NPU)
@@ -282,7 +282,6 @@ void BatchInputBuilder::process_single_sequence(
       << allowed_max_tokens_[seq_index];

   // Update state
-  state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0);
   state.max_seq_len = std::max(state.max_seq_len, seq_len);
   state.q_max_seq_len = std::max(state.q_max_seq_len, q_seq_len);
 #if defined(USE_NPU)
@@ -496,12 +495,7 @@ void BatchInputBuilder::padding_decode_batch_size(
     uint32_t min_decoding_batch_size) {
   if (num_sequences_ < min_decoding_batch_size) {
     const uint32_t n_tokens = state_.flatten_tokens_vec.size();
-    // kv_cache is not empty in decoding phase
-    const bool in_decoding_phase = !state_.empty_kv_cache;
-    const bool same_num_decoding_tokens =
-        state_.q_max_seq_len == num_decoding_tokens &&
-        n_tokens == num_sequences_ * num_decoding_tokens;
-    if (in_decoding_phase && same_num_decoding_tokens) {
+    if (state_.batch_forward_type.is_decode()) {
       // add padding tokens to the batch
       for (int32_t i = num_sequences_; i < min_decoding_batch_size; ++i) {
         for (int32_t k = 0; k < num_decoding_tokens; ++k) {
@@ -547,7 +541,7 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
   }

   auto& input_params = forward_input.input_params;
-  input_params.empty_kv_cache = state_.empty_kv_cache;
+  input_params.batch_forward_type = state_.batch_forward_type;
   input_params.num_sequences = state_.block_tables_vec.size();
   input_params.kv_max_seq_len = state_.max_seq_len;
   input_params.q_max_seq_len = state_.q_max_seq_len;
@@ -557,8 +551,6 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
   input_params.q_seq_lens_vec = std::move(state_.q_seq_lens);
   input_params.new_cache_slots =
       torch::tensor(state_.new_token_slot_ids, torch::kInt);
-  input_params.decode_seq_range =
-      util::find_ones_indices(input_params.q_seq_lens_vec);

   // Setup multimodal data
   input_params.mm_data = MMData::batch(mm_data_vec_);
@@ -621,14 +613,13 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
       std::move(state_.unique_token_counts_vec);
   raw_forward_input.unique_token_lens_vec =
       std::move(state_.unique_token_lens_vec);
-  raw_forward_input.empty_kv_cache = state_.empty_kv_cache;
-  // raw_forward_input.global_empty_kv_cache = ;
   raw_forward_input.max_seq_len = state_.max_seq_len;
  raw_forward_input.q_max_seq_len = state_.q_max_seq_len;
  raw_forward_input.seq_lens = std::move(state_.seq_lens);
  raw_forward_input.q_seq_lens = std::move(state_.q_seq_lens);
  raw_forward_input.new_token_slot_ids = std::move(state_.new_token_slot_ids);
  raw_forward_input.block_tables_vec = std::move(state_.block_tables_vec);
+  raw_forward_input.batch_forward_type = std::move(state_.batch_forward_type);
  raw_forward_input.num_sequences = num_sequences_;
  // raw_forward_input.dp_global_token_nums = ;
  raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos);
@@ -726,4 +717,69 @@ void BatchInputBuilder::process_swap_block_infos(
                                      swap_cache_block_infos_->end());
   }
 }
+
+void BatchInputBuilder::process_batch_forward_type() {
+  CHECK_EQ(state_.seq_lens.size(), state_.q_seq_lens.size())
+      << "seq_lens size must be equal to q_seq_lens size";
+
+  if (state_.q_max_seq_len == 1) {
+    state_.batch_forward_type = BatchForwardType::DECODE;
+    return;
+  }
+
+  bool empty_kv_cache = true;
+  bool all_decode = true;
+  bool all_prefill = true;
+
+#if defined(USE_NPU)
+  if (state_.seq_lens.size() == 0) {
+    state_.batch_forward_type = BatchForwardType::IDLE;
+    return;
+  }
+  for (size_t i = 0; i < state_.seq_lens.size(); ++i) {
+    auto q_len = state_.q_seq_lens[i];
+    auto kv_len = state_.seq_lens[i];
+    auto cache_len = kv_len - q_len;
+    if (cache_len > 0) {
+      empty_kv_cache = false;
+    }
+    if (q_len > 1) {
+      all_decode = false;
+    }
+    if (q_len == 1) {
+      all_prefill = false;
+    }
+  }
+#elif defined(USE_MLU)
+  if (state_.seq_lens.size() == 1) {
+    state_.batch_forward_type = BatchForwardType::IDLE;
+    return;
+  }
+  for (size_t i = 1; i < state_.seq_lens.size(); ++i) {
+    auto q_len = state_.q_seq_lens[i] - state_.q_seq_lens[i - 1];
+    auto kv_len = state_.seq_lens[i] - state_.seq_lens[i - 1];
+    auto cache_len = kv_len - q_len;
+    if (cache_len > 0) {
+      empty_kv_cache = false;
+    }
+    if (q_len > 1) {
+      all_decode = false;
+    }
+    if (q_len == 1) {
+      all_prefill = false;
+    }
+  }
+#endif
+  if (empty_kv_cache) {
+    state_.batch_forward_type = BatchForwardType::PREFILL;
+  } else {
+    if (all_prefill) {
+      state_.batch_forward_type = BatchForwardType::CHUNKED_PREFILL;
+    } else if (all_decode) {
+      state_.batch_forward_type = BatchForwardType::DECODE;
+    } else {
+      state_.batch_forward_type = BatchForwardType::MIXED;
+    }
+  }
+}
 }  // namespace xllm
diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h
index 9b76bfb1..a301dd65 100644
--- a/xllm/core/framework/batch/batch_input_builder.h
+++ b/xllm/core/framework/batch/batch_input_builder.h
@@ -62,6 +62,8 @@ class BatchInputBuilder {
   void process_swap_block_infos(RawForwardInput& raw_forward_input);

+  void process_batch_forward_type();
+
   // State management
   struct BuilderState {
     // Token and position data
@@ -80,7 +82,7 @@ class BatchInputBuilder {
     std::vector<int32_t> unique_token_lens_vec;

     // Sequence metadata
-    bool empty_kv_cache = true;
+    BatchForwardType batch_forward_type;
     uint32_t max_seq_len = 0;
     uint32_t q_max_seq_len = 0;
 #if defined(USE_NPU)
diff --git a/xllm/core/framework/batch/batch_test.cpp b/xllm/core/framework/batch/batch_test.cpp
index b79f7b6d..083d427f 100644
--- a/xllm/core/framework/batch/batch_test.cpp
+++ b/xllm/core/framework/batch/batch_test.cpp
@@ -145,7 +145,7 @@ TEST(BatchTest, Basic) {

   // check the input parameters
   const ModelInputParams& input_params = forward_input.input_params;
-  EXPECT_FALSE(input_params.empty_kv_cache);
+  EXPECT_TRUE(input_params.batch_forward_type.is_mixed());
  EXPECT_EQ(input_params.num_sequences, 4);
  EXPECT_EQ(input_params.q_max_seq_len, 9);
  EXPECT_EQ(input_params.kv_max_seq_len, 16);
diff --git a/xllm/core/framework/model/CMakeLists.txt b/xllm/core/framework/model/CMakeLists.txt
index 9bdd452d..9559ec60 100644
--- a/xllm/core/framework/model/CMakeLists.txt
+++ b/xllm/core/framework/model/CMakeLists.txt
@@ -33,6 +33,7 @@ cc_library(
     embedding_lm.h
     model_args.h
     npu_dp_ep_padding.h
+    batch_forward_type.h
     model_input_params.h
   SRCS
     npu_dp_ep_padding.cpp
diff --git a/xllm/core/framework/model/batch_forward_type.h b/xllm/core/framework/model/batch_forward_type.h
new file mode 100644
index 00000000..1197ee0a
--- /dev/null
+++ b/xllm/core/framework/model/batch_forward_type.h
@@ -0,0 +1,81 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+Copyright 2024 The ScaleLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+namespace xllm {
+
+class BatchForwardType {
+ public:
+  enum Value : int32_t {
+    // Prefill without using kv cache.
+    PREFILL = 0,
+    // Chunked prefill using kv cache.
+    // No decode sequence in this type.
+    CHUNKED_PREFILL = 1,
+    // Decode one token.
+    // No prefill sequence in this type.
+    DECODE = 2,
+    // Mixed prefill and decode in one batch when doing chunked prefill.
+    MIXED = 3,
+    // No sequence to forward.
+    IDLE = 4,
+  };
+
+  BatchForwardType() : value_(IDLE) {}
+
+  BatchForwardType(int32_t v) : value_(static_cast<Value>(v)) {}
+
+  constexpr BatchForwardType(Value v) : value_(v) {}
+
+  BatchForwardType& operator=(Value v) {
+    value_ = v;
+    return *this;
+  }
+
+  int32_t value() const { return value_; }
+
+  bool is_prefill() const { return (value_ == PREFILL); }
+
+  bool is_chunked_prefill() const { return (value_ == CHUNKED_PREFILL); }
+
+  bool is_decode() const { return (value_ == DECODE); }
+
+  bool is_mixed() const { return (value_ == MIXED); }
+
+  bool is_idle() const { return (value_ == IDLE); }
+
+  const char* to_string() const {
+    switch (value_) {
+      case PREFILL:
+        return "PREFILL";
+      case CHUNKED_PREFILL:
+        return "CHUNKED_PREFILL";
+      case DECODE:
+        return "DECODE";
+      case MIXED:
+        return "MIXED";
+      case IDLE:
+        return "IDLE";
+      default:
+        return "UNKNOWN";
+    }
+  }
+
+ private:
+  Value value_;
+};
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h
index aaaae36d..60b472f4 100644
--- a/xllm/core/framework/model/model_input_params.h
+++ b/xllm/core/framework/model/model_input_params.h
@@ -21,6 +21,7 @@ limitations under the License.
 #if defined(USE_NPU)
 #include "platform/npu/npu_layer_synchronizer.h"
 #endif
+#include "framework/model/batch_forward_type.h"
 #include "framework/request/mm_data.h"
 #include "npu_dp_ep_padding.h"
 #include "util/tensor_helper.h"
@@ -50,8 +51,7 @@ struct CacheBlockInfo {
 struct ModelInputParams {
   ModelInputParams to(const torch::Device& device) const {
     ModelInputParams params;
-    params.empty_kv_cache = empty_kv_cache;
-    params.global_empty_kv_cache = global_empty_kv_cache;
+    params.batch_forward_type = batch_forward_type;
     params.num_sequences = num_sequences;
     params.kv_max_seq_len = kv_max_seq_len;
     params.q_max_seq_len = q_max_seq_len;
@@ -63,7 +63,6 @@ struct ModelInputParams {
     params.block_tables = safe_to(block_tables, device, true);
     params.kv_seq_lens_vec = kv_seq_lens_vec;
     params.q_seq_lens_vec = q_seq_lens_vec;
-    params.decode_seq_range = decode_seq_range;

     params.input_embedding = safe_to(input_embedding, device);
@@ -98,15 +97,13 @@ struct ModelInputParams {
   }

   void print() const {
-    LOG(INFO) << "ModelInputParams: empty_kv_cache is " << empty_kv_cache
-              << " , global_empty_kv_cache is " << global_empty_kv_cache
-              << " , num_sequences is " << num_sequences
-              << " , kv_max_seq_len is " << kv_max_seq_len
+    LOG(INFO) << "ModelInputParams: batch_forward_type is "
+              << batch_forward_type.to_string() << " , num_sequences is "
+              << num_sequences << " , kv_max_seq_len is " << kv_max_seq_len
               << " , q_max_seq_len is " << q_max_seq_len
               << " , prefill_seq_len is " << prefill_seq_len;
     LOG(INFO) << "ModelInputParams: kv_seq_lens_vec is " << kv_seq_lens_vec;
     LOG(INFO) << "ModelInputParams: q_seq_lens_vec is " << q_seq_lens_vec;
-    LOG(INFO) << "ModelInputParams: decode_seq_range is " << decode_seq_range;
     print_tensor(kv_seq_lens, "ModelInputParams: kv_seq_lens", 4);
     print_tensor(q_seq_lens, "ModelInputParams: q_seq_lens", 4);
     print_tensor(new_cache_slots, "ModelInputParams: new_cache_slots", 4);
@@ -114,8 +111,8 @@
     LOG(INFO) << "ModelInputParams: dp_global_token_nums is "
               << dp_global_token_nums;
   }
-  // whether the kv-cache is empty for all sequences.
-  bool empty_kv_cache = true;
+  // forward type of the batch, used by worker/kernel.
+  BatchForwardType batch_forward_type;
   // total number of sequences in the batch
   int32_t num_sequences = 0;
@@ -124,15 +121,7 @@
   torch::Tensor kv_seq_lens;
   std::vector<int32_t> kv_seq_lens_vec;
   std::vector<int32_t> q_seq_lens_vec;
-  // Range of decode sequence indices in the batch [start, end].
-  // Decode sequences are identified by q_seq_lens == 1,
-  // prefill sequences by q_seq_lens > 1 .
-  // Used to determine whether to use prefill_node_ or
-  // decode_node_ in NPU layers
-  // Values: {-1, -1} if no decode requests (all prefill),
-  // {0, batch_size-1} if all decode requests,
-  // {start_idx, end_idx} if mixed prefill/decode requests
-  std::pair<int32_t, int32_t> decode_seq_range;
+  // max length for qkv.
  int32_t kv_max_seq_len = 0;
  int32_t q_max_seq_len = 0;
@@ -151,8 +140,6 @@
  // num tokens of all workers,mainly used for dp case
  std::vector<int32_t> dp_global_token_nums;
-  // whether the kv-cache is empty for all sequences,mainly used for dp case
-  bool global_empty_kv_cache = true;
  // num of prefill sequence in chunked prefill case
  uint32_t prefill_seq_len = 0;
diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp
index 0ae6cedb..759e0769 100644
--- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp
@@ -1520,8 +1520,8 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward(
    int node_id) {
  atb::Status st;
  // all micro batches are in same prefill/decode stage,
-  // so, to judge empty_kv_cache, use input_params[0] here
-  if (input_params[0].global_empty_kv_cache) {
+  // deepseek doesn't support chunked prefill, so only check is_prefill.
+  if (input_params[0].batch_forward_type.is_prefill()) {
    build_node_variant_pack(prefill_node_,
                            x,
                            cos_pos,
diff --git a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp
index 927b9806..20429a83 100644
--- a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp
+++ b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp
@@ -1085,8 +1085,7 @@ torch::Tensor Glm4MoeDecoderImpl::forward(
    std::vector*> event_flag,
    int node_id) {
  atb::Status st;
-  if (input_params.decode_seq_range.second !=
-      input_params.q_seq_lens.size(0) - 1) {
+  if (!input_params.batch_forward_type.is_decode()) {
    build_node_variant_pack(prefill_node_,
                            x,
                            cos_pos,
diff --git a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp
index 9696353c..eb7794c2 100644
--- a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp
@@ -277,8 +277,7 @@ torch::Tensor NpuLlamaDecoderLayerImpl::forward(torch::Tensor& x,
    int node_id) {
  atb::Status st;

-  if (input_params.decode_seq_range.second !=
-      input_params.q_seq_lens.size(0) - 1) {
+  if (!input_params.batch_forward_type.is_decode()) {
    build_node_variant_pack(prefill_node_,
                            x,
                            cos_pos,
diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp
index 0f788904..92f1c671 100644
--- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp
@@ -405,9 +405,7 @@ torch::Tensor NpuQwen2DecoderLayerImpl::forward(
    std::vector*> event_flag,
    int node_id) {
  atb::Status st;
-  if (input_params[0].decode_seq_range.second !=
-      input_params[0].q_seq_lens.size(0) - 1) {
-    // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr);
+  if (!input_params[0].batch_forward_type.is_decode()) {
    build_node_variant_pack(prefill_node_,
                            x[0],
                            cos_pos[0],
@@ -416,7 +414,6 @@ torch::Tensor NpuQwen2DecoderLayerImpl::forward(
                            kv_cache,
                            input_params[0],
                            true);
-    // mstxRangeEnd(id);
    st = execute_node(prefill_node_, node_id, event, event_flag);
    LOG_IF(FATAL, st != 0) << model_name_
                           << "excute prefill layer fail, error code: " << st;
diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
index 0067e38a..89887c05 100644
--- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
@@ -485,10 +485,7 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward(
    std::vector*> event_flag,
    int node_id) {
  atb::Status st;
-  if (input_params[0].decode_seq_range.second !=
-      input_params[0].q_seq_lens.size(0) - 1) {
-    // if (input_params.empty_kv_cache) {
-    // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr);
+  if (!input_params[0].batch_forward_type.is_decode()) {
    build_node_variant_pack(prefill_node_,
                            x,
                            cos_pos,
@@ -497,7 +494,6 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward(
                            kv_cache,
                            input_params,
                            true);
-    // mstxRangeEnd(id);
    st = execute_node(prefill_node_, node_id, event, event_flag);
    LOG_IF(FATAL, st != 0) << model_name_
                           << "excute prefill layer fail, error code: " << st;
diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
index 3aefc3a3..ed13ad3e 100644
--- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
@@ -886,7 +886,7 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
    std::atomic* event_flag,
    int node_id) {
  atb::Status st;
-  if (input_params.global_empty_kv_cache) {
+  if (input_params.batch_forward_type.is_prefill()) {
    build_node_variant_pack(prefill_node_,
                            x,
                            cos_pos,
diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp
index 4a603945..23988fda 100644
--- a/xllm/core/runtime/acl_graph_executor_impl.cpp
+++ b/xllm/core/runtime/acl_graph_executor_impl.cpp
@@ -196,16 +196,9 @@ torch::Tensor AclGraphExecutorImpl::run(
  const torch::Tensor& tokens_tensor = tokens[0];
  const torch::Tensor& positions_tensor = positions[0];
  const ModelInputParams& params_single = params[0];
-  // Identify decode phase using q_max_seq_len for precise detection
-  // Decode phase: all sequences have q_seq_len == 1 (generating one token at a
-  // time) Prefill phase: sequences have q_seq_len > 1 (processing multiple
-  // prompt tokens) We also check empty_kv_cache to ensure KV cache is not empty
-  // (not first forward pass)
-  const bool in_decoding_phase =
-      (params_single.q_max_seq_len == 1) && !params_single.empty_kv_cache;

  // If not in decode phase, use eager mode directly without acl graph
-  if (!in_decoding_phase) {
+  if (!params_single.batch_forward_type.is_decode()) {
    COUNTER_INC(num_model_execution_total_eager);
    return model_->forward(tokens, positions, kv_caches, params);
  }
diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h
index dd4a3d8f..e7c3ef11 100644
--- a/xllm/core/runtime/forward_params.h
+++ b/xllm/core/runtime/forward_params.h
@@ -144,8 +144,7 @@ struct RawForwardInput {
  std::vector> unique_token_ids_vec;
  std::vector> unique_token_counts_vec;
  std::vector<int32_t> unique_token_lens_vec;
-  bool empty_kv_cache = true;
-  bool global_empty_kv_cache = true;
+  BatchForwardType batch_forward_type;
  uint32_t max_seq_len;
  uint32_t q_max_seq_len;
  std::vector<int32_t> seq_lens;
diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp
index cd3ee268..186c6060 100644
--- a/xllm/core/runtime/forward_shared_memory_manager.cpp
+++ b/xllm/core/runtime/forward_shared_memory_manager.cpp
@@ -144,7 +144,7 @@ INLINE size_t calculate_raw_forward_input_size(const RawForwardInput& input) {
  total += type_size * 4 + cache_block_size * cache_block_info_fixed_size();

-  total += type_size * 2   // empty_kv_cache + global_empty_kv_cache
+  total += type_size       // batch_forward_type
           + type_size * 3  // max_seq_len + q_max_seq_len + prefill_seq_len
           + type_size      // num_sequences
@@ -462,8 +462,9 @@ INLINE void deserialize_raw_forward_input(
  read_cache_blocks(buffer, input.copy_in_blocks);
  read_cache_blocks(buffer, input.swap_blocks);

-  read_data(buffer, input.empty_kv_cache);
-  read_data(buffer, input.global_empty_kv_cache);
+  int32_t batch_forward_type;
+  read_data(buffer, batch_forward_type);
+  input.batch_forward_type = BatchForwardType(batch_forward_type);
  read_data(buffer, input.max_seq_len);
  read_data(buffer, input.q_max_seq_len);
  read_data(buffer, input.num_sequences);
@@ -514,10 +515,8 @@ INLINE void serialize_raw_forward_input(const RawForwardInput& input,
  write_cache_blocks(buffer, input.copy_in_blocks);
  write_cache_blocks(buffer, input.swap_blocks);

-  *reinterpret_cast<bool*>(buffer) = input.empty_kv_cache;
-  buffer += 1;
-  *reinterpret_cast<bool*>(buffer) = input.global_empty_kv_cache;
-  buffer += 1;
+  *reinterpret_cast<int32_t*>(buffer) = input.batch_forward_type.value();
+  buffer += 4;
  *reinterpret_cast<uint32_t*>(buffer) = input.max_seq_len;
  buffer += 4;
  *reinterpret_cast<uint32_t*>(buffer) = input.q_max_seq_len;
@@ -707,15 +706,8 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
  forward_input.positions = torch::tensor(
      std::move(raw_input.flatten_positions_vec), tensor_options);
-  std::pair<int32_t, int32_t> decode_seq_range{0, 0};
-#if defined(USE_NPU)
-  if (raw_input.q_seq_lens.size() >= 1) {
-    decode_seq_range = util::find_ones_indices(raw_input.q_seq_lens);
-  }
-#endif
  auto& input_params = forward_input.input_params;
-  input_params.empty_kv_cache = raw_input.empty_kv_cache;
-  input_params.global_empty_kv_cache = raw_input.global_empty_kv_cache;
+  input_params.batch_forward_type = raw_input.batch_forward_type;
  input_params.num_sequences = raw_input.num_sequences;
  input_params.kv_max_seq_len = raw_input.max_seq_len;
  input_params.q_max_seq_len = raw_input.q_max_seq_len;
@@ -732,7 +724,6 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
  input_params.new_cache_slots =
      torch::tensor(std::move(raw_input.new_token_slot_ids), tensor_options);
-  input_params.decode_seq_range = decode_seq_range;

  util::pad_2d_vector(raw_input.block_tables_vec, 0);
  input_params.block_tables =
      create_2d_tensor(std::move(raw_input.block_tables_vec), torch::kInt);
diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp
index a9111b2a..4984252c 100644
--- a/xllm/core/runtime/llm_engine.cpp
+++ b/xllm/core/runtime/llm_engine.cpp
@@ -815,7 +815,9 @@ std::vector> LLMEngine::prepare_inputs(
  std::vector<std::vector<int32_t>> dp_global_token_nums;
  dp_global_token_nums.resize(micro_batches_num,
                              std::vector<int32_t>(dp_size_));
-  bool global_empty_kv_cache = true;
+
+  // All idle batches use the first non-idle batch's forward type.
+  BatchForwardType batch_forward_type;

  // eplb related
  EplbInfo eplb_info;
@@ -831,8 +833,9 @@ std::vector> LLMEngine::prepare_inputs(
          split_seq_index[i], split_seq_index[i + 1], threadpool_.get())));
      dp_global_token_nums[i][dp_rank] =
          batched_inputs[dp_rank][i].flatten_tokens_vec.size();
-      global_empty_kv_cache =
-          batched_inputs[dp_rank][i].empty_kv_cache && global_empty_kv_cache;
+      if (batch_forward_type.is_idle()) {
+        batch_forward_type = batched_inputs[dp_rank][i].batch_forward_type;
+      }
    }
  }
@@ -840,11 +843,15 @@ std::vector> LLMEngine::prepare_inputs(
    eplb_info = eplb_manager_->get_eplb_info();
  }

-  // update dp_global_token_nums and global_empty_kv_cache
+  // update dp_global_token_nums and batch_forward_type
+  CHECK(!batch_forward_type.is_idle())
+      << "Forward types of all batches are idle.";
  for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) {
    for (auto i = 0; i < micro_batches_num; ++i) {
      batched_inputs[dp_rank][i].dp_global_token_nums = dp_global_token_nums[i];
-      batched_inputs[dp_rank][i].global_empty_kv_cache = global_empty_kv_cache;
+      if (batched_inputs[dp_rank][i].batch_forward_type.is_idle()) {
+        batched_inputs[dp_rank][i].batch_forward_type = batch_forward_type;
+      }
      if (FLAGS_enable_eplb) {
        batched_inputs[dp_rank][i].eplb_info = eplb_info;
      }
diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp
index 820bb9cc..a2ac1bb6 100644
--- a/xllm/core/runtime/llm_worker_impl.cpp
+++ b/xllm/core/runtime/llm_worker_impl.cpp
@@ -181,31 +181,14 @@ std::optional LLMWorkerImpl::step(
  }

  // if running in multi_stream_parallel step, all micro batches
-  // should be in same prefill stage, so, to judge empty_kv_cache,
+  // should be in same prefill stage, so, to judge forward_type,
  // just use micro batch 0 here
  if (options_.enable_speculative_decode() && !is_spec_draft_) {
-    if (input_params_micro_batches[0].q_seq_lens_vec[0] > 1) {
+    CHECK_EQ(input_params_micro_batches[0].batch_forward_type.is_mixed(),
+             false);
+    if (!input_params_micro_batches[0].batch_forward_type.is_decode()) {
      output.sample_output.embeddings = hidden_states;
    } else if (concated_sampling_params.sample_idxes.defined()) {
-      // auto sample_idxes =
-      //     concated_sampling_params.selected_token_idxes.index_select(
-      //         /*dim=*/0, concated_sampling_params.sample_idxes);
-      auto embeddings = hidden_states.index_select(
-          /*dim=*/0, concated_sampling_params.sample_idxes);
-      output.sample_output.embeddings = embeddings;
-    }
-  }
-
-  // if running in multi_stream_parallel step
-  // should be in same prefill stage, so, to judge empty_kv_cache,
-  // just use micro batch 0 here
-  if (options_.enable_speculative_decode() && !is_spec_draft_) {
-    if (input_params_micro_batches[0].q_seq_lens_vec[0] > 1) {
-      output.sample_output.embeddings = hidden_states;
-    } else if (concated_sampling_params.sample_idxes.defined()) {
-      // auto sample_idxes =
-      //     concated_sampling_params.selected_token_idxes.index_select(
-      //         /*dim=*/0, concated_sampling_params.sample_idxes);
      auto embeddings = hidden_states.index_select(
          /*dim=*/0, concated_sampling_params.sample_idxes);
      output.sample_output.embeddings = embeddings;
diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp
index 428c0c3e..a2565247 100644
--- a/xllm/core/runtime/params_utils.cpp
+++ b/xllm/core/runtime/params_utils.cpp
@@ -193,16 +193,10 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
  forward_inputs.acc_logprob = torch::tensor(
      acc_logprob_vec,
      torch::dtype(torch::kFloat32).device(torch::kCPU).pinned_memory(true));
-  std::pair<int32_t, int32_t> decode_seq_range{0, 0};
-#if defined(USE_NPU)
-  if (q_seq_lens.size() >= 1) {
-    decode_seq_range = util::find_ones_indices(q_seq_lens);
-  }
-#endif
+
  auto& input_params = forward_inputs.input_params;
-  input_params.empty_kv_cache = pb_forward_input->empty_kv_cache();
-  input_params.global_empty_kv_cache =
-      pb_forward_input->global_empty_kv_cache();
+  input_params.batch_forward_type =
+      BatchForwardType(pb_forward_input->batch_forward_type());
  input_params.num_sequences = block_tables_vec.size();
  assert(input_params.num_sequences == pb_forward_input->num_sequences());
  input_params.prefill_seq_len = pb_forward_input->prefill_seq_len();
@@ -215,7 +209,6 @@
  input_params.new_cache_slots =
      torch::tensor(new_token_slot_ids, tensor_options);
-  input_params.decode_seq_range = decode_seq_range;

  util::pad_2d_vector(block_tables_vec, /*pad_value=*/0);
  input_params.block_tables =
@@ -389,8 +382,7 @@ void forward_input_to_proto(const RawForwardInput& inputs,
  }
  ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_unique_token_lens_vec(),
                      inputs.unique_token_lens_vec);
-  pb_forward_input->set_empty_kv_cache(inputs.empty_kv_cache);
-  pb_forward_input->set_global_empty_kv_cache(inputs.global_empty_kv_cache);
+  pb_forward_input->set_batch_forward_type(inputs.batch_forward_type.value());
  pb_forward_input->set_max_seq_len(inputs.max_seq_len);
  pb_forward_input->set_q_max_seq_len(inputs.q_max_seq_len);
  ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_seq_lens(), inputs.seq_lens);
diff --git a/xllm/core/runtime/speculative_worker_impl.cpp b/xllm/core/runtime/speculative_worker_impl.cpp
index 8e10f5b8..dfef0b6b 100644
--- a/xllm/core/runtime/speculative_worker_impl.cpp
+++ b/xllm/core/runtime/speculative_worker_impl.cpp
@@ -168,12 +168,12 @@ std::optional SpeculativeWorkerImpl::step(
    const BatchedForwardInputs& inputs) {
  // all micro batches in multi stream parallel share the same
  // prefill/decode stage, use inputs[0] here
-  if (inputs.micro_inputs[0].token_ids.numel() == 0) {
+  if (inputs.micro_inputs[0].input_params.batch_forward_type.is_idle()) {
    return step_empty(inputs);
  }

  // TODO: support data parallel case
-  if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
+  if (!inputs.micro_inputs[0].input_params.batch_forward_type.is_decode()) {
    return step_prefill(inputs);
  } else {
    return step_decode(inputs);
@@ -182,7 +182,7 @@ std::optional SpeculativeWorkerImpl::step(
std::optional SpeculativeWorkerImpl::step_empty(
    const BatchedForwardInputs& inputs) {
-  if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
+  if (!inputs.micro_inputs[0].input_params.batch_forward_type.is_decode()) {
    auto output = impl_->step(inputs);
    auto draft_output = draft_impl_->step(inputs);
    return output;
@@ -614,7 +614,6 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
    input_params.block_tables =
        create_2d_tensor(block_tables_vec, torch::kInt).to(device_);
  }
-  input_params.decode_seq_range.second = input_params.num_sequences - 1;

  // update the sampling_params
  update_sampling_params(
@@ -824,7 +823,7 @@ void SpeculativeWorkerImpl::update_sampling_params(
void SpeculativeWorkerImpl::prepare_work_before_execute(
    const BatchedForwardInputs& inputs,
    BatchedForwardInputs& processed_inputs) {
-  if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
+  if (!inputs.micro_inputs[0].input_params.batch_forward_type.is_decode()) {
    WorkerImpl::prepare_work_before_execute(inputs, processed_inputs);
  } else {
    if (enable_schedule_overlap()) {
diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp
index 3fad5fc7..be01a33f 100644
--- a/xllm/core/runtime/worker_impl.cpp
+++ b/xllm/core/runtime/worker_impl.cpp
@@ -433,9 +433,10 @@ void WorkerImpl::prepare_work_before_execute(
                                      .device(torch::kCPU)
                                      .dtype(torch::kInt32)
                                      .pinned_memory(true));
-    bool is_prefill = fwd_inputs_on_device.input_params.global_empty_kv_cache
-                          ? true
-                          : false;
+    bool is_prefill =
+        fwd_inputs_on_device.input_params.batch_forward_type.is_prefill()
+            ? true
+            : false;
    DpEpPadding dp_ep_padding(token_size_per_dp_group,
                              context_.get_model_args().num_experts_per_tok(),
                              context_.get_parallel_args().mapping_data(),
@@ -523,7 +524,8 @@ folly::SemiFuture> WorkerImpl::step_async(
  } else {
    for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
      if (last_step_output_valid_ &&
-          !inputs.micro_inputs[i].input_params.empty_kv_cache) {
+          !inputs.micro_inputs[i]
+               .input_params.batch_forward_type.is_prefill()) {
        // replace step i model input with true output of step i-1
        inputs.micro_inputs[i] =
            update_input_by_last_step_output(inputs.micro_inputs[i]);
diff --git a/xllm/models/llm/deepseek_v2.h b/xllm/models/llm/deepseek_v2.h
index 010993a4..3410ec6d 100644
--- a/xllm/models/llm/deepseek_v2.h
+++ b/xllm/models/llm/deepseek_v2.h
@@ -185,7 +185,7 @@ class DeepseekV2ModelImpl : public torch::nn::Module {

      torch::Tensor attn_mask;
      if (num_speculative_tokens_ == 0 ||
-          input_params[i].global_empty_kv_cache) {
+          input_params[i].batch_forward_type.is_prefill()) {
        attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
      } else {
        attn_mask = attn_mask_.gen_free_mask(
diff --git a/xllm/models/llm/glm4_moe.h b/xllm/models/llm/glm4_moe.h
index 79dbefd7..79a6e4c5 100644
--- a/xllm/models/llm/glm4_moe.h
+++ b/xllm/models/llm/glm4_moe.h
@@ -162,7 +162,8 @@ class Glm4MoeModelImpl : public torch::nn::Module {
        attn_mask = torch::cat(req_mask_vec, 0);
      }
    } else {
-      if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
+      if (num_speculative_tokens_ == 0 ||
+          input_params.batch_forward_type.is_prefill()) {
        attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
      } else {
        attn_mask = attn_mask_.gen_free_mask(
diff --git a/xllm/models/llm/glm4_moe_mtp.h b/xllm/models/llm/glm4_moe_mtp.h
index 5c005a24..56c7dea5 100644
--- a/xllm/models/llm/glm4_moe_mtp.h
+++ b/xllm/models/llm/glm4_moe_mtp.h
@@ -132,7 +132,8 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module {
        attn_mask = torch::cat(req_mask_vec, 0);
      }
    } else {
-      if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
+      if (num_speculative_tokens_ == 0 ||
+          input_params.batch_forward_type.is_prefill()) {
        attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
      } else {
        attn_mask = attn_mask_.gen_free_mask(
diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h
index 7b4212be..fc9ed5fd 100644
--- a/xllm/models/llm/llm_model_base.h
+++ b/xllm/models/llm/llm_model_base.h
@@ -191,7 +191,7 @@ class LlmModelImplBase : public torch::nn::Module {
        const_cast&>(input_params);
    for (auto i = 0; i < micro_batch_num; ++i) {
-      if (tokens[i].numel() == 0) {
+      if (input_params[0].batch_forward_type.is_idle()) {
       tokens[i] = torch::tensor({1}).to(torch::kInt32).to(tokens[0].device());
       positions[i] =
           torch::tensor({0}).to(torch::kInt32).to(tokens[0].device());
diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h
index 16771fb9..7093f19e 100644
--- a/xllm/models/llm/qwen3_moe.h
+++ b/xllm/models/llm/qwen3_moe.h
@@ -178,7 +178,8 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
    auto sin_pos = target_cos_sin_chunks[1].contiguous();

    torch::Tensor attn_mask;
-    if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
+    if (num_speculative_tokens_ == 0 ||
+        input_params.batch_forward_type.is_prefill()) {
      attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
    } else {
      attn_mask = attn_mask_.gen_free_mask(
diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto
index 5344a2e8..6142b7a0 100644
--- a/xllm/proto/worker.proto
+++ b/xllm/proto/worker.proto
@@ -146,6 +146,14 @@ message BlockPair {
  int32 host_block_id = 2;
}

+enum BatchForwardType {
+  PREFILL = 0;
+  CHUNKED_PREFILL = 1;
+  DECODE = 2;
+  MIXED = 3;
+  IDLE = 4;
+}
+
message ForwardInput {
  // flatten the token ids and positions
  repeated int32 flatten_tokens_vec = 1;
@@ -160,11 +168,11 @@ message ForwardInput {
  repeated UniqueTokenIds unique_token_ids_vec = 6;
  repeated UniqueTokenCounts unique_token_counts_vec = 7;
  repeated int32 unique_token_lens_vec = 8;
-  bool empty_kv_cache = 9;
+  int32 batch_forward_type = 9;
  uint32 max_seq_len = 10;
  uint32 q_max_seq_len = 11;
-  repeated int32 seq_lens = 12;
-  repeated int32 q_seq_lens = 13;
+  repeated int32 seq_lens = 12;
+  repeated int32 q_seq_lens = 13;
  repeated int32 paged_kv_indptr = 14;
  repeated int32 paged_kv_indices = 15;
  repeated int32 paged_kv_last_page_len = 16;
@@ -173,26 +181,25 @@ message ForwardInput {
  repeated BlockTables block_tables_vec = 18;
  int32 num_sequences = 19;
  repeated int32 dp_global_token_nums = 20;
-  bool global_empty_kv_cache = 21;
-  repeated TransferKVInfo transfer_kv_infos = 22;
-  repeated Embeddings embeds = 23;
-  uint32 prefill_seq_len = 24;
-  repeated int32 embedding_ids = 25;
-  repeated int32 extra_token_ids = 26;
-  EplbInfo eplb_info =27;
-  repeated CacheBlockInfo async_copy_out_blocks = 28;
-  repeated CacheBlockInfo copy_out_blocks = 29;
-  repeated CacheBlockInfo copy_in_blocks = 30;
-  repeated CacheBlockInfo swap_blocks = 31;
+  repeated TransferKVInfo transfer_kv_infos = 21;
+  repeated Embeddings embeds = 22;
+  uint32 prefill_seq_len = 23;
+  repeated int32 embedding_ids = 24;
+  repeated int32 extra_token_ids = 25;
+  EplbInfo eplb_info = 26;
+  repeated CacheBlockInfo async_copy_out_blocks = 27;
+  repeated CacheBlockInfo copy_out_blocks = 28;
+  repeated CacheBlockInfo copy_in_blocks = 29;
+  repeated CacheBlockInfo swap_blocks = 30;
  // block copy kernel
-  repeated int32 src_block_indices = 32;
-  repeated int32 dst_block_indices = 33;
-  repeated int32 cum_sum = 34;
+  repeated int32 src_block_indices = 31;
+  repeated int32 dst_block_indices = 32;
+  repeated int32 cum_sum = 33;
  // for continuous kvcache
-  repeated int64 new_cache_slot_offsets = 35;
-  repeated int64 kv_cache_start_offsets = 36;
+  repeated int64 new_cache_slot_offsets = 34;
+  repeated int64 kv_cache_start_offsets = 35;
  // beam search kernel
-  repeated float acc_logprob_vec = 37;
+  repeated float acc_logprob_vec = 36;
}

message BatchedForwardInputs {
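For reference, the decision rule introduced by process_batch_forward_type() can be summarized in a small, self-contained C++ sketch (illustrative only; the enum, function, and variable names below are not part of the patch). It mirrors the NPU branch, where seq_lens and q_seq_lens hold per-sequence totals rather than the MLU-style prefix sums:

// Standalone sketch of the batch classification rule; not part of the patch.
#include <cstdint>
#include <cstdio>
#include <vector>

enum class ForwardType { PREFILL, CHUNKED_PREFILL, DECODE, MIXED, IDLE };

ForwardType classify(const std::vector<int32_t>& q_seq_lens,
                     const std::vector<int32_t>& seq_lens,
                     int32_t q_max_seq_len) {
  if (q_max_seq_len == 1) return ForwardType::DECODE;  // every sequence emits one token
  if (seq_lens.empty()) return ForwardType::IDLE;      // nothing to forward
  bool empty_kv_cache = true, all_decode = true, all_prefill = true;
  for (size_t i = 0; i < seq_lens.size(); ++i) {
    const int32_t q_len = q_seq_lens[i];
    const int32_t cache_len = seq_lens[i] - q_len;  // tokens already in kv cache
    if (cache_len > 0) empty_kv_cache = false;
    if (q_len > 1) all_decode = false;
    if (q_len == 1) all_prefill = false;
  }
  if (empty_kv_cache) return ForwardType::PREFILL;
  if (all_prefill) return ForwardType::CHUNKED_PREFILL;
  if (all_decode) return ForwardType::DECODE;
  return ForwardType::MIXED;
}

int main() {
  // One chunked-prefill sequence (q_len = 9, cache_len = 7) mixed with one
  // decode sequence (q_len = 1) classifies as MIXED, analogous to what
  // BatchTest.Basic now expects via is_mixed(). Prints 3 (MIXED).
  std::printf("%d\n", static_cast<int>(classify({9, 1}, {16, 8}, 9)));
  return 0;
}

Workers and layers then branch on this single classification (is_prefill(), is_chunked_prefill(), is_decode(), is_mixed(), is_idle()) instead of re-deriving the phase from empty_kv_cache, decode_seq_range, or q_seq_lens_vec, and the value travels across processes as a plain int32 in ForwardInput.batch_forward_type.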