Commit 8376517

refactor: refine shared memory comm code to adapt the offline inference params.
Signed-off-by: Tao Peng <[email protected]>
1 parent: a188831

File tree: 16 files changed (+79, −39 lines)


xllm/core/common/global_flags.cpp

Lines changed: 1 addition & 1 deletion

@@ -274,7 +274,7 @@ DEFINE_string(kv_cache_transfer_mode,
 DEFINE_int32(transfer_listen_port, 26000, "The KVCacheTranfer listen port.");
 
 DEFINE_bool(enable_shm,
-            true,
+            false,
             "Whether to enable shared memory for executing model.");
 // --- function call config ---
 
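
With the default flipped to false, the shared-memory transport is now opt-in: it can still be requested through the gflag (e.g. --enable_shm=true on the command line), while offline inference threads the value through runtime::Options instead, as the rest of this commit shows. A minimal standalone sketch of standard gflags usage, not taken from this commit:

#include <gflags/gflags.h>
#include <iostream>

// Same declaration shape as the flag above; default is now false.
DEFINE_bool(enable_shm,
            false,
            "Whether to enable shared memory for executing model.");

int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  // Run as `./a.out --enable_shm=true` to opt in; prints false by default.
  std::cout << "enable_shm = " << std::boolalpha << FLAGS_enable_shm << "\n";
  return 0;
}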

xllm/core/common/options.h

Lines changed: 4 additions & 0 deletions

@@ -178,6 +178,10 @@ class Options {
   // for offline inference: the path to spawn worker binary
   PROPERTY(std::string, spawn_worker_path) = "";
 
+  // use shared memory for inter-process communication in the single-machine
+  // multi-GPU scenario.
+  PROPERTY(bool, enable_shm) = false;
+
   // whether the worker and master are on the same machine.
   PROPERTY(bool, is_local) = false;
 };
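
The new PROPERTY(bool, enable_shm) = false; is used at call sites both as a getter (options.enable_shm()) and as a chainable setter (.enable_shm(value), as master.cpp does below). A hypothetical sketch of what such a fluent-property macro could expand to, inferred from those call sites; the real macro in xllm may differ:

#include <iostream>

class Options {
 public:
  // Getter, matching the options.enable_shm() reads in this commit.
  bool enable_shm() const { return enable_shm_; }
  // Chainable setter, matching the .enable_shm(...) builder chains.
  Options& enable_shm(bool v) {
    enable_shm_ = v;
    return *this;
  }

 private:
  bool enable_shm_ = false;
};

int main() {
  Options opts;
  opts.enable_shm(true);  // returns Options&, so setters can chain
  std::cout << std::boolalpha << opts.enable_shm() << "\n";  // true
  return 0;
}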

xllm/core/distributed_runtime/dist_manager.cpp

Lines changed: 6 additions & 4 deletions

@@ -96,19 +96,20 @@ void DistManager::setup_single_node_workers(const runtime::Options& options) {
 namespace {
 std::unique_ptr<CommChannel> create_channel(const std::string& worker_addrs,
                                             int r,
-                                            int dp_local_tp_size) {
+                                            int dp_local_tp_size,
+                                            const runtime::Options& options) {
   std::unique_ptr<CommChannel> channel;
 
   if (net::extract_ip(FLAGS_master_node_addr) ==
           net::extract_ip(worker_addrs) &&
-      FLAGS_enable_shm) {
+      options.enable_shm()) {
     // create shared memory manager for local rank
     bool is_driver = false;
     int dp_group = r / dp_local_tp_size;
     if (r % dp_local_tp_size == 0) {
       is_driver = true;
     }
-    channel = std::make_unique<ShmChannel>(dp_group, r, is_driver);
+    channel = std::make_unique<ShmChannel>(dp_group, r, is_driver, options);
   } else {
     channel = std::make_unique<CommChannel>();
   }
@@ -220,7 +221,8 @@ void DistManager::setup_multi_node_workers(
         << r;
     return;
   }
-  auto channel = create_channel(worker_addrs_map[r], r, dp_local_tp_size);
+  auto channel =
+      create_channel(worker_addrs_map[r], r, dp_local_tp_size, options);
   worker_clients_.emplace_back(
       std::make_unique<RemoteWorker>(r,
                                      worker_addrs_map[r],
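
create_channel now takes the runtime options: a rank only gets a ShmChannel when its address resolves to the master node's IP and options.enable_shm() is set, and the first rank of each data-parallel group becomes the driver. A worked sketch of that dp-group/driver arithmetic with illustrative values (world_size = 8, dp_size = 2), not part of the commit:

#include <iostream>

int main() {
  const int world_size = 8;
  const int dp_size = 2;
  const int dp_local_tp_size = world_size / dp_size;  // 4 ranks per dp group

  for (int r = 0; r < world_size; ++r) {
    const int dp_group = r / dp_local_tp_size;           // 0,0,0,0,1,1,1,1
    const bool is_driver = (r % dp_local_tp_size == 0);  // ranks 0 and 4
    std::cout << "rank " << r << ": dp_group=" << dp_group
              << " is_driver=" << std::boolalpha << is_driver << "\n";
  }
  return 0;
}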

xllm/core/distributed_runtime/shm_channel.cpp

Lines changed: 8 additions & 4 deletions

@@ -19,7 +19,11 @@ limitations under the License.
 
 namespace xllm {
 
-ShmChannel::ShmChannel(int dp_group, int rank, bool is_driver) {
+ShmChannel::ShmChannel(int dp_group,
+                       int rank,
+                       bool is_driver,
+                       const runtime::Options& options)
+    : enable_shm_(options.enable_shm()) {
   bool is_creator;
 
   if (is_driver) {
@@ -45,7 +49,7 @@ bool ShmChannel::execute_model_with_shm(
   int use_shm_ret = input_shm_manager_->raw_input_write(inputs);
   if (use_shm_ret < 0) {
     // fallback
-    FLAGS_enable_shm = false;
+    enable_shm_ = false;
     LOG(ERROR)
         << "RemoteWorker SharedMemoryManager write failed, fallback to brpc.";
     return false;
@@ -58,7 +62,7 @@ bool ShmChannel::execute_model_with_shm(
 void ShmChannel::execute_model_async(
     const std::vector<RawForwardInput>& inputs,
     folly::Promise<std::optional<RawForwardOutput>>& promise) {
-  if (FLAGS_enable_shm) {
+  if (enable_shm_) {
     // write to shared memory, then wait output.
     RawForwardOutput raw_output;
     bool shm_success = execute_model_with_shm(inputs, raw_output);
@@ -69,4 +73,4 @@ void ShmChannel::execute_model_async(
   }
   execute_model_with_brpc(inputs, promise);
 }
-}  // namespace xllm
+}  // namespace xllm
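
The fallback state moves from the process-wide FLAGS_enable_shm into a per-channel enable_shm_ member, so a failed shared-memory write now downgrades only the channel that hit the error instead of disabling shm for every channel in the process. A runnable toy model of that behavior, with illustrative names rather than the commit's code:

#include <iostream>

class ToyChannel {
 public:
  explicit ToyChannel(bool enable_shm) : enable_shm_(enable_shm) {}

  void execute(bool shm_write_fails) {
    if (enable_shm_) {
      if (!shm_write_fails) {
        std::cout << "served via shared memory\n";
        return;
      }
      enable_shm_ = false;  // fallback is local to this channel only
      std::cout << "shm write failed, falling back to brpc\n";
    }
    std::cout << "served via brpc\n";
  }

 private:
  bool enable_shm_;
};

int main() {
  ToyChannel a(true), b(true);
  a.execute(/*shm_write_fails=*/true);   // a downgrades itself to brpc
  b.execute(/*shm_write_fails=*/false);  // b keeps using shared memory
  return 0;
}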

xllm/core/distributed_runtime/shm_channel.h

Lines changed: 8 additions & 2 deletions

@@ -16,12 +16,16 @@ limitations under the License.
 #pragma once
 #include "comm_channel.h"
 #include "runtime/forward_shared_memory_manager.h"
+#include "runtime/options.h"
 
 namespace xllm {
 
 class ShmChannel : public CommChannel {
  public:
-  explicit ShmChannel(int dp_group, int rank, bool is_driver);
+  explicit ShmChannel(int dp_group,
+                      int rank,
+                      bool is_driver,
+                      const runtime::Options& options);
   ~ShmChannel() = default;
 
   void execute_model_async(
@@ -31,8 +35,10 @@ class ShmChannel : public CommChannel {
  private:
  bool execute_model_with_shm(const std::vector<RawForwardInput>& inputs,
                              RawForwardOutput& raw_output);
+
+  bool enable_shm_ = false;
   std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager_ = nullptr;
   std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager_ = nullptr;
 };
 
-}  // namespace xllm
+}  // namespace xllm

xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp

Lines changed: 4 additions & 2 deletions

@@ -38,14 +38,16 @@ SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr,
                                      int world_size,
                                      int device_idx,
                                      int num_decoding_tokens,
-                                     int block_size) {
+                                     int block_size,
+                                     bool enable_shm) {
   // TODO: pass whole xllm::runtime::Options here from main process.
   xllm::runtime::Options runner_options;
   runner_options.block_size(block_size)
       .num_decoding_tokens(num_decoding_tokens)
       .enable_schedule_overlap(false)
       .enable_offline_inference(true)
-      .master_node_addr(master_node_addr);
+      .master_node_addr(master_node_addr)
+      .enable_shm(enable_shm);
   FLAGS_enable_schedule_overlap = false;
   FLAGS_master_node_addr = master_node_addr;
   FLAGS_block_size = block_size;

xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h

Lines changed: 2 additions & 1 deletion

@@ -27,7 +27,8 @@ class SpawnWorkerServer final {
                     int world_size,
                     int device_idx,
                     int num_decoding_tokens,
-                    int block_size);
+                    int block_size,
+                    bool enable_shm);
 
   ~SpawnWorkerServer() = default;

xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp

Lines changed: 8 additions & 4 deletions

@@ -28,10 +28,11 @@ limitations under the License.
 // @device_idx
 // @num_decoding_tokens
 // @block_size
+// @enable_shm
 int main(int argc, char* argv[]) {
-  if (argc < 7) {
+  if (argc < 8) {
     LOG(ERROR)
-        << "Spwan worker process receive wrong args. Need 7 args, receive "
+        << "Spwan worker process receive wrong args. Need 8 args, receive "
        << argc;
     return 1;
   }
@@ -50,22 +51,25 @@ int main(int argc, char* argv[]) {
   int device_idx = atoi(argv[5]);
   int num_decoding_tokens = atoi(argv[6]);
   int block_size = atoi(argv[7]);
+  int enable_shm = atoi(argv[8]);
 
   LOG(INFO) << "Spwan worker: "
             << "master_node_addr = " << master_node_addr
             << ", local_rank = " << local_rank
             << ", world_size = " << world_size
             << ", device_idx = " << device_idx
             << ", num_decoding_tokens = " << num_decoding_tokens
-            << ", block_size = " << block_size << "\n";
+            << ", block_size = " << block_size
+            << ", enable_shm = " << (enable_shm > 0) << "\n";
 
   xllm::SpawnWorkerServer worker(master_node_addr,
                                  local_rank,
                                  global_rank,
                                  world_size,
                                  device_idx,
                                  num_decoding_tokens,
-                                 block_size);
+                                 block_size,
+                                 enable_shm > 0);
 
   worker.run();
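
enable_shm crosses the process boundary as a plain argv string and is mapped back to a bool with atoi(argv[8]) > 0. A minimal sketch of that round trip using only the standard library (illustrative, not the commit's code):

#include <cstdlib>
#include <iostream>
#include <string>

int main() {
  // Parent side: serialize the option as a decimal string for argv ("1"/"0").
  bool enable_shm = true;
  std::string arg = std::to_string(enable_shm);  // "1"

  // Child side: atoi() the argv slot and map any positive value back to true.
  int parsed = std::atoi(arg.c_str());
  bool enable_shm_child = parsed > 0;

  std::cout << std::boolalpha << enable_shm_child << "\n";  // true
  return 0;
}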

xllm/core/distributed_runtime/worker_server.cpp

Lines changed: 23 additions & 20 deletions

@@ -110,7 +110,7 @@ void WorkerServer::create_server(
   std::unique_ptr<Worker> worker =
       std::make_unique<Worker>(*parallel_args, device, options, worker_type);
   worker_service->set_worker(std::move(worker));
-  if (FLAGS_enable_shm && input_shm_manager && output_shm_manager) {
+  if (options.enable_shm() && input_shm_manager && output_shm_manager) {
     worker_service->create_polling_shm_thread(std::move(input_shm_manager),
                                               std::move(output_shm_manager));
   }
@@ -127,29 +127,32 @@ void WorkerServer::create_spawn_server(int local_rank,
                                        const ParallelArgs& parallel_args,
                                        const torch::Device& d,
                                        const runtime::Options& options) {
-  auto local_rank_str0 = std::to_string(local_rank);
-  const char* local_rank_str = local_rank_str0.c_str();
-  auto global_rank_str0 = std::to_string(parallel_args.rank());
-  const char* global_rank_str = global_rank_str0.c_str();
-  auto world_size_str0 = std::to_string(parallel_args.world_size());
-  const char* world_size_str = world_size_str0.c_str();
-  auto device_idx_str0 = std::to_string(d.index());
-  const char* device_idx_str = device_idx_str0.c_str();
-  auto num_decoding_tokens_str0 = std::to_string(options.num_decoding_tokens());
-  const char* num_decoding_tokens_str = num_decoding_tokens_str0.c_str();
-  auto block_size_str0 = std::to_string(options.block_size());
-  const char* block_size_str = block_size_str0.c_str();
+  auto local_rank_str = std::to_string(local_rank);
+  const char* local_rank_ptr = local_rank_str.c_str();
+  auto global_rank_str = std::to_string(parallel_args.rank());
+  const char* global_rank_ptr = global_rank_str.c_str();
+  auto world_size_str = std::to_string(parallel_args.world_size());
+  const char* world_size_ptr = world_size_str.c_str();
+  auto device_idx_str = std::to_string(d.index());
+  const char* device_idx_ptr = device_idx_str.c_str();
+  auto num_decoding_tokens_str = std::to_string(options.num_decoding_tokens());
+  const char* num_decoding_tokens_ptr = num_decoding_tokens_str.c_str();
+  auto block_size_str = std::to_string(options.block_size());
+  const char* block_size_ptr = block_size_str.c_str();
+  auto enable_shm_str = std::to_string(options.enable_shm());
+  const char* enable_shm_ptr = enable_shm_str.c_str();
   std::string spawn_worker_bin_path =
       options.spawn_worker_path() + "/spawn_worker";
   LOG(INFO) << "Spawn worker path: " << spawn_worker_bin_path;
   const char* argv[] = {spawn_worker_bin_path.c_str(),
                         master_node_addr.c_str(),
-                        local_rank_str,
-                        global_rank_str,
-                        world_size_str,
-                        device_idx_str,
-                        num_decoding_tokens_str,
-                        block_size_str,
+                        local_rank_ptr,
+                        global_rank_ptr,
+                        world_size_ptr,
+                        device_idx_ptr,
+                        num_decoding_tokens_ptr,
+                        block_size_ptr,
+                        enable_shm_ptr,
                         nullptr};
   pid_t pid;
   posix_spawn_file_actions_init(&file_actions_);
@@ -173,7 +176,7 @@ void WorkerServer::prepare_shm(
     const runtime::Options& options,
     std::unique_ptr<ForwardSharedMemoryManager>& input_shm_manager,
     std::unique_ptr<ForwardSharedMemoryManager>& output_shm_manager) {
-  if (options.is_local() && FLAGS_enable_shm) {
+  if (options.is_local() && options.enable_shm()) {
     bool is_creator;
     int dp_local_tp_size = parallel_args.world_size() / parallel_args.dp_size();
     int dp_group = parallel_args.rank() / dp_local_tp_size;
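
The renames here (local_rank_str0/local_rank_str becoming local_rank_str/local_rank_ptr) make the lifetime contract easier to read: each std::string local owns the bytes that its c_str() pointer exposes, and all of them must stay in scope until posix_spawn has consumed the argv array. A self-contained sketch of that pattern with illustrative program name and arguments (POSIX only):

#include <spawn.h>
#include <string>
#include <sys/types.h>
#include <sys/wait.h>

extern char** environ;

int main() {
  std::string prog = "/bin/echo";
  // The std::string objects own the storage; the const char* views into
  // them must not outlive these locals.
  std::string rank_str = std::to_string(3);
  std::string enable_shm_str = std::to_string(true);  // "1"

  const char* argv[] = {prog.c_str(), rank_str.c_str(),
                        enable_shm_str.c_str(), nullptr};

  pid_t pid;
  if (posix_spawn(&pid, prog.c_str(), nullptr, nullptr,
                  const_cast<char* const*>(argv), environ) == 0) {
    waitpid(pid, nullptr, 0);  // strings stay in scope until the child is reaped
  }
  return 0;
}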

xllm/core/runtime/master.cpp

Lines changed: 3 additions & 0 deletions

@@ -112,6 +112,7 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
       .enable_schedule_overlap(options_.enable_schedule_overlap())
       .enable_offline_inference(options_.enable_offline_inference())
       .spawn_worker_path(options_.spawn_worker_path())
+      .enable_shm(options_.enable_shm())
       .is_local(options_.is_local());
 
   auto engine = std::make_unique<VLMEngine>(eng_options);
@@ -154,6 +155,7 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
       .enable_cache_upload(options_.enable_cache_upload())
       .enable_offline_inference(options_.enable_offline_inference())
       .spawn_worker_path(options_.spawn_worker_path())
+      .enable_shm(options_.enable_shm())
       .is_local(options_.is_local());
 
   if (options_.device_ip().has_value()) {
@@ -201,6 +203,7 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
       .enable_continuous_kvcache(options_.enable_continuous_kvcache())
       .enable_offline_inference(options_.enable_offline_inference())
       .spawn_worker_path(options_.spawn_worker_path())
+      .enable_shm(options_.enable_shm())
       .is_local(options_.is_local());
 
   if (options_.device_ip().has_value()) {
