@@ -168,17 +168,38 @@ void WorkerServer::create_spawn_server(int local_rank,
168168 done.store (true );
169169}
170170
171- WorkerServer::WorkerServer (
172- int local_worker_idx,
173- const std::string& master_node_addr,
174- std::atomic<bool >& done,
171+ void WorkerServer::prepare_shm (
175172 const ParallelArgs& parallel_args,
176- const torch::Device& d,
177173 const runtime::Options& options,
178- WorkerType worker_type,
179- bool use_spawn_worker,
180- std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager,
181- std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager) {
174+ std::unique_ptr<ForwardSharedMemoryManager>& input_shm_manager,
175+ std::unique_ptr<ForwardSharedMemoryManager>& output_shm_manager) {
176+ if (options.is_local () && FLAGS_enable_shm) {
177+ bool is_creator;
178+ int dp_local_tp_size = parallel_args.world_size () / parallel_args.dp_size ();
179+ int dp_group = parallel_args.rank () / dp_local_tp_size;
180+
181+ string name = ForwardSharedMemoryManager::create_unique_name (
182+ dp_group, FORWARD_RAW_INPUT_TYPE, parallel_args.rank ());
183+ input_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
184+ name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE);
185+ LOG (INFO) << " Create input shared memory manager with name: " << name;
186+
187+ name = ForwardSharedMemoryManager::create_unique_name (
188+ dp_group, FORWARD_RAW_OUTPUT_TYPE, parallel_args.rank ());
189+ output_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
190+ name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE);
191+ LOG (INFO) << " Create output shared memory manager with name: " << name;
192+ }
193+ }
194+
195+ WorkerServer::WorkerServer (int local_worker_idx,
196+ const std::string& master_node_addr,
197+ std::atomic<bool >& done,
198+ const ParallelArgs& parallel_args,
199+ const torch::Device& d,
200+ const runtime::Options& options,
201+ WorkerType worker_type,
202+ bool use_spawn_worker) {
182203 if (worker_type == WorkerType::LLM || worker_type == WorkerType::ELM) {
183204 if (use_spawn_worker) {
184205 // start worker in a spawn process(for offline inference worker.)
@@ -187,6 +208,9 @@ WorkerServer::WorkerServer(
187208 return ;
188209 }
189210
211+ std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager = nullptr ;
212+ std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager = nullptr ;
213+ prepare_shm (parallel_args, options, input_shm_manager, output_shm_manager);
190214 // start worker in a thread.
191215 worker_thread_ =
192216 std::make_unique<std::thread>(&WorkerServer::create_server,
0 commit comments