Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ if(USE_NPU)
if(DEVICE_TYPE STREQUAL "USE_A3")
message("downloading a3 arm xllm kernels")
file(DOWNLOAD
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.1-Linux.a3.arm.rpm"
"https://9n-online-service.s3-internal.cn-north-1.jdcloud-oss.com/9n-xllm-atb/xllm_kernels-1.3.0-Linux.rpm"
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
)
else()
Expand Down
112 changes: 67 additions & 45 deletions xllm/api_service/api_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "core/common/metrics.h"
#include "core/runtime/dit_master.h"
#include "core/runtime/llm_master.h"
#include "core/runtime/rec_master.h"
#include "core/runtime/vlm_master.h"
#include "core/util/closure_guard.h"
#include "embedding.pb.h"
Expand Down Expand Up @@ -62,6 +63,9 @@ APIService::APIService(Master* master,
image_generation_service_impl_ =
std::make_unique<ImageGenerationServiceImpl>(
dynamic_cast<DiTMaster*>(master), model_names);
} else if (FLAGS_backend == "rec") {
rec_completion_service_impl_ = std::make_unique<RecCompletionServiceImpl>(
dynamic_cast<RecMaster*>(master), model_names);
}
models_service_impl_ =
ServiceImplFactory<ModelsServiceImpl>::create_service_impl(
Expand All @@ -72,13 +76,6 @@ void APIService::Completions(::google::protobuf::RpcController* controller,
const proto::CompletionRequest* request,
proto::CompletionResponse* response,
::google::protobuf::Closure* done) {
// TODO with xllm-service
}

void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
const proto::HttpRequest* request,
proto::HttpResponse* response,
::google::protobuf::Closure* done) {
xllm::ClosureGuard done_guard(
done,
std::bind(request_in_metric, nullptr),
Expand All @@ -87,47 +84,38 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
LOG(ERROR) << "brpc request | respose | controller is null";
return;
}

auto arena = response->GetArena();
auto req_pb =
google::protobuf::Arena::CreateMessage<proto::CompletionRequest>(arena);
auto resp_pb =
google::protobuf::Arena::CreateMessage<proto::CompletionResponse>(arena);

auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
std::string error;
json2pb::Json2PbOptions options;
butil::IOBuf& buf = ctrl->request_attachment();
butil::IOBufAsZeroCopyInputStream iobuf_stream(buf);
auto st = json2pb::JsonToProtoMessage(&iobuf_stream, req_pb, options, &error);
if (!st) {
ctrl->SetFailed(error);
LOG(ERROR) << "parse json to proto failed: " << error;
return;
}

std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
ctrl, done_guard.release(), req_pb, resp_pb);
completion_service_impl_->process_async(call);
}

void APIService::ChatCompletions(::google::protobuf::RpcController* controller,
const proto::ChatRequest* request,
proto::ChatResponse* response,
::google::protobuf::Closure* done) {
// TODO with xllm-service
if (FLAGS_backend == "llm") {
CHECK(completion_service_impl_) << " completion service is invalid.";
std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
ctrl,
done_guard.release(),
const_cast<proto::CompletionRequest*>(request),
response);
completion_service_impl_->process_async(call);
} else if (FLAGS_backend == "rec") {
CHECK(rec_completion_service_impl_)
<< " rec completion service is invalid.";
std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
ctrl,
done_guard.release(),
const_cast<proto::CompletionRequest*>(request),
response);
rec_completion_service_impl_->process_async(call);
}
}

namespace {
template <typename ChatCall, typename Service>
void ChatCompletionsImpl(std::unique_ptr<Service>& service,
xllm::ClosureGuard& guard,
::google::protobuf::Arena* arena,
brpc::Controller* ctrl) {
template <typename Call, typename Service>
void CommonCompletionsImpl(std::unique_ptr<Service>& service,
xllm::ClosureGuard& guard,
::google::protobuf::Arena* arena,
brpc::Controller* ctrl) {
auto req_pb =
google::protobuf::Arena::CreateMessage<typename ChatCall::ReqType>(arena);
google::protobuf::Arena::CreateMessage<typename Call::ReqType>(arena);
auto resp_pb =
google::protobuf::Arena::CreateMessage<typename ChatCall::ResType>(arena);
google::protobuf::Arena::CreateMessage<typename Call::ResType>(arena);

std::string error;
json2pb::Json2PbOptions options;
Expand All @@ -140,12 +128,46 @@ void ChatCompletionsImpl(std::unique_ptr<Service>& service,
return;
}

auto call =
std::make_shared<ChatCall>(ctrl, guard.release(), req_pb, resp_pb);
auto call = std::make_shared<Call>(ctrl, guard.release(), req_pb, resp_pb);
service->process_async(call);
}
} // namespace

void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
const proto::HttpRequest* request,
proto::HttpResponse* response,
::google::protobuf::Closure* done) {
xllm::ClosureGuard done_guard(
done,
std::bind(request_in_metric, nullptr),
std::bind(request_out_metric, (void*)controller));
if (!request || !response || !controller) {
LOG(ERROR) << "brpc request | respose | controller is null";
return;
}

auto arena = response->GetArena();
auto ctrl = reinterpret_cast<brpc::Controller*>(controller);

if (FLAGS_backend == "llm") {
CHECK(completion_service_impl_) << " completion service is invalid.";
CommonCompletionsImpl<CompletionCall, CompletionServiceImpl>(
completion_service_impl_, done_guard, arena, ctrl);
} else if (FLAGS_backend == "rec") {
CHECK(rec_completion_service_impl_)
<< " rec completion service is invalid.";
CommonCompletionsImpl<CompletionCall, RecCompletionServiceImpl>(
rec_completion_service_impl_, done_guard, arena, ctrl);
}
}

void APIService::ChatCompletions(::google::protobuf::RpcController* controller,
const proto::ChatRequest* request,
proto::ChatResponse* response,
::google::protobuf::Closure* done) {
// TODO with xllm-service
}

void APIService::ChatCompletionsHttp(
::google::protobuf::RpcController* controller,
const proto::HttpRequest* request,
Expand All @@ -165,11 +187,11 @@ void APIService::ChatCompletionsHttp(

if (FLAGS_backend == "llm") {
CHECK(chat_service_impl_) << " chat service is invalid.";
ChatCompletionsImpl<ChatCall, ChatServiceImpl>(
CommonCompletionsImpl<ChatCall, ChatServiceImpl>(
chat_service_impl_, done_guard, arena, ctrl);
} else if (FLAGS_backend == "vlm") {
CHECK(mm_chat_service_impl_) << " mm chat service is invalid.";
ChatCompletionsImpl<MMChatCall, MMChatServiceImpl>(
CommonCompletionsImpl<MMChatCall, MMChatServiceImpl>(
mm_chat_service_impl_, done_guard, arena, ctrl);
}
}
Expand Down
1 change: 1 addition & 0 deletions xllm/api_service/api_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class APIService : public proto::XllmAPIService {
std::unique_ptr<ModelsServiceImpl> models_service_impl_;
std::unique_ptr<ImageGenerationServiceImpl> image_generation_service_impl_;
std::unique_ptr<RerankServiceImpl> rerank_service_impl_;
std::unique_ptr<RecCompletionServiceImpl> rec_completion_service_impl_;
};

} // namespace xllm
Loading