@@ -40,6 +40,8 @@ WorkerService::WorkerService(runtime::Options options,
4040 device_.set_device ();
4141 device_.init_device_context ();
4242 stream_ = device_.get_stream_from_pool ();
43+ threadpool_ = std::make_unique<ThreadPool>(
44+ 4 , [this ]() mutable { device_.set_device (); });
4345}
4446
4547WorkerService::WorkerService (runtime::Options options,
@@ -52,6 +54,8 @@ WorkerService::WorkerService(runtime::Options options,
5254 device_.set_device ();
5355 device_.init_device_context ();
5456 stream_ = device_.get_stream_from_pool ();
57+ threadpool_ = std::make_unique<ThreadPool>(
58+ 4 , [this ]() mutable { device_.set_device (); });
5559}
5660
5761WorkerService::~WorkerService () = default ;  // compiler-generated destructor; no custom teardown logic is defined here
@@ -72,7 +76,6 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
7276 torch::Tensor& src_seq_idxes,
7377 torch::Tensor& out_tokens,
7478 torch::Tensor& out_logprobs) {
75- device_.set_device ();
7679 // execute model
7780 auto future = worker_->step_async (batched_fwd_inputs);
7881
@@ -250,7 +253,7 @@ void WorkerService::InitModel(::google::protobuf::RpcController* controller,
250253 const proto::ModelPath* request,
251254 proto::Status* response,
252255 ::google::protobuf::Closure* done) {
253- threadpool_. schedule ([this , controller, request, response, done]() mutable {
256+ threadpool_-> schedule ([this , controller, request, response, done]() mutable {
254257 brpc::ClosureGuard done_guard (done);
255258 auto model_weights_path = request->model_weights_path ();
256259 auto init_future = worker_->init_model_async (model_weights_path);
@@ -270,7 +273,7 @@ void WorkerService::ProcessGroupTest(
270273 const proto::Empty* request,
271274 proto::Status* response,
272275 ::google::protobuf::Closure* done) {
273- threadpool_. schedule ([this , controller, request, response, done]() mutable {
276+ threadpool_-> schedule ([this , controller, request, response, done]() mutable {
274277 brpc::ClosureGuard done_guard (done);
275278 auto future = worker_->process_group_test_async ();
276279 std::move (future).get ();
@@ -284,7 +287,7 @@ void WorkerService::ProfileDeviceMemory(
284287 const proto::Empty* request,
285288 proto::DeviceMemory* response,
286289 ::google::protobuf::Closure* done) {
287- threadpool_. schedule ([this , controller, request, response, done]() mutable {
290+ threadpool_-> schedule ([this , controller, request, response, done]() mutable {
288291 brpc::ClosureGuard done_guard (done);
289292 auto future = worker_->estimate_kv_cache_capacity_async ();
290293 std::tuple<int64_t , int64_t > result = std::move (future).get ();
@@ -299,7 +302,7 @@ void WorkerService::AllocateKVCache(
299302 const proto::KVCacheShape* request,
300303 proto::Status* response,
301304 ::google::protobuf::Closure* done) {
302- threadpool_. schedule ([this , controller, request, response, done]() mutable {
305+ threadpool_-> schedule ([this , controller, request, response, done]() mutable {
303306 brpc::ClosureGuard done_guard (done);
304307 std::vector<std::vector<int64_t >> kv_cache_shape;
305308 kv_cache_shape.reserve (2 );
@@ -319,7 +322,7 @@ void WorkerService::AllocateContinuousKVCache(
319322 const proto::XTensorOptionsVec* request,
320323 proto::Status* response,
321324 ::google::protobuf::Closure* done) {
322- threadpool_. schedule ([this , controller, request, response, done]() mutable {
325+ threadpool_-> schedule ([this , controller, request, response, done]() mutable {
323326 brpc::ClosureGuard done_guard (done);
324327 XTensor::Options key_options;
325328 XTensor::Options value_options;
@@ -350,7 +353,7 @@ void WorkerService::AllocateKVCacheWithTransfer(
350353 const proto::AllocateKVCacheWithTransferRequest* req,
351354 proto::Status* resp,
352355 ::google::protobuf::Closure* done) {
353- threadpool_. schedule ([this , controller, req, resp, done]() mutable {
356+ threadpool_-> schedule ([this , controller, req, resp, done]() mutable {
354357 brpc::ClosureGuard done_guard (done);
355358 uint64_t kv_cache_size = req->kv_cache_size ();
356359 std::vector<std::vector<int64_t >> kv_cache_shape;
@@ -373,7 +376,7 @@ void WorkerService::GetCacheInfo(::google::protobuf::RpcController* controller,
373376 const proto::Empty* req,
374377 proto::CacheInfo* resp,
375378 ::google::protobuf::Closure* done) {
376- threadpool_. schedule ([this , controller, req, resp, done]() mutable {
379+ threadpool_-> schedule ([this , controller, req, resp, done]() mutable {
377380 brpc::ClosureGuard done_guard (done);
378381 uint64_t cluster_id;
379382 std::string addr;
@@ -392,7 +395,7 @@ void WorkerService::PullKVCache(::google::protobuf::RpcController* controller,
392395 const proto::PullKVCacheRequest* req,
393396 proto::Status* resp,
394397 ::google::protobuf::Closure* done) {
395- threadpool_. schedule ([this , controller, req, resp, done]() mutable {
398+ threadpool_-> schedule ([this , controller, req, resp, done]() mutable {
396399 brpc::ClosureGuard done_guard (done);
397400 uint64_t src_cluster_id = req->cluster_id ();
398401 std::string addr = req->addr ();
@@ -433,7 +436,7 @@ void WorkerService::GetDeviceInfo(::google::protobuf::RpcController* controller,
433436 const proto::Empty* req,
434437 proto::DeviceInfo* resp,
435438 ::google::protobuf::Closure* done) {
436- threadpool_. schedule ([this , controller, req, resp, done]() mutable {
439+ threadpool_-> schedule ([this , controller, req, resp, done]() mutable {
437440 brpc::ClosureGuard done_guard (done);
438441 std::string device_ip;
439442 uint16_t listen_port;
@@ -448,7 +451,7 @@ void WorkerService::LinkCluster(::google::protobuf::RpcController* controller,
448451 const proto::ClusterInfo* req,
449452 proto::Status* resp,
450453 ::google::protobuf::Closure* done) {
451- threadpool_. schedule ([this , controller, req, resp, done]() mutable {
454+ threadpool_-> schedule ([this , controller, req, resp, done]() mutable {
452455 brpc::ClosureGuard done_guard (done);
453456 std::vector<uint64_t > cluster_ids (req->cluster_ids ().begin (),
454457 req->cluster_ids ().end ());
@@ -467,7 +470,7 @@ void WorkerService::UnlinkCluster(::google::protobuf::RpcController* controller,
467470 const proto::ClusterInfo* req,
468471 proto::Status* resp,
469472 ::google::protobuf::Closure* done) {
470- threadpool_. schedule ([this , controller, req, resp, done]() mutable {
473+ threadpool_-> schedule ([this , controller, req, resp, done]() mutable {
471474 brpc::ClosureGuard done_guard (done);
472475 std::vector<uint64_t > cluster_ids (req->cluster_ids ().begin (),
473476 req->cluster_ids ().end ());
@@ -488,11 +491,11 @@ void WorkerService::ExecuteModel(
488491 const proto::BatchedForwardInputs* pb_batched_fwd_inputs,
489492 proto::ForwardOutput* pb_forward_output,
490493 ::google::protobuf::Closure* done) {
491- threadpool_. schedule ([this ,
492- controller,
493- pb_batched_fwd_inputs,
494- pb_forward_output,
495- done]() mutable {
494+ threadpool_-> schedule ([this ,
495+ controller,
496+ pb_batched_fwd_inputs,
497+ pb_forward_output,
498+ done]() mutable {
496499 brpc::ClosureGuard done_guard (done);
497500 Timer timer;
498501 // convert proto::BatchedForwardInputs to BatchedForwardInputs
@@ -574,9 +577,8 @@ void WorkerService::GetLastStepResult(
574577 const proto::Empty* req,
575578 proto::ForwardOutput* pb_forward_output,
576579 ::google::protobuf::Closure* done) {
577- threadpool_. schedule (
580+ threadpool_-> schedule (
578581 [this , controller, req, pb_forward_output, done]() mutable {
579- device_.set_device ();
580582 brpc::ClosureGuard done_guard (done);
581583
582584 auto future = worker_->get_last_step_result_async ();
@@ -642,7 +644,7 @@ void WorkerService::GetActiveActivationMemory(
642644 const proto::Empty* req,
643645 proto::ActivationMemory* resp,
644646 ::google::protobuf::Closure* done) {
645- threadpool_. schedule ([this , controller, req, resp, done]() mutable {
647+ threadpool_-> schedule ([this , controller, req, resp, done]() mutable {
646648 brpc::ClosureGuard done_guard (done);
647649 auto future = worker_->get_active_activation_memory_async ();
648650 int64_t active_activation_memory = std::move (future).get ();
0 commit comments