
Commit c8fba1a

refactor: optimize 'set_device' calls to avoid setting the device on each step. (#321)
Signed-off-by: Tao Peng <[email protected]>
1 parent f74f283 commit c8fba1a

File tree

7 files changed: +47 −35 lines changed

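The theme of the change: device_.set_device() binds the calling thread to the accelerator device (the usual semantics of such calls), and it was previously invoked inside every step() and RPC handler. The commit moves that call into one-time, per-thread initialization instead. A minimal self-contained sketch of the pattern follows; the Device struct is a stand-in for the xllm device handle, not the real class:

#include <iostream>
#include <thread>
#include <vector>

// Stand-in for the xllm device handle; set_device() binds the calling
// thread to the accelerator device.
struct Device {
  void set_device() { std::cout << "thread bound to device\n"; }
};

int main() {
  Device device;
  std::vector<std::thread> workers;

  // The commit's pattern: rather than calling set_device() at the top of
  // every step()/handler, each pool thread binds itself once at startup
  // and then processes tasks without rebinding.
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([&device] {
      device.set_device();  // one-time per-thread initialization
      // ... loop: dequeue tasks and run them ...
    });
  }
  for (auto& t : workers) t.join();
  return 0;
}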

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 22 additions & 20 deletions

@@ -40,6 +40,8 @@ WorkerService::WorkerService(runtime::Options options,
   device_.set_device();
   device_.init_device_context();
   stream_ = device_.get_stream_from_pool();
+  threadpool_ = std::make_unique<ThreadPool>(
+      4, [this]() mutable { device_.set_device(); });
 }
 
 WorkerService::WorkerService(runtime::Options options,
@@ -52,6 +54,8 @@ WorkerService::WorkerService(runtime::Options options,
   device_.set_device();
   device_.init_device_context();
   stream_ = device_.get_stream_from_pool();
+  threadpool_ = std::make_unique<ThreadPool>(
+      4, [this]() mutable { device_.set_device(); });
 }
 
 WorkerService::~WorkerService() = default;
@@ -72,7 +76,6 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
                          torch::Tensor& src_seq_idxes,
                          torch::Tensor& out_tokens,
                          torch::Tensor& out_logprobs) {
-  device_.set_device();
   // execute model
   auto future = worker_->step_async(batched_fwd_inputs);
 
@@ -250,7 +253,7 @@ void WorkerService::InitModel(::google::protobuf::RpcController* controller,
                              const proto::ModelPath* request,
                              proto::Status* response,
                              ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, request, response, done]() mutable {
+  threadpool_->schedule([this, controller, request, response, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     auto model_weights_path = request->model_weights_path();
     auto init_future = worker_->init_model_async(model_weights_path);
@@ -270,7 +273,7 @@ void WorkerService::ProcessGroupTest(
     const proto::Empty* request,
     proto::Status* response,
     ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, request, response, done]() mutable {
+  threadpool_->schedule([this, controller, request, response, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     auto future = worker_->process_group_test_async();
     std::move(future).get();
@@ -284,7 +287,7 @@ void WorkerService::ProfileDeviceMemory(
     const proto::Empty* request,
     proto::DeviceMemory* response,
    ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, request, response, done]() mutable {
+  threadpool_->schedule([this, controller, request, response, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     auto future = worker_->estimate_kv_cache_capacity_async();
     std::tuple<int64_t, int64_t> result = std::move(future).get();
@@ -299,7 +302,7 @@ void WorkerService::AllocateKVCache(
     const proto::KVCacheShape* request,
     proto::Status* response,
     ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, request, response, done]() mutable {
+  threadpool_->schedule([this, controller, request, response, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     std::vector<std::vector<int64_t>> kv_cache_shape;
     kv_cache_shape.reserve(2);
@@ -319,7 +322,7 @@ void WorkerService::AllocateContinuousKVCache(
     const proto::XTensorOptionsVec* request,
     proto::Status* response,
     ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, request, response, done]() mutable {
+  threadpool_->schedule([this, controller, request, response, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     XTensor::Options key_options;
     XTensor::Options value_options;
@@ -350,7 +353,7 @@ void WorkerService::AllocateKVCacheWithTransfer(
     const proto::AllocateKVCacheWithTransferRequest* req,
     proto::Status* resp,
     ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, req, resp, done]() mutable {
+  threadpool_->schedule([this, controller, req, resp, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     uint64_t kv_cache_size = req->kv_cache_size();
     std::vector<std::vector<int64_t>> kv_cache_shape;
@@ -373,7 +376,7 @@ void WorkerService::GetCacheInfo(::google::protobuf::RpcController* controller,
                                  const proto::Empty* req,
                                  proto::CacheInfo* resp,
                                  ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, req, resp, done]() mutable {
+  threadpool_->schedule([this, controller, req, resp, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     uint64_t cluster_id;
     std::string addr;
@@ -392,7 +395,7 @@ void WorkerService::PullKVCache(::google::protobuf::RpcController* controller,
                                 const proto::PullKVCacheRequest* req,
                                 proto::Status* resp,
                                 ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, req, resp, done]() mutable {
+  threadpool_->schedule([this, controller, req, resp, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     uint64_t src_cluster_id = req->cluster_id();
     std::string addr = req->addr();
@@ -433,7 +436,7 @@ void WorkerService::GetDeviceInfo(::google::protobuf::RpcController* controller,
                                   const proto::Empty* req,
                                   proto::DeviceInfo* resp,
                                   ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, req, resp, done]() mutable {
+  threadpool_->schedule([this, controller, req, resp, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     std::string device_ip;
     uint16_t listen_port;
@@ -448,7 +451,7 @@ void WorkerService::LinkCluster(::google::protobuf::RpcController* controller,
                                 const proto::ClusterInfo* req,
                                 proto::Status* resp,
                                 ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, req, resp, done]() mutable {
+  threadpool_->schedule([this, controller, req, resp, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     std::vector<uint64_t> cluster_ids(req->cluster_ids().begin(),
                                       req->cluster_ids().end());
@@ -467,7 +470,7 @@ void WorkerService::UnlinkCluster(::google::protobuf::RpcController* controller,
                                   const proto::ClusterInfo* req,
                                   proto::Status* resp,
                                   ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, req, resp, done]() mutable {
+  threadpool_->schedule([this, controller, req, resp, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     std::vector<uint64_t> cluster_ids(req->cluster_ids().begin(),
                                       req->cluster_ids().end());
@@ -488,11 +491,11 @@ void WorkerService::ExecuteModel(
     const proto::BatchedForwardInputs* pb_batched_fwd_inputs,
     proto::ForwardOutput* pb_forward_output,
     ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this,
-                        controller,
-                        pb_batched_fwd_inputs,
-                        pb_forward_output,
-                        done]() mutable {
+  threadpool_->schedule([this,
+                         controller,
+                         pb_batched_fwd_inputs,
+                         pb_forward_output,
+                         done]() mutable {
     brpc::ClosureGuard done_guard(done);
     Timer timer;
     // convert proto::BatchedForwardInputs to BatchedForwardInputs
@@ -574,9 +577,8 @@ void WorkerService::GetLastStepResult(
     const proto::Empty* req,
     proto::ForwardOutput* pb_forward_output,
     ::google::protobuf::Closure* done) {
-  threadpool_.schedule(
+  threadpool_->schedule(
       [this, controller, req, pb_forward_output, done]() mutable {
-        device_.set_device();
         brpc::ClosureGuard done_guard(done);
 
         auto future = worker_->get_last_step_result_async();
@@ -642,7 +644,7 @@ void WorkerService::GetActiveActivationMemory(
     const proto::Empty* req,
     proto::ActivationMemory* resp,
     ::google::protobuf::Closure* done) {
-  threadpool_.schedule([this, controller, req, resp, done]() mutable {
+  threadpool_->schedule([this, controller, req, resp, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     auto future = worker_->get_active_activation_memory_async();
     int64_t active_activation_memory = std::move(future).get();

xllm/core/distributed_runtime/worker_service.h

Lines changed: 1 addition & 1 deletion

@@ -149,7 +149,7 @@ class WorkerService : public proto::DistributeWorker {
 
   std::unique_ptr<std::thread> polling_thread_;
 
-  ThreadPool threadpool_{4};
+  std::unique_ptr<ThreadPool> threadpool_;
 };
 
 }  // namespace xllm
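A likely reason for switching the member from ThreadPool threadpool_{4}; to a std::unique_ptr<ThreadPool> (an inference from the diff, not stated in the commit): the default member initializer runs before the constructor body, too early to hand the pool an init functor that touches device_. Holding the pool behind a pointer defers construction until the device is set up. A stub sketch of the idiom, with simplified stand-in types:

#include <cstddef>
#include <functional>
#include <memory>
#include <utility>

// Simplified stand-ins for the xllm types (sketch only).
struct Device {
  void set_device() {}
};
struct ThreadPool {
  using Runnable = std::function<void()>;
  ThreadPool(std::size_t /*num_threads*/, Runnable init_func) {
    if (init_func) init_func();  // the real pool runs this on its worker threads
  }
};

class Service {
 public:
  explicit Service(Device device) : device_(std::move(device)) {
    // Deferred construction: device_ is fully initialized before the pool
    // threads start, so the init functor can safely capture `this`.
    threadpool_ = std::make_unique<ThreadPool>(
        4, [this] { device_.set_device(); });
  }

 private:
  Device device_;
  std::unique_ptr<ThreadPool> threadpool_;
};

int main() { Service service{Device{}}; }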

xllm/core/runtime/dit_worker.cpp

Lines changed: 3 additions & 3 deletions

@@ -41,11 +41,12 @@ namespace xllm {
 DiTWorker::DiTWorker(const ParallelArgs& parallel_args,
                      const torch::Device& device,
                      const runtime::Options& options)
-    : device_(device), options_(options), parallel_args_(parallel_args) {}
+    : device_(device), options_(options), parallel_args_(parallel_args) {
+  device_.set_device();
+}
 
 bool DiTWorker::init_model(const std::string& model_weights_path) {
   CHECK(dit_model_ == nullptr) << "Model is already initialized.";
-  device_.set_device();
 
   auto loader = std::make_unique<DiTModelLoader>(model_weights_path);
   dtype_ = util::parse_dtype(loader->get_torch_dtype(), device_);
@@ -80,7 +81,6 @@ bool DiTWorker::init_model(const std::string& model_weights_path) {
 }
 
 std::optional<DiTForwardOutput> DiTWorker::step(const DiTForwardInput& inputs) {
-  device_.set_device();
   Timer timer;
 
   auto output = dit_model_executor_->forward(inputs.to(device_, dtype_));

xllm/core/runtime/llm_worker_impl.cpp

Lines changed: 3 additions & 3 deletions

@@ -41,11 +41,12 @@ namespace xllm {
 LLMWorkerImpl::LLMWorkerImpl(const ParallelArgs& parallel_args,
                              const torch::Device& device,
                              const runtime::Options& options)
-    : WorkerImpl(parallel_args, device, options) {}
+    : WorkerImpl(parallel_args, device, options) {
+  device_.set_device();
+}
 
 bool LLMWorkerImpl::init_model(ModelContext& context) {
   CHECK(model_ == nullptr) << "Model is already initialized.";
-  device_.set_device();
 
   // Try to create a causal LM model
   model_ = create_llm_model(context);
@@ -67,7 +68,6 @@ bool LLMWorkerImpl::init_model(ModelContext& context) {
 
 std::optional<ForwardOutput> LLMWorkerImpl::step(
     const BatchedForwardInputs& inputs) {
-  device_.set_device();
   Timer timer;
   std::vector<torch::Tensor> flatten_tokens_micro_batches;
   std::vector<torch::Tensor> flatten_positions_micro_batches;

xllm/core/runtime/vlm_worker_impl.cpp

Lines changed: 3 additions & 4 deletions

@@ -38,13 +38,13 @@ namespace xllm {
 VLMWorkerImpl::VLMWorkerImpl(const ParallelArgs& parallel_args,
                              const torch::Device& device,
                              const runtime::Options& options)
-    : WorkerImpl(parallel_args, device, options) {}
+    : WorkerImpl(parallel_args, device, options) {
+  device_.set_device();
+}
 
 bool VLMWorkerImpl::init_model(ModelContext& context) {
   CHECK(model_ == nullptr) << "Model is already initialized.";
 
-  device_.set_device();
-
   // initialize model
   context.set_image_embedding_mode(false);
   model_ = create_vlm_model(context);
@@ -56,7 +56,6 @@ bool VLMWorkerImpl::init_model(ModelContext& context) {
 
 std::optional<ForwardOutput> VLMWorkerImpl::step(
     const BatchedForwardInputs& inputs) {
-  device_.set_device();
   Timer timer;
   // TODO guojinrong, to adapt multi stream parallel later
   // all tensors should be on the same device as model
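For the worker classes (DiTWorker, LLMWorkerImpl, VLMWorkerImpl) the binding moves into the constructor rather than a pool init functor, which implicitly assumes the worker is constructed on the same thread that later calls init_model() and step(). A small sketch of that variant, with stand-in types; the assertion just makes the single-thread assumption explicit:

#include <cassert>
#include <thread>
#include <utility>

// Stand-in device type (sketch only).
struct Device {
  void set_device() { bound_thread = std::this_thread::get_id(); }
  std::thread::id bound_thread{};
};

class Worker {
 public:
  explicit Worker(Device device) : device_(std::move(device)) {
    device_.set_device();  // was previously repeated in init_model()/step()
  }

  void step() {
    // No rebinding here anymore; verify we are still on the bound thread.
    assert(std::this_thread::get_id() == device_.bound_thread);
  }

 private:
  Device device_;
};

int main() {
  Worker worker{Device{}};
  worker.step();
  return 0;
}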

xllm/core/util/threadpool.cpp

Lines changed: 13 additions & 3 deletions

@@ -18,9 +18,15 @@ limitations under the License.
 #include <thread>
 
 namespace xllm {
-ThreadPool::ThreadPool(size_t num_threads) : queues_(num_threads) {
+ThreadPool::ThreadPool(size_t num_threads) : ThreadPool(num_threads, nullptr) {}
+
+ThreadPool::ThreadPool(size_t num_threads, Runnable init_func)
+    : queues_(num_threads) {
   for (size_t i = 0; i < num_threads; ++i) {
-    threads_.emplace_back([this, i]() { internal_loop(i); });
+    threads_.emplace_back(
+        [this, i, init_func = std::move(init_func)]() mutable {
+          internal_loop(i, std::move(init_func));
+        });
   }
 }
 
@@ -60,7 +66,11 @@ void ThreadPool::schedule_with_tid(Runnable runnable, size_t tid) {
   queues_[tid].enqueue(std::move(runnable));
 }
 
-void ThreadPool::internal_loop(size_t index) {
+void ThreadPool::internal_loop(size_t index, Runnable&& init_func) {
+  if (init_func != nullptr) {
+    init_func();
+  }
+
   while (true) {
     Runnable runnable;
     queues_[index].wait_dequeue(runnable);
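Call sites use the new two-argument constructor the same way the WorkerService constructors above do; a hedged usage sketch (assuming Runnable is a std::function<void()>-style alias, as the != nullptr check suggests):

#include "threadpool.h"  // xllm/core/util/threadpool.h

#include <cstdio>

int main() {
  // Each worker thread runs init_func on startup (see internal_loop)
  // before entering its dequeue loop.
  xllm::ThreadPool pool(4, [] { std::puts("per-thread init"); });

  // Scheduled tasks then run on already-initialized threads with no
  // per-task setup cost. The shutdown path is not part of this diff; the
  // destructor is assumed to stop and join the workers.
  pool.schedule([] { std::puts("task"); });
  return 0;
}

One caveat when reading the constructor: init_func is move-captured inside the thread-creation loop, so lambdas after the first capture a moved-from callable; whether every thread actually receives the functor depends on the Runnable type's moved-from behavior.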

xllm/core/util/threadpool.h

Lines changed: 2 additions & 1 deletion

@@ -40,6 +40,7 @@ class ThreadPool final {
   ThreadPool& operator=(ThreadPool&&) = delete;
 
   explicit ThreadPool(size_t num_threads);
+  explicit ThreadPool(size_t num_threads, Runnable init_func);
 
   // schedule a runnable to be executed
   int32_t schedule(Runnable runnable);
@@ -55,7 +56,7 @@ class ThreadPool final {
   size_t size() { return threads_.size(); }
 
  private:
-  void internal_loop(size_t tid);
+  void internal_loop(size_t tid, Runnable&& init_func);
 
   std::vector<std::thread> threads_;
   std::vector<moodycamel::BlockingConcurrentQueue<Runnable>> queues_;
