diff --git a/CMakeLists.txt b/CMakeLists.txt index c7765ee7..3e3e37a0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 +28,7 @@ if(USE_NPU) if(DEVICE_TYPE STREQUAL "USE_A3") message("downloading a3 arm xllm kernels") file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.1-Linux.a3.arm.rpm" + "https://9n-online-service.s3-internal.cn-north-1.jdcloud-oss.com/9n-xllm-atb/xllm_kernels-1.3.0-Linux.rpm" "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" ) else() diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index 788d0991..5e8502e9 100644 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -27,6 +27,7 @@ limitations under the License. #include "core/common/metrics.h" #include "core/runtime/dit_master.h" #include "core/runtime/llm_master.h" +#include "core/runtime/rec_master.h" #include "core/runtime/vlm_master.h" #include "core/util/closure_guard.h" #include "embedding.pb.h" @@ -62,6 +63,9 @@ APIService::APIService(Master* master, image_generation_service_impl_ = std::make_unique( dynamic_cast(master), model_names); + } else if (FLAGS_backend == "rec") { + rec_completion_service_impl_ = std::make_unique( + dynamic_cast(master), model_names); } models_service_impl_ = ServiceImplFactory::create_service_impl( @@ -72,13 +76,6 @@ void APIService::Completions(::google::protobuf::RpcController* controller, const proto::CompletionRequest* request, proto::CompletionResponse* response, ::google::protobuf::Closure* done) { - // TODO with xllm-service -} - -void APIService::CompletionsHttp(::google::protobuf::RpcController* controller, - const proto::HttpRequest* request, - proto::HttpResponse* response, - ::google::protobuf::Closure* done) { xllm::ClosureGuard done_guard( done, std::bind(request_in_metric, nullptr), @@ -87,47 +84,38 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller, LOG(ERROR) << "brpc request | respose | controller is null"; return; } - - auto arena = response->GetArena(); - auto req_pb = - google::protobuf::Arena::CreateMessage(arena); - auto resp_pb = - google::protobuf::Arena::CreateMessage(arena); - auto ctrl = reinterpret_cast(controller); - std::string error; - json2pb::Json2PbOptions options; - butil::IOBuf& buf = ctrl->request_attachment(); - butil::IOBufAsZeroCopyInputStream iobuf_stream(buf); - auto st = json2pb::JsonToProtoMessage(&iobuf_stream, req_pb, options, &error); - if (!st) { - ctrl->SetFailed(error); - LOG(ERROR) << "parse json to proto failed: " << error; - return; - } - std::shared_ptr call = std::make_shared( - ctrl, done_guard.release(), req_pb, resp_pb); - completion_service_impl_->process_async(call); -} - -void APIService::ChatCompletions(::google::protobuf::RpcController* controller, - const proto::ChatRequest* request, - proto::ChatResponse* response, - ::google::protobuf::Closure* done) { - // TODO with xllm-service + if (FLAGS_backend == "llm") { + CHECK(completion_service_impl_) << " completion service is invalid."; + std::shared_ptr call = std::make_shared( + ctrl, + done_guard.release(), + const_cast(request), + response); + completion_service_impl_->process_async(call); + } else if (FLAGS_backend == "rec") { + CHECK(rec_completion_service_impl_) + << " rec completion service is invalid."; + std::shared_ptr call = std::make_shared( + ctrl, + done_guard.release(), + const_cast(request), + response); + rec_completion_service_impl_->process_async(call); + } } namespace { -template -void 
ChatCompletionsImpl(std::unique_ptr& service, - xllm::ClosureGuard& guard, - ::google::protobuf::Arena* arena, - brpc::Controller* ctrl) { +template +void CommonCompletionsImpl(std::unique_ptr& service, + xllm::ClosureGuard& guard, + ::google::protobuf::Arena* arena, + brpc::Controller* ctrl) { auto req_pb = - google::protobuf::Arena::CreateMessage(arena); + google::protobuf::Arena::CreateMessage(arena); auto resp_pb = - google::protobuf::Arena::CreateMessage(arena); + google::protobuf::Arena::CreateMessage(arena); std::string error; json2pb::Json2PbOptions options; @@ -140,12 +128,46 @@ void ChatCompletionsImpl(std::unique_ptr& service, return; } - auto call = - std::make_shared(ctrl, guard.release(), req_pb, resp_pb); + auto call = std::make_shared(ctrl, guard.release(), req_pb, resp_pb); service->process_async(call); } } // namespace +void APIService::CompletionsHttp(::google::protobuf::RpcController* controller, + const proto::HttpRequest* request, + proto::HttpResponse* response, + ::google::protobuf::Closure* done) { + xllm::ClosureGuard done_guard( + done, + std::bind(request_in_metric, nullptr), + std::bind(request_out_metric, (void*)controller)); + if (!request || !response || !controller) { + LOG(ERROR) << "brpc request | respose | controller is null"; + return; + } + + auto arena = response->GetArena(); + auto ctrl = reinterpret_cast(controller); + + if (FLAGS_backend == "llm") { + CHECK(completion_service_impl_) << " completion service is invalid."; + CommonCompletionsImpl( + completion_service_impl_, done_guard, arena, ctrl); + } else if (FLAGS_backend == "rec") { + CHECK(rec_completion_service_impl_) + << " rec completion service is invalid."; + CommonCompletionsImpl( + rec_completion_service_impl_, done_guard, arena, ctrl); + } +} + +void APIService::ChatCompletions(::google::protobuf::RpcController* controller, + const proto::ChatRequest* request, + proto::ChatResponse* response, + ::google::protobuf::Closure* done) { + // TODO with xllm-service +} + void APIService::ChatCompletionsHttp( ::google::protobuf::RpcController* controller, const proto::HttpRequest* request, @@ -165,11 +187,11 @@ void APIService::ChatCompletionsHttp( if (FLAGS_backend == "llm") { CHECK(chat_service_impl_) << " chat service is invalid."; - ChatCompletionsImpl( + CommonCompletionsImpl( chat_service_impl_, done_guard, arena, ctrl); } else if (FLAGS_backend == "vlm") { CHECK(mm_chat_service_impl_) << " mm chat service is invalid."; - ChatCompletionsImpl( + CommonCompletionsImpl( mm_chat_service_impl_, done_guard, arena, ctrl); } } diff --git a/xllm/api_service/api_service.h b/xllm/api_service/api_service.h index 0911fb05..5b92e497 100644 --- a/xllm/api_service/api_service.h +++ b/xllm/api_service/api_service.h @@ -122,6 +122,7 @@ class APIService : public proto::XllmAPIService { std::unique_ptr models_service_impl_; std::unique_ptr image_generation_service_impl_; std::unique_ptr rerank_service_impl_; + std::unique_ptr rec_completion_service_impl_; }; } // namespace xllm diff --git a/xllm/api_service/completion_service_impl.cpp b/xllm/api_service/completion_service_impl.cpp index 10d5e207..0ceda926 100644 --- a/xllm/api_service/completion_service_impl.cpp +++ b/xllm/api_service/completion_service_impl.cpp @@ -26,8 +26,10 @@ limitations under the License. 
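// Hedged sketch (not part of the patch) of the shared JSON-over-HTTP handler that
// CompletionsHttp and ChatCompletionsHttp now funnel through. The template parameter
// list of CommonCompletionsImpl is not visible in this hunk, so the names below
// (ServiceImpl, Call, RequestPb, ResponsePb) are assumptions about its likely shape.
#include <brpc/controller.h>
#include <butil/iobuf.h>
#include <google/protobuf/arena.h>
#include <json2pb/json_to_pb.h>
#include <memory>
#include <string>

template <typename ServiceImpl, typename Call, typename RequestPb, typename ResponsePb>
void common_completions_sketch(std::unique_ptr<ServiceImpl>& service,
                               ::google::protobuf::Arena* arena,
                               brpc::Controller* ctrl) {
  // Arena-allocate the request/response protos so their lifetime follows the RPC.
  auto* req_pb = google::protobuf::Arena::CreateMessage<RequestPb>(arena);
  auto* resp_pb = google::protobuf::Arena::CreateMessage<ResponsePb>(arena);

  // Parse the JSON body attached to the HTTP request straight from the IOBuf.
  std::string error;
  json2pb::Json2PbOptions options;
  butil::IOBufAsZeroCopyInputStream stream(ctrl->request_attachment());
  if (!json2pb::JsonToProtoMessage(&stream, req_pb, options, &error)) {
    ctrl->SetFailed(error);
    return;
  }

  // Hand the parsed call to the backend-specific service implementation
  // (done closure omitted here; the real handler passes guard.release()).
  auto call = std::make_shared<Call>(ctrl, /*done=*/nullptr, req_pb, resp_pb);
  service->process_async(call);
}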
#include "common/instance_name.h" #include "completion.pb.h" +#include "core/framework/request/mm_data.h" #include "core/framework/request/request_output.h" #include "core/runtime/llm_master.h" +#include "core/runtime/rec_master.h" #include "core/util/utils.h" #define likely(x) __builtin_expect(!!(x), 1) @@ -126,6 +128,7 @@ bool send_result_to_client_brpc(std::shared_ptr call, response.set_created(created_time); response.set_model(model); + // add choices into response response.mutable_choices()->Reserve(req_output.outputs.size()); for (const auto& output : req_output.outputs) { auto* choice = response.add_choices(); @@ -137,6 +140,7 @@ bool send_result_to_client_brpc(std::shared_ptr call, } } + // add usage statistics if (req_output.usage.has_value()) { const auto& usage = req_output.usage.value(); auto* proto_usage = response.mutable_usage(); @@ -147,35 +151,68 @@ bool send_result_to_client_brpc(std::shared_ptr call, proto_usage->set_total_tokens(static_cast(usage.num_total_tokens)); } - return call->write_and_finish(response); -} + if (FLAGS_backend == "rec") { + auto output_tensor = response.mutable_output_tensors()->Add(); + output_tensor->set_name("omnirec_result"); + // TODO: replace true with flags after converter merge + if (true) { + output_tensor->set_datatype(proto::DataType::INT64); + output_tensor->mutable_shape()->Add(req_output.outputs.size()); + output_tensor->mutable_shape()->Add(1); // Single item per output -} // namespace + auto context = output_tensor->mutable_contents(); + for (int i = 0; i < req_output.outputs.size(); ++i) { + if (req_output.outputs[i].item_ids.has_value()) { + context->mutable_int64_contents()->Add( + req_output.outputs[i].item_ids.value()); + } + } + } else { + output_tensor->set_datatype(proto::DataType::INT32); -CompletionServiceImpl::CompletionServiceImpl( - LLMMaster* master, - const std::vector& models) - : APIServiceImpl(models), master_(master) { - CHECK(master_ != nullptr); + output_tensor->mutable_shape()->Add(req_output.outputs.size()); + output_tensor->mutable_shape()->Add( + req_output.outputs[0].token_ids.size()); + + auto context = output_tensor->mutable_contents(); + for (int i = 0; i < req_output.outputs.size(); ++i) { + // LOG(INFO) << req_output.outputs[i].token_ids; + context->mutable_int_contents()->Add( + req_output.outputs[i].token_ids.begin(), + req_output.outputs[i].token_ids.end()); + } + } + } + + return call->write_and_finish(response); } -// complete_async for brpc -void CompletionServiceImpl::process_async_impl( - std::shared_ptr call) { +// Type alias for the return type of process_completion_request_params +using ProcessCompletionResult = + std::optional>, + bool, + std::string>>; +// Common function to process request parameters and validation +ProcessCompletionResult process_completion_request_params( + std::shared_ptr call, + const absl::flat_hash_set& models, + xllm::RateLimiter* rate_limiter) { const auto& rpc_request = call->request(); + // check if model is supported const auto& model = rpc_request.model(); - if (unlikely(!models_.contains(model))) { + if (unlikely(!models.contains(model))) { call->finish_with_error(StatusCode::UNKNOWN, "Model not supported"); - return; + return std::nullopt; } // Check if the request is being rate-limited. 
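// Minimal, self-contained sketch of the std::optional<std::tuple<...>> pattern used by
// process_completion_request_params() above: validation failures reply via
// finish_with_error() and return std::nullopt, so the caller can simply bail out, while
// the happy path hands back everything the handler needs in one movable bundle. The
// names below (FakeParams, validate, handle) are illustrative, not part of the xLLM API.
#include <optional>
#include <string>
#include <tuple>
#include <vector>

struct FakeParams { std::string request_id; bool streaming = false; };

std::optional<std::tuple<FakeParams, std::vector<int>, bool, std::string>>
validate(const std::string& model, bool model_known) {
  if (!model_known) return std::nullopt;  // error reply already sent; caller just returns
  return std::make_tuple(FakeParams{"req-1"}, std::vector<int>{1, 2, 3}, true, model);
}

void handle(const std::string& model) {
  auto result = validate(model, /*model_known=*/true);
  if (!result.has_value()) return;
  // Unpack with structured bindings, as the real process_async_impl() does.
  auto [params, prompt_tokens, include_usage, resolved_model] = std::move(result.value());
  (void)params; (void)prompt_tokens; (void)include_usage; (void)resolved_model;
}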
- if (unlikely(master_->get_rate_limiter()->is_limited())) { + if (unlikely(rate_limiter->is_limited())) { call->finish_with_error( StatusCode::RESOURCE_EXHAUSTED, "The number of concurrent requests has reached the limit."); - return; + return std::nullopt; } RequestParams request_params( @@ -195,44 +232,127 @@ void CompletionServiceImpl::process_async_impl( request_params.decode_address = rpc_request.routing().decode_name(); } + + return std::make_tuple(std::move(request_params), + std::move(prompt_tokens), + include_usage, + model); +} + +// Common callback function for handling request output +auto request_callback(std::shared_ptr call, + const std::string& model, + Master* master, + bool stream, + bool include_usage, + const std::string& request_id, + int64_t created_time) { + return [call, model, master, stream, include_usage, request_id, created_time]( + const RequestOutput& req_output) -> bool { + if (req_output.status.has_value()) { + const auto& status = req_output.status.value(); + if (!status.ok()) { + // Reduce the number of concurrent requests when a request is + // finished with error. + master->get_rate_limiter()->decrease_one_request(); + + return call->finish_with_error(status.code(), status.message()); + } + } + + // Reduce the number of concurrent requests when a request is finished + // or canceled. + if (req_output.finished || req_output.cancelled) { + master->get_rate_limiter()->decrease_one_request(); + } + + if (stream) { + return send_delta_to_client_brpc( + call, include_usage, request_id, created_time, model, req_output); + } + return send_result_to_client_brpc( + call, request_id, created_time, model, req_output); + }; +} + +} // namespace + +CompletionServiceImpl::CompletionServiceImpl( + LLMMaster* master, + const std::vector& models) + : APIServiceImpl(models), master_(master) { + CHECK(master_ != nullptr); +} + +// complete_async for brpc +void CompletionServiceImpl::process_async_impl( + std::shared_ptr call) { + auto result = process_completion_request_params( + call, models_, master_->get_rate_limiter()); + if (!result.has_value()) { + return; // Error already handled in process_completion_request_params + } + + auto [request_params, prompt_tokens, include_usage, model] = + std::move(result.value()); // schedule the request - master_->handle_request( - std::move(rpc_request.prompt()), - std::move(prompt_tokens), - std::move(request_params), - call.get(), - [call, - model, - master = master_, - stream = request_params.streaming, - include_usage = include_usage, - request_id = request_params.request_id, - created_time = absl::ToUnixSeconds(absl::Now())]( - const RequestOutput& req_output) -> bool { - if (req_output.status.has_value()) { - const auto& status = req_output.status.value(); - if (!status.ok()) { - // Reduce the number of concurrent requests when a request is - // finished with error. - master->get_rate_limiter()->decrease_one_request(); - - return call->finish_with_error(status.code(), status.message()); - } - } + master_->handle_request(std::move(call->request().prompt()), + std::move(prompt_tokens), + std::move(request_params), + call.get(), + request_callback(call, + model, + master_, + request_params.streaming, + include_usage, + request_params.request_id, + absl::ToUnixSeconds(absl::Now()))); +} - // Reduce the number of concurrent requests when a request is finished - // or canceled. 
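// Sketch of the callback-factory pattern introduced by request_callback(): the factory
// captures the per-request state once (shared call object, model name, streaming flag)
// and returns a single lambda that both CompletionServiceImpl and
// RecCompletionServiceImpl can pass to their masters. Output and Call below are
// simplified stand-ins, not the real RequestOutput / CompletionCall classes.
#include <functional>
#include <memory>
#include <string>

struct Output { bool finished = false; };
struct Call {
  bool send_delta(const Output&) { return true; }   // streaming: one delta per output
  bool send_final(const Output&) { return true; }   // non-streaming: single reply
};

std::function<bool(const Output&)> make_output_callback(std::shared_ptr<Call> call,
                                                        std::string model,
                                                        bool stream) {
  return [call, model = std::move(model), stream](const Output& out) -> bool {
    // The real callback also decrements the rate limiter when out.finished is set.
    return stream ? call->send_delta(out) : call->send_final(out);
  };
}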
- if (req_output.finished || req_output.cancelled) { - master->get_rate_limiter()->decrease_one_request(); - } +RecCompletionServiceImpl::RecCompletionServiceImpl( + RecMaster* master, + const std::vector& models) + : APIServiceImpl(models), master_(master) { + CHECK(master_ != nullptr); +} - if (stream) { - return send_delta_to_client_brpc( - call, include_usage, request_id, created_time, model, req_output); - } - return send_result_to_client_brpc( - call, request_id, created_time, model, req_output); - }); +void RecCompletionServiceImpl::process_async_impl( + std::shared_ptr call) { + auto result = process_completion_request_params( + call, models_, master_->get_rate_limiter()); + if (!result.has_value()) { + return; // Error already handled in process_completion_request_params + } + + auto [request_params, prompt_tokens, include_usage, model] = + std::move(result.value()); + const auto& rpc_request = call->request(); + std::optional mm_data = std::nullopt; + if (rpc_request.input_tensors_size()) { + // HISTOGRAM_OBSERVE(rec_input_first_dim, + // rpc_request.input_tensors(0).shape(0)); + + MMDict mm_dict; + for (int i = 0; i < rpc_request.input_tensors_size(); ++i) { + const auto& tensor = rpc_request.input_tensors(i); + mm_dict[tensor.name()] = + xllm::util::convert_rec_tensor_to_torch(tensor).to(torch::kBFloat16); + } + mm_data = std::move(MMData(MMType::EMBEDDING, mm_dict)); + } + + // schedule the request + master_->handle_request(std::move(rpc_request.prompt()), + std::move(prompt_tokens), + std::move(mm_data), + std::move(request_params), + request_callback(call, + model, + master_, + request_params.streaming, + include_usage, + request_params.request_id, + absl::ToUnixSeconds(absl::Now()))); } } // namespace xllm diff --git a/xllm/api_service/completion_service_impl.h b/xllm/api_service/completion_service_impl.h index 5fdc74f5..12b823a9 100644 --- a/xllm/api_service/completion_service_impl.h +++ b/xllm/api_service/completion_service_impl.h @@ -20,6 +20,7 @@ limitations under the License. 
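// Hedged sketch of what a proto-tensor to torch conversion such as
// util::convert_rec_tensor_to_torch() might do for a float tensor before the rec path
// above casts it to bfloat16; the real helper and proto field layout may differ.
// torch::from_blob() does not own the buffer, so clone() before the proto goes away.
#include <torch/torch.h>
#include <cstdint>
#include <vector>

torch::Tensor to_bf16_tensor(const std::vector<float>& values,
                             const std::vector<int64_t>& shape) {
  auto t = torch::from_blob(const_cast<float*>(values.data()), shape, torch::kFloat32)
               .clone();              // take ownership of the data
  return t.to(torch::kBFloat16);      // matches the .to(torch::kBFloat16) call above
}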
#include "api_service_impl.h" #include "completion.pb.h" +#include "rec.pb.h" #include "stream_call.h" namespace xllm { @@ -41,4 +42,19 @@ class CompletionServiceImpl final : public APIServiceImpl { LLMMaster* master_ = nullptr; }; +class RecMaster; +// a class to handle completion requests +class RecCompletionServiceImpl final : public APIServiceImpl { + public: + RecCompletionServiceImpl(RecMaster* master, + const std::vector& models); + + // brpc call_data needs to use shared_ptr + void process_async_impl(std::shared_ptr call); + + private: + DISALLOW_COPY_AND_ASSIGN(RecCompletionServiceImpl); + RecMaster* master_ = nullptr; +}; + } // namespace xllm diff --git a/xllm/core/common/metrics.cpp b/xllm/core/common/metrics.cpp index 5f792f2b..3af87f54 100644 --- a/xllm/core/common/metrics.cpp +++ b/xllm/core/common/metrics.cpp @@ -180,6 +180,20 @@ DEFINE_COUNTER(proto_latency_seconds_o2proto, // engine metrics DEFINE_COUNTER(prepare_input_latency_seconds, "Latency of preparing input in seconds"); +DEFINE_COUNTER(prepare_input_latency_microseconds, + "Latency of preparing input in microseconds"); + +// rec engine metrics +DEFINE_COUNTER(rec_first_token_latency_microseconds, + "Latency of rec first token generation in microseconds"); +DEFINE_COUNTER(rec_second_token_latency_microseconds, + "Latency of rec second token generation in microseconds"); +DEFINE_COUNTER(rec_third_token_latency_microseconds, + "Latency of rec third token generation in microseconds"); +DEFINE_COUNTER(rec_sampling_latency_microseconds, + "Latency of rec sampling in microseconds"); +DEFINE_HISTOGRAM(expand_beam_latency_microseconds, + "Histogram of expand beam latency in microseconds"); // multi node metrics DEFINE_COUNTER(worker_service_latency_seconds, diff --git a/xllm/core/common/metrics.h b/xllm/core/common/metrics.h index 48663341..82c9f231 100644 --- a/xllm/core/common/metrics.h +++ b/xllm/core/common/metrics.h @@ -205,6 +205,14 @@ DECLARE_COUNTER(proto_latency_seconds_o2proto); // engine metrics DECLARE_COUNTER(prepare_input_latency_seconds); +// rec engine metrics +DECLARE_COUNTER(prepare_input_latency_microseconds); +DECLARE_COUNTER(rec_first_token_latency_microseconds); +DECLARE_COUNTER(rec_second_token_latency_microseconds); +DECLARE_COUNTER(rec_third_token_latency_microseconds); +DECLARE_COUNTER(rec_sampling_latency_microseconds); +DECLARE_HISTOGRAM(expand_beam_latency_microseconds); + // multi node metrics DECLARE_COUNTER(worker_service_latency_seconds); DECLARE_COUNTER(engine_latency_seconds); diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h index faa1d11a..e5d39b08 100644 --- a/xllm/core/common/types.h +++ b/xllm/core/common/types.h @@ -31,6 +31,7 @@ class EngineType { SSM = 1, VLM = 2, DIT = 3, + REC = 4, INVALID = -1, }; @@ -44,6 +45,8 @@ class EngineType { value_ = VLM; } else if (str == "DIT") { value_ = DIT; + } else if (str == "REC") { + value_ = REC; } else { value_ = INVALID; } @@ -68,6 +71,8 @@ class EngineType { return "VLM"; } else if (this->value_ == DIT) { return "DIT"; + } else if (this->value_ == REC) { + return "REC"; } else { return "INVALID"; } diff --git a/xllm/core/framework/batch/CMakeLists.txt b/xllm/core/framework/batch/CMakeLists.txt index 94d20240..9676e906 100644 --- a/xllm/core/framework/batch/CMakeLists.txt +++ b/xllm/core/framework/batch/CMakeLists.txt @@ -10,12 +10,14 @@ cc_library( batch.h batch_factory.h batch_input_builder.h + rec_batch_input_builder.h mposition.h SRCS dit_batch.cpp batch.cpp batch_factory.cpp batch_input_builder.cpp + 
rec_batch_input_builder.cpp mposition.cpp beam_search.h DEPS diff --git a/xllm/core/framework/batch/batch.cpp b/xllm/core/framework/batch/batch.cpp index d2de8049..44f4990f 100644 --- a/xllm/core/framework/batch/batch.cpp +++ b/xllm/core/framework/batch/batch.cpp @@ -29,6 +29,7 @@ limitations under the License. #include "framework/model/model_input_params.h" #include "framework/request/sequence.h" #include "framework/sampling/sampling_params.h" +#include "rec_batch_input_builder.h" #include "runtime/params_utils.h" #include "util/slice.h" #include "util/tensor_helper.h" @@ -69,6 +70,10 @@ void Batch::add(const std::vector& sequences) { ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size, const ModelArgs& args) { + if (FLAGS_backend == "rec") { + return prepare_rec_forward_input( + num_decoding_tokens, min_decoding_batch_size, args); + } BatchInputBuilder builder(sequences_, allowed_max_tokens_, input_embeddings_vec_, @@ -81,6 +86,41 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens, min_decoding_batch_size); } +ForwardInput Batch::prepare_rec_forward_input(uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size, + const ModelArgs& args, + ThreadPool* thread_pool) { + // Convert SequencesGroup* to std::unique_ptr for + // compatibility + std::vector> sequence_groups_ptrs; + for (auto* group : sequence_groups_) { + // Note: This is a temporary workaround. In production, we should avoid this + // conversion and modify the interface to work with raw pointers directly. + sequence_groups_ptrs.emplace_back(std::unique_ptr(group)); + } + + RecBatchInputBuilder builder( + sequence_groups_ptrs, + allowed_max_tokens_, + input_embeddings_vec_, + mm_data_vec_, + copy_in_cache_block_infos_, + copy_out_cache_block_infos_, + swap_cache_block_infos_, + &args, + thread_pool); // Temporarily not using thread pool + + auto result = builder.build_rec_forward_input(num_decoding_tokens, + min_decoding_batch_size); + + // Release the unique_ptrs without deleting the objects + for (auto& ptr : sequence_groups_ptrs) { + ptr.release(); + } + + return result; +} + RawForwardInput Batch::prepare_forward_input(uint32_t start_idx, uint32_t end_idx, ThreadPool* thread_pool) { @@ -338,4 +378,11 @@ void Batch::process_beam_search_output(const RawForwardOutput& raw_output, update_for_sequence_group(sequence_group_id); } } + +void Batch::finish() { + // Finish all sequence groups + for (auto* sequence_group : sequence_groups_) { + sequence_group->finish(); + } +} } // namespace xllm diff --git a/xllm/core/framework/batch/batch.h b/xllm/core/framework/batch/batch.h index f862b305..45799220 100644 --- a/xllm/core/framework/batch/batch.h +++ b/xllm/core/framework/batch/batch.h @@ -111,6 +111,14 @@ class Batch { bool get_batch_prefill_status() const { return all_seqs_in_prefill_; } + void finish(); + + // prepare forward inputs for Rec model + ForwardInput prepare_rec_forward_input(uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size, + const ModelArgs& args, + ThreadPool* thread_pool = nullptr); + private: bool update_sequence_state(Sequence* seq, bool replace_fake_token); diff --git a/xllm/core/framework/batch/batch_factory.cpp b/xllm/core/framework/batch/batch_factory.cpp index 5dd9d428..97ba94de 100644 --- a/xllm/core/framework/batch/batch_factory.cpp +++ b/xllm/core/framework/batch/batch_factory.cpp @@ -106,4 +106,79 @@ std::vector BatchFactory::create_batches( return batches; } +std::vector BatchFactory::create_rec_batches( + 
const std::vector>& running_requests, + const std::vector& running_sequences, + const std::vector& running_sequences_budgets, + std::vector>* copy_in_cache_block_infos, + std::vector>* copy_out_cache_block_infos, + std::vector>* swap_cache_block_infos) { + size_t num_prompt_tokens = 0; + size_t num_generated_tokens = 0; + std::vector batches(dp_size_); + for (size_t i = 0; i < running_sequences.size(); ++i) { + auto* sequence = running_sequences[i]; + const size_t token_budget = running_sequences_budgets[i]; + + const size_t remaining_prompt_tokens = + sequence->num_prompt_tokens() > + sequence->kv_state().kv_cache_tokens_num() + ? sequence->num_prompt_tokens() - + sequence->kv_state().kv_cache_tokens_num() + : 0; + const size_t prompt_tokens = + std::min(remaining_prompt_tokens, token_budget); + const size_t generated_tokens = token_budget - prompt_tokens; + num_prompt_tokens += prompt_tokens; + num_generated_tokens += generated_tokens; + + // if dp enabled, each sequence is required to + // dispatch to the same rank in the whole lifetime + // batches[sequence->dp_rank()].add(sequence, token_budget); + if (!((sequence->stage() == SequenceStage::DECODE) && + (sequence->kv_state().kv_cache_tokens_num() > 0))) { + batches[sequence->dp_rank()].set_batch_prefill_status(true); + } + } + // for rec, only use seq_group to prepare_input. + for (const auto& request : running_requests) { + auto seq_group = request->sequence_group(); + int32_t dp_rank = seq_group->dp_rank(); + batches[dp_rank].add(seq_group); + } + + for (int i = 0; i < dp_size_; i++) { + if (!batches[i].empty()) { + if (copy_in_cache_block_infos != nullptr && + copy_in_cache_block_infos->size() == dp_size_) { + batches[i].set_copy_in_cache_block_infos( + &(copy_in_cache_block_infos->at(i))); + } + if (copy_out_cache_block_infos != nullptr && + copy_out_cache_block_infos->size() == dp_size_) { + batches[i].set_copy_out_cache_block_infos( + &(copy_out_cache_block_infos->at(i))); + } + if (swap_cache_block_infos != nullptr && + swap_cache_block_infos->size() == dp_size_) { + batches[i].set_swap_cache_block_infos(&(swap_cache_block_infos->at(i))); + } + } + } + + COUNTER_ADD(num_processing_tokens_total_prompt, num_prompt_tokens); + COUNTER_ADD(num_processing_tokens_total_generated, num_generated_tokens); + + if (running_sequences.size() > 0) { + HISTOGRAM_OBSERVE( + num_prompt_tokens_per_request, + static_cast(num_prompt_tokens / running_sequences.size())); + HISTOGRAM_OBSERVE( + num_generated_tokens_per_request, + static_cast(num_generated_tokens / running_sequences.size())); + } + + return batches; +} + } // namespace xllm diff --git a/xllm/core/framework/batch/batch_factory.h b/xllm/core/framework/batch/batch_factory.h index 44106771..10627400 100644 --- a/xllm/core/framework/batch/batch_factory.h +++ b/xllm/core/framework/batch/batch_factory.h @@ -41,6 +41,20 @@ class BatchFactory { std::vector>* swap_cache_block_infos = nullptr); + std::vector create_rec_batches( + const std::vector>& running_requests, + const std::vector& running_sequences, + const std::vector& running_sequences_budgets, + // for global kv cache copy block from host to device + std::vector>* copy_in_cache_block_infos = + nullptr, + // for global kv cache copy block from device to host + std::vector>* copy_out_cache_block_infos = + nullptr, + // for beam-search + std::vector>* swap_cache_block_infos = + nullptr); + private: BatchFactory(int32_t dp_size) : dp_size_(dp_size) {} ~BatchFactory() = default; diff --git a/xllm/core/framework/batch/batch_input_builder.h 
b/xllm/core/framework/batch/batch_input_builder.h index 9b76bfb1..c8171731 100644 --- a/xllm/core/framework/batch/batch_input_builder.h +++ b/xllm/core/framework/batch/batch_input_builder.h @@ -137,6 +137,7 @@ class BatchInputBuilder { uint32_t q_seq_len, BuilderState* state_ptr = nullptr); + protected: // Input data const std::vector& sequences_; const std::vector& allowed_max_tokens_; diff --git a/xllm/core/framework/batch/rec_batch_input_builder.cpp b/xllm/core/framework/batch/rec_batch_input_builder.cpp new file mode 100644 index 00000000..fda85054 --- /dev/null +++ b/xllm/core/framework/batch/rec_batch_input_builder.cpp @@ -0,0 +1,1000 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "rec_batch_input_builder.h" + +#include +#include +#include +#include +#include +#include + +#include "framework/model/model_args.h" +#include "framework/model/model_input_params.h" +#include "framework/request/sequence.h" +#include "framework/sampling/sampling_params.h" +#include "util/tensor_helper.h" +#include "util/threadpool.h" +#include "util/utils.h" + +namespace xllm { + +// Static member definition +RecBatchInputBuilder::HighPerformanceCache RecBatchInputBuilder::perf_cache_; + +RecBatchInputBuilder::RecBatchInputBuilder( + const std::vector>& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + const std::vector* copy_in_cache_block_infos, + const std::vector* copy_out_cache_block_infos, + std::vector* swap_cache_block_infos, + const ModelArgs* args, + ThreadPool* thread_pool) + : BatchInputBuilder( // extract_sequences_from_groups(sequence_groups), + {}, + allowed_max_tokens, + input_embeddings_vec, + mm_data_vec, + copy_in_cache_block_infos, + copy_out_cache_block_infos, + swap_cache_block_infos, + args, + thread_pool), + sequence_groups_(sequence_groups) { + // Reset high performance cache + perf_cache_.memory_pool.reset(); +} + +std::vector RecBatchInputBuilder::extract_sequences_from_groups( + const std::vector>& sequence_groups) { + std::vector sequences; + for (const auto& group : sequence_groups) { + for (const auto& seq : group->sequences()) { + sequences.push_back(seq.get()); + } + } + return sequences; +} + +ForwardInput RecBatchInputBuilder::build_rec_forward_input( + uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size) { + // ========== Global constant cache ========== + static const std::vector FIXED_POSITIONS = {0}; + static const torch::Tensor FIXED_ENCODER_POSITIONS = + torch::tensor({0}, torch::kInt); + + // ========== Fast sequence information extraction ========== + const int32_t num_sequences = + !sequence_groups_.empty() + ? 
std::accumulate(sequence_groups_.begin(), + sequence_groups_.end(), + 0, + [](int sum, const auto& group) { + return sum + group->sequences().size(); + }) + : 0; + + if (UNLIKELY(num_sequences == 0)) { + return ForwardInput{}; + } + + // Get basic information of first sequence - optimize pointer access + Sequence* first_sequence = nullptr; + if (!sequence_groups_.empty() && !sequence_groups_[0]->sequences().empty()) { + first_sequence = sequence_groups_[0]->sequences()[0].get(); + } + + if (!first_sequence) { + return ForwardInput{}; + } + + const uint32_t seq_len = first_sequence->num_tokens(); + const uint32_t num_decoder_embeddings = + first_sequence->num_decoder_embeddings(); + const uint32_t n_prompt_tokens = first_sequence->num_prompt_tokens(); + const bool is_first_prefill = (first_sequence->num_generated_tokens() == 0); + // const uint64_t model_version = first_sequence->get_model_version(); + + // ========== High-performance encoder tokens construction ========== + auto buildEncoderTokensOptimized = [&]() -> const std::vector& { + auto& cache_data = perf_cache_.cache_data; + + // encoder doesn't use cache key, because encoder doesn't use encoder_tokens + // in non-first prefill scenarios, only uses encoder_seq_len + if (!is_first_prefill) { + return cache_data.encoder_tokens; + } + // Below is the reconstruction for first prefill stage + const auto& encoder_tokens = first_sequence->encoder_tokens(); + if (encoder_tokens.empty()) { + cache_data.encoder_tokens.clear(); + return cache_data.encoder_tokens; + } + + // Optimization: Use SIMD-friendly memory access patterns + cache_data.encoder_tokens.clear(); + cache_data.encoder_seq_lens.clear(); + + // Optimization for scenarios where sequences have different lengths across + // sequence groups Pre-calculate total token count to avoid multiple memory + // reallocations + int32_t total_tokens = 0; + for (const auto& group_ptr : sequence_groups_) { + if (!group_ptr->sequences().empty()) { + // Sequences within group have same length, only need to get first + // sequence's length + const int32_t group_encoder_seq_len = + group_ptr->sequences()[0]->encoder_seq_len(); + total_tokens += group_encoder_seq_len * group_ptr->sequences().size(); + } + } + + cache_data.encoder_tokens.reserve(total_tokens); + cache_data.encoder_seq_lens.resize(num_sequences); + cache_data.encoder_sparse_embeddings.clear(); + cache_data.encoder_sparse_embeddings.reserve(num_sequences); + cache_data.decoder_context_embeddings.clear(); + cache_data.decoder_context_embeddings.reserve(num_sequences); + + // Process by groups in batch + int32_t global_seq_idx = 0; + for (const auto& group_ptr : sequence_groups_) { + const auto& group = *group_ptr; + const int32_t group_size = group.sequences().size(); + + if (group_size == 0) continue; + + const int32_t group_encoder_seq_len = + group.sequences()[0]->encoder_seq_len(); + + // Batch set same values + std::fill_n(&cache_data.encoder_seq_lens[global_seq_idx], + group_size, + group_encoder_seq_len); + + // Batch copy tokens by sequence and collect sparse_embedding + for (const auto& sequence : group.sequences()) { + const auto& encoder_tokens = sequence->encoder_tokens(); + const int32_t* src_ptr = encoder_tokens.data(); + + // Use efficient batch insertion + cache_data.encoder_tokens.insert(cache_data.encoder_tokens.end(), + src_ptr, + src_ptr + group_encoder_seq_len); + // Collect sparse_embedding + auto mm_data = sequence->get_mm_data(); + auto sparse_embedding_optional = + 
mm_data.get(Sequence::ENCODER_SPARSE_EMBEDDING_NAME); + if (sparse_embedding_optional.has_value()) { + cache_data.encoder_sparse_embeddings.push_back( + sparse_embedding_optional.value()); + } + + auto decoder_context_embedding_optional = mm_data.get( + Sequence::DECODER_CONTEXT_EMBEDDING_NAME); + if (decoder_context_embedding_optional.has_value()) { + cache_data.decoder_context_embeddings.push_back( + decoder_context_embedding_optional.value()); + } + } + + global_seq_idx += group_size; + } + + return cache_data.encoder_tokens; + }; + + // ========== High-performance decoder data construction ========== + auto buildDecoderDataOptimized = [&]() { + // Pre-allocate all containers to avoid dynamic expansion + const size_t total_tokens = num_sequences * seq_len; + std::vector flatten_tokens_vec; + flatten_tokens_vec.reserve(total_tokens); + std::vector sampling_params; + sampling_params.reserve(num_sequences); + std::vector selected_token_idxes; + selected_token_idxes.reserve(num_sequences); + std::vector sample_idxes; + sample_idxes.reserve(num_sequences); + std::vector> generated_tokens; + generated_tokens.reserve(num_sequences); + + // Multi-threading optimization: Use parallel processing when sequence count + // exceeds threshold and thread pool is available + const int32_t THREADPOOL_THRESHOLD = 16; + ThreadPool* threadpool = thread_pool_; + if (num_sequences >= THREADPOOL_THRESHOLD && threadpool != nullptr) { + // Thread-safe result containers + std::vector> thread_flatten_tokens(num_sequences); + std::vector thread_sampling_params( + num_sequences); + std::vector thread_selected_token_idxes(num_sequences); + std::vector thread_sample_idxes(num_sequences); + std::vector> thread_generated_tokens(num_sequences); + + // Calculate thread allocation + const size_t num_threads = + std::min(static_cast(num_sequences), static_cast(16)); + const size_t sequences_per_thread = + (num_sequences + num_threads - 1) / num_threads; + + std::vector> futures; + std::vector>> promises; + futures.reserve(num_threads); + promises.reserve(num_threads); + + // Parallel processing function + auto process_sequences_range = [&](size_t start_idx, size_t end_idx) { + for (size_t i = start_idx; + i < end_idx && i < static_cast(num_sequences); + ++i) { + const Sequence* sequence = nullptr; + // Get sequence from sequence_groups + size_t seq_idx = 0; + for (const auto& group : sequence_groups_) { + if (seq_idx + group->sequences().size() > i) { + sequence = group->sequences()[i - seq_idx].get(); + break; + } + seq_idx += group->sequences().size(); + } + + if (!sequence) continue; + + const auto& token_ids = sequence->tokens(); + + // Build generated tokens + auto& cur_generated_tokens = thread_generated_tokens[i]; + cur_generated_tokens.reserve(seq_len - n_prompt_tokens); + for (uint32_t j = n_prompt_tokens; j < seq_len; ++j) { + cur_generated_tokens.push_back(token_ids[j]); + } + + // Build flatten tokens + auto& cur_flatten_tokens = thread_flatten_tokens[i]; + cur_flatten_tokens.reserve(seq_len); + cur_flatten_tokens.insert(cur_flatten_tokens.end(), + token_ids.begin(), + token_ids.begin() + seq_len); + + // Set sampling parameters + thread_sampling_params[i] = sequence->sampling_param(); + thread_sample_idxes[i] = static_cast(i); + } + }; + + // Launch parallel tasks + for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + size_t start_idx = thread_idx * sequences_per_thread; + size_t end_idx = std::min(start_idx + sequences_per_thread, + static_cast(num_sequences)); + + if (start_idx >= 
static_cast(num_sequences)) break; + + auto promise = std::make_shared>(); + futures.push_back(promise->get_future()); + promises.push_back(promise); + + threadpool->schedule( + [process_sequences_range, start_idx, end_idx, promise]() mutable { + try { + process_sequences_range(start_idx, end_idx); + promise->set_value(); + } catch (...) { + promise->set_exception(std::current_exception()); + } + }); + } + + // Wait for all tasks to complete + for (auto& future : futures) { + future.get(); + } + + // Merge results + size_t start_idx = 0; + size_t total_tokens = seq_len + num_decoder_embeddings; + for (int32_t i = 0; i < num_sequences; ++i) { + flatten_tokens_vec.insert(flatten_tokens_vec.end(), + thread_flatten_tokens[i].begin(), + thread_flatten_tokens[i].end()); + selected_token_idxes.push_back( + static_cast(start_idx + total_tokens - 1)); + start_idx += total_tokens; + sampling_params.push_back(thread_sampling_params[i]); + sample_idxes.push_back(thread_sample_idxes[i]); + generated_tokens.push_back(std::move(thread_generated_tokens[i])); + } + } else { + // Original single-thread processing logic + size_t start_idx = 0; + size_t total_tokens = seq_len + num_decoder_embeddings; + size_t seq_idx = 0; + for (const auto& group : sequence_groups_) { + for (const auto& sequence : group->sequences()) { + const auto& token_ids = sequence->tokens(); + + // Optimize generated tokens construction + auto& cur_generated_tokens = generated_tokens.emplace_back(); + cur_generated_tokens.reserve(seq_len - n_prompt_tokens); + for (uint32_t j = n_prompt_tokens; j < seq_len; ++j) { + cur_generated_tokens.push_back(token_ids[j]); + } + // Optimize token processing - batch operations + flatten_tokens_vec.insert(flatten_tokens_vec.end(), + token_ids.begin(), + token_ids.begin() + seq_len); + + // Simplify sampling parameter processing + selected_token_idxes.push_back( + static_cast(start_idx + total_tokens - 1)); + start_idx += total_tokens; + sampling_params.push_back(sequence->sampling_param()); + sample_idxes.push_back(seq_idx); + seq_idx++; + } + } + } + + return std::make_tuple(std::move(flatten_tokens_vec), + std::move(sampling_params), + std::move(selected_token_idxes), + std::move(sample_idxes), + std::move(generated_tokens)); + }; + + // ========== Comprehensive parallel execution of optimized data construction + // ========== Use thread pool to execute all independent data construction + // tasks in parallel + std::future&> encoder_future; + std::future, + std::vector, + std::vector, + std::vector, + std::vector>>> + decoder_future; + + // Declare variables to store results + const std::vector* encoder_tokens_ptr = nullptr; + std::vector flatten_tokens_vec; + std::vector sampling_params; + std::vector selected_token_idxes; + std::vector sample_idxes; + std::vector> generated_tokens; + + if (thread_pool_ && num_sequences >= 8) { + // Use ThreadPool's schedule method to execute independent tasks in parallel + // buildDecoderDataOptimized handles multi-threading internally, no external + // parallel calls + + // Task 1: buildEncoderTokensOptimized + std::promise*> encoder_promise; + auto encoder_future = encoder_promise.get_future(); + thread_pool_->schedule([&, promise = std::move(encoder_promise)]() mutable { + const auto& result = buildEncoderTokensOptimized(); + promise.set_value(&result); + }); + // Wait for encoder to complete + encoder_tokens_ptr = encoder_future.get(); + // Task 2: buildDecoderDataOptimized executes directly, handles + // multi-threading internally + 
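// Minimal sketch of the range-partitioning scheme used above: N items are split across
// at most 16 workers, each worker writes only its own [begin, end) slice, and the caller
// joins all futures before merging. std::async stands in for the project's
// ThreadPool::schedule() + promise/future plumbing.
#include <algorithm>
#include <future>
#include <vector>

void parallel_fill(std::vector<int>& out) {
  const size_t n = out.size();
  const size_t num_threads = std::min<size_t>(n, 16);
  if (num_threads == 0) return;
  const size_t per_thread = (n + num_threads - 1) / num_threads;  // ceil division

  std::vector<std::future<void>> futures;
  for (size_t t = 0; t < num_threads; ++t) {
    const size_t begin = t * per_thread;
    const size_t end = std::min(begin + per_thread, n);
    if (begin >= n) break;
    futures.push_back(std::async(std::launch::async, [&out, begin, end] {
      for (size_t i = begin; i < end; ++i) out[i] = static_cast<int>(i);  // disjoint slice
    }));
  }
  for (auto& f : futures) f.get();  // join every worker before touching the merged result
}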
std::tie(flatten_tokens_vec, + sampling_params, + selected_token_idxes, + sample_idxes, + generated_tokens) = buildDecoderDataOptimized(); + } else { + // Single-thread execution (original logic) + encoder_tokens_ptr = &buildEncoderTokensOptimized(); + std::tie(flatten_tokens_vec, + sampling_params, + selected_token_idxes, + sample_idxes, + generated_tokens) = buildDecoderDataOptimized(); + } + + const auto& encoder_tokens = *encoder_tokens_ptr; + + // ========== High-performance ForwardInput construction ========== + ForwardInput forward_input; + auto& input_params = forward_input.input_params; + auto& cache_data = perf_cache_.cache_data; + + // Initialize key fields for asynchronous tasks + const int64_t bs = sequence_groups_.size(); + const int64_t group_width = + sequence_groups_.empty() ? 1 : sequence_groups_[0]->sequences().size(); + + std::vector> decoder_embedding_futures; + torch::Tensor result_embedding; + + // ========== Parallel tensor construction tasks ========== + if (thread_pool_ && num_sequences >= 4) { + // Only use parallelization for time-consuming tasks (token_ids and + // encoder_token_ids) + std::promise token_ids_promise; + std::promise encoder_token_ids_promise; + + auto token_ids_future = token_ids_promise.get_future(); + auto encoder_token_ids_future = encoder_token_ids_promise.get_future(); + + // Task 1: Build token_ids tensor - + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + thread_pool_->schedule([&flatten_tokens_vec, + promise = std::move(token_ids_promise)]() mutable { + try { + // Optimization: Pre-allocate memory and use std::memcpy to avoid clone + // operations + auto tensor = + torch::empty({static_cast(flatten_tokens_vec.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(tensor.data_ptr(), + flatten_tokens_vec.data(), + flatten_tokens_vec.size() * sizeof(int)); + promise.set_value(std::move(tensor)); + } catch (...) { + promise.set_exception(std::current_exception()); + } + }); + + // Task 2: Build encoder_token_ids tensor (if needed) - + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + thread_pool_->schedule( + [&encoder_tokens, + promise = std::move(encoder_token_ids_promise)]() mutable { + try { + torch::Tensor tensor; + if (!encoder_tokens.empty()) { + // Optimization: Pre-allocate memory and use std::memcpy to avoid + // clone operations + tensor = + torch::empty({static_cast(encoder_tokens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(tensor.data_ptr(), + encoder_tokens.data(), + encoder_tokens.size() * sizeof(int)); + } + promise.set_value(std::move(tensor)); + } catch (...) 
{ + promise.set_exception(std::current_exception()); + } + }); + + if (!perf_cache_.cache_data.decoder_context_embeddings.empty()) { + // Task 3: Synchronously process decoder_embedding, inner group dimension + // parallelization optimization + + // Optimization: Directly get shape information from first embedding to + // avoid torch::cat + auto first_embedding = + perf_cache_.cache_data.decoder_context_embeddings[0]; + auto original_shape = first_embedding.sizes(); + int64_t context_len = original_shape[0]; + int64_t hidden_size = original_shape[1]; + + // Create tensor on pinned memory + auto options = torch::TensorOptions() + .dtype(first_embedding.dtype()) + .device(first_embedding.device()) + .pinned_memory(true) + .memory_format(torch::MemoryFormat::Contiguous); + + // Calculate total sequence length, pre-allocate context_len + seq_len + int64_t total_seq_len = context_len + seq_len; + + auto combined_embedding = + torch::empty({bs, group_width, total_seq_len, hidden_size}, options); + + // High-performance optimization: group dimension segmented + // parallelization + void* dst_data = combined_embedding.data_ptr(); + + // Get element size (supports float, bfloat16 and other types) + const size_t element_size = first_embedding.element_size(); + const size_t context_size = context_len * hidden_size * element_size; + const size_t group_stride = total_seq_len * hidden_size * element_size; + const size_t batch_stride = + group_width * total_seq_len * hidden_size * element_size; + + // Parallelization strategy: segment by group dimension, consistent with + // thread calculations elsewhere + const size_t num_threads = + std::min(static_cast(group_width), static_cast(16)); + const size_t groups_per_thread = + (group_width + num_threads - 1) / num_threads; + + for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + size_t start_group = thread_idx * groups_per_thread; + size_t end_group = std::min(start_group + groups_per_thread, + static_cast(group_width)); + + if (start_group >= static_cast(group_width)) break; + + std::promise promise; + decoder_embedding_futures.push_back(promise.get_future()); + + thread_pool_->schedule( + [start_group, + end_group, + bs, + dst_data, + context_len, + hidden_size, + element_size, + batch_stride, + group_stride, + context_size, + embeddings = perf_cache_.cache_data.decoder_context_embeddings, + dst_tensor = combined_embedding, + promise = std::move(promise)]() mutable { + // Copy context_embedding for specified group range of each batch + for (int64_t b = 0; b < bs; ++b) { + // Optimization: Access corresponding batch embedding directly + // through index + const void* batch_src = embeddings[b].data_ptr(); + auto* batch_dst = + static_cast(dst_data) + b * batch_stride; + + for (size_t g = start_group; g < end_group; ++g) { + std::memcpy( + batch_dst + g * group_stride, batch_src, context_size); + } + } + promise.set_value(); + }); + } + + result_embedding = combined_embedding; + } + + // Task 4: Build sequence length vector - changed to serial execution (very + // time-consuming, ~0.001785ms) + std::vector cu_seq_lens, q_cu_seq_lens; +#ifdef USE_ASCEND + cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); + q_cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); +#else + cu_seq_lens.reserve(num_sequences + 1); + q_cu_seq_lens.reserve(num_sequences + 1); + cu_seq_lens.push_back(0); + q_cu_seq_lens.push_back(0); + + for (int32_t i = 0; i < num_sequences; ++i) { + cu_seq_lens.push_back(cu_seq_lens.back() + 
seq_len + + num_decoder_embeddings); + q_cu_seq_lens.push_back(q_cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + } +#endif + + // Task 5: Build encoder_seq_lens_tensor - changed to serial execution (less + // time-consuming) + torch::Tensor encoder_seq_lens_tensor; + if (!cache_data.encoder_seq_lens.empty()) { + // Optimization: Pre-allocate memory and use std::memcpy to avoid clone + // operations + encoder_seq_lens_tensor = torch::empty( + {static_cast(cache_data.encoder_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(encoder_seq_lens_tensor.data_ptr(), + cache_data.encoder_seq_lens.data(), + cache_data.encoder_seq_lens.size() * sizeof(int)); + } + + // Set basic parameters simultaneously (not dependent on asynchronous tasks) + input_params.num_sequences = num_sequences; + input_params.empty_kv_cache = true; + input_params.global_empty_kv_cache = true; + input_params.kv_max_seq_len = seq_len + num_decoder_embeddings; + input_params.q_max_seq_len = seq_len + num_decoder_embeddings; + forward_input.positions = perf_cache_.fixed_positions_tensor; + if (!encoder_tokens.empty()) { + forward_input.positions = perf_cache_.fixed_encoder_positions_tensor; + } + + // Wait and collect results + forward_input.token_ids = token_ids_future.get(); + auto encoder_token_ids = encoder_token_ids_future.get(); + + // seq_lens has been changed to serial execution, use the constructed + // variable directly + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.kv_seq_lens = + torch::empty({static_cast(cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.kv_seq_lens.data_ptr(), + cu_seq_lens.data(), + cu_seq_lens.size() * sizeof(int)); + + input_params.q_seq_lens = + torch::empty({static_cast(q_cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.q_seq_lens.data_ptr(), + q_cu_seq_lens.data(), + q_cu_seq_lens.size() * sizeof(int)); + input_params.kv_seq_lens_vec = std::move(cu_seq_lens); + input_params.q_seq_lens_vec = std::move(q_cu_seq_lens); + + // encoder_seq_lens_tensor has been changed to serial execution, use the + // constructed variable directly + if (encoder_seq_lens_tensor.defined()) { + // Set RecModelInputParams encoder data + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + input_params.rec_params->encoder_seq_lens_tensor = + std::move(encoder_seq_lens_tensor); + input_params.rec_params->encoder_seq_lens = cache_data.encoder_seq_lens; + } + } else { + // Single-threaded execution (original logic) + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + forward_input.token_ids = + torch::empty({static_cast(flatten_tokens_vec.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(forward_input.token_ids.data_ptr(), + flatten_tokens_vec.data(), + flatten_tokens_vec.size() * sizeof(int)); + forward_input.positions = perf_cache_.fixed_positions_tensor; + + if (!encoder_tokens.empty()) { + // Set RecModelInputParams encoder data + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + 
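// Sketch of the host-tensor construction pattern these "Optimization" comments refer to:
// allocate a pinned CPU tensor once with torch::empty() and memcpy the vector into it,
// instead of wrapping the vector with torch::from_blob() and paying for an extra
// clone(). Pinned (page-locked) memory also makes the later host-to-device copy cheaper.
#include <torch/torch.h>
#include <cstring>
#include <vector>

torch::Tensor int_vec_to_pinned_tensor(const std::vector<int>& values) {
  auto t = torch::empty({static_cast<int64_t>(values.size())},
                        torch::TensorOptions()
                            .dtype(torch::kInt)
                            .device(torch::kCPU)
                            .pinned_memory(true));
  std::memcpy(t.data_ptr(), values.data(), values.size() * sizeof(int));
  return t;
  // Equivalent but slower: torch::from_blob(const_cast<int*>(values.data()),
  //     {static_cast<int64_t>(values.size())}, torch::kInt).clone();
}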
input_params.rec_params->encoder_token_ids = + torch::empty({static_cast(encoder_tokens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.rec_params->encoder_token_ids.data_ptr(), + encoder_tokens.data(), + encoder_tokens.size() * sizeof(int)); + input_params.rec_params->encoder_positions = + perf_cache_.fixed_encoder_positions_tensor; + } + + // Pre-allocate and batch fill + std::vector cu_seq_lens, q_cu_seq_lens; +#ifdef USE_ASCEND + cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); + q_cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); +#else + cu_seq_lens.reserve(num_sequences + 1); + q_cu_seq_lens.reserve(num_sequences + 1); + cu_seq_lens.push_back(0); + q_cu_seq_lens.push_back(0); + + for (int32_t i = 0; i < num_sequences; ++i) { + cu_seq_lens.push_back(cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + q_cu_seq_lens.push_back(q_cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + } +#endif + + input_params.num_sequences = num_sequences; + input_params.empty_kv_cache = true; + input_params.global_empty_kv_cache = true; + input_params.kv_max_seq_len = seq_len + num_decoder_embeddings; + input_params.q_max_seq_len = seq_len + num_decoder_embeddings; + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.kv_seq_lens = + torch::empty({static_cast(cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.kv_seq_lens.data_ptr(), + cu_seq_lens.data(), + cu_seq_lens.size() * sizeof(int)); + + input_params.q_seq_lens = + torch::empty({static_cast(q_cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.q_seq_lens.data_ptr(), + q_cu_seq_lens.data(), + q_cu_seq_lens.size() * sizeof(int)); + + input_params.kv_seq_lens_vec = std::move(cu_seq_lens); + input_params.q_seq_lens_vec = std::move(q_cu_seq_lens); + + if (!cache_data.encoder_seq_lens.empty()) { + // Set RecModelInputParams encoder data + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + + input_params.rec_params->encoder_seq_lens = cache_data.encoder_seq_lens; + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.rec_params->encoder_seq_lens_tensor = torch::empty( + {static_cast(cache_data.encoder_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy( + input_params.rec_params->encoder_seq_lens_tensor.data_ptr(), + cache_data.encoder_seq_lens.data(), + cache_data.encoder_seq_lens.size() * sizeof(int)); + } + } + + // ========== Parallel processing of independent code blocks ========== + if (thread_pool_ && num_sequences >= 4) { + // Define promise/future for parallel tasks + std::promise block_tables_promise; + auto block_tables_future = block_tables_promise.get_future(); + + // Task 1: Empty block tables processing - use thread pool (relatively + // time-consuming) + thread_pool_->schedule([&input_params, + num_sequences, + &perf_cache_, + &block_tables_promise]() mutable { + try { + std::vector> empty_block_tables(num_sequences); + util::pad_2d_vector(empty_block_tables, 0); + // Optimization: Use create_2d_tensor_optimized, has special + // optimization for all-zero matrices + input_params.block_tables = 
+ create_2d_tensor(empty_block_tables, torch::kInt); + + std::vector paged_kv_indptr(num_sequences + 1, 0); + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.new_cache_slots = + torch::empty({static_cast(paged_kv_indptr.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.new_cache_slots.data_ptr(), + paged_kv_indptr.data(), + paged_kv_indptr.size() * sizeof(int)); + + block_tables_promise.set_value(); + } catch (...) { + block_tables_promise.set_exception(std::current_exception()); + } + }); + + // Optimization: Merge small tasks into sequential execution to reduce + // thread switching overhead Cross-attention parameter construction - use + // placeholder + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + input_params.rec_params->cross_attn_kv_cu_seq_lens = + torch::zeros({1}, torch::kInt); + input_params.rec_params->cross_attn_kv_cu_seq_lens_vec = {0}; + input_params.rec_params->cross_attn_block_tables = + torch::zeros({1, 1}, torch::kInt); + + // Sampling parameter processing + if (!selected_token_idxes.empty()) { + forward_input.sampling_params.init(sampling_params, + selected_token_idxes, + sample_idxes, + std::vector>{}, + std::vector>{}, + std::vector{}); + } + + // First prefill processing - use placeholder + if (is_first_prefill) { + // Use placeholder instead of complex cross_attn_new_cache_slots + // construction + input_params.rec_params->cross_attn_new_cache_slots = + torch::zeros({1}, torch::kInt); + } + + // Wait for parallel tasks to complete (only block_tables uses thread pool) + block_tables_future.wait(); + } else { + // ========== Non-parallel case: sequential processing ========== + // Optimize empty block tables processing + std::vector> empty_block_tables(num_sequences); + util::pad_2d_vector(empty_block_tables, 0); + // Optimization: Use create_2d_tensor_optimized, has special optimization + // for all-zero matrices + input_params.block_tables = + create_2d_tensor(empty_block_tables, torch::kInt); + + std::vector paged_kv_indptr(num_sequences + 1, 0); + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.new_cache_slots = + torch::empty({static_cast(paged_kv_indptr.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.new_cache_slots.data_ptr(), + paged_kv_indptr.data(), + paged_kv_indptr.size() * sizeof(int)); + + // ========== Cross-attention parameter construction (using placeholder) + // ========== Use placeholder tensor instead of actual data + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + input_params.rec_params->cross_attn_kv_cu_seq_lens = + torch::zeros({1}, torch::kInt); + input_params.rec_params->cross_attn_kv_cu_seq_lens_vec = {0}; + + // Use placeholder tensor instead of actual data + input_params.rec_params->cross_attn_block_tables = + torch::zeros({1, 1}, torch::kInt); + + // ========== Optimize sampling parameter processing ========== + if (!selected_token_idxes.empty()) { + forward_input.sampling_params.init(sampling_params, + selected_token_idxes, + sample_idxes, + std::vector>{}, + std::vector>{}, + std::vector{}); + } + + // ========== First prefill processing (using placeholder) ========== + if (is_first_prefill) { + // Use placeholder tensor instead of actual data + 
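// Hedged sketch of roughly what the util::pad_2d_vector() / create_2d_tensor() helpers
// used above amount to: pad every row of a ragged int matrix to the longest row, then
// copy it into a dense [rows, cols] torch tensor. The real xLLM utilities may differ.
#include <torch/torch.h>
#include <algorithm>
#include <cstring>
#include <vector>

torch::Tensor ragged_to_2d_tensor(std::vector<std::vector<int>> rows, int pad_value) {
  size_t cols = 0;
  for (const auto& r : rows) cols = std::max(cols, r.size());
  if (cols == 0) cols = 1;                 // keep at least one column
  for (auto& r : rows) r.resize(cols, pad_value);

  auto t = torch::empty({static_cast<int64_t>(rows.size()), static_cast<int64_t>(cols)},
                        torch::kInt);
  for (size_t i = 0; i < rows.size(); ++i) {
    std::memcpy(t[i].data_ptr(), rows[i].data(), cols * sizeof(int));  // row-wise copy
  }
  return t;
}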
input_params.rec_params->cross_attn_new_cache_slots = + torch::zeros({1}, torch::kInt); + } + } + + // ========== Common parameter settings ========== + // Batch set other parameters + input_params.embedding_ids.assign(num_sequences, 0); + +#ifdef USE_ASCEND + auto prefill_indices = util::find_ones_indices(input_params.q_seq_lens_vec); + input_params.decode_seq_range = + std::make_pair(0, static_cast(flatten_tokens_vec.size())); +#else + input_params.decode_seq_range = { + 0, static_cast(flatten_tokens_vec.size())}; +#endif + + // Rec model parameters + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + + input_params.rec_params->rec_stage = RecModelInputParams::RecStage::PREFILL; + input_params.rec_params->is_hybrid_mode = false; + input_params.rec_params->has_encoder_output = true; + input_params.rec_params->is_first_prefill = is_first_prefill; + input_params.rec_params->bs = bs; + input_params.rec_params->group_width = group_width; + input_params.rec_params->seq_len = seq_len; + input_params.rec_params->encoder_max_seq_len = + cache_data.encoder_seq_lens.empty() + ? 0 + : *std::max_element(cache_data.encoder_seq_lens.begin(), + cache_data.encoder_seq_lens.end()); + + input_params.rec_params->generated_tokens = std::move(generated_tokens); + + // Process sparse_embedding: Efficiently concatenate from cache_data + if (!perf_cache_.cache_data.encoder_sparse_embeddings.empty()) { + // Use torch::cat for efficient concatenation, concatenate along dim=0 + input_params.rec_params->encoder_sparse_embedding = + torch::cat(perf_cache_.cache_data.encoder_sparse_embeddings, /*dim=*/0); + } + + if (!perf_cache_.cache_data.decoder_context_embeddings.empty()) { + // Get group_width + int64_t group_width = input_params.rec_params->group_width; + if (group_width == 1 && seq_len == 0) { + // Optimization: When bs==1, directly use the first embedding to avoid + // unnecessary torch::cat + if (bs == 1) { + input_params.rec_params->decoder_context_embedding = + perf_cache_.cache_data.decoder_context_embeddings[0]; + } else { + // Use torch::cat for efficient concatenation, concatenate along dim=0 + auto original_context_embedding = torch::cat( + perf_cache_.cache_data.decoder_context_embeddings, /*dim=*/0); + input_params.rec_params->decoder_context_embedding = + original_context_embedding; + } + } else if (group_width == 1 && seq_len > 0) { + // Handle the scenario where group_width==1 and seq_len>0 + // Get information from the first embedding + const auto& first_embedding = + perf_cache_.cache_data.decoder_context_embeddings[0]; + auto original_shape = first_embedding.sizes(); + int64_t context_len = original_shape[0]; + int64_t hidden_size = original_shape[1]; + int64_t total_seq_len = context_len + seq_len; + + // Allocate a tensor of shape {bs, 1, total_seq_len, hidden_size}, + // optimized with pinned memory + auto options = torch::TensorOptions() + .dtype(first_embedding.dtype()) + .device(first_embedding.device()) + .pinned_memory(true) + .memory_format(torch::MemoryFormat::Contiguous); + auto combined_embedding = + torch::empty({bs, 1, total_seq_len, hidden_size}, options); + + // Single-threaded copy of context_len portion of data + void* dst_data = combined_embedding.data_ptr(); + const size_t element_size = first_embedding.element_size(); + const size_t context_size = context_len * hidden_size * element_size; + const size_t batch_stride = total_seq_len * hidden_size * element_size; + + // Copy context_embedding for each batch + for (int64_t b = 
0; b < bs; ++b) { + const void* batch_src = + perf_cache_.cache_data.decoder_context_embeddings[b].data_ptr(); + auto* batch_dst = static_cast(dst_data) + b * batch_stride; + std::memcpy(batch_dst, batch_src, context_size); + } + input_params.rec_params->decoder_context_embedding = combined_embedding; + } else { + for (auto& future : decoder_embedding_futures) { + future.get(); + } + input_params.rec_params->decoder_context_embedding = + std::move(result_embedding); + } + } + + return forward_input; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/batch/rec_batch_input_builder.h b/xllm/core/framework/batch/rec_batch_input_builder.h new file mode 100644 index 00000000..5bcf1347 --- /dev/null +++ b/xllm/core/framework/batch/rec_batch_input_builder.h @@ -0,0 +1,134 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include +#include + +#include "batch_input_builder.h" +#include "framework/model/model_args.h" +#include "framework/model/model_input_params.h" +#include "framework/request/mm_data.h" +#include "framework/request/sequence.h" +#include "framework/request/sequences_group.h" +#include "runtime/forward_params.h" +#include "util/threadpool.h" + +namespace xllm { + +class RecBatchInputBuilder : public BatchInputBuilder { + public: + explicit RecBatchInputBuilder( + const std::vector>& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + const std::vector* copy_in_cache_block_infos, + const std::vector* copy_out_cache_block_infos, + std::vector* swap_cache_block_infos, + const ModelArgs* args, + ThreadPool* thread_pool = nullptr); + + protected: + // Provide protected access methods for subclasses - modified to access + // parent's protected members + const std::vector>& get_sequence_groups() + const { + return sequence_groups_; + } + const std::vector& get_allowed_max_tokens() const { + return allowed_max_tokens_; + } + const std::vector& get_input_embeddings_vec() const { + return input_embeddings_vec_; + } + const std::vector& get_mm_data_vec() const { return mm_data_vec_; } + const std::vector* get_copy_in_cache_block_infos() const { + return copy_in_cache_block_infos_; + } + const std::vector* get_copy_out_cache_block_infos() const { + return copy_out_cache_block_infos_; + } + std::vector* get_swap_cache_block_infos() const { + return swap_cache_block_infos_; + } + const ModelArgs* get_args() const { return args_; } + ThreadPool* get_thread_pool() const { return thread_pool_; } + + public: + // Main public interface + ForwardInput build_rec_forward_input(uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size); + + private: + // Helper method to extract sequences from groups + static std::vector extract_sequences_from_groups( + const std::vector>& sequence_groups); + + // Member variables - only 
keep sequence_groups_, others inherited from parent + // class + const std::vector>& sequence_groups_; + + // High performance cache system + struct HighPerformanceCache { + // Memory pool - avoid frequent allocation/deallocation + struct MemoryPool { + std::vector> int32_pools; + size_t pool_index = 0; + + std::vector& getInt32Vector(size_t reserve_size = 0) { + if (pool_index >= int32_pools.size()) { + int32_pools.emplace_back(); + } + auto& vec = int32_pools[pool_index++]; + vec.clear(); + if (reserve_size > 0) vec.reserve(reserve_size); + return vec; + } + + void reset() { pool_index = 0; } + }; + + // Cache data structure + struct CacheData { + std::vector encoder_tokens; + std::vector encoder_seq_lens; + std::vector encoder_sparse_embeddings; + std::vector decoder_context_embeddings; + }; + + // Pre-created constant tensors + torch::Tensor fixed_positions_tensor; + torch::Tensor fixed_encoder_positions_tensor; + torch::Tensor empty_tensor; + + MemoryPool memory_pool; + CacheData cache_data; + + HighPerformanceCache() { + // Pre-create commonly used tensors to avoid repeated creation + fixed_positions_tensor = torch::tensor({0}, torch::kInt); + fixed_encoder_positions_tensor = torch::tensor({0}, torch::kInt); + empty_tensor = torch::tensor(std::vector{}, torch::kInt); + } + }; + + static HighPerformanceCache perf_cache_; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index aaaae36d..f4926eda 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -26,6 +26,137 @@ limitations under the License. #include "util/tensor_helper.h" namespace xllm { + +// Rec model specific input parameters +struct RecModelInputParams { + // Rec model specific parameters + + enum class RecStage { + PREFILL, // Prefill stage + DECODE // Decode stage + }; + + RecStage rec_stage = RecStage::PREFILL; + bool is_hybrid_mode = false; + // Flag to distinguish encoder vs decoder forward calls + bool is_encoder_forward = false; + // For Rec decoder cross-attention + bool has_encoder_output = false; + // Length of encoder output sequence for each sequence + std::vector encoder_seq_lens; + // Pre-constructed tensor for encoder_seq_lens + torch::Tensor encoder_seq_lens_tensor; + // max encoder seq len + int32_t encoder_max_seq_len = 0; + + // Additional parameters needed by rec_batch_input_builder + bool is_first_prefill = true; + int32_t bs = 0; // batch size + int32_t group_width = 0; + int32_t seq_len = 0; + std::vector> generated_tokens; + torch::Tensor encoder_sparse_embedding; + torch::Tensor decoder_context_embedding; + + // Separate KV cache parameters for different attention types + // For Rec decoder: self_attn uses growing cache, cross_attn uses fixed cache + torch::Tensor cross_attn_kv_cu_seq_lens; // KV lengths for cross-attention + torch::Tensor cross_attn_new_cache_slots; // Cache slots for cross-attention + torch::Tensor cross_attn_block_tables; // Block tables for cross-attention + std::vector cross_attn_kv_cu_seq_lens_vec; + + torch::Tensor encoder_token_ids; + // Rec encoder positions + torch::Tensor encoder_positions; + + RecModelInputParams to(const c10::Device& device) const { + RecModelInputParams result = *this; + + // Move tensors to the specified device + if (encoder_seq_lens_tensor.defined()) { + result.encoder_seq_lens_tensor = encoder_seq_lens_tensor.to(device); + } + + if (encoder_sparse_embedding.defined()) { + 
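+ // encoder_sparse_embedding is optional and may be undefined; copy it to the target device only when it is defined.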
result.encoder_sparse_embedding = encoder_sparse_embedding.to(device); + } + + if (decoder_context_embedding.defined()) { + result.decoder_context_embedding = decoder_context_embedding.to(device); + } + + if (cross_attn_kv_cu_seq_lens.defined()) { + result.cross_attn_kv_cu_seq_lens = cross_attn_kv_cu_seq_lens.to(device); + } + + if (cross_attn_new_cache_slots.defined()) { + result.cross_attn_new_cache_slots = cross_attn_new_cache_slots.to(device); + } + + if (cross_attn_block_tables.defined()) { + result.cross_attn_block_tables = cross_attn_block_tables.to(device); + } + + if (encoder_token_ids.defined()) { + result.encoder_token_ids = encoder_token_ids.to(device); + } + + if (encoder_positions.defined()) { + result.encoder_positions = encoder_positions.to(device); + } + + return result; + } + + void print() const { + LOG(INFO) << "RecModelInputParams:" + << " rec_stage: " + << (rec_stage == RecStage::PREFILL ? "PREFILL" : "DECODE") + << " is_hybrid_mode: " << is_hybrid_mode + << " is_encoder_forward: " << is_encoder_forward + << " has_encoder_output: " << has_encoder_output + << " encoder_max_seq_len: " << encoder_max_seq_len + << " is_first_prefill: " << is_first_prefill << " bs: " << bs + << " group_width: " << group_width << " seq_len: " << seq_len + << " encoder_seq_lens size: " << encoder_seq_lens.size() + << " cross_attn_kv_cu_seq_lens_vec size: " + << cross_attn_kv_cu_seq_lens_vec.size() + << " generated_tokens size: " << generated_tokens.size(); + + // Print tensor shapes if defined + if (encoder_seq_lens_tensor.defined()) { + LOG(INFO) << " encoder_seq_lens_tensor shape: " + << encoder_seq_lens_tensor.sizes(); + } + if (encoder_sparse_embedding.defined()) { + LOG(INFO) << " encoder_sparse_embedding shape: " + << encoder_sparse_embedding.sizes(); + } + if (decoder_context_embedding.defined()) { + LOG(INFO) << " decoder_context_embedding shape: " + << decoder_context_embedding.sizes(); + } + if (cross_attn_kv_cu_seq_lens.defined()) { + LOG(INFO) << " cross_attn_kv_cu_seq_lens shape: " + << cross_attn_kv_cu_seq_lens.sizes(); + } + if (cross_attn_new_cache_slots.defined()) { + LOG(INFO) << " cross_attn_new_cache_slots shape: " + << cross_attn_new_cache_slots.sizes(); + } + if (cross_attn_block_tables.defined()) { + LOG(INFO) << " cross_attn_block_tables shape: " + << cross_attn_block_tables.sizes(); + } + if (encoder_token_ids.defined()) { + LOG(INFO) << " encoder_token_ids shape: " << encoder_token_ids.sizes(); + } + if (encoder_positions.defined()) { + LOG(INFO) << " encoder_positions shape: " << encoder_positions.sizes(); + } + } +}; + struct CacheBlockInfo { int32_t device_block_id = 0; int32_t host_block_id = 0; @@ -94,10 +225,14 @@ struct ModelInputParams { // Copy graph_buffer to device params.graph_buffer = safe_to(graph_buffer, device, true); + // Copy optional Rec parameters if present + if (rec_params.has_value()) { + params.rec_params = rec_params->to(device); + } return params; } - void print() const { + virtual void print() const { LOG(INFO) << "ModelInputParams: empty_kv_cache is " << empty_kv_cache << " , global_empty_kv_cache is " << global_empty_kv_cache << " , num_sequences is " << num_sequences @@ -113,6 +248,10 @@ struct ModelInputParams { print_tensor(block_tables, "ModelInputParams: block_tables", 4); LOG(INFO) << "ModelInputParams: dp_global_token_nums is " << dp_global_token_nums; + if (rec_params.has_value()) { + LOG(INFO) << "ModelInputParams: has rec_params"; + rec_params->print(); + } } // whether the kv-cache is empty for all sequences. 
bool empty_kv_cache = true; @@ -193,6 +332,12 @@ struct ModelInputParams { // Graph execution buffer for temporary tensor storage // Used by ACL Graph Executor to avoid repeated memory allocation torch::Tensor graph_buffer; + + // Optional Rec model specific parameters + std::optional rec_params; + + // Helper function to check if this is a Rec model + bool is_rec_model() const { return rec_params.has_value(); } }; } // namespace xllm diff --git a/xllm/core/framework/prefix_cache/prefix_cache.cpp b/xllm/core/framework/prefix_cache/prefix_cache.cpp index fac8ccfb..3e73c5d4 100644 --- a/xllm/core/framework/prefix_cache/prefix_cache.cpp +++ b/xllm/core/framework/prefix_cache/prefix_cache.cpp @@ -15,7 +15,6 @@ limitations under the License. #include "prefix_cache.h" -#include #include #include #include @@ -25,36 +24,10 @@ limitations under the License. #include "common/global_flags.h" #include "common/metrics.h" +#include "util/hash_util.h" namespace xllm { -void murmur_hash3(const uint8_t* pre_hash_value, - const Slice& token_ids, - uint8_t* hash_value) { - if (pre_hash_value == nullptr) { - MurmurHash3_x64_128(reinterpret_cast(token_ids.data()), - sizeof(int32_t) * token_ids.size(), - FLAGS_murmur_hash3_seed, - hash_value); - } else { - uint8_t key[1024]; - - int32_t data_len = - sizeof(int32_t) * token_ids.size() + MURMUR_HASH3_VALUE_LEN; - CHECK_GT(sizeof(key), data_len) << "key size is too small"; - - memcpy(key, pre_hash_value, MURMUR_HASH3_VALUE_LEN); - memcpy(key + MURMUR_HASH3_VALUE_LEN, - reinterpret_cast(token_ids.data()), - sizeof(int32_t) * token_ids.size()); - - MurmurHash3_x64_128(reinterpret_cast(key), - data_len, - FLAGS_murmur_hash3_seed, - hash_value); - } -} - std::vector PrefixCache::match( const Slice& token_ids, const Slice& existed_shared_blocks) { diff --git a/xllm/core/framework/prefix_cache/prefix_cache.h b/xllm/core/framework/prefix_cache/prefix_cache.h index fc778419..9db26b83 100644 --- a/xllm/core/framework/prefix_cache/prefix_cache.h +++ b/xllm/core/framework/prefix_cache/prefix_cache.h @@ -39,10 +39,6 @@ inline size_t round_down(size_t n, size_t multiple) { return (n / multiple) * multiple; } -void murmur_hash3(const uint8_t* pre_hash_value, - const Slice& token_ids, - uint8_t* hash_value); - class PrefixCache { public: PrefixCache(const PrefixCache&) = delete; diff --git a/xllm/core/framework/prefix_cache/prefix_cache_test.cpp b/xllm/core/framework/prefix_cache/prefix_cache_test.cpp index d0b0ca7e..a9fa1466 100644 --- a/xllm/core/framework/prefix_cache/prefix_cache_test.cpp +++ b/xllm/core/framework/prefix_cache/prefix_cache_test.cpp @@ -7,6 +7,7 @@ #include #include "framework/block/block_manager_impl.h" +#include "util/hash_util.h" namespace xllm { diff --git a/xllm/core/framework/request/request.cpp b/xllm/core/framework/request/request.cpp index 84f5a71e..836d4ce7 100644 --- a/xllm/core/framework/request/request.cpp +++ b/xllm/core/framework/request/request.cpp @@ -56,6 +56,8 @@ void Request::create_sequences_group() { sequence_params.best_of = state_.best_of; sequence_params.streaming = state_.stream; sequence_params.enable_schedule_overlap = state_.enable_schedule_overlap; + sequence_params.is_rec_model = state_.is_rec_model; + sequence_params.bos_token_id = state_.bos_token_id; sequence_params.sampling_param = &(state_.sampling_param); sequence_params.stopping_checker = &(state_.stopping_checker); sequences_group_ = std::make_unique(state_.prompt, diff --git a/xllm/core/framework/request/request_output.h b/xllm/core/framework/request/request_output.h 
index c4781ac3..2527bc88 100644 --- a/xllm/core/framework/request/request_output.h +++ b/xllm/core/framework/request/request_output.h @@ -66,6 +66,9 @@ struct SequenceOutput { // the token ids of the generated text. std::vector token_ids; + // item_id for rec. + std::optional item_ids; + // the reason the sequence finished. std::optional finish_reason; diff --git a/xllm/core/framework/request/request_state.h b/xllm/core/framework/request/request_state.h index 5ff04322..2dc52032 100644 --- a/xllm/core/framework/request/request_state.h +++ b/xllm/core/framework/request/request_state.h @@ -137,6 +137,12 @@ struct RequestState final { bool enable_schedule_overlap = false; + // rec model specific flag + bool is_rec_model = false; + + // The bos token id of the model. + int32_t bos_token_id = 0; + // The thread id of the thread pool in the response handler to ensure that // stream responses for the same request are executed sequentially during // multi-threaded stream processing. diff --git a/xllm/core/framework/request/sequence.cpp b/xllm/core/framework/request/sequence.cpp index 9a7d2ce9..5fe110b7 100644 --- a/xllm/core/framework/request/sequence.cpp +++ b/xllm/core/framework/request/sequence.cpp @@ -34,6 +34,15 @@ limitations under the License. namespace xllm { +// Number of decoder BOS tokens to add for rec models +static constexpr size_t kDecoderBosTokenCount = 1; +static constexpr size_t kDecoderMaxTokenCount = 4; + +// rec model specific: static constants for embedding names +const std::string Sequence::ENCODER_SPARSE_EMBEDDING_NAME = "sparse_embedding"; +const std::string Sequence::DECODER_CONTEXT_EMBEDDING_NAME = + "decoder_context_embedding"; + Sequence::Sequence(size_t index, const std::vector& prompt_token_ids, torch::Tensor input_embedding, @@ -44,25 +53,70 @@ Sequence::Sequence(size_t index, mm_data_(mm_data), latest_generate_time_(absl::Now()), sequence_params_(seq_params), - decoder_(std::move(decoder)) { - CHECK(!prompt_token_ids.empty()) << "empty prompt token ids"; - auto capacity = sequence_params_.seq_capacity; - CHECK_GT(capacity, prompt_token_ids.size()) << "capacity too small"; - - num_prompt_tokens_ = prompt_token_ids.size(); - volatile_num_prompt_tokens_ = num_prompt_tokens_; - tokens_.resize(capacity); - - // init logprob state - logprob_state_ = std::make_unique(num_prompt_tokens_, capacity); - - // add the prompt tokens - for (const auto token_id : prompt_token_ids) { - tokens_[num_tokens_++] = token_id; - token_to_count_map_[token_id]++; + decoder_(std::move(decoder)), + is_rec_model_(seq_params.is_rec_model) { + // rec model specific: handle encoder tokens and decoder embeddings + if (is_rec_model_) { + // For rec model, treat prompt_token_ids as encoder_tokens + if (prompt_token_ids.size() > 0) { + encoder_tokens_.resize(prompt_token_ids.size()); + for (size_t i = 0; i < prompt_token_ids.size(); ++i) { + encoder_tokens_[i] = prompt_token_ids[i]; + } + num_encoder_tokens_ = prompt_token_ids.size(); + } else { + // If no prompt tokens, check for encoder sparse embedding in mm_data + auto encoder_sparse_embedding = + mm_data_.get(ENCODER_SPARSE_EMBEDDING_NAME); + CHECK(encoder_sparse_embedding.has_value()) + << "encoder sparse embedding not found in mm_data"; + num_encoder_tokens_ = encoder_sparse_embedding.value().size(0); + } + + // Check if decoder context embedding exists in mm_data + auto decoder_context_embedding = + mm_data_.get(DECODER_CONTEXT_EMBEDDING_NAME); + auto capacity = kDecoderMaxTokenCount; + if (decoder_context_embedding.has_value()) { + // Use 
context embedding replacing bos + prompt + num_prompt_tokens_ = 0; + num_decoder_embeddings_ = decoder_context_embedding.value().size(0); + capacity = num_decoder_embeddings_ + capacity - kDecoderBosTokenCount; + } else { + // Only BOS token for decoder + num_prompt_tokens_ = kDecoderBosTokenCount; // kDecoderBosTokenCount + } + tokens_.resize(capacity); + for (size_t i = 0; i < num_prompt_tokens_; ++i) { + tokens_[num_tokens_++] = sequence_params_.bos_token_id; + token_to_count_map_[sequence_params_.bos_token_id]++; + } + + volatile_num_prompt_tokens_ = num_prompt_tokens_; + input_embedding_ = input_embedding; + cur_generated_token_idx_ = num_prompt_tokens_; + // init logprob state + logprob_state_ = + std::make_unique(num_prompt_tokens_, capacity); + } else { + CHECK(!prompt_token_ids.empty()) << "empty prompt token ids"; + auto capacity = sequence_params_.seq_capacity; + CHECK_GT(capacity, prompt_token_ids.size()) << "capacity too small"; + num_prompt_tokens_ = prompt_token_ids.size(); + tokens_.resize(capacity); + // add the prompt tokens + for (const auto token_id : prompt_token_ids) { + tokens_[num_tokens_++] = token_id; + token_to_count_map_[token_id]++; + } + + volatile_num_prompt_tokens_ = num_prompt_tokens_; + input_embedding_ = input_embedding; + cur_generated_token_idx_ = num_prompt_tokens_; + // init logprob state + logprob_state_ = + std::make_unique(num_prompt_tokens_, capacity); } - input_embedding_ = input_embedding; - cur_generated_token_idx_ = num_prompt_tokens_; } Sequence::Sequence(const Sequence& other) @@ -84,6 +138,10 @@ Sequence::Sequence(const Sequence& other) num_tokens_(other.num_tokens_), token_to_count_map_(other.token_to_count_map_), num_prompt_tokens_(other.num_prompt_tokens_), + num_encoder_tokens_(other.num_encoder_tokens_), + num_decoder_embeddings_(other.num_decoder_embeddings_), + encoder_tokens_(other.encoder_tokens_), + is_rec_model_(other.is_rec_model_), volatile_num_prompt_tokens_(other.volatile_num_prompt_tokens_), embedding_id_(other.embedding_id_), finished_(other.finished_), @@ -249,6 +307,15 @@ std::optional Sequence::generate_streaming_output( AUTO_COUNTER(detokenization_latency_seconds_stream); const auto ids = Slice(tokens_, size); + // For rec model, return token_ids directly without decode + if (is_rec_model()) { + const size_t start = num_prompt_tokens_; + SequenceOutput output; + output.index = index_; + output.token_ids = ids.slice(start, size); + return output; + } + // record the start index of token ids const size_t start = decoder_.output_offset(); auto delta = decoder_.decode(ids, tokenizer); @@ -456,4 +523,12 @@ Slice Sequence::get_generated_tokens() const { return {tokens_.data(), 0}; } +void Sequence::finish() { + finished_ = true; + finish_status_invalidated_ = false; + if (finish_reason_ == FinishReason::NONE) { + finish_reason_ = FinishReason::STOP; + } +} + } // namespace xllm diff --git a/xllm/core/framework/request/sequence.h b/xllm/core/framework/request/sequence.h index 846c037b..753f8358 100644 --- a/xllm/core/framework/request/sequence.h +++ b/xllm/core/framework/request/sequence.h @@ -65,6 +65,9 @@ struct SequenceParams { // enable_schedule_overlap or not. default = false. bool enable_schedule_overlap = false; + // whether this is a rec model. default = false. 
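+ // For rec models the prompt tokens are treated as encoder tokens and the decoder is seeded with bos_token_id (see the Sequence constructor).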
+ bool is_rec_model = false; + int32_t bos_token_id = 0; // sampling params // reference from request RequestSamplingParam* sampling_param; // not owned @@ -192,6 +195,9 @@ class Sequence final { void close() { closed_ = true; } bool is_closed() const { return closed_; } + // finish the sequence by setting finished status and reason + void finish(); + // time between two tokens int64_t tbt(const absl::Time& now); // set sequence ttft @@ -265,6 +271,22 @@ class Sequence final { // get sequence id int32_t seq_id() const { return seq_id_; } + // rec model specific: get encoder tokens + const std::vector& encoder_tokens() const { return encoder_tokens_; } + + // rec model specific: get encoder sequence length + size_t encoder_seq_len() const { return num_encoder_tokens_; } + + // rec model specific: get number of decoder embeddings + size_t num_decoder_embeddings() const { return num_decoder_embeddings_; } + + // rec model specific: check if this is a rec model + bool is_rec_model() const { return is_rec_model_; } + + // rec model specific: static constants for embedding names + static const std::string ENCODER_SPARSE_EMBEDDING_NAME; + static const std::string DECODER_CONTEXT_EMBEDDING_NAME; + private: // the index of the sequence in the request size_t index_ = 0; @@ -312,6 +334,18 @@ class Sequence final { // the length of the prompt tokens size_t num_prompt_tokens_ = 0; + // rec model specific: number of encoder tokens + size_t num_encoder_tokens_ = 0; + + // rec model specific: number of decoder embeddings + size_t num_decoder_embeddings_ = 0; + + // rec model specific: encoder tokens storage + std::vector encoder_tokens_; + + // rec model specific: whether this is a rec model + bool is_rec_model_ = false; + // NOTE: MUST FIXME Later // record all tokens num in last turn when the request is // interrupted due to the lack of kv cache capacity. diff --git a/xllm/core/framework/request/sequences_group.cpp b/xllm/core/framework/request/sequences_group.cpp index 7bbce9af..d7c232f5 100644 --- a/xllm/core/framework/request/sequences_group.cpp +++ b/xllm/core/framework/request/sequences_group.cpp @@ -174,7 +174,7 @@ void SequencesGroup::process_beam_search() { if (!check_beam_search()) { return; } - + Timer timer; size_t beam_width = sequence_params_.sampling_param->beam_width; size_t seq_size = sequences_.size(); size_t topk = sequence_params_.sampling_param->top_logprobs; @@ -290,6 +290,14 @@ void SequencesGroup::process_beam_search() { CHECK_EQ(sequences_.size(), beam_width); update_for_sequence(0, beam_width); + HISTOGRAM_OBSERVE(expand_beam_latency_microseconds, + timer.elapsed_microseconds()); +} + +void SequencesGroup::finish() { + for (auto& sequence : sequences_) { + sequence->finish(); + } } } // namespace xllm diff --git a/xllm/core/framework/request/sequences_group.h b/xllm/core/framework/request/sequences_group.h index 1ed5ceca..5d0c9174 100644 --- a/xllm/core/framework/request/sequences_group.h +++ b/xllm/core/framework/request/sequences_group.h @@ -24,6 +24,7 @@ limitations under the License. 
#include #include "common.pb.h" +#include "common/metrics.h" #include "core/framework/sampling/sampling_params.h" #include "mm_data.h" #include "sequence.h" @@ -55,11 +56,17 @@ class SequencesGroup { } std::vector>& sequences() { return sequences_; } + const std::vector>& sequences() const { + return sequences_; + } int32_t dp_rank() { return sequences_[0]->dp_rank(); } bool is_prefill_stage() const { return sequences_[0]->is_prefill_stage(); } + // finish all sequences in the group + void finish(); + private: void add(); diff --git a/xllm/core/framework/sampling/CMakeLists.txt b/xllm/core/framework/sampling/CMakeLists.txt index a3cbe5a4..53070940 100644 --- a/xllm/core/framework/sampling/CMakeLists.txt +++ b/xllm/core/framework/sampling/CMakeLists.txt @@ -10,16 +10,19 @@ cc_library( rejection_sampler.h sampler.h beam_searcher.h + valid_path_filter.h SRCS sampling_params.cpp logits_utils.cpp rejection_sampler.cpp sampler.cpp beam_searcher.cpp + valid_path_filter.cpp DEPS glog::glog torch :kernels + :util $<$:xllm_ops> ) @@ -30,12 +33,14 @@ cc_test( rejection_sampler_test.cpp rejection_sampler.cpp sampling_params_test.cpp + valid_path_filter_test.cpp DEPS absl::strings GTest::gtest_main :flags :sampler glog::glog + torch ) target_link_libraries(sampler_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto leveldb::leveldb ZLIB::ZLIB protobuf::libprotobuf) target_link_libraries(sampler_test diff --git a/xllm/core/framework/sampling/valid_path_filter.cpp b/xllm/core/framework/sampling/valid_path_filter.cpp new file mode 100644 index 00000000..7784ea0e --- /dev/null +++ b/xllm/core/framework/sampling/valid_path_filter.cpp @@ -0,0 +1,269 @@ +#include "valid_path_filter.h" + +#include + +#include +#include +#include +#include +#include + +#include "util/env_var.h" +#include "util/hash_util.h" +#include "util/slice.h" +#include "util/tensor_helper.h" +#include "util/timer.h" + +namespace xllm { + +namespace { + +void parse_valid_path_filter_file( + std::vector>& tokens_list, + const std::string& valid_path_filter_file) { + if (valid_path_filter_file.empty()) { + LOG(WARNING) << "Get empty valid path filter file: " + << valid_path_filter_file; + return; + } + if (!std::filesystem::exists(valid_path_filter_file)) { + LOG(ERROR) << "Failed to find valid path filter file: " + << valid_path_filter_file; + return; + } + std::ifstream ifs(valid_path_filter_file, std::ios::binary | std::ios::ate); + if (!ifs.is_open()) { + LOG(ERROR) << "Failed to load valid path filter file: " + << valid_path_filter_file; + return; + } + + const size_t file_size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + + const int elements_per_line = 3; + const size_t elements_size = elements_per_line * sizeof(int32_t); + const size_t line_size = sizeof(int64_t) + elements_size; + const size_t estimated_lines = (file_size + line_size - 1) / line_size; + + tokens_list.reserve(estimated_lines); + + int64_t item_id; + std::vector buffer(elements_per_line); + while (ifs.read(reinterpret_cast(&item_id), sizeof(int64_t)) && + ifs.read(reinterpret_cast(buffer.data()), elements_size)) { + tokens_list.emplace_back(buffer.begin(), buffer.end()); + } + LOG(INFO) << "ValidPathFilter parse tokens list size:" << tokens_list.size(); + + if (ifs.gcount() != 0 && ifs.gcount() != line_size) { + LOG(ERROR) << "Possibly containing incomplete lines: " + << valid_path_filter_file; + return; + } +} +} // namespace + +float ValidPathFilter::pre_mask_factor_ = -10000.0f; + +ValidPathFilter::ValidPathFilter(const std::string valid_path_filter_file, + const
int32_t vocab_size, + torch::ScalarType dtype, + torch::Device device) + : vocab_size_(vocab_size), dtype_(dtype), device_(device) { + std::vector> tokens_list; + Timer timer; + parse_valid_path_filter_file(tokens_list, valid_path_filter_file); + init_cached_mask(tokens_list, vocab_size); + LOG(INFO) << " ValidPathFilter generate " << cached_sparse_mask_.size() + << " key for " << tokens_list.size() << " items which took " + << timer.elapsed_seconds() << " secs."; +} + +ValidPathFilter::ValidPathFilter( + const std::vector>& tokens_list, + const int32_t vocab_size, + torch::ScalarType dtype, + torch::Device device) + : vocab_size_(vocab_size), dtype_(dtype), device_(device) { + init_cached_mask(tokens_list, vocab_size); +} + +void ValidPathFilter::init_cached_mask( + const std::vector>& tokens_list, + const int32_t vocab_size) { + size_t total_num = tokens_list.size(); + if (total_num > 0) { + init_cached_tokens_ = true; + } + + // init extra thread pool + thread_num_ = util::get_int_env(util::EXTRA_THREAD_NUM, 16); + extra_threadpool_ = std::make_unique(thread_num_); + + // generate mask + torch::TensorOptions options = torch::dtype(dtype_).device(device_); + first_token_mask_ = torch::full({vocab_size}, pre_mask_factor_, dtype_); + empty_place_holder_ = torch::full({vocab_size}, 0.0f, options); + + cached_sparse_mask_.reserve(total_num); + for (size_t t_idx = 0; t_idx < total_num; t_idx++) { + Slice tokens_slice(tokens_list[t_idx]); + CHECK_EQ(tokens_slice.size(), 3); + + // handle first token + first_token_mask_[tokens_slice[0]] = 0; + + // handle extra token + for (int i = 1; i < tokens_slice.size(); i++) { + Murmur3Key murmur3_key; + Slice sub_slice(tokens_slice.data(), i); + murmur_hash3(nullptr, sub_slice, murmur3_key.data); + auto iter = cached_sparse_mask_.find(murmur3_key); + if (iter != cached_sparse_mask_.end()) { + iter->second.push_back(tokens_slice[i]); + } else { + std::vector false_indices = {tokens_slice[i]}; + cached_sparse_mask_.emplace(std::make_pair(murmur3_key, false_indices)); + } + } + } + + // Remove duplicates and sort for better performance + // Sort false indices in sparse masks for better performance + for (auto& pair : cached_sparse_mask_) { + std::sort(pair.second.begin(), pair.second.end()); + pair.second.erase(std::unique(pair.second.begin(), pair.second.end()), + pair.second.end()); + } + // first_token_mask_ = safe_to(first_token_mask_, device_, true); + LOG(INFO) << " ValidPathFilter third sparse storage: " + << cached_sparse_mask_.size(); +} + +torch::Tensor ValidPathFilter::forward( + const std::vector>& tokens_list) { + if (!init_cached_tokens_ || tokens_list.size() == 0) { + return torch::Tensor(); + } + + size_t token_size = tokens_list[0].size(); + + // prepare mask for first token + if (token_size == 0) { + size_t total_nums = tokens_list.size(); + auto mask = first_token_mask_.unsqueeze(0); + return mask.repeat({total_nums, 1}); + } + return forward_sparse_mask(tokens_list); +} + +torch::Tensor ValidPathFilter::forward_sparse_mask( + const std::vector>& tokens_list) { + Timer timer; + size_t total_nums = tokens_list.size(); + torch::TensorOptions options = torch::dtype(dtype_).device(device_); + auto mask = torch::full({total_nums, vocab_size_}, pre_mask_factor_, options); + + // Global batch collection for sparse storage optimization + std::vector global_batch_token_indices; + std::vector global_batch_vocab_indices; + std::mutex batch_mutex; // Protect global batch vectors in multi-threading + + // Pre-allocate space: assume max 8192 false 
indices per token + global_batch_token_indices.reserve(8192 * total_nums); + global_batch_vocab_indices.reserve(8192 * total_nums); + + auto update_mask = [&](size_t start_idx, size_t end_idx) { + // Local collection for this thread + std::vector local_token_indices; + std::vector local_vocab_indices; + local_token_indices.reserve(8192 * (end_idx - start_idx)); + local_vocab_indices.reserve(8192 * (end_idx - start_idx)); + + for (size_t token_idx = start_idx; token_idx < end_idx; ++token_idx) { + auto& tokens = tokens_list[token_idx]; + if (tokens.size() == 0) { + mask[token_idx] = first_token_mask_.to(device_); + } else { + Slice tokens_slice(tokens); + Murmur3Key murmur3_key; + murmur_hash3(nullptr, tokens_slice, murmur3_key.data); + + auto iter = cached_sparse_mask_.find(murmur3_key); + if (iter != cached_sparse_mask_.end()) { + // Collect indices locally first + for (int32_t vocab_idx : iter->second) { + local_token_indices.push_back(static_cast(token_idx)); + local_vocab_indices.push_back(static_cast(vocab_idx)); + } + } else { + mask[token_idx] = empty_place_holder_; + LOG(ERROR) << "Failed to generate mask for " << tokens; + } + } + } + + // Merge local results to global batch (thread-safe) + if (!local_token_indices.empty()) { + std::lock_guard lock(batch_mutex); + global_batch_token_indices.insert(global_batch_token_indices.end(), + local_token_indices.begin(), + local_token_indices.end()); + global_batch_vocab_indices.insert(global_batch_vocab_indices.end(), + local_vocab_indices.begin(), + local_vocab_indices.end()); + } + }; + + if (use_threadpool_for_beam_expansion_) { + // Segmented processing optimization: each thread handles multiple masks + const size_t batch_size = + std::max(1UL, (total_nums + thread_num_ - 1) / thread_num_); + const size_t num_batches = (total_nums + batch_size - 1) / batch_size; + + std::vector>> promises; + std::vector> futures; + promises.reserve(num_batches); + futures.reserve(num_batches); + + for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) { + auto promise = std::make_shared>(); + futures.push_back(promise->get_future()); + promises.push_back(promise); + + size_t start_idx = batch_idx * batch_size; + size_t end_idx = std::min(start_idx + batch_size, total_nums); + + extra_threadpool_->schedule( + [update_mask, start_idx, end_idx, promise]() mutable { + update_mask(start_idx, end_idx); + promise->set_value(); + }); + } + + for (auto& future : futures) { + future.get(); + } + } else { + update_mask(0, total_nums); + } + + // Global batch tensor operation after all threads complete + if (!global_batch_token_indices.empty()) { + auto token_indices = + torch::tensor(global_batch_token_indices, torch::kInt64); + auto vocab_indices = + torch::tensor(global_batch_vocab_indices, torch::kInt64); + torch::TensorOptions device_options = + torch::dtype(torch::kInt64).device(device_); + token_indices = safe_to(token_indices, device_options, true); + vocab_indices = safe_to(vocab_indices, device_options, true); + mask.index_put_({token_indices, vocab_indices}, 0.0f); + // auto indices = torch::stack({token_indices, vocab_indices}, 1); + // return indices; + } + + return mask; +} +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/sampling/valid_path_filter.h b/xllm/core/framework/sampling/valid_path_filter.h new file mode 100644 index 00000000..c3eda7e1 --- /dev/null +++ b/xllm/core/framework/sampling/valid_path_filter.h @@ -0,0 +1,65 @@ +#pragma once +#include +#include +#include + +#include "util/hash_util.h" +#include "util/threadpool.h" + +namespace xllm { + 
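+// ValidPathFilter restricts decoding to known-valid token paths: it caches the allowed first tokens and a sparse map from a MurmurHash3 key of each token prefix to its allowed next tokens, and forward() returns a [num_tokens, vocab_size] additive logit mask (0 for allowed tokens, -10000 otherwise).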
+class ValidPathFilter final { + public: + ValidPathFilter(const std::string valid_path_filter_file, + const int32_t vocab_size, + torch::ScalarType dtype, + torch::Device device); + ValidPathFilter(const std::vector>& tokens_list, + const int32_t vocab_size, + torch::ScalarType dtype, + torch::Device device); + + // operator() allows us to use the module as a function. + template + auto operator()(Args&&... args) const { + return this->forward(::std::forward(args)...); + } + + // output: [num_tokens, vocab_size] + torch::Tensor forward(const std::vector>& tokens_list); + + private: + void init_cached_mask(const std::vector>& tokens_list, + const int32_t vocab_size); + + // prepare mask using cached sparse mask + torch::Tensor forward_sparse_mask( + const std::vector>& tokens_list); + + // Sparse storage: map from key to indices of candidate tokens. + std::unordered_map, + FixedStringKeyHash, + FixedStringKeyEqual> + cached_sparse_mask_; + + torch::Tensor empty_place_holder_; + torch::Tensor first_token_mask_; + + bool init_cached_tokens_ = false; + + static float pre_mask_factor_; + + int32_t vocab_size_; + + torch::ScalarType dtype_ = torch::ScalarType::Undefined; + + torch::Device device_; + + int32_t thread_num_; + std::unique_ptr extra_threadpool_; + // Controls whether the thread pool is used for beam expansion + bool use_threadpool_for_beam_expansion_ = true; +}; + +} // namespace xllm diff --git a/xllm/core/framework/sampling/valid_path_filter_test.cpp b/xllm/core/framework/sampling/valid_path_filter_test.cpp new file mode 100644 index 00000000..cb2e4f2a --- /dev/null +++ b/xllm/core/framework/sampling/valid_path_filter_test.cpp @@ -0,0 +1,167 @@ +#include "valid_path_filter.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace xllm { + +TEST(ValidPathFilterTest, Vector) { + // Test data based on a real usage scenario + // tokens_list holds the valid token sequence paths; each sequence has length 3 + std::vector> tokens_list = { + {1, 2, 3}, // sequence 1: 1->2->3 + {1, 2, 4}, // sequence 2: 1->2->4 + {1, 3, 5}, // sequence 3: 1->3->5 + {2, 4, 6}, // sequence 4: 2->4->6 + {3, 5, 7} // sequence 5: 3->5->7 + }; + + torch::ScalarType dtype(torch::kFloat32); + torch::Device device(torch::kCPU); + int32_t vocab_size = 8; // vocabulary size is 8 (tokens 0-7) + + ValidPathFilter filter = + ValidPathFilter(tokens_list, vocab_size, dtype, device); + + // Test different candidate token sequences + std::vector> candidate_tokens = { + {1, 2}, // prefix [1,2], tokens 3 and 4 should be allowed + {1}, // prefix [1], tokens 2 and 3 should be allowed + {}, // empty prefix, first tokens 1, 2, 3 should be allowed + {2, 4}, // prefix [2,4], token 6 should be allowed + {9, 9} // invalid prefix, everything should be masked + }; + + const auto options = torch::dtype(dtype).device(device); + torch::Tensor mask = filter.forward(candidate_tokens); + + // Verify the output shape + EXPECT_EQ(mask.sizes(), + torch::IntArrayRef({candidate_tokens.size(), vocab_size})); + + // Verify the mask values + // A mask value of 0 means allowed, -10000 means forbidden + + // For prefix [1,2]: the next token can be 3 or 4 + auto mask_1_2 = mask[0]; + EXPECT_EQ(mask_1_2[3].item(), 0.0f); // token 3 allowed + EXPECT_EQ(mask_1_2[4].item(), 0.0f); // token 4 allowed + EXPECT_EQ(mask_1_2[0].item(), -10000.0f); // token 0 forbidden + EXPECT_EQ(mask_1_2[1].item(), -10000.0f); // token 1 forbidden + EXPECT_EQ(mask_1_2[2].item(), -10000.0f); // token 2 forbidden + + // For prefix [1]: the next token can be 2 or 3 + auto mask_1 = mask[1]; + EXPECT_EQ(mask_1[2].item(), 0.0f); // token 2 allowed + EXPECT_EQ(mask_1[3].item(), 0.0f); // token 3 allowed + EXPECT_EQ(mask_1[0].item(), -10000.0f); // token 0 forbidden + EXPECT_EQ(mask_1[1].item(), -10000.0f); // token 1 forbidden + + // For empty prefix []: the first token can be 1, 2, or 3 + auto mask_empty = mask[2]; + EXPECT_EQ(mask_empty[1].item(), 0.0f); // token 1 allowed + EXPECT_EQ(mask_empty[2].item(), 0.0f); // token 2 allowed
+ EXPECT_EQ(mask_empty[3].item(), 0.0f); // token 3 allowed + EXPECT_EQ(mask_empty[0].item(), -10000.0f); // token 0 forbidden +} + +TEST(ValidPathFilterTest, File) { + // Create the test data file + std::vector> tokens_list = { + {1, 2, 3}, {1, 2, 4}, {1, 3, 5}, {2, 4, 6}, {3, 5, 7}}; + + const std::string rec_tokens_file = "./test_data.bin"; + + // Clean up any old file + if (std::ifstream(rec_tokens_file)) { + std::remove(rec_tokens_file.c_str()); + } + + // Write the file in the format the implementation expects: int64_t item_id + three int32_t values + std::ofstream outfile(rec_tokens_file, std::ios::binary); + if (!outfile) { + LOG(ERROR) << "Failed to create test file: " << rec_tokens_file; + return; + } + + int64_t item_id = 0; + for (const auto& row : tokens_list) { + outfile.write(reinterpret_cast(&item_id), sizeof(int64_t)); + outfile.write(reinterpret_cast(row.data()), + row.size() * sizeof(int32_t)); + item_id++; + } + outfile.close(); + + torch::ScalarType dtype(torch::kFloat32); + torch::Device device(torch::kCPU); + int32_t vocab_size = 8; + + ValidPathFilter filter = + ValidPathFilter(rec_tokens_file, vocab_size, dtype, device); + + // Use the same test cases + std::vector> candidate_tokens = { + {1, 2}, // prefix [1,2] + {1}, // prefix [1] + {} // empty prefix + }; + + const auto options = torch::dtype(dtype).device(device); + torch::Tensor mask = filter.forward(candidate_tokens); + + // Verify the output shape + EXPECT_EQ(mask.sizes(), + torch::IntArrayRef({candidate_tokens.size(), vocab_size})); + + // Verify the same results as the Vector test + // For prefix [1,2]: the next token can be 3 or 4 + auto mask_1_2 = mask[0]; + EXPECT_EQ(mask_1_2[3].item(), 0.0f); + EXPECT_EQ(mask_1_2[4].item(), 0.0f); + + // For prefix [1]: the next token can be 2 or 3 + auto mask_1 = mask[1]; + EXPECT_EQ(mask_1[2].item(), 0.0f); + EXPECT_EQ(mask_1[3].item(), 0.0f); + + // For empty prefix []: the first token can be 1, 2, or 3 + auto mask_empty = mask[2]; + EXPECT_EQ(mask_empty[1].item(), 0.0f); + EXPECT_EQ(mask_empty[2].item(), 0.0f); + EXPECT_EQ(mask_empty[3].item(), 0.0f); + + // Clean up the test file + if (std::ifstream(rec_tokens_file)) { + std::remove(rec_tokens_file.c_str()); + } +} + +TEST(ValidPathFilterTest, EmptyInput) { + // Test the empty input case + std::vector> tokens_list = {{1, 2, 3}}; + torch::ScalarType dtype(torch::kFloat32); + torch::Device device(torch::kCPU); + int32_t vocab_size = 5; + + ValidPathFilter filter = + ValidPathFilter(tokens_list, vocab_size, dtype, device); + + // Test an empty candidate token list + std::vector> empty_candidates = {}; + torch::Tensor mask = filter.forward(empty_candidates); + + // An undefined tensor should be returned + EXPECT_FALSE(mask.defined()); +} + +} // namespace xllm diff --git a/xllm/core/runtime/CMakeLists.txt b/xllm/core/runtime/CMakeLists.txt index 54b10152..ebb43dcd 100644 --- a/xllm/core/runtime/CMakeLists.txt +++ b/xllm/core/runtime/CMakeLists.txt @@ -22,10 +22,12 @@ cc_library( dit_worker.h embed_worker_impl.h embed_vlm_worker_impl.h + rec_worker_impl.h engine.h llm_engine.h vlm_engine.h dit_engine.h + rec_engine.h worker_client.h xservice_client.h speculative_engine.h @@ -40,12 +42,14 @@ cc_library( worker_impl.cpp llm_worker_impl.cpp vlm_worker_impl.cpp + rec_worker_impl.cpp dit_worker.cpp embed_worker_impl.cpp embed_vlm_worker_impl.cpp llm_engine.cpp vlm_engine.cpp dit_engine.cpp + rec_engine.cpp worker_client.cpp xservice_client.cpp params_utils.cpp @@ -88,11 +92,13 @@ cc_library( master.h vlm_master.h dit_master.h + rec_master.h SRCS llm_master.cpp master.cpp vlm_master.cpp dit_master.cpp + rec_master.cpp DEPS :common :distributed_runtime diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index dd4a3d8f..fbce4155 100644 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h
@@ -37,6 +37,7 @@ class WorkerType { DIT, // DIT ELM, // Embedding LM EVLM, // Embedding VLM + REC, // Rec }; constexpr WorkerType(Value v) : value_(v) {} @@ -51,6 +52,8 @@ class WorkerType { value_ = ELM; } else if (str == "EVLM") { value_ = EVLM; + } else if (str == "REC") { + value_ = REC; } else { value_ = INVALID; } @@ -77,6 +80,8 @@ class WorkerType { return "ELM"; } else if (this->value_ == EVLM) { return "EVLM"; + } else if (this->value_ == REC) { + return "REC"; } else { return "INVALID"; } diff --git a/xllm/core/runtime/master.cpp b/xllm/core/runtime/master.cpp index c4418e73..2e6b901d 100644 --- a/xllm/core/runtime/master.cpp +++ b/xllm/core/runtime/master.cpp @@ -34,6 +34,8 @@ limitations under the License. #include "runtime/dit_master.h" #include "runtime/llm_engine.h" #include "runtime/llm_master.h" +#include "runtime/rec_engine.h" +#include "runtime/rec_master.h" #include "runtime/speculative_engine.h" #include "runtime/vlm_engine.h" #include "runtime/vlm_master.h" @@ -207,6 +209,39 @@ Master::Master(const Options& options, EngineType type) : options_(options) { eng_options.device_ip(options_.device_ip().value()); } engine_ = std::make_unique(eng_options); + } else if (type == EngineType::REC) { + runtime::Options eng_options; + eng_options.model_path(options_.model_path()) + .devices(devices) + .block_size(options_.block_size()) + .max_cache_size(options_.max_cache_size()) + .max_memory_utilization(options_.max_memory_utilization()) + .enable_prefix_cache(options_.enable_prefix_cache()) + .task_type(options_.task_type()) + .enable_mla(options_.enable_mla()) + .master_node_addr(options_.master_node_addr()) + .nnodes(options_.nnodes()) + .node_rank(options_.node_rank()) + .dp_size(options_.dp_size()) + .ep_size(options_.ep_size()) + .enable_chunked_prefill(options_.enable_chunked_prefill()) + .max_seqs_per_batch(options_.max_seqs_per_batch()) + .max_tokens_per_chunk_for_prefill( + options_.max_tokens_per_chunk_for_prefill()) + .instance_role(options_.instance_role()) + .kv_cache_transfer_mode(options_.kv_cache_transfer_mode()) + .transfer_listen_port(options_.transfer_listen_port()) + .enable_disagg_pd(options_.enable_disagg_pd()) + .enable_service_routing(options_.enable_service_routing()) + .enable_schedule_overlap(options_.enable_schedule_overlap()) + .enable_cache_upload(options_.enable_cache_upload()) + .host_blocks_factor(options_.host_blocks_factor()) + .enable_kvcache_store(options_.enable_kvcache_store()) + .store_protocol(options_.store_protocol()) + .store_master_server_entry(options_.store_master_server_entry()) + .store_metadata_connstring(options_.store_metadata_connstring()) + .enable_continuous_kvcache(options_.enable_continuous_kvcache()); + engine_ = std::make_unique(eng_options); } else { LOG(WARNING) << "Not supported llm engine type: " << static_cast(type); @@ -222,6 +257,8 @@ std::unique_ptr create_master(const std::string& backend, } else if (backend == "dit") { LOG(INFO) << "creating dit master"; return std::make_unique(options); + } else if (backend == "rec") { + return std::make_unique(options); } else { LOG(FATAL) << "Failed to create master, backend is" << backend; return nullptr; diff --git a/xllm/core/runtime/rec_engine.cpp b/xllm/core/runtime/rec_engine.cpp new file mode 100644 index 00000000..469a66f0 --- /dev/null +++ b/xllm/core/runtime/rec_engine.cpp @@ -0,0 +1,340 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "rec_engine.h" + +#include + +#include +#include + +#include "common/metrics.h" +#include "framework/model/model_args.h" +#include "framework/model_loader.h" +#include "framework/parallel_state/parallel_state.h" +#include "util/pretty_print.h" +#include "util/timer.h" +#include "util/utils.h" +#include "worker.h" + +namespace xllm { + +RecEngine::RecEngine(const runtime::Options& options) : options_(options) { + const auto& devices = options_.devices(); + CHECK_GT(devices.size(), 0) << "At least one device is required"; + + CHECK(!devices[0].is_cpu()) << "CPU device is not supported"; + const auto device_type = devices[0].type(); + for (const auto device : devices) { + CHECK_EQ(device.type(), device_type) + << "All devices should be the same type"; + } + + // initialize process groups if there are multiple devices + if (devices.size() > 1) { + // create a process group for each device if there are multiple gpus + process_groups_ = parallel_state::create_npu_process_groups(devices); + } + + WorkerType worker_type = WorkerType::REC; + const int32_t world_size = static_cast(devices.size()); + for (size_t i = 0; i < devices.size(); ++i) { + const int32_t rank = static_cast(i); + ProcessGroup* pg = world_size > 1 ? 
process_groups_[i].get() : nullptr; + ParallelArgs parallel_args(rank, world_size, pg); + workers_.emplace_back(std::make_unique( + parallel_args, devices[i], options_, worker_type)); + } + + if (workers_.size() > 1) { + // test process group + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.emplace_back(worker->process_group_test_async()); + } + // wait up to 4 seconds for all futures to complete + folly::collectAll(futures).within(std::chrono::seconds(4)).get(); + } +} + +bool RecEngine::init() { + if (!init_model()) { + LOG(ERROR) << "Failed to init model from: " << options_.model_path(); + return false; + } + + auto kv_cache_cap = estimate_kv_cache_capacity(); + + if (!allocate_kv_cache(kv_cache_cap)) { + LOG(ERROR) << "Failed to allocate kv cache"; + return false; + } + + return true; +} + +bool RecEngine::init_model() { + const std::string& model_path = options_.model_path(); + auto model_loader = ModelLoader::create(model_path); + LOG(INFO) << "Initializing model from: " << model_path; + + // RecEngine does not use tokenizer + tokenizer_ = nullptr; + + args_ = model_loader->model_args(); + quant_args_ = model_loader->quant_args(); + tokenizer_args_ = model_loader->tokenizer_args(); + + // compute the number of local kv heads and head dim + const int world_size = static_cast(workers_.size()); + const int64_t n_heads = args_.n_heads(); + const int64_t n_kv_heads = args_.n_kv_heads().value_or(n_heads); + n_local_kv_heads_ = std::max(1, n_kv_heads / world_size); + head_dim_ = args_.head_dim(); + dtype_ = xllm::util::parse_dtype(args_.dtype(), options_.devices()[0]); + + // key + value for all layers + LOG(INFO) << "Block info, block_size: " << options_.block_size() + << ", n_local_kv_heads: " << n_local_kv_heads_ + << ", head_dim: " << head_dim_ << ", n_layers: " << args_.n_layers() + << ", dtype: " << dtype_; + + // RecEngine does not use tokenizer, skip vocab_size check + + LOG(INFO) << "Initializing model with " << args_; + LOG(INFO) << "Initializing model with quant args: " << quant_args_; + LOG(INFO) << "Initializing model with tokenizer args: " << tokenizer_args_; + + // init model for each worker in parallel + // multiple workers, call async init + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->init_model_async(model_path)); + } + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + for (const auto& result : results) { + if (!result.value()) { + return false; + } + } + + return true; +} + +Engine::KVCacheCapacity RecEngine::estimate_kv_cache_capacity() { + const int64_t max_cache_size = options_.max_cache_size(); + const double max_memory_utilization = options_.max_memory_utilization(); + + const auto& device = workers_[0]->device(); + // call worker to profile memory usage + std::vector>> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->estimate_kv_cache_capacity_async()); + } + + // pick smallest available memory from all devices + int64_t cache_size_in_bytes = std::numeric_limits::max(); + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + for (size_t i = 0; i < results.size(); ++i) { + const auto device = workers_[i]->device(); + if (!results[i].hasValue()) { + LOG(ERROR) << "Failed to profile memory usage for device: " << device; + continue; + } + auto [available_memory, total_memory] = results[i].value(); + LOG(INFO) << 
device + << ": available memory: " << readable_size(available_memory) + << ", total memory: " << readable_size(total_memory) + << ", Using max_memory_utilization: " << max_memory_utilization + << ", max_cache_size: " << readable_size(max_cache_size); + // apply memory cap from config if it is set + if (max_memory_utilization < 1.0) { + const int64_t buffer_memory = + total_memory * (1.0 - max_memory_utilization); + available_memory -= buffer_memory; + } + if (max_cache_size > 0) { + available_memory = std::min(available_memory, max_cache_size); + } + cache_size_in_bytes = std::min(cache_size_in_bytes, available_memory); + } + + KVCacheCapacity kv_cache_cap; + kv_cache_cap.cache_size_in_bytes = std::max(cache_size_in_bytes, int64_t(0)); + CHECK_GT(kv_cache_cap.cache_size_in_bytes, 0) + << "Available kv cache size must be greater than 0"; + + // compute kv cache slot size + const auto dtype_size = torch::scalarTypeToTypeMeta(dtype_).itemsize(); + // key + value for all layers + const int64_t slot_size = + 2 * n_local_kv_heads_ * head_dim_ * args_.n_layers() * dtype_size; + kv_cache_cap.slot_size = slot_size; + + // compute kv blocks num + const int32_t block_size = options_.block_size(); + const int64_t block_size_in_bytes = block_size * slot_size; + kv_cache_cap.n_blocks = cache_size_in_bytes / block_size_in_bytes; + CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache"; + + return kv_cache_cap; +} + +bool RecEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) { + LOG(INFO) << "kv cache capacity: " + << "bytes: " << kv_cache_cap.cache_size_in_bytes + << ", blocks: " << kv_cache_cap.n_blocks + << ", slot_size: " << kv_cache_cap.slot_size; + + const int32_t block_size = options_.block_size(); + + // init kv cache for each worker + std::vector> kv_cache_shape; + kv_cache_shape.reserve(2); + kv_cache_shape.emplace_back(std::vector{ + kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_}); + kv_cache_shape.emplace_back(std::vector{ + kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_}); + + LOG(INFO) << "Initializing k cache with shape: [" << kv_cache_shape[0] << "]"; + LOG(INFO) << "Initializing v cache with shape: [" << kv_cache_shape[1] << "]"; + + // initialize block manager + BlockManagerPool::Options options; + options.num_blocks(kv_cache_cap.n_blocks) + .host_num_blocks(kv_cache_cap.n_blocks) + .block_size(block_size) + .enable_prefix_cache(options_.enable_prefix_cache()) + .enable_disagg_pd(options_.enable_disagg_pd()) + .enable_cache_upload(options_.enable_cache_upload()); + kv_cache_manager_ = std::make_unique(options); + + // init kv cache for each worker in parallel + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->allocate_kv_cache_async(kv_cache_shape)); + } + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + for (const auto& result : results) { + if (!result.value()) { + return false; + } + } + return true; +} + +// RecEngine executes model three times: prefill + 2 decode steps +ForwardOutput RecEngine::step(std::vector& batches) { + if (workers_.empty()) { + // empty worker, return + return {}; + } + + Timer timer; + auto forward_inputs = workers_[0]->prepare_inputs(batches[0]); + COUNTER_ADD(prepare_input_latency_microseconds, timer.elapsed_microseconds()); + + if (!forward_inputs.token_ids.defined()) { + // empty input, just return + return {}; + } + + timer.reset(); + // Prefill step: Run the first model execution 
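+ // Its sampled output is written back into the batch and seeds the two decode steps below.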
+ const auto& prefill_output = get_model_output(forward_inputs); + COUNTER_ADD(rec_first_token_latency_microseconds, + timer.elapsed_microseconds()); + + timer.reset(); + batches[0].process_sample_output(prefill_output.sample_output, false); + COUNTER_ADD(rec_sampling_latency_microseconds, timer.elapsed_microseconds()); + + // Decode steps: Run the model 2 more times for decoding + ForwardOutput decode_output; + + for (int i = 0; i < 2; ++i) { + timer.reset(); + forward_inputs = workers_[0]->prepare_inputs(batches[0]); + COUNTER_ADD(prepare_input_latency_microseconds, + timer.elapsed_microseconds()); + + timer.reset(); + decode_output = get_model_output(forward_inputs); + if (i == 0) { + COUNTER_ADD(rec_second_token_latency_microseconds, + timer.elapsed_microseconds()); + } else if (i == 1) { + COUNTER_ADD(rec_third_token_latency_microseconds, + timer.elapsed_microseconds()); + } + + timer.reset(); + batches[0].process_sample_output(decode_output.sample_output, false); + COUNTER_ADD(rec_sampling_latency_microseconds, + timer.elapsed_microseconds()); + } + + batches[0].finish(); + + // Return the final model output + return decode_output; +} + +void RecEngine::update_last_step_result(std::vector& batch) {} + +std::vector RecEngine::get_active_activation_memory() const { + // call worker to get active activation memory + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->get_active_activation_memory_async()); + } + + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + std::vector active_activation_memories; + active_activation_memories.reserve(workers_.size()); + for (auto& result : results) { + active_activation_memories.push_back(result.value()); + } + return active_activation_memories; +} + +ForwardOutput RecEngine::get_model_output(const ForwardInput& model_inputs) { + std::vector>> futures; + futures.reserve(workers_.size()); + // TODO to adapt multi stream parallel later + BatchedForwardInputs batched_fwd_inputs; + batched_fwd_inputs.micro_inputs = {model_inputs}; + for (auto& worker : workers_) { + futures.emplace_back(worker->step_async(batched_fwd_inputs)); + } + // wait for the all future to complete + auto results = folly::collectAll(futures).get(); + // return the result from the driver + auto forward_output = results.front().value(); + + DCHECK(forward_output.has_value()) << "Failed to execute model"; + return forward_output.value(); +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/rec_engine.h b/xllm/core/runtime/rec_engine.h new file mode 100644 index 00000000..b9d586cc --- /dev/null +++ b/xllm/core/runtime/rec_engine.h @@ -0,0 +1,83 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include + +#include + +#include "common/macros.h" +#include "engine.h" +#include "framework/batch/batch.h" +#include "framework/block/block_manager_pool.h" +#include "framework/quant_args.h" +#include "framework/tokenizer/tokenizer.h" +#include "framework/tokenizer/tokenizer_args.h" +#include "worker.h" + +namespace xllm { + +class RecEngine : public Engine { + public: + // create an engine with the given devices + RecEngine(const runtime::Options& options); + + virtual ~RecEngine() = default; + + ForwardOutput step(std::vector& batch) override; + + const runtime::Options& options() const { return options_; } + + bool init() override; + + void update_last_step_result(std::vector& batch) override; + + // return the active activation memory + std::vector get_active_activation_memory() const override; + + // Override tokenizer to return nullptr for rec models + const Tokenizer* tokenizer() const override { return nullptr; } + + private: + bool init_model(); + Engine::KVCacheCapacity estimate_kv_cache_capacity(); + bool allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap); + + // Helper methods for rec-specific execution + ForwardOutput get_model_output(const ForwardInput& model_inputs); + + private: + // options + runtime::Options options_; + + // dtype + torch::ScalarType dtype_; + + // quantization args + QuantArgs quant_args_; + + // a list of process groups, with each process group handling a single device + std::vector> process_groups_; + + // a list of workers, with each worker handling a partial of model + std::vector> workers_; + + // config for kv cache + int64_t n_local_kv_heads_ = 0; + int64_t head_dim_ = 0; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/rec_master.cpp b/xllm/core/runtime/rec_master.cpp new file mode 100644 index 00000000..f23792e5 --- /dev/null +++ b/xllm/core/runtime/rec_master.cpp @@ -0,0 +1,267 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "rec_master.h" + +#include +#include +#include + +#include "absl/time/time.h" +#include "models/model_registry.h" +#include "runtime/rec_engine.h" +#include "runtime/xservice_client.h" +#include "scheduler/scheduler_factory.h" +#include "util/scope_guard.h" +#include "util/threadpool.h" +#include "util/utils.h" + +namespace xllm { + +RecMaster::RecMaster(const Options& options) + : Master(options, EngineType::LLM) { + // Initialize with Rec engine type + // The rest of the initialization follows the same pattern as LLMMaster + CHECK(engine_->init()); + + model_args_ = engine_->model_args(); + + bool enable_decode_response_to_service = false; + if (options_.enable_service_routing()) { + XServiceClient* xservice_client = XServiceClient::get_instance(); + if (!xservice_client->init(options_.etcd_addr().value_or(""), + options_.xservice_addr().value_or(""), + options_.instance_name().value_or(""), + engine_->block_manager_pool())) { + LOG(FATAL) << "XServiceClient init fail!"; + return; + } + auto service_config = xservice_client->get_config(); + enable_decode_response_to_service = + service_config.enable_decode_response_to_service; + } + + ContinuousScheduler::Options scheduler_options; + scheduler_options.max_tokens_per_batch(options_.max_tokens_per_batch()) + .max_seqs_per_batch(options_.max_seqs_per_batch()) + .max_tokens_per_chunk_for_prefill( + options_.max_tokens_per_chunk_for_prefill()) + .num_speculative_tokens(options_.num_speculative_tokens()) + .dp_size(options_.dp_size()) + .enable_disagg_pd(options_.enable_disagg_pd()) + .enable_schedule_overlap(options_.enable_schedule_overlap()) + .enable_chunked_prefill(options_.enable_chunked_prefill()) + .instance_role(options_.instance_role()) + .kv_cache_transfer_mode(options_.kv_cache_transfer_mode()) + .enable_service_routing(options_.enable_service_routing()) + .enable_decode_response_to_service(enable_decode_response_to_service); + scheduler_ = create_fixsteps_scheduler(engine_.get(), scheduler_options); + + // OmniRec model does not have a tokenizer + chat_template_ = nullptr; + tokenizer_ = nullptr; + threadpool_ = std::make_unique(options_.num_handling_threads()); +} + +void RecMaster::run() { + const bool already_running = running_.load(std::memory_order_relaxed); + if (already_running) { + LOG(WARNING) << "RecMaster is already running."; + return; + } + running_.store(true, std::memory_order_relaxed); + loop_thread_ = std::thread([this]() { + const auto timeout = absl::Milliseconds(5); + while (!stopped_.load(std::memory_order_relaxed)) { + // move scheduler forward + scheduler_->step(timeout); + } + running_.store(false, std::memory_order_relaxed); + }); + + // Engine run method is not available, remove this call +} + +RecMaster::~RecMaster() { + // set stop flag + stopped_.store(true, std::memory_order_relaxed); + // wait for the loop thread to finish + if (loop_thread_.joinable()) { + loop_thread_.join(); + } +} + +void RecMaster::handle_request(std::string prompt, + std::optional> prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback) { + // add one pending request + scheduler_->incr_pending_requests(1); + auto cb = [callback = std::move(callback), + scheduler = scheduler_.get()](const RequestOutput& output) { + output.log_request_status(); + return callback(output); + }; + // add into the queue + threadpool_->schedule([this, + prompt = std::move(prompt), + prompt_tokens = std::move(prompt_tokens), + 
mm_data = std::move(mm_data), + sp = std::move(sp), + callback = std::move(cb)]() mutable { + AUTO_COUNTER(request_handling_latency_seconds_completion); + + // remove the pending request after scheduling + SCOPE_GUARD([this] { scheduler_->decr_pending_requests(); }); + + Timer timer; + // verify the prompt + if (!sp.verify_params(callback)) { + return; + } + + auto request = generate_request(std::move(prompt), + std::move(prompt_tokens), + std::move(mm_data), + sp, + callback); + if (!request) { + return; + } + + if (!scheduler_->add_request(request)) { + CALLBACK_WITH_ERROR(StatusCode::RESOURCE_EXHAUSTED, + "No available resources to schedule request"); + } + }); +} + +std::shared_ptr RecMaster::generate_request( + std::string prompt, + std::optional> prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback) { + // For Rec model, prompt is expected to be empty and prompt_tokens should + // contain the actual data Skip prompt empty check as mentioned in + // requirements + + Timer timer; + std::vector local_prompt_tokens; + + if (prompt_tokens.has_value()) { + local_prompt_tokens = std::move(prompt_tokens.value()); + LOG(INFO) + << "[Rec DEBUG] generate_request - received prompt_tokens.size(): " + << local_prompt_tokens.size() + << ", prompt.length(): " << prompt.length(); + } else if (!mm_data.has_value()) { + // sparse LLM + LOG(ERROR) << "Rec model requires prompt_tokens/embedding to be provided"; + CALLBACK_WITH_ERROR( + StatusCode::INVALID_ARGUMENT, + "Rec model requires prompt_tokens/embedding to be provided"); + return nullptr; + } + + COUNTER_ADD(tokenization_latency_seconds, timer.elapsed_seconds()); + + int32_t max_context_len = model_args_.max_position_embeddings(); + if (!options_.enable_chunked_prefill()) { + max_context_len = + std::min(max_context_len, options_.max_tokens_per_batch()); + } + if (local_prompt_tokens.size() >= max_context_len) { + LOG(ERROR) << "Prompt is too long: " << local_prompt_tokens.size(); + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Prompt is too long"); + return nullptr; + } + + uint32_t max_tokens = sp.max_tokens; + if (max_tokens == 0) { + const uint32_t kDefaultMaxTokens = 5120; + max_tokens = kDefaultMaxTokens; + } + + // allocate enough capacity for prompt tokens, max tokens, and speculative + // tokens + size_t capacity = local_prompt_tokens.size() + max_tokens + + options_.num_speculative_tokens() + /*bonus_token*/ 1; + if (options_.enable_schedule_overlap()) { + capacity += options_.num_speculative_tokens() + 1; + } + const size_t best_of = sp.best_of.value_or(sp.n); + + RequestSamplingParam sampling_param; + sampling_param.frequency_penalty = sp.frequency_penalty; + sampling_param.presence_penalty = sp.presence_penalty; + sampling_param.repetition_penalty = sp.repetition_penalty; + sampling_param.temperature = sp.temperature; + sampling_param.top_p = sp.top_p; + sampling_param.top_k = sp.top_k; + sampling_param.logprobs = sp.logprobs; + sampling_param.top_logprobs = sp.top_logprobs; + sampling_param.is_embeddings = sp.is_embeddings; + sampling_param.beam_width = sp.beam_width; + if (best_of > sp.n) { + // enable logprobs for best_of to generate sequence logprob + sampling_param.logprobs = true; + } + // sampling_param.do_sample = sp.do_sample; + + bool stream = sp.streaming; + // results cannot be streamed when best_of != n + if (best_of != sp.n) { + stream = false; + } + // std::unordered_set stop_tokens; + // std::vector> stop_sequences; + // StoppingChecker stopping_checker( + // max_tokens, + // 
max_context_len - options_.num_speculative_tokens(), + // , + // model_args_.eos_token_id(), + // sp.ignore_eos, + // std::move(stop_tokens), + // std::move(stop_sequences)); + StoppingChecker stopping_checker; + RequestState req_state(std::move(prompt), + std::move(local_prompt_tokens), + mm_data.value_or(MMData{}), + std::move(sampling_param), + std::move(stopping_checker), + capacity, + sp.n, + best_of, + sp.logprobs, + stream, + sp.echo, + sp.skip_special_tokens, + options_.enable_schedule_overlap(), + callback, + nullptr, + sp.decode_address); + req_state.is_rec_model = true; + req_state.bos_token_id = model_args_.bos_token_id(); + auto request = std::make_shared(sp.request_id, + sp.x_request_id, + sp.x_request_time, + std::move(req_state), + sp.service_request_id); + return request; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/rec_master.h b/xllm/core/runtime/rec_master.h new file mode 100644 index 00000000..60d20c42 --- /dev/null +++ b/xllm/core/runtime/rec_master.h @@ -0,0 +1,71 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include "framework/chat_template/jinja_chat_template.h" +#include "framework/model/model_args.h" +#include "runtime/master.h" +#include "runtime/rec_engine.h" +#include "scheduler/continuous_scheduler.h" +#include "scheduler/fixsteps_scheduler.h" +#include "util/threadpool.h" + +namespace xllm { + +class RecMaster : public Master { + public: + explicit RecMaster(const Options& options); + ~RecMaster(); + + // handle a request, the engine will execute the request asynchronously + // completion/encode + void handle_request(std::string prompt, + std::optional> prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback); + + // start the handling loop + void run() override; + + private: + std::shared_ptr generate_request( + std::string prompt, + std::optional> prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback); + + std::unique_ptr scheduler_; + // model args + ModelArgs model_args_; + std::unique_ptr threadpool_; + std::unique_ptr tokenizer_; + // chat template instance + std::unique_ptr chat_template_; + // thread for moving forward the scheduler + std::thread loop_thread_; + // flag to stop the loop + std::atomic stopped_{false}; + + // flag to indicate if the handler is running + std::atomic running_{false}; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/rec_worker_impl.cpp b/xllm/core/runtime/rec_worker_impl.cpp new file mode 100644 index 00000000..35f3245e --- /dev/null +++ b/xllm/core/runtime/rec_worker_impl.cpp @@ -0,0 +1,347 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "rec_worker_impl.h" + +#include + +#include +#include +#include + +#include "butil/file_util.h" +#include "butil/files/dir_reader_linux.h" +#include "butil/files/file_path.h" +#include "butil/strings/string_util.h" +#include "common/metrics.h" +#include "models/model_registry.h" +#include "util/env_var.h" +#include "util/utils.h" + +namespace xllm { + +RecWorkerImpl::RecWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options) + : WorkerImpl(parallel_args, device, options) { + // Initialize filter mask stream for H2D operations + filter_mask_stream_ = device_.get_stream_from_pool(); + + // Initialize thread pool for async operations using environment variable + int thread_num = util::get_int_env(util::EXTRA_THREAD_NUM, 16); + thread_pool_ = std::make_shared(thread_num); +} + +bool RecWorkerImpl::init_model(const std::string& model_weights_path) { + auto model_loader = ModelLoader::create(model_weights_path); + + auto args = model_loader->model_args(); + auto quant_args = model_loader->quant_args(); + torch::ScalarType dtype = util::parse_dtype(args.dtype(), device_); + + if (options_.enable_speculative_decode() && FLAGS_enable_atb_spec_kernel) { + args.num_speculative_tokens(options_.num_speculative_tokens()); + } + + // create model context + dtype_ = dtype; + auto tensor_options = torch::dtype(dtype_).device(device_); + context_ = ModelContext(parallel_args_, args, quant_args, tensor_options); + + // init model, create model executor + bool status = this->init_model(context_); + if (!status) { + return false; + } + + this->load_model(std::move(model_loader)); + + status_ = Status::LOADED; + // TODO: replace path with flags after filter merge + butil::FilePath filter_bin_path = + butil::FilePath(model_weights_path).Append("replace me when merge"); + valid_path_filter_ = std::make_unique( + filter_bin_path.value(), args.vocab_size(), dtype_, device_); + + return true; +} + +bool RecWorkerImpl::init_model(ModelContext& context) { + CHECK(model_ == nullptr) << "Model is already initialized."; + device_.set_device(); + + // Try to create a causal LM model (Rec models are typically based on + // CausalLM) + model_ = create_llm_model(context); + + // Check if model creation was successful + CHECK(model_ != nullptr) << "Failed to create Rec model."; + model_executor_ = std::make_unique( + model_.get(), context.get_model_args(), device_, options_); + + if (FLAGS_enable_beam_search_kernel) { + beam_searcher_ = std::make_unique(); + } + return true; +} + +std::optional RecWorkerImpl::step( + const BatchedForwardInputs& inputs) { + device_.set_device(); + + // Timer for performance monitoring + auto start_time = std::chrono::high_resolution_clock::now(); + + std::vector flatten_tokens_micro_batches; + std::vector flatten_positions_micro_batches; + std::vector input_params_micro_batches; + auto& concated_sampling_params = inputs.concated_sampling_params; + + for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { + 
flatten_tokens_micro_batches.push_back( + std::move(inputs.micro_inputs[i].token_ids)); + flatten_positions_micro_batches.push_back( + std::move(inputs.micro_inputs[i].positions)); + input_params_micro_batches.push_back( + std::move(inputs.micro_inputs[i].input_params)); + } + + // Start async filter mask preparation early for overlap (if beam search is + // enabled) + std::future filter_mask_future; + + if (!input_params_micro_batches.empty() && + input_params_micro_batches[0].is_rec_model() && + input_params_micro_batches[0].rec_params.has_value()) { + auto& rec_params = input_params_micro_batches[0].rec_params.value(); + if (!rec_params.generated_tokens.empty()) { + filter_mask_future = + prepare_filter_mask_async(rec_params.generated_tokens); + } + } + + // Check if we have encoder inputs (rec model with encoder/decoder) + torch::Tensor hidden_states; + bool has_encoder_inputs = false; + + // Check if this is a rec model with encoder inputs + if (!input_params_micro_batches.empty() && + input_params_micro_batches[0].is_rec_model() && + input_params_micro_batches[0].rec_params.has_value()) { + auto& rec_params = input_params_micro_batches[0].rec_params.value(); + + // Check for encoder inputs + if ((rec_params.encoder_token_ids.defined() && + rec_params.encoder_positions.defined()) || + rec_params.encoder_sparse_embedding.defined()) { + has_encoder_inputs = true; + + // Set hybrid mode if sparse embedding is defined + if (rec_params.encoder_sparse_embedding.defined()) { + input_params_micro_batches[0].rec_params->is_hybrid_mode = true; + } + } + } + + if (has_encoder_inputs) { + // Two-stage forward: encoder then decoder + auto& rec_params = input_params_micro_batches[0].rec_params.value(); + + if (rec_params.rec_stage == RecModelInputParams::RecStage::PREFILL) { + // Check if this is the first prefill or subsequent prefill + if (!rec_params.is_first_prefill) { + // Subsequent prefill: only run decoder + hidden_states = + model_executor_->forward(flatten_tokens_micro_batches, + flatten_positions_micro_batches, + kv_caches_, + input_params_micro_batches); + } else { + // First prefill: run encoder first, then decoder + + // 1. Run encoder forward + auto encoder_input_params = input_params_micro_batches; + encoder_input_params[0].rec_params->is_encoder_forward = true; + + std::vector encoder_tokens; + std::vector encoder_positions; + + if (rec_params.is_hybrid_mode && + rec_params.encoder_sparse_embedding.defined()) { + encoder_tokens.push_back(rec_params.encoder_sparse_embedding); + } else { + encoder_tokens.push_back(rec_params.encoder_token_ids); + } + encoder_positions.push_back(rec_params.encoder_positions); + + // Run encoder + hidden_states = model_executor_->forward(encoder_tokens, + encoder_positions, + kv_caches_, + encoder_input_params); + + // 2. 
Run decoder forward + encoder_input_params[0].rec_params->is_encoder_forward = false; + hidden_states = + model_executor_->forward(flatten_tokens_micro_batches, + flatten_positions_micro_batches, + kv_caches_, + encoder_input_params); + } + } else { + // Decode stage: only run decoder + hidden_states = model_executor_->forward(flatten_tokens_micro_batches, + flatten_positions_micro_batches, + kv_caches_, + input_params_micro_batches); + } + } else { + // Non-rec model or rec model without encoder: use standard forward + LOG(ERROR) << "RecWorkerImpl not supports decoder-only model"; + } + + torch::Tensor logits; + if (concated_sampling_params.selected_token_idxes.defined()) { + logits = model_->logits(hidden_states, + concated_sampling_params.selected_token_idxes); + } + + ForwardOutput output; + + if (!driver_) { + return std::nullopt; + } + + // Get filter mask result from async preparation if available + torch::Tensor filter_mask; + if (filter_mask_future.valid()) { + // Get the result from async preparation (this will block if not ready) + filter_mask = filter_mask_future.get(); + } + + // Driver prepare model output + if (concated_sampling_params.selected_token_idxes.defined()) { + auto sample_logits = + logits.index_select(/*dim=*/0, concated_sampling_params.sample_idxes); + + // Apply filter mask if available + if (filter_mask.defined()) { + // Ensure filter_mask has the same batch size as sample_logits + if (filter_mask.size(0) == sample_logits.size(0)) { + sample_logits = sample_logits + filter_mask; + } else { + // If dimensions don't match, select the appropriate rows from + // filter_mask + auto selected_filter_mask = filter_mask.index_select( + /*dim=*/0, concated_sampling_params.sample_idxes); + sample_logits = sample_logits + selected_filter_mask; + } + } + + auto sample_output = + sampler_->forward(sample_logits, concated_sampling_params); + output.logits = logits; + + // Set sample output to output + output.sample_output = sample_output; + + // Carry over the sampling params + output.do_sample = concated_sampling_params.do_sample; + output.logprobs = concated_sampling_params.logprobs; + output.max_top_logprobs = concated_sampling_params.max_top_logprobs; + } + + // Synchronize at the end like in llm_worker_impl + auto ret = device_.synchronize_default_stream(); + + // Record execution latency + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time); + COUNTER_ADD(execution_latency_seconds_model, duration.count() / 1000000.0); + + return output; +} + +ForwardInput RecWorkerImpl::prepare_inputs(Batch& batch) { + // Use the rec-specific input preparation method + return batch.prepare_rec_forward_input(options_.num_decoding_tokens(), + 0, // min_decoding_batch_size + context_.get_model_args()); +} + +std::future RecWorkerImpl::prepare_filter_mask_async( + const std::vector>& generated_tokens) { + // Create promise/future pair for async result + auto promise = std::make_shared>(); + auto future = promise->get_future(); + + // Submit async task to thread pool + thread_pool_->schedule([this, generated_tokens, promise]() -> void { + try { + // Set stream guard for H2D operations + c10::StreamGuard streamGuard = filter_mask_stream_->set_stream_guard(); + + torch::Tensor cpu_mask; + + // Use ValidPathFilter if available, otherwise create placeholder mask + if (valid_path_filter_ && !generated_tokens.empty()) { + // Use ValidPathFilter to generate the actual filter mask + cpu_mask = 
valid_path_filter_->forward(generated_tokens); + + // If ValidPathFilter returns empty tensor, create placeholder + if (!cpu_mask.defined()) { + int batch_size = generated_tokens.size(); + int vocab_size = 32000; // Default vocab size + cpu_mask = torch::zeros({batch_size, vocab_size}, torch::kFloat32); + } + } else if (!generated_tokens.empty()) { + // Fallback: create placeholder mask when ValidPathFilter is not + // available + int batch_size = generated_tokens.size(); + int vocab_size = 32000; // Default vocab size + cpu_mask = torch::zeros({batch_size, vocab_size}, torch::kFloat32); + + // Apply some basic filtering logic (placeholder) + for (int i = 0; i < batch_size; ++i) { + // Set some tokens to -inf to filter them out + cpu_mask[i] + .slice(0, 0, 1000) + .fill_(-std::numeric_limits::infinity()); + } + } else { + // Return empty tensor if no generated tokens + promise->set_value(torch::Tensor()); + return; + } + + // Copy to device using the dedicated H2D stream + torch::Tensor device_mask = cpu_mask.to(device_, /*non_blocking=*/true); + + // Synchronize the H2D stream to ensure copy is complete + filter_mask_stream_->synchronize(); + + // Set the result in the promise + promise->set_value(device_mask); + } catch (const std::exception& e) { + // Set exception in promise if something goes wrong + promise->set_exception(std::current_exception()); + } + }); + + return future; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/rec_worker_impl.h b/xllm/core/runtime/rec_worker_impl.h new file mode 100644 index 00000000..9c5739fc --- /dev/null +++ b/xllm/core/runtime/rec_worker_impl.h @@ -0,0 +1,76 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include +#include + +#include "framework/batch/batch.h" +#include "framework/model_context.h" +#include "framework/sampling/valid_path_filter.h" +#include "platform/stream.h" +#include "runtime/forward_params.h" +#include "util/threadpool.h" +#include "worker_impl.h" + +namespace xllm { + +// Rec specific worker implementation +class RecWorkerImpl : public WorkerImpl { + public: + RecWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options); + + // Override init_model for Rec specific implementation + bool init_model(const std::string& model_weights_path) override; + + // Override init_model with ModelContext for Rec specific implementation + bool init_model(ModelContext& context) override; + + // Override step for Rec specific implementation + std::optional step( + const BatchedForwardInputs& inputs) override; + + // Override prepare_inputs for Rec specific implementation + ForwardInput prepare_inputs(Batch& batch) override; + + private: + // Helper method for filter mask preparation (placeholder for future + // implementation) + torch::Tensor prepare_filter_mask( + const std::vector>& generated_tokens); + + // Async filter mask preparation with overlap + std::future prepare_filter_mask_async( + const std::vector>& generated_tokens); + + // Stream for H2D memory copy operations + std::unique_ptr filter_mask_stream_; + + // ThreadPool for async operations + std::shared_ptr thread_pool_; + + // ValidPathFilter for beam search filtering + std::unique_ptr valid_path_filter_; + + // BeamSearcher for beam search functionality + std::unique_ptr beam_searcher_; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/worker.cpp b/xllm/core/runtime/worker.cpp index 3ab9b6e6..b641ebe8 100644 --- a/xllm/core/runtime/worker.cpp +++ b/xllm/core/runtime/worker.cpp @@ -32,6 +32,7 @@ limitations under the License. #include "runtime/embed_vlm_worker_impl.h" #include "runtime/embed_worker_impl.h" #include "runtime/llm_worker_impl.h" +#include "runtime/rec_worker_impl.h" #include "runtime/speculative_worker_impl.h" #include "runtime/vlm_worker_impl.h" #include "util/timer.h" @@ -51,6 +52,8 @@ Worker::Worker(const ParallelArgs& parallel_args, impl_ = new EmbedWorkerImpl(parallel_args, device, options); } else if (worker_type == WorkerType::EVLM) { impl_ = new EmbedVLMWorkerImpl(parallel_args, device, options); + } else if (worker_type == WorkerType::REC) { + impl_ = new RecWorkerImpl(parallel_args, device, options); } else { LOG(ERROR) << "Unknown worker type, please check logic"; } diff --git a/xllm/core/scheduler/CMakeLists.txt b/xllm/core/scheduler/CMakeLists.txt index d694b3b1..999c6ce1 100644 --- a/xllm/core/scheduler/CMakeLists.txt +++ b/xllm/core/scheduler/CMakeLists.txt @@ -17,6 +17,7 @@ cc_library( scheduler.h dit_scheduler.h prefill_only_scheduler.h + fixsteps_scheduler.h scheduler_factory.h decode_priority_queue.h perf_model.h @@ -27,6 +28,7 @@ cc_library( disagg_pd_scheduler.cpp pd_ooc_scheduler.cpp async_response_processor.cpp + fixsteps_scheduler.cpp dit_scheduler.cpp prefill_only_scheduler.cpp scheduler_factory.cpp diff --git a/xllm/core/scheduler/fixsteps_scheduler.cpp b/xllm/core/scheduler/fixsteps_scheduler.cpp new file mode 100644 index 00000000..ccd076ee --- /dev/null +++ b/xllm/core/scheduler/fixsteps_scheduler.cpp @@ -0,0 +1,309 @@ +/* Copyright 2025 The xLLM Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "fixsteps_scheduler.h" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "common/metrics.h" +#include "framework/batch/batch.h" +#include "framework/batch/batch_factory.h" +#include "framework/request/request.h" +#include "framework/request/sequence.h" +#include "runtime/engine.h" + +namespace xllm { + +namespace { +constexpr size_t kRequestQueueSize = 100000; +} // namespace + +FixStepsScheduler::FixStepsScheduler(Engine* engine, const Options& options) + : ContinuousScheduler(engine, options) {} + +bool FixStepsScheduler::add_request(std::shared_ptr& request) { + CHECK(request != nullptr); + CHECK(!request->sequences().empty()); + + if (request_queue_.write(request)) { //.get() + // take over the ownership of the request + // request.release(); + return true; + } + // queue is full + return false; +} + +void FixStepsScheduler::handle_prefill_requests( + size_t& remaining_token_budget, + size_t& remaining_seq_budget, + std::vector>& finished_requests) { + // Handle new request prompt first. + // Include those requests that are preempted by others. + // + // schedule the prefill requests in the waiting priority queue until budgets + // are exhausted. + // When the KV Cache usage reaches the threshold, prefill requests will no + // longer be scheduled to avoid frequent preemption. + // + // NOTE: preempted requests will be pushed in waiting_priority_queue, + // they may contian many sequences, so we should check here. 
+ bool budget_exhausted = false; + bool blocks_exhausted = false; + while (!waiting_priority_queue_.empty() && remaining_seq_budget > 0 && + remaining_token_budget > 0 && + kv_cache_manager_->kv_cache_utilization() < + FLAGS_prefill_scheduling_memory_usage_threshold) { + std::shared_ptr request(waiting_priority_queue_.top()); + if (request->finished() || request->cancelled()) { + // kv_cache_manager_->deallocate(request.get()); + // release the ownership of the request + finished_requests.emplace_back(request); + // remove the request from the priority queue + waiting_priority_queue_.pop(); + continue; + } + + const size_t num_sequences = request->sequences().size(); + if (!request->preempted()) { + CHECK(num_sequences == 1) + << "Waiting request should have only one sequence."; + } + + // TODO: FIXME later + // Optimization of the scheduling algorithm under multiple sequences + size_t allocated_tokens = 0; + size_t allocated_seqs = 0; + double allocated_estimate_latency = 0; + bool can_schedule = true; + std::vector prefill_sequences; + std::vector prefill_sequences_budget; + prefill_sequences.reserve(request->sequences().size()); + prefill_sequences_budget.reserve(request->sequences().size()); + for (auto& prefill_sequence : request->sequences()) { + if (prefill_sequence->finished()) { + continue; + } + + size_t num_tokens = prefill_sequence->num_need_compute_tokens(); + if (remaining_token_budget < allocated_tokens + num_tokens || + remaining_seq_budget < allocated_seqs + 1) { + can_schedule = false; + budget_exhausted = true; + break; + } + + prefill_sequences_budget.emplace_back(num_tokens); + prefill_sequences.emplace_back(prefill_sequence.get()); + allocated_tokens += num_tokens; + allocated_seqs += 1; + } + + if (!can_schedule) { + for (auto& seq : prefill_sequences) { + // release shared blocks + kv_cache_manager_->deallocate(seq); + } + break; + } + + if (prefill_sequences.empty()) { + continue; + } + + remaining_token_budget -= allocated_tokens; + remaining_seq_budget -= allocated_seqs; + waiting_priority_queue_.pop(); + running_requests_.emplace_back(request); + running_sequences_.insert(running_sequences_.end(), + prefill_sequences.begin(), + prefill_sequences.end()); + running_sequences_budgets_.insert(running_sequences_budgets_.end(), + prefill_sequences_budget.begin(), + prefill_sequences_budget.end()); + } + + if (running_sequences_.empty() && !waiting_priority_queue_.empty() && + running_queue_->empty()) { + LOG(ERROR) + << "Request prompt is too long, no enough budget/memory to schedule " + "a single sequence."; + // no enough memory to schedule single sequence, just finish the request + std::shared_ptr request(waiting_priority_queue_.top()); + waiting_priority_queue_.pop(); + // block_manager_->release_blocks_for(request.get()); + response_processor_->process_failed_request( + request, + {StatusCode::RESOURCE_EXHAUSTED, + "No enough budget to schedule single sequence."}); + } +} + +std::vector FixStepsScheduler::prepare_batch() { + Timer timer; + // propogate new requests to waiting_priority_queue_ + // Include those requests that are preempted by others. + std::shared_ptr request; + // read from request queue then push to waiting priority queue + while (request_queue_.read(request)) { + CHECK(request); + + // expand sequences to the target number if prefix cache is disabled. 
+ if (!enable_prefix_cache_) { + // expand sequences to the target number + request->expand_sequences(false); + } + + if (request->sequences()[0]->kv_state().kv_cache_tokens_num() == 0) { + waiting_priority_queue_.push(request); + } else { + // request from prefill instance in disagge pd mode. + running_requests_.emplace_back(request); + } + } + + // handle finished/cancelled requests + std::vector> finished_requests; + for (auto it = running_requests_.rbegin(); it != running_requests_.rend(); + ++it) { + if (*it == nullptr) { + continue; + } + std::shared_ptr request = *it; + request->update_connection_status(); + if (request->finished() || request->cancelled()) { + // kv_cache_manager_->deallocate(request.get()); + // release the ownership of the request + finished_requests.emplace_back(request); + // finished request is set to nullptr + *it = nullptr; + } + } + + // clear previous batch + running_requests_.clear(); + running_sequences_.clear(); + running_sequences_budgets_.clear(); + + // remaining budget for the current batch + size_t remaining_token_budget = options_.max_tokens_per_batch(); + size_t remaining_seq_budget = std::max(options_.max_seqs_per_batch(), 1); + size_t num_preempted_requests = 0; + + handle_prefill_requests( + remaining_token_budget, remaining_seq_budget, finished_requests); + + // only forward once, no decode requests + // handle_decode_requests( + // remaining_token_budget, remaining_seq_budget, num_preempted_requests); + + if (!finished_requests.empty()) { + response_processor_->process_completed_requests(finished_requests); + } + + // update the batch + auto batches = BatchFactory::get_instance(options_.dp_size()) + ->create_rec_batches( + running_requests_, + running_sequences_, + running_sequences_budgets_, + kv_cache_manager_->get_copy_in_cache_block_infos(), + kv_cache_manager_->get_copy_out_cache_block_infos(), + kv_cache_manager_->get_swap_cache_block_infos()); + + // update metrics before returning + if (!batches[0].empty()) { + // only update the scheduling latency when there are requests to process + COUNTER_ADD(scheduling_latency_seconds, timer.elapsed_seconds()); + } + + GAUGE_SET(num_pending_requests, + pending_requests_.load(std::memory_order_relaxed)); + GAUGE_SET(num_running_requests, running_requests_.size()); + GAUGE_SET(num_waiting_requests, + waiting_priority_queue_.size() + running_queue_->size()); + + GAUGE_ADD(num_preempted_requests, num_preempted_requests); + + GAUGE_SET(num_running_sequences, running_sequences_.size()); + + GAUGE_SET(kv_cache_utilization_perc, + kv_cache_manager_->kv_cache_utilization()); + if (!FLAGS_enable_continuous_kvcache) { + GAUGE_SET(num_blocks_in_prefix_cache, + kv_cache_manager_->num_blocks_in_prefix_cache().size()); + GAUGE_SET(num_free_blocks, kv_cache_manager_->num_free_blocks().size()); + GAUGE_SET(num_used_blocks, kv_cache_manager_->num_used_blocks().size()); + } + return batches; +} + +std::vector FixStepsScheduler::schedule_request( + const absl::Duration& timeout) { + const auto deadline = absl::Now() + timeout; + std::vector batch; + while (true) { + batch = prepare_batch(); + bool all_empty = + std::all_of(batch.begin(), batch.end(), [](const Batch& one_batch) { + return one_batch.empty(); + }); + if (!all_empty) { + return batch; + } + const auto now = absl::Now(); + if (now > deadline) { + break; + } + // wait for new requests to arrive + constexpr uint64_t kStepSleepTimeMs = 1; + const auto time_to_sleep = + std::min(absl::Milliseconds(kStepSleepTimeMs), deadline - now); + 
absl::SleepFor(time_to_sleep); + } + // return an empty batch + return batch; +} + +// step the scheduler forward by one step +// may get blocked if there are no requests to process +void FixStepsScheduler::step(const absl::Duration& timeout) { + if (!options_.enable_schedule_overlap()) { + // get a new batch of requests + std::vector batch = schedule_request(timeout); + bool all_empty = + std::all_of(batch.begin(), batch.end(), [](const Batch& one_batch) { + return one_batch.empty(); + }); + if (all_empty) { + return; + } + engine_->step(batch); + kv_cache_manager_->reset_copy_content(); + } else { + LOG(ERROR) << "FixStepsScheduler::step() not supported with " + "enable_schedule_overlap"; + } +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/scheduler/fixsteps_scheduler.h b/xllm/core/scheduler/fixsteps_scheduler.h new file mode 100644 index 00000000..1fcfc3b3 --- /dev/null +++ b/xllm/core/scheduler/fixsteps_scheduler.h @@ -0,0 +1,62 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include "async_response_processor.h" +#include "common/macros.h" +#include "common/types.h" +#include "framework/batch/batch.h" +#include "framework/request/request.h" +#include "framework/request/sequence.h" +#include "runtime/xservice_client.h" +#include "scheduler.h" +#include "scheduler/continuous_scheduler.h" + +namespace xllm { +class Engine; + +class FixStepsScheduler final : public ContinuousScheduler { + public: + FixStepsScheduler(Engine* engine, const Options& options); + virtual ~FixStepsScheduler() = default; + + bool add_request(std::shared_ptr& request) override; + + // step the scheduler forward by one step + // may get blocked if there are no requests to process + void step(const absl::Duration& timeout) override; + + private: + std::vector schedule_request(const absl::Duration& timeout); + + // build a batch of requests from the priority queue + virtual std::vector prepare_batch(); + + void handle_prefill_requests( + size_t& remaining_token_budget, + size_t& remaining_seq_budget, + std::vector>& finished_requests); +}; + +} // namespace xllm diff --git a/xllm/core/scheduler/scheduler_factory.cpp b/xllm/core/scheduler/scheduler_factory.cpp index 8be5a8b8..de85bd13 100644 --- a/xllm/core/scheduler/scheduler_factory.cpp +++ b/xllm/core/scheduler/scheduler_factory.cpp @@ -19,6 +19,7 @@ limitations under the License. 
#include "scheduler/continuous_scheduler.h" #include "scheduler/disagg_pd_scheduler.h" #include "scheduler/dit_scheduler.h" +#include "scheduler/fixsteps_scheduler.h" #include "scheduler/pd_ooc_scheduler.h" #include "scheduler/prefill_only_scheduler.h" #include "scheduler/zero_eviction_scheduler.h" @@ -51,6 +52,12 @@ std::unique_ptr create_continuous_scheduler( return std::make_unique(engine, options); } +std::unique_ptr create_fixsteps_scheduler( + Engine* engine, + ContinuousScheduler::Options options) { + return std::make_unique(engine, options); +} + std::unique_ptr create_dit_scheduler( DiTEngine* engine, DiTScheduler::Options options) { diff --git a/xllm/core/scheduler/scheduler_factory.h b/xllm/core/scheduler/scheduler_factory.h index daf153ba..0fa452dd 100644 --- a/xllm/core/scheduler/scheduler_factory.h +++ b/xllm/core/scheduler/scheduler_factory.h @@ -18,6 +18,7 @@ limitations under the License. #include "runtime/xservice_client.h" #include "scheduler/continuous_scheduler.h" #include "scheduler/dit_scheduler.h" +#include "scheduler/fixsteps_scheduler.h" namespace xllm { @@ -25,6 +26,10 @@ std::unique_ptr create_continuous_scheduler( Engine* engine, ContinuousScheduler::Options options); +std::unique_ptr create_fixsteps_scheduler( + Engine* engine, + ContinuousScheduler::Options options); + std::unique_ptr create_dit_scheduler( DiTEngine* engine, DiTScheduler::Options options); diff --git a/xllm/core/util/CMakeLists.txt b/xllm/core/util/CMakeLists.txt index 3318822e..cd2fbd8d 100644 --- a/xllm/core/util/CMakeLists.txt +++ b/xllm/core/util/CMakeLists.txt @@ -31,6 +31,7 @@ cc_library( SRCS device_name_utils.cpp env_var.cpp + hash_util.cpp json_reader.cpp net.cpp pretty_print.cpp @@ -50,7 +51,9 @@ cc_library( Boost::serialization absl::synchronization ${Python_LIBRARIES} + proto::xllm_proto :platform + SMHasherSupport ) target_link_libraries(util PRIVATE OpenSSL::SSL OpenSSL::Crypto) add_dependencies(util brpc-static) @@ -70,8 +73,3 @@ cc_test( ) target_link_libraries(util_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto) add_dependencies(util_test brpc-static) - - - - - diff --git a/xllm/core/util/env_var.cpp b/xllm/core/util/env_var.cpp index 6aac81ef..937fe17a 100644 --- a/xllm/core/util/env_var.cpp +++ b/xllm/core/util/env_var.cpp @@ -20,6 +20,9 @@ limitations under the License. namespace xllm { namespace util { +// Environment variable keys +const std::string EXTRA_THREAD_NUM = "EXTRA_THREAD_NUM"; + bool get_bool_env(const std::string& key, bool defaultValue) { const char* val = std::getenv(key.c_str()); if (val == nullptr) { @@ -30,5 +33,17 @@ bool get_bool_env(const std::string& key, bool defaultValue) { strVal == "True"); } +int get_int_env(const std::string& key, int defaultValue) { + const char* val = std::getenv(key.c_str()); + if (val == nullptr) { + return defaultValue; + } + try { + return std::stoi(val); + } catch (const std::exception&) { + return defaultValue; + } +} + } // namespace util } // namespace xllm diff --git a/xllm/core/util/env_var.h b/xllm/core/util/env_var.h index cbd69d67..9e524d36 100644 --- a/xllm/core/util/env_var.h +++ b/xllm/core/util/env_var.h @@ -20,7 +20,12 @@ limitations under the License. 
 namespace xllm {
 namespace util {
+// Environment variable keys
+extern const std::string EXTRA_THREAD_NUM;
+
 bool get_bool_env(const std::string& key, bool defaultValue);
+int get_int_env(const std::string& key, int defaultValue);
+
 } // namespace util
 } // namespace xllm
diff --git a/xllm/core/util/hash_util.cpp b/xllm/core/util/hash_util.cpp
new file mode 100644
index 00000000..f335f545
--- /dev/null
+++ b/xllm/core/util/hash_util.cpp
@@ -0,0 +1,55 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "hash_util.h"
+
+#include
+#include
+
+#include "third_party/smhasher/src/MurmurHash3.h"
+
+// Use a constant seed instead of FLAGS to avoid circular dependency
+constexpr uint32_t MURMUR_HASH3_SEED = 0;
+
+namespace xllm {
+
+void murmur_hash3(const uint8_t* pre_hash_value,
+                  const Slice& token_ids,
+                  uint8_t* hash_value) {
+  if (pre_hash_value == nullptr) {
+    MurmurHash3_x64_128(reinterpret_cast(token_ids.data()),
+                        sizeof(int32_t) * token_ids.size(),
+                        MURMUR_HASH3_SEED,
+                        hash_value);
+  } else {
+    uint8_t key[1024];
+
+    int32_t data_len =
+        sizeof(int32_t) * token_ids.size() + MURMUR_HASH3_VALUE_LEN;
+    CHECK_GT(sizeof(key), data_len) << "key size is too small";
+
+    memcpy(key, pre_hash_value, MURMUR_HASH3_VALUE_LEN);
+    memcpy(key + MURMUR_HASH3_VALUE_LEN,
+           reinterpret_cast(token_ids.data()),
+           sizeof(int32_t) * token_ids.size());
+
+    MurmurHash3_x64_128(reinterpret_cast(key),
+                        data_len,
+                        MURMUR_HASH3_SEED,
+                        hash_value);
+  }
+}
+
+} // namespace xllm
diff --git a/xllm/core/util/hash_util.h b/xllm/core/util/hash_util.h
index 31393d5b..e4886ee4 100644
--- a/xllm/core/util/hash_util.h
+++ b/xllm/core/util/hash_util.h
@@ -21,6 +21,8 @@ limitations under the License.
 #include
 #include
+#include "slice.h"
+
 namespace xllm {
 constexpr uint32_t MURMUR_HASH3_VALUE_LEN = 16;
@@ -62,4 +64,8 @@ struct FixedStringKeyEqual {
   }
 };
+void murmur_hash3(const uint8_t* pre_hash_value,
+                  const Slice& token_ids,
+                  uint8_t* hash_value);
+
 } // namespace xllm
diff --git a/xllm/core/util/tensor_helper.h b/xllm/core/util/tensor_helper.h
index 714810c9..418d1cef 100644
--- a/xllm/core/util/tensor_helper.h
+++ b/xllm/core/util/tensor_helper.h
@@ -50,6 +50,58 @@ inline torch::Tensor create_2d_tensor(const std::vector >& vec,
   return tensor;
 };
+// Specially optimized version for empty 2D vectors
+template
+inline torch::Tensor create_2d_tensor_optimized(
+    const std::vector >& vec,
+    torch::ScalarType dtype) {
+  if (vec.empty()) {
+    return {};
+  }
+
+  const size_t n_rows = vec.size();
+  const size_t n_cols = vec.empty() ? 0 : vec[0].size();
+
+  // Special optimization for all-zero matrices
+  bool all_zero = true;
+  for (const auto& row : vec) {
+    for (const auto& val : row) {
+      if (val != T(0)) {
+        all_zero = false;
+        break;
+      }
+    }
+    if (!all_zero) break;
+  }
+
+  if (all_zero) {
+    // Create a zero tensor directly, which is more efficient
+    return torch::zeros(
+        {static_cast(n_rows), static_cast(n_cols)},
+        torch::TensorOptions()
+            .dtype(dtype)
+            .device(torch::kCPU)
+            .pinned_memory(true));
+  }
+
+  // Otherwise use the optimized memory-copy path
+  auto tensor =
+      torch::empty({static_cast(n_rows), static_cast(n_cols)},
+                   torch::TensorOptions()
+                       .dtype(dtype)
+                       .device(torch::kCPU)
+                       .pinned_memory(true));
+
+  // Optimization: use bulk memory copies instead of per-row torch::tensor creation
+  T* tensor_data = tensor.data_ptr();
+  for (int64_t i = 0; i < n_rows; ++i) {
+    CHECK_EQ(vec[i].size(), n_cols);
+    // Copy the memory directly to avoid creating a temporary tensor
+    std::memcpy(tensor_data + i * n_cols, vec[i].data(), n_cols * sizeof(T));
+  }
+  return tensor;
+};
+
 inline torch::Tensor safe_to(const torch::Tensor& t,
                              const torch::TensorOptions& options,
                              bool non_blocking = false) {
diff --git a/xllm/core/util/utils.cpp b/xllm/core/util/utils.cpp
index 0182b687..d3c25164 100644
--- a/xllm/core/util/utils.cpp
+++ b/xllm/core/util/utils.cpp
@@ -148,5 +148,103 @@ std::vector cal_vec_split_index(uint32_t vec_size,
   return split_index;
 }
 
+torch::Dtype convert_rec_type_to_torch(proto::DataType data_type) {
+  // Future extensions go here.
+  switch (data_type) {
+    case proto::DataType::FLOAT:
+      return torch::kFloat32;
+
+    case proto::DataType::BFLOAT16:
+      return torch::kBFloat16;
+
+    case proto::DataType::BOOL:
+      return torch::kBool;
+
+    case proto::DataType::UINT8:
+      return torch::kUInt8;
+
+    // case proto::DataType::UINT32:
+    //   return torch::kUInt32;
+
+    case proto::DataType::INT8:
+      return torch::kInt8;
+
+    case proto::DataType::INT16:
+      return torch::kInt16;
+
+    default:
+      throw std::runtime_error("Unsupported data type: " +
+                               std::to_string(static_cast(data_type)));
+  }
+}
+
+torch::Tensor convert_rec_tensor_to_torch(
+    const proto::InferInputTensor& input_tensor) {
+  std::vector shape;
+  shape.reserve(input_tensor.shape_size());
+  for (int i = 0; i < input_tensor.shape_size(); ++i) {
+    shape.push_back(input_tensor.shape(i));
+  }
+
+  if (!input_tensor.has_contents()) {
+    throw std::runtime_error("Input tensor '" + input_tensor.name() +
+                             "' has no contents");
+  }
+
+  const auto& contents = input_tensor.contents();
+  torch::Dtype dtype = convert_rec_type_to_torch(input_tensor.data_type());
+
+  switch (dtype) {
+    case torch::kFloat32: {
+      // Directly use protobuf's float array
+      const auto& data = contents.fp32_contents();
+      return torch::from_blob(
+                 const_cast(data.data()),
+                 shape,
+                 torch::dtype(torch::kFloat32).requires_grad(false))
+          .clone();  // Clone to ensure independent memory
+    }
+    // Not supported for now.
+ // case torch::kFloat16: { + // // Need type conversion (protobuf usually stores float16 as uint16) + // const auto& data = contents.bytes_contents(); + // std::vector half_data; + // half_data.reserve(data.size()); + // for (auto val : data) { + // half_data.push_back(static_cast(val)); + // } + // return torch::tensor(half_data, torch::dtype(torch::kFloat16)) + // .view(shape); + // } + + case torch::kInt32: { + const auto& data = contents.int_contents(); + return torch::from_blob(const_cast(data.data()), + shape, + torch::dtype(torch::kInt32)) + .clone(); + } + + case torch::kInt64: { + const auto& data = contents.int64_contents(); + return torch::from_blob(const_cast(data.data()), + shape, + torch::dtype(torch::kInt64)) + .clone(); + } + + case torch::kBool: { + const auto& data = contents.bool_contents(); + return torch::tensor(std::vector(data.begin(), data.end()), + torch::dtype(torch::kBool)) + .view(shape); + } + + default: + throw std::runtime_error("Unhandled data type conversion for: " + + std::to_string(static_cast(dtype))); + } +} + } // namespace util } // namespace xllm diff --git a/xllm/core/util/utils.h b/xllm/core/util/utils.h index 51491972..3c95ca84 100644 --- a/xllm/core/util/utils.h +++ b/xllm/core/util/utils.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "rec.pb.h" #include "slice.h" namespace xllm { @@ -71,5 +72,8 @@ bool match_suffix(const Slice& data, const Slice& suffix); std::vector cal_vec_split_index(uint32_t vec_size, uint32_t part_num); +torch::Tensor convert_rec_tensor_to_torch( + const proto::InferInputTensor& input_tensor); + } // namespace util } // namespace xllm diff --git a/xllm/proto/CMakeLists.txt b/xllm/proto/CMakeLists.txt index 38be2b75..5014c409 100644 --- a/xllm/proto/CMakeLists.txt +++ b/xllm/proto/CMakeLists.txt @@ -6,6 +6,7 @@ proto_library( SRCS tensor.proto common.proto + rec.proto completion.proto chat.proto multimodal.proto diff --git a/xllm/proto/completion.proto b/xllm/proto/completion.proto index ccddd0e6..37b16738 100644 --- a/xllm/proto/completion.proto +++ b/xllm/proto/completion.proto @@ -4,6 +4,7 @@ option go_package = "jd.com/jd-infer/xllm;xllm"; package xllm.proto; import "common.proto"; +import "rec.proto"; // Next ID: 26 message CompletionRequest { @@ -95,6 +96,9 @@ message CompletionRequest { optional Priority priority = 28; optional int32 beam_width = 29; + + // tensor for rec embedding. + repeated InferInputTensor input_tensors = 30; } message LogProbs { @@ -142,5 +146,8 @@ message CompletionResponse { // usage statistics for the completion request. Usage usage = 6; + + // for rec output + repeated InferOutputTensor output_tensors = 7; } diff --git a/xllm/proto/rec.proto b/xllm/proto/rec.proto new file mode 100644 index 00000000..5504b865 --- /dev/null +++ b/xllm/proto/rec.proto @@ -0,0 +1,119 @@ +syntax = "proto3"; +option go_package = "jd.com/jd-infer/xllm;xllm"; +package xllm.proto; +import "common.proto"; + +option cc_enable_arenas = true; +option cc_generic_services = true; +enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. 
+  FLOAT16 = 10;
+  DOUBLE = 11;
+  UINT32 = 12;
+  UINT64 = 13;
+  COMPLEX64 = 14;   // complex with float32 real and imaginary components
+  COMPLEX128 = 15;  // complex with float64 real and imaginary components
+
+  // Non-IEEE floating-point format based on IEEE754 single-precision
+  // floating-point number truncated to 16 bits.
+  // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+  BFLOAT16 = 16;
+
+  // Non-IEEE floating-point formats based on the papers
+  // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433, and
+  // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
+  // Operators that support FP8 are Cast, CastLike, QuantizeLinear, and DequantizeLinear.
+  // The computation usually happens inside a block quantize / dequantize
+  // fused by the runtime.
+  FLOAT8E4M3FN = 17;    // float 8, mostly used for coefficients, supports nan, not inf
+  FLOAT8E4M3FNUZ = 18;  // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
+  FLOAT8E5M2 = 19;      // follows IEEE 754, supports nan, inf, mostly used for gradients
+  FLOAT8E5M2FNUZ = 20;  // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero
+
+  // 4-bit integer data types
+  UINT4 = 21;  // Unsigned integer in range [0, 15]
+  INT4 = 22;   // Signed integer in range [-8, 7], using two's-complement representation
+
+  // 4-bit floating point data types
+  FLOAT4E2M1 = 23;
+
+  // E8M0 type used as the scale for microscaling (MX) formats:
+  // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  FLOAT8E8M0 = 24;
+
+  // Future extensions go here.
+}
+
+// The data contained in a tensor represented by the repeated type
+// that matches the tensor's data type. Protobuf oneof is not used
+// because oneofs cannot contain repeated fields.
+message InferTensorContents
+{
+  // Representation for BOOL data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated bool bool_contents = 1;
+
+  // Representation for INT8, INT16, and INT32 data types. The size
+  // must match what is expected by the tensor's shape. The contents
+  // must be the flattened, one-dimensional, row-major order of the
+  // tensor elements.
+  repeated int32 int_contents = 2;
+
+  // Representation for INT64 data types. The size must match what
+  // is expected by the tensor's shape. The contents must be the
+  // flattened, one-dimensional, row-major order of the tensor elements.
+  repeated int64 int64_contents = 3;
+
+  // Representation for UINT8, UINT16, and UINT32 data types. The size
+  // must match what is expected by the tensor's shape. The contents
+  // must be the flattened, one-dimensional, row-major order of the
+  // tensor elements.
+  repeated uint32 uint_contents = 4;
+
+  // Representation for UINT64 data types. The size must match what
+  // is expected by the tensor's shape. The contents must be the
+  // flattened, one-dimensional, row-major order of the tensor elements.
+  repeated uint64 uint64_contents = 5;
+
+  // Representation for FP32 data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated float fp32_contents = 6;
+
+  // Representation for FP64 data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated double fp64_contents = 7;
+
+  // Representation for BYTES data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated bytes bytes_contents = 8;
+}
+
+// An input tensor for an inference request.
+message InferInputTensor
+{
+  // The tensor name.
+  string name = 1;
+
+  // The tensor data type.
+  DataType data_type = 2;
+
+  // The tensor shape.
+  repeated int64 shape = 3;
+
+  // The tensor contents using a data-type format. This field must
+  // not be specified if "raw" tensor contents are being used for
+  // the inference request.
+  InferTensorContents contents = 4;
+}
+
+// An output tensor returned for an inference request.
+message InferOutputTensor
+{
+  // The tensor name.
+  string name = 1;
+
+  // The tensor data type.
+  DataType datatype = 2;
+
+  // The tensor shape.
+  repeated int64 shape = 3;
+
+  // The tensor contents using a data-type format. This field must
+  // not be specified if "raw" tensor contents are being used for
+  // the inference response.
+  InferTensorContents contents = 4;
+}
\ No newline at end of file
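
For reference, a minimal sketch (not part of the patch) of how a caller might populate one of the new InferInputTensor messages and convert it with the util::convert_rec_tensor_to_torch helper added in xllm/core/util/utils.h above. The include path for the utils header and the standalone main() scaffolding are illustrative assumptions; the proto accessors (set_data_type, add_shape, mutable_contents, add_fp32_contents) are the standard protobuf-generated ones for the fields defined in rec.proto.

// Sketch only: exercises proto::InferInputTensor -> torch::Tensor conversion.
// Assumes the generated "rec.pb.h" header and the helper declared above;
// the "core/util/utils.h" include path is an assumption for illustration.
#include <iostream>

#include <torch/torch.h>

#include "core/util/utils.h"
#include "rec.pb.h"

int main() {
  xllm::proto::InferInputTensor input;
  input.set_name("user_embedding");
  input.set_data_type(xllm::proto::DataType::FLOAT);

  // Shape [2, 3]; contents must be flattened in row-major order per rec.proto.
  input.add_shape(2);
  input.add_shape(3);
  auto* contents = input.mutable_contents();
  for (float v : {0.1f, 0.2f, 0.3f, 1.1f, 1.2f, 1.3f}) {
    contents->add_fp32_contents(v);
  }

  // Returns a cloned (independently owned) CPU tensor of dtype kFloat32.
  torch::Tensor t = xllm::util::convert_rec_tensor_to_torch(input);
  std::cout << t.sizes() << " " << t.dtype() << std::endl;  // [2, 3] float
  return 0;
}

The same pattern applies to int_contents / int64_contents / bool_contents inputs, which take the kInt32 / kInt64 / kBool branches of the converter shown in utils.cpp.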