Skip to content

Commit 6162f01

Browse files
authored
bugfix: fix the hang issue of offline inference when enable_shm. (#312)
Signed-off-by: Tao Peng <[email protected]>
1 parent 90dfadb commit 6162f01

File tree

14 files changed

+74
-41
lines changed

14 files changed

+74
-41
lines changed

xllm/core/distributed_runtime/dist_manager.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ std::unique_ptr<CommChannel> create_channel(const std::string& worker_addrs,
100100
const runtime::Options& options) {
101101
std::unique_ptr<CommChannel> channel;
102102

103-
if (net::extract_ip(FLAGS_master_node_addr) ==
103+
if (net::extract_ip(options.master_node_addr().value_or("")) ==
104104
net::extract_ip(worker_addrs) &&
105105
options.enable_shm()) {
106106
// create shared memory manager for local rank
@@ -118,6 +118,7 @@ std::unique_ptr<CommChannel> create_channel(const std::string& worker_addrs,
118118

119119
return channel;
120120
}
121+
121122
} // namespace
122123

123124
void DistManager::setup_multi_node_workers(

xllm/core/distributed_runtime/shm_channel.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "shm_channel.h"
1717

1818
#include "common/global_flags.h"
19+
#include "util/net.h"
1920

2021
namespace xllm {
2122

@@ -26,16 +27,18 @@ ShmChannel::ShmChannel(int dp_group,
2627
: enable_shm_(options.enable_shm()) {
2728
bool is_creator;
2829

30+
std::string name_prefix =
31+
"xllm_" + net::extract_port(options.master_node_addr().value_or(""));
2932
if (is_driver) {
3033
auto name = ForwardSharedMemoryManager::create_unique_name(
31-
dp_group, FORWARD_RAW_INPUT_TYPE, rank);
34+
name_prefix, dp_group, FORWARD_RAW_INPUT_TYPE, rank);
3235
input_shm_manager_ = std::make_unique<ForwardSharedMemoryManager>(
3336
name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE);
3437
LOG(INFO) << "Create input shared memory manager with name: " << name;
3538
}
3639

3740
auto name = ForwardSharedMemoryManager::create_unique_name(
38-
dp_group, FORWARD_RAW_OUTPUT_TYPE, rank);
41+
name_prefix, dp_group, FORWARD_RAW_OUTPUT_TYPE, rank);
3942
output_shm_manager_ = std::make_unique<ForwardSharedMemoryManager>(
4043
name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE);
4144
LOG(INFO) << "Create output shared memory manager with name: " << name;

xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,17 @@ SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr,
3939
int device_idx,
4040
int num_decoding_tokens,
4141
int block_size,
42-
bool enable_shm) {
42+
bool enable_shm,
43+
bool is_local) {
4344
// TODO: pass whole xllm::runtime::Options here from main process.
4445
xllm::runtime::Options runner_options;
4546
runner_options.block_size(block_size)
4647
.num_decoding_tokens(num_decoding_tokens)
4748
.enable_schedule_overlap(false)
4849
.enable_offline_inference(true)
4950
.master_node_addr(master_node_addr)
50-
.enable_shm(enable_shm);
51+
.enable_shm(enable_shm)
52+
.is_local(is_local);
5153
FLAGS_enable_schedule_overlap = false;
5254
FLAGS_master_node_addr = master_node_addr;
5355
FLAGS_block_size = block_size;

xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ class SpawnWorkerServer final {
2828
int device_idx,
2929
int num_decoding_tokens,
3030
int block_size,
31-
bool enable_shm);
31+
bool enable_shm,
32+
bool is_local);
3233

3334
~SpawnWorkerServer() = default;
3435

xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ limitations under the License.
2929
// @num_decoding_tokens
3030
// @block_size
3131
// @enable_shm
32+
// @is_local
3233
int main(int argc, char* argv[]) {
33-
if (argc < 8) {
34+
if (argc < 9) {
3435
LOG(ERROR)
35-
<< "Spwan worker process receive wrong args. Need 8 args, receive "
36+
<< "Spwan worker process receive wrong args. Need 9 args, receive "
3637
<< argc;
3738
return 1;
3839
}
@@ -52,15 +53,17 @@ int main(int argc, char* argv[]) {
5253
int num_decoding_tokens = atoi(argv[6]);
5354
int block_size = atoi(argv[7]);
5455
int enable_shm = atoi(argv[8]);
56+
int is_local = atoi(argv[9]);
5557

5658
LOG(INFO) << "Spwan worker: "
5759
<< "master_node_addr = " << master_node_addr
58-
<< ", local_rank = " << local_rank
60+
<< ", is_local = " << is_local << ", local_rank = " << local_rank
5961
<< ", world_size = " << world_size
6062
<< ", device_idx = " << device_idx
6163
<< ", num_decoding_tokens = " << num_decoding_tokens
6264
<< ", block_size = " << block_size
63-
<< ", enable_shm = " << (enable_shm > 0) << "\n";
65+
<< ", enable_shm = " << (enable_shm > 0)
66+
<< ", enable_shm = " << (is_local > 0) << "\n";
6467

6568
xllm::SpawnWorkerServer worker(master_node_addr,
6669
local_rank,
@@ -69,7 +72,8 @@ int main(int argc, char* argv[]) {
6972
device_idx,
7073
num_decoding_tokens,
7174
block_size,
72-
enable_shm > 0);
75+
enable_shm > 0,
76+
is_local > 0);
7377

7478
worker.run();
7579

xllm/core/distributed_runtime/worker_server.cpp

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ void WorkerServer::create_spawn_server(int local_rank,
141141
const char* block_size_ptr = block_size_str.c_str();
142142
auto enable_shm_str = std::to_string(options.enable_shm());
143143
const char* enable_shm_ptr = enable_shm_str.c_str();
144+
auto is_local_str = std::to_string(options.is_local());
145+
const char* is_local_ptr = is_local_str.c_str();
144146
std::string spawn_worker_bin_path =
145147
options.spawn_worker_path() + "/spawn_worker";
146148
LOG(INFO) << "Spawn worker path: " << spawn_worker_bin_path;
@@ -153,6 +155,7 @@ void WorkerServer::create_spawn_server(int local_rank,
153155
num_decoding_tokens_ptr,
154156
block_size_ptr,
155157
enable_shm_ptr,
158+
is_local_ptr,
156159
nullptr};
157160
pid_t pid;
158161
posix_spawn_file_actions_init(&file_actions_);
@@ -181,14 +184,16 @@ void WorkerServer::prepare_shm(
181184
int dp_local_tp_size = parallel_args.world_size() / parallel_args.dp_size();
182185
int dp_group = parallel_args.rank() / dp_local_tp_size;
183186

187+
std::string name_prefix =
188+
"xllm_" + net::extract_port(options.master_node_addr().value());
184189
string name = ForwardSharedMemoryManager::create_unique_name(
185-
dp_group, FORWARD_RAW_INPUT_TYPE, parallel_args.rank());
190+
name_prefix, dp_group, FORWARD_RAW_INPUT_TYPE, parallel_args.rank());
186191
input_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
187192
name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE);
188193
LOG(INFO) << "Create input shared memory manager with name: " << name;
189194

190195
name = ForwardSharedMemoryManager::create_unique_name(
191-
dp_group, FORWARD_RAW_OUTPUT_TYPE, parallel_args.rank());
196+
name_prefix, dp_group, FORWARD_RAW_OUTPUT_TYPE, parallel_args.rank());
192197
output_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
193198
name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE);
194199
LOG(INFO) << "Create output shared memory manager with name: " << name;
@@ -204,31 +209,34 @@ WorkerServer::WorkerServer(int local_worker_idx,
204209
WorkerType worker_type,
205210
bool use_spawn_worker) {
206211
if (worker_type == WorkerType::LLM || worker_type == WorkerType::ELM) {
212+
// TODO: Refactor these code later.
207213
if (use_spawn_worker) {
208214
// start worker in a spawn process(for offline inference worker.)
209215
create_spawn_server(
210216
local_worker_idx, master_node_addr, done, parallel_args, d, options);
211217
return;
212-
}
218+
} else {
219+
std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager = nullptr;
220+
std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager = nullptr;
221+
prepare_shm(
222+
parallel_args, options, input_shm_manager, output_shm_manager);
213223

214-
std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager = nullptr;
215-
std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager = nullptr;
216-
prepare_shm(parallel_args, options, input_shm_manager, output_shm_manager);
217-
// start worker in a thread.
218-
worker_thread_ =
219-
std::make_unique<std::thread>(&WorkerServer::create_server,
220-
this,
221-
std::cref(options),
222-
std::ref(done),
223-
std::cref(master_node_addr),
224-
std::cref(d),
225-
parallel_args.world_size(),
226-
parallel_args.rank(),
227-
parallel_args.dp_size(),
228-
local_worker_idx,
229-
parallel_args.ep_size(),
230-
std::move(input_shm_manager),
231-
std::move(output_shm_manager));
224+
// start worker in a thread.
225+
worker_thread_ =
226+
std::make_unique<std::thread>(&WorkerServer::create_server,
227+
this,
228+
std::cref(options),
229+
std::ref(done),
230+
std::cref(master_node_addr),
231+
std::cref(d),
232+
parallel_args.world_size(),
233+
parallel_args.rank(),
234+
parallel_args.dp_size(),
235+
local_worker_idx,
236+
parallel_args.ep_size(),
237+
std::move(input_shm_manager),
238+
std::move(output_shm_manager));
239+
}
232240
} else {
233241
// TODO: support other model type later.
234242
LOG(ERROR) << "Unsupported model type: " << worker_type;

xllm/core/runtime/dit_engine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ DiTEngine::DiTEngine(const runtime::Options& options) : options_(options) {
4242
}
4343
const int32_t world_size = static_cast<int32_t>(devices.size());
4444

45+
CHECK(!options_.enable_shm()) << "Dit can not support enable_shm currently.";
46+
4547
// create workers
4648
for (size_t i = 0; i < devices.size(); ++i) {
4749
const int32_t rank = static_cast<int32_t>(i);

xllm/core/runtime/forward_shared_memory_manager.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -655,10 +655,12 @@ ForwardSharedMemoryManager::~ForwardSharedMemoryManager() = default;
655655

656656
/* The shared memory filename may have duplicates when using kill -9 xllm, but
657657
this doesn't affect usage.*/
658-
std::string ForwardSharedMemoryManager::create_unique_name(int dp_group,
659-
int forward_type,
660-
int rank) {
661-
std::string filename = "xllm_" + net::extract_port(FLAGS_master_node_addr);
658+
std::string ForwardSharedMemoryManager::create_unique_name(
659+
const std::string& prefix,
660+
int dp_group,
661+
int forward_type,
662+
int rank) {
663+
std::string filename = prefix;
662664
if (forward_type == FORWARD_PB_INPUT_TYPE ||
663665
forward_type == FORWARD_RAW_INPUT_TYPE) {
664666
filename += "_dpg_" + std::to_string(dp_group) + "_input";
@@ -997,4 +999,4 @@ void ForwardSharedMemoryManager::raw_output_read(RawForwardOutput& output) {
997999
void ForwardSharedMemoryManager::clear() {
9981000
std::memset(base_address(), 0, size());
9991001
}
1000-
} // namespace xllm
1002+
} // namespace xllm

xllm/core/runtime/forward_shared_memory_manager.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ class ForwardSharedMemoryManager : public SharedMemoryManager {
4747
bool& is_creator,
4848
ForwardType type);
4949
~ForwardSharedMemoryManager();
50-
static std::string create_unique_name(int dp_group,
50+
static std::string create_unique_name(const std::string& prefix,
51+
int dp_group,
5152
int forward_type,
5253
int rank);
5354

@@ -121,4 +122,4 @@ class ForwardSharedMemoryManager : public SharedMemoryManager {
121122
void* metadata_addr_ = nullptr;
122123
ControlMetadata* control_ptr_ = nullptr;
123124
};
124-
} // namespace xllm
125+
} // namespace xllm

xllm/core/runtime/vlm_engine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ VLMEngine::VLMEngine(const runtime::Options& options) : options_(options) {
5050
process_groups_ = parallel_state::create_npu_process_groups(devices);
5151
}
5252

53+
CHECK(!options_.enable_shm()) << "VLM can not support enable_shm currently.";
54+
5355
WorkerType worker_type =
5456
(options_.task_type() == "generate") ? WorkerType::VLM : WorkerType::EVLM;
5557
const int32_t world_size = static_cast<int32_t>(devices.size());

0 commit comments

Comments
 (0)