@@ -27,6 +27,7 @@ limitations under the License.
2727#include " core/common/metrics.h"
2828#include " core/runtime/dit_master.h"
2929#include " core/runtime/llm_master.h"
30+ #include " core/runtime/rec_master.h"
3031#include " core/runtime/vlm_master.h"
3132#include " core/util/closure_guard.h"
3233#include " embedding.pb.h"
@@ -62,6 +63,9 @@ APIService::APIService(Master* master,
6263 image_generation_service_impl_ =
6364 std::make_unique<ImageGenerationServiceImpl>(
6465 dynamic_cast <DiTMaster*>(master), model_names);
66+ } else if (FLAGS_backend == " rec" ) {
67+ rec_completion_service_impl_ = std::make_unique<RecCompletionServiceImpl>(
68+ dynamic_cast <RecMaster*>(master), model_names);
6569 }
6670 models_service_impl_ =
6771 ServiceImplFactory<ModelsServiceImpl>::create_service_impl (
@@ -72,13 +76,6 @@ void APIService::Completions(::google::protobuf::RpcController* controller,
7276 const proto::CompletionRequest* request,
7377 proto::CompletionResponse* response,
7478 ::google::protobuf::Closure* done) {
75- // TODO with xllm-service
76- }
77-
78- void APIService::CompletionsHttp (::google::protobuf::RpcController* controller,
79- const proto::HttpRequest* request,
80- proto::HttpResponse* response,
81- ::google::protobuf::Closure* done) {
8279 xllm::ClosureGuard done_guard (
8380 done,
8481 std::bind (request_in_metric, nullptr ),
@@ -87,47 +84,38 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
8784 LOG (ERROR) << " brpc request | respose | controller is null" ;
8885 return ;
8986 }
90-
91- auto arena = response->GetArena ();
92- auto req_pb =
93- google::protobuf::Arena::CreateMessage<proto::CompletionRequest>(arena);
94- auto resp_pb =
95- google::protobuf::Arena::CreateMessage<proto::CompletionResponse>(arena);
96-
9787 auto ctrl = reinterpret_cast <brpc::Controller*>(controller);
98- std::string error;
99- json2pb::Json2PbOptions options;
100- butil::IOBuf& buf = ctrl->request_attachment ();
101- butil::IOBufAsZeroCopyInputStream iobuf_stream (buf);
102- auto st = json2pb::JsonToProtoMessage (&iobuf_stream, req_pb, options, &error);
103- if (!st) {
104- ctrl->SetFailed (error);
105- LOG (ERROR) << " parse json to proto failed: " << error;
106- return ;
107- }
10888
109- std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
110- ctrl, done_guard.release (), req_pb, resp_pb);
111- completion_service_impl_->process_async (call);
112- }
113-
114- void APIService::ChatCompletions (::google::protobuf::RpcController* controller,
115- const proto::ChatRequest* request,
116- proto::ChatResponse* response,
117- ::google::protobuf::Closure* done) {
118- // TODO with xllm-service
89+ if (FLAGS_backend == " llm" ) {
90+ CHECK (completion_service_impl_) << " completion service is invalid." ;
91+ std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
92+ ctrl,
93+ done_guard.release (),
94+ const_cast <proto::CompletionRequest*>(request),
95+ response);
96+ completion_service_impl_->process_async (call);
97+ } else if (FLAGS_backend == " rec" ) {
98+ CHECK (rec_completion_service_impl_)
99+ << " rec completion service is invalid." ;
100+ std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
101+ ctrl,
102+ done_guard.release (),
103+ const_cast <proto::CompletionRequest*>(request),
104+ response);
105+ rec_completion_service_impl_->process_async (call);
106+ }
119107}
120108
121109namespace {
122- template <typename ChatCall , typename Service>
123- void ChatCompletionsImpl (std::unique_ptr<Service>& service,
124- xllm::ClosureGuard& guard,
125- ::google::protobuf::Arena* arena,
126- brpc::Controller* ctrl) {
110+ template <typename Call , typename Service>
111+ void CommonCompletionsImpl (std::unique_ptr<Service>& service,
112+ xllm::ClosureGuard& guard,
113+ ::google::protobuf::Arena* arena,
114+ brpc::Controller* ctrl) {
127115 auto req_pb =
128- google::protobuf::Arena::CreateMessage<typename ChatCall ::ReqType>(arena);
116+ google::protobuf::Arena::CreateMessage<typename Call ::ReqType>(arena);
129117 auto resp_pb =
130- google::protobuf::Arena::CreateMessage<typename ChatCall ::ResType>(arena);
118+ google::protobuf::Arena::CreateMessage<typename Call ::ResType>(arena);
131119
132120 std::string error;
133121 json2pb::Json2PbOptions options;
@@ -140,12 +128,46 @@ void ChatCompletionsImpl(std::unique_ptr<Service>& service,
140128 return ;
141129 }
142130
143- auto call =
144- std::make_shared<ChatCall>(ctrl, guard.release (), req_pb, resp_pb);
131+ auto call = std::make_shared<Call>(ctrl, guard.release (), req_pb, resp_pb);
145132 service->process_async (call);
146133}
147134} // namespace
148135
136+ void APIService::CompletionsHttp (::google::protobuf::RpcController* controller,
137+ const proto::HttpRequest* request,
138+ proto::HttpResponse* response,
139+ ::google::protobuf::Closure* done) {
140+ xllm::ClosureGuard done_guard (
141+ done,
142+ std::bind (request_in_metric, nullptr ),
143+ std::bind (request_out_metric, (void *)controller));
144+ if (!request || !response || !controller) {
145+ LOG (ERROR) << " brpc request | respose | controller is null" ;
146+ return ;
147+ }
148+
149+ auto arena = response->GetArena ();
150+ auto ctrl = reinterpret_cast <brpc::Controller*>(controller);
151+
152+ if (FLAGS_backend == " llm" ) {
153+ CHECK (completion_service_impl_) << " completion service is invalid." ;
154+ CommonCompletionsImpl<CompletionCall, CompletionServiceImpl>(
155+ completion_service_impl_, done_guard, arena, ctrl);
156+ } else if (FLAGS_backend == " rec" ) {
157+ CHECK (rec_completion_service_impl_)
158+ << " rec completion service is invalid." ;
159+ CommonCompletionsImpl<CompletionCall, RecCompletionServiceImpl>(
160+ rec_completion_service_impl_, done_guard, arena, ctrl);
161+ }
162+ }
163+
164+ void APIService::ChatCompletions (::google::protobuf::RpcController* controller,
165+ const proto::ChatRequest* request,
166+ proto::ChatResponse* response,
167+ ::google::protobuf::Closure* done) {
168+ // TODO with xllm-service
169+ }
170+
149171void APIService::ChatCompletionsHttp (
150172 ::google::protobuf::RpcController* controller,
151173 const proto::HttpRequest* request,
@@ -165,11 +187,11 @@ void APIService::ChatCompletionsHttp(
165187
166188 if (FLAGS_backend == " llm" ) {
167189 CHECK (chat_service_impl_) << " chat service is invalid." ;
168- ChatCompletionsImpl <ChatCall, ChatServiceImpl>(
190+ CommonCompletionsImpl <ChatCall, ChatServiceImpl>(
169191 chat_service_impl_, done_guard, arena, ctrl);
170192 } else if (FLAGS_backend == " vlm" ) {
171193 CHECK (mm_chat_service_impl_) << " mm chat service is invalid." ;
172- ChatCompletionsImpl <MMChatCall, MMChatServiceImpl>(
194+ CommonCompletionsImpl <MMChatCall, MMChatServiceImpl>(
173195 mm_chat_service_impl_, done_guard, arena, ctrl);
174196 }
175197}
0 commit comments