84 changes: 70 additions & 14 deletions xllm/core/framework/batch/batch_input_builder.cpp
@@ -89,8 +89,8 @@ ForwardInput BatchInputBuilder::build_forward_input(
uint32_t num_decoding_tokens,
uint32_t min_decoding_batch_size) {
process_sequences(0, static_cast<uint32_t>(num_sequences_));
process_batch_forward_type();
padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size);

return state_to_forward_input();
}

@@ -102,6 +102,7 @@ RawForwardInput BatchInputBuilder::build_raw_forward_input(uint32_t start_idx,
} else {
process_sequences_multithreaded(start_idx, end_idx);
}
process_batch_forward_type();
return state_to_raw_forward_input();
}

@@ -207,7 +208,6 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
state_.unique_token_lens_vec.insert(state_.unique_token_lens_vec.end(),
state.unique_token_lens_vec.begin(),
state.unique_token_lens_vec.end());
state_.empty_kv_cache = state_.empty_kv_cache && state.empty_kv_cache;
state_.max_seq_len = std::max(state_.max_seq_len, state.max_seq_len);
state_.q_max_seq_len = std::max(state_.q_max_seq_len, state.q_max_seq_len);
#if defined(USE_NPU)
@@ -282,7 +282,6 @@ void BatchInputBuilder::process_single_sequence(
<< allowed_max_tokens_[seq_index];

// Update state
state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0);
state.max_seq_len = std::max(state.max_seq_len, seq_len);
state.q_max_seq_len = std::max(state.q_max_seq_len, q_seq_len);
#if defined(USE_NPU)
@@ -496,12 +495,7 @@ void BatchInputBuilder::padding_decode_batch_size(
uint32_t min_decoding_batch_size) {
if (num_sequences_ < min_decoding_batch_size) {
const uint32_t n_tokens = state_.flatten_tokens_vec.size();
// kv_cache is not empty in decoding phase
const bool in_decoding_phase = !state_.empty_kv_cache;
const bool same_num_decoding_tokens =
state_.q_max_seq_len == num_decoding_tokens &&
n_tokens == num_sequences_ * num_decoding_tokens;
if (in_decoding_phase && same_num_decoding_tokens) {
if (state_.batch_forward_type.is_decode()) {
// add padding tokens to the batch
for (int32_t i = num_sequences_; i < min_decoding_batch_size; ++i) {
for (int32_t k = 0; k < num_decoding_tokens; ++k) {
@@ -547,7 +541,7 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
}

auto& input_params = forward_input.input_params;
input_params.empty_kv_cache = state_.empty_kv_cache;
input_params.batch_forward_type = state_.batch_forward_type;
input_params.num_sequences = state_.block_tables_vec.size();
input_params.kv_max_seq_len = state_.max_seq_len;
input_params.q_max_seq_len = state_.q_max_seq_len;
@@ -557,8 +551,6 @@
input_params.q_seq_lens_vec = std::move(state_.q_seq_lens);
input_params.new_cache_slots =
torch::tensor(state_.new_token_slot_ids, torch::kInt);
input_params.decode_seq_range =
util::find_ones_indices(input_params.q_seq_lens_vec);

// Setup multimodal data
input_params.mm_data = MMData::batch(mm_data_vec_);
@@ -621,14 +613,13 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
std::move(state_.unique_token_counts_vec);
raw_forward_input.unique_token_lens_vec =
std::move(state_.unique_token_lens_vec);
raw_forward_input.empty_kv_cache = state_.empty_kv_cache;
// raw_forward_input.global_empty_kv_cache = ;
raw_forward_input.max_seq_len = state_.max_seq_len;
raw_forward_input.q_max_seq_len = state_.q_max_seq_len;
raw_forward_input.seq_lens = std::move(state_.seq_lens);
raw_forward_input.q_seq_lens = std::move(state_.q_seq_lens);
raw_forward_input.new_token_slot_ids = std::move(state_.new_token_slot_ids);
raw_forward_input.block_tables_vec = std::move(state_.block_tables_vec);
raw_forward_input.batch_forward_type = std::move(state_.batch_forward_type);
raw_forward_input.num_sequences = num_sequences_;
// raw_forward_input.dp_global_token_nums = ;
raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos);
@@ -726,4 +717,69 @@ void BatchInputBuilder::process_swap_block_infos(
swap_cache_block_infos_->end());
}
}

void BatchInputBuilder::process_batch_forward_type() {
CHECK_EQ(state_.seq_lens.size(), state_.q_seq_lens.size())
<< "seq_lens size must be equal to q_seq_lens size";

if (state_.q_max_seq_len == 1) {
state_.batch_forward_type = BatchForwardType::DECODE;
return;
}

bool empty_kv_cache = true;
bool all_decode = true;
bool all_prefill = true;

#if defined(USE_NPU)
if (state_.seq_lens.size() == 0) {
state_.batch_forward_type = BatchForwardType::IDLE;
return;
}
for (size_t i = 0; i < state_.seq_lens.size(); ++i) {
auto q_len = state_.q_seq_lens[i];
auto kv_len = state_.seq_lens[i];
auto cache_len = kv_len - q_len;
if (cache_len > 0) {
empty_kv_cache = false;
}
if (q_len > 1) {
all_decode = false;
}
if (q_len == 1) {
all_prefill = false;
}
}
#elif defined(USE_MLU)
if (state_.seq_lens.size() == 1) {
state_.batch_forward_type = BatchForwardType::IDLE;
return;
}
for (size_t i = 1; i < state_.seq_lens.size(); ++i) {
auto q_len = state_.q_seq_lens[i] - state_.q_seq_lens[i - 1];
auto kv_len = state_.seq_lens[i] - state_.seq_lens[i - 1];
auto cache_len = kv_len - q_len;
if (cache_len > 0) {
empty_kv_cache = false;
}
if (q_len > 1) {
all_decode = false;
}
if (q_len == 1) {
all_prefill = false;
}
}
#endif
if (empty_kv_cache) {
state_.batch_forward_type = BatchForwardType::PREFILL;
} else {
if (all_prefill) {
state_.batch_forward_type = BatchForwardType::CHUNKED_PREFILL;
} else if (all_decode) {
state_.batch_forward_type = BatchForwardType::DECODE;
} else {
state_.batch_forward_type = BatchForwardType::MIXED;
}
}
}
} // namespace xllm
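
For reference, here is a small worked example (made-up lengths, not values from this PR) of the two state layouts that process_batch_forward_type() iterates over: the USE_NPU branch reads per-sequence lengths directly, while the USE_MLU branch appears to store cumulative lengths with a leading zero, which is why a single entry means an idle batch and adjacent differences recover each sequence's lengths.

#include <cstdint>
#include <vector>

// Three sequences with q lengths {4, 1, 1} and kv lengths {4, 7, 9}.
int main() {
  // NPU-style layout: one entry per sequence.
  std::vector<int32_t> npu_q_seq_lens = {4, 1, 1};
  std::vector<int32_t> npu_seq_lens = {4, 7, 9};

  // Assumed MLU-style layout: cumulative sums with a leading 0, so
  // seq_lens.size() == 1 means an empty batch and
  // seq_lens[i] - seq_lens[i - 1] gives the per-sequence value.
  std::vector<int32_t> mlu_q_seq_lens = {0, 4, 5, 6};
  std::vector<int32_t> mlu_seq_lens = {0, 4, 11, 20};

  // In both layouts sequence 0 is a prefill chunk (q_len > 1, no cached
  // tokens) while sequences 1 and 2 are decode steps (q_len == 1 with
  // cached tokens), so process_batch_forward_type() would classify this
  // batch as MIXED.
  return 0;
}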
4 changes: 3 additions & 1 deletion xllm/core/framework/batch/batch_input_builder.h
@@ -62,6 +62,8 @@ class BatchInputBuilder {

void process_swap_block_infos(RawForwardInput& raw_forward_input);

void process_batch_forward_type();

// State management
struct BuilderState {
// Token and position data
@@ -80,7 +82,7 @@
std::vector<int32_t> unique_token_lens_vec;

// Sequence metadata
bool empty_kv_cache = true;
BatchForwardType batch_forward_type;
uint32_t max_seq_len = 0;
uint32_t q_max_seq_len = 0;
#if defined(USE_NPU)
2 changes: 1 addition & 1 deletion xllm/core/framework/batch/batch_test.cpp
@@ -145,7 +145,7 @@ TEST(BatchTest, Basic) {

// check the input parameters
const ModelInputParams& input_params = forward_input.input_params;
EXPECT_FALSE(input_params.empty_kv_cache);
EXPECT_TRUE(input_params.batch_forward_type.is_mixed());
EXPECT_EQ(input_params.num_sequences, 4);
EXPECT_EQ(input_params.q_max_seq_len, 9);
EXPECT_EQ(input_params.kv_max_seq_len, 16);
1 change: 1 addition & 0 deletions xllm/core/framework/model/CMakeLists.txt
@@ -33,6 +33,7 @@ cc_library(
embedding_lm.h
model_args.h
npu_dp_ep_padding.h
batch_forward_type.h
model_input_params.h
SRCS
npu_dp_ep_padding.cpp
81 changes: 81 additions & 0 deletions xllm/core/framework/model/batch_forward_type.h
@@ -0,0 +1,81 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

namespace xllm {

class BatchForwardType {
public:
enum Value : int32_t {
// Prefill without using kv cache.
PREFILL = 0,
// Chunked prefill using kv cache.
// No decode sequence in this type.
CHUNKED_PREFILL = 1,
// Decode one token.
// No prefill sequence in this type.
DECODE = 2,
// Mixed prefill and decode in one batch when doing chunked prefill.
MIXED = 3,
// No sequence to forward.
IDLE = 4,
};

BatchForwardType() : value_(IDLE) {}

BatchForwardType(int32_t v) : value_(static_cast<Value>(v)) {}

constexpr BatchForwardType(Value v) : value_(v) {}

BatchForwardType& operator=(Value v) {
value_ = v;
return *this;
}

int32_t value() const { return value_; }

bool is_prefill() const { return (value_ == PREFILL); }

bool is_chunked_prefill() const { return (value_ == CHUNKED_PREFILL); }

bool is_decode() const { return (value_ == DECODE); }

bool is_mixed() const { return (value_ == MIXED); }

bool is_idle() const { return (value_ == IDLE); }

const char* to_string() const {
switch (value_) {
case PREFILL:
return "PREFILL";
case CHUNKED_PREFILL:
return "CHUNKED_PREFILL";
case DECODE:
return "DECODE";
case MIXED:
return "MIXED";
case IDLE:
return "IDLE";
default:
return "UNKNOWN";
}
}

private:
Value value_;
};
} // namespace xllm
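
A quick usage sketch of the new type (hypothetical call site, not code from this PR): because the enum is backed by an int32_t and has a converting constructor, it can travel as a plain integer, for example inside RawForwardInput between processes, and be rebuilt and inspected on the other side.

#include <cassert>
#include <cstdint>
#include <cstring>

#include "framework/model/batch_forward_type.h"

int main() {
  xllm::BatchForwardType type = xllm::BatchForwardType::CHUNKED_PREFILL;

  // Serialize to a plain integer and rebuild via the int32_t constructor.
  int32_t wire = type.value();
  xllm::BatchForwardType restored(wire);

  assert(restored.is_chunked_prefill());
  assert(!restored.is_decode());
  assert(std::strcmp(restored.to_string(), "CHUNKED_PREFILL") == 0);
  return 0;
}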
29 changes: 8 additions & 21 deletions xllm/core/framework/model/model_input_params.h
@@ -21,6 +21,7 @@ limitations under the License.
#if defined(USE_NPU)
#include "platform/npu/npu_layer_synchronizer.h"
#endif
#include "framework/model/batch_forward_type.h"
#include "framework/request/mm_data.h"
#include "npu_dp_ep_padding.h"
#include "util/tensor_helper.h"
@@ -50,8 +51,7 @@ struct CacheBlockInfo {
struct ModelInputParams {
ModelInputParams to(const torch::Device& device) const {
ModelInputParams params;
params.empty_kv_cache = empty_kv_cache;
params.global_empty_kv_cache = global_empty_kv_cache;
params.batch_forward_type = batch_forward_type;
params.num_sequences = num_sequences;
params.kv_max_seq_len = kv_max_seq_len;
params.q_max_seq_len = q_max_seq_len;
@@ -63,7 +63,6 @@
params.block_tables = safe_to(block_tables, device, true);
params.kv_seq_lens_vec = kv_seq_lens_vec;
params.q_seq_lens_vec = q_seq_lens_vec;
params.decode_seq_range = decode_seq_range;

params.input_embedding = safe_to(input_embedding, device);

@@ -98,24 +97,22 @@
}

void print() const {
LOG(INFO) << "ModelInputParams: empty_kv_cache is " << empty_kv_cache
<< " , global_empty_kv_cache is " << global_empty_kv_cache
<< " , num_sequences is " << num_sequences
<< " , kv_max_seq_len is " << kv_max_seq_len
LOG(INFO) << "ModelInputParams: batch_forward_type is "
<< batch_forward_type.to_string() << " , num_sequences is "
<< num_sequences << " , kv_max_seq_len is " << kv_max_seq_len
<< " , q_max_seq_len is " << q_max_seq_len
<< " , prefill_seq_len is " << prefill_seq_len;
LOG(INFO) << "ModelInputParams: kv_seq_lens_vec is " << kv_seq_lens_vec;
LOG(INFO) << "ModelInputParams: q_seq_lens_vec is " << q_seq_lens_vec;
LOG(INFO) << "ModelInputParams: decode_seq_range is " << decode_seq_range;
print_tensor(kv_seq_lens, "ModelInputParams: kv_seq_lens", 4);
print_tensor(q_seq_lens, "ModelInputParams: q_seq_lens", 4);
print_tensor(new_cache_slots, "ModelInputParams: new_cache_slots", 4);
print_tensor(block_tables, "ModelInputParams: block_tables", 4);
LOG(INFO) << "ModelInputParams: dp_global_token_nums is "
<< dp_global_token_nums;
}
// whether the kv-cache is empty for all sequences.
bool empty_kv_cache = true;
// forward type of the batch, used by worker/kernel.
BatchForwardType batch_forward_type;

// total number of sequences in the batch
int32_t num_sequences = 0;
@@ -124,15 +121,7 @@
torch::Tensor kv_seq_lens;
std::vector<int> kv_seq_lens_vec;
std::vector<int> q_seq_lens_vec;
// Range of decode sequence indices in the batch [start, end].
// Decode sequences are identified by q_seq_lens == 1,
// prefill sequences by q_seq_lens > 1 .
// Used to determine whether to use prefill_node_ or
// decode_node_ in NPU layers
// Values: {-1, -1} if no decode requests (all prefill),
// {0, batch_size-1} if all decode requests,
// {start_idx, end_idx} if mixed prefill/decode requests
std::pair<int, int> decode_seq_range;

// max length for qkv.
int32_t kv_max_seq_len = 0;
int32_t q_max_seq_len = 0;
@@ -151,8 +140,6 @@

// num tokens of all workers,mainly used for dp case
std::vector<int32_t> dp_global_token_nums;
// whether the kv-cache is empty for all sequences,mainly used for dp case
bool global_empty_kv_cache = true;

// num of prefill sequence in chunked prefill case
uint32_t prefill_seq_len = 0;
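
As a hedged summary of the field migration above (illustrative helper, not part of the PR): the removed members empty_kv_cache, global_empty_kv_cache, and decode_seq_range were separate answers to "does any sequence already have kv cache?" and "is every sequence a single-token decode?"; both questions now fall out of the single batch_forward_type field.

#include "framework/model/model_input_params.h"

// Illustrative mapping from the removed fields to the new one:
//   old: params.empty_kv_cache == true            -> new: batch_forward_type.is_prefill()
//   old: decode_seq_range spanned the whole batch -> new: batch_forward_type.is_decode()
//   old: mixed prefill/decode detected by hand    -> new: batch_forward_type.is_mixed()
bool use_decode_kernel(const xllm::ModelInputParams& params) {
  return params.batch_forward_type.is_decode();
}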
4 changes: 2 additions & 2 deletions xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp
@@ -1520,8 +1520,8 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward(
int node_id) {
atb::Status st;
// all micro batches are in same prefill/decode stage,
// so, to judge empty_kv_cache, use input_params[0] here
if (input_params[0].global_empty_kv_cache) {
// DeepSeek doesn't support chunked prefill, so only check is_prefill().
if (input_params[0].batch_forward_type.is_prefill()) {
build_node_variant_pack(prefill_node_,
x,
cos_pos,
3 changes: 1 addition & 2 deletions xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp
@@ -1085,8 +1085,7 @@ torch::Tensor Glm4MoeDecoderImpl::forward(
std::vector<std::atomic<bool>*> event_flag,
int node_id) {
atb::Status st;
if (input_params.decode_seq_range.second !=
input_params.q_seq_lens.size(0) - 1) {
if (!input_params.batch_forward_type.is_decode()) {
build_node_variant_pack(prefill_node_,
x,
cos_pos,
3 changes: 1 addition & 2 deletions xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp
@@ -277,8 +277,7 @@ torch::Tensor NpuLlamaDecoderLayerImpl::forward(torch::Tensor& x,
int node_id) {
atb::Status st;

if (input_params.decode_seq_range.second !=
input_params.q_seq_lens.size(0) - 1) {
if (!input_params.batch_forward_type.is_decode()) {
build_node_variant_pack(prefill_node_,
x,
cos_pos,