Skip to content

Commit b3af2da

Browse files
Super User authored and yq33victor committed
refactor: integrate prepare_shm function into the worker_server class.
1 parent 884377f commit b3af2da

File tree

3 files changed

+56
-59
lines changed

3 files changed

+56
-59
lines changed

xllm/core/distributed_runtime/dist_manager.cpp

Lines changed: 9 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -117,27 +117,6 @@ std::unique_ptr<CommChannel> create_channel(const std::string& worker_addrs,
117117

118118
return channel;
119119
}
120-
121-
void prepare_shm(
122-
int dp_local_tp_size,
123-
int rank,
124-
std::unique_ptr<ForwardSharedMemoryManager>& input_shm_manager,
125-
std::unique_ptr<ForwardSharedMemoryManager>& output_shm_manager) {
126-
bool is_creator;
127-
int32_t dp_group = rank / dp_local_tp_size;
128-
129-
string name = ForwardSharedMemoryManager::create_unique_name(
130-
dp_group, FORWARD_RAW_INPUT_TYPE, rank);
131-
input_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
132-
name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE);
133-
LOG(INFO) << "Create input shared memory manager with name: " << name;
134-
135-
name = ForwardSharedMemoryManager::create_unique_name(
136-
dp_group, FORWARD_RAW_OUTPUT_TYPE, rank);
137-
output_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
138-
name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE);
139-
LOG(INFO) << "Create output shared memory manager with name: " << name;
140-
}
141120
} // namespace
142121

143122
void DistManager::setup_multi_node_workers(
@@ -201,24 +180,15 @@ void DistManager::setup_multi_node_workers(
201180
bool use_spawn_worker = options.enable_offline_inference() && i > 0;
202181
ParallelArgs parallel_args(rank, world_size, dp_size, nullptr, ep_size);
203182

204-
std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager = nullptr;
205-
std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager = nullptr;
206-
if (options.is_local() && FLAGS_enable_shm) {
207-
prepare_shm(
208-
dp_local_tp_size, rank, input_shm_manager, output_shm_manager);
209-
}
210-
servers_.emplace_back(
211-
std::make_unique<WorkerServer>(i,
212-
master_node_addr,
213-
// done,
214-
dones[i],
215-
parallel_args,
216-
devices[i],
217-
worker_server_options,
218-
worker_type,
219-
use_spawn_worker,
220-
std::move(input_shm_manager),
221-
std::move(output_shm_manager)));
183+
servers_.emplace_back(std::make_unique<WorkerServer>(i,
184+
master_node_addr,
185+
// done,
186+
dones[i],
187+
parallel_args,
188+
devices[i],
189+
worker_server_options,
190+
worker_type,
191+
use_spawn_worker));
222192
}
223193

224194
// Master node need to wait all workers done

xllm/core/distributed_runtime/worker_server.cpp

Lines changed: 33 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -168,17 +168,38 @@ void WorkerServer::create_spawn_server(int local_rank,
168168
done.store(true);
169169
}
170170

171-
WorkerServer::WorkerServer(
172-
int local_worker_idx,
173-
const std::string& master_node_addr,
174-
std::atomic<bool>& done,
171+
void WorkerServer::prepare_shm(
175172
const ParallelArgs& parallel_args,
176-
const torch::Device& d,
177173
const runtime::Options& options,
178-
WorkerType worker_type,
179-
bool use_spawn_worker,
180-
std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager,
181-
std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager) {
174+
std::unique_ptr<ForwardSharedMemoryManager>& input_shm_manager,
175+
std::unique_ptr<ForwardSharedMemoryManager>& output_shm_manager) {
176+
if (options.is_local() && FLAGS_enable_shm) {
177+
bool is_creator;
178+
int dp_local_tp_size = parallel_args.world_size() / parallel_args.dp_size();
179+
int dp_group = parallel_args.rank() / dp_local_tp_size;
180+
181+
string name = ForwardSharedMemoryManager::create_unique_name(
182+
dp_group, FORWARD_RAW_INPUT_TYPE, parallel_args.rank());
183+
input_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
184+
name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE);
185+
LOG(INFO) << "Create input shared memory manager with name: " << name;
186+
187+
name = ForwardSharedMemoryManager::create_unique_name(
188+
dp_group, FORWARD_RAW_OUTPUT_TYPE, parallel_args.rank());
189+
output_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
190+
name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE);
191+
LOG(INFO) << "Create output shared memory manager with name: " << name;
192+
}
193+
}
194+
195+
WorkerServer::WorkerServer(int local_worker_idx,
196+
const std::string& master_node_addr,
197+
std::atomic<bool>& done,
198+
const ParallelArgs& parallel_args,
199+
const torch::Device& d,
200+
const runtime::Options& options,
201+
WorkerType worker_type,
202+
bool use_spawn_worker) {
182203
if (worker_type == WorkerType::LLM || worker_type == WorkerType::ELM) {
183204
if (use_spawn_worker) {
184205
// start worker in a spawn process(for offline inference worker.)
@@ -187,6 +208,9 @@ WorkerServer::WorkerServer(
187208
return;
188209
}
189210

211+
std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager = nullptr;
212+
std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager = nullptr;
213+
prepare_shm(parallel_args, options, input_shm_manager, output_shm_manager);
190214
// start worker in a thread.
191215
worker_thread_ =
192216
std::make_unique<std::thread>(&WorkerServer::create_server,

xllm/core/distributed_runtime/worker_server.h

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -37,17 +37,14 @@ namespace xllm {
3737

3838
class WorkerServer {
3939
public:
40-
WorkerServer(
41-
int local_worker_idx,
42-
const std::string& master_node_addr,
43-
std::atomic<bool>& done,
44-
const ParallelArgs& parallel_args,
45-
const torch::Device& d,
46-
const runtime::Options& options,
47-
WorkerType worker_type,
48-
bool use_spawn_worker = false,
49-
std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager = nullptr,
50-
std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager = nullptr);
40+
WorkerServer(int local_worker_idx,
41+
const std::string& master_node_addr,
42+
std::atomic<bool>& done,
43+
const ParallelArgs& parallel_args,
44+
const torch::Device& d,
45+
const runtime::Options& options,
46+
WorkerType worker_type,
47+
bool use_spawn_worker = false);
5148

5249
virtual ~WorkerServer();
5350

@@ -82,6 +79,12 @@ class WorkerServer {
8279
proto::AddressInfo& addr_info,
8380
proto::CommUniqueIdList& uids);
8481

82+
void prepare_shm(
83+
const ParallelArgs& parallel_args,
84+
const runtime::Options& options,
85+
std::unique_ptr<ForwardSharedMemoryManager>& input_shm_manager,
86+
std::unique_ptr<ForwardSharedMemoryManager>& output_shm_manager);
87+
8588
private:
8689
std::unique_ptr<std::thread> worker_thread_;
8790

0 commit comments

Comments
 (0)