Skip to content

Commit 0ed421e

Browse files
authored
perf: use async H2D copy for contiguous shm forward input on npu. (#1101)
1 parent 23f31d4 commit 0ed421e

File tree

2 files changed

+49 -8 lines changed

2 files changed

+49 -8 lines changed

xllm/core/runtime/forward_shared_memory_manager.cpp

Lines changed: 44 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,15 @@ limitations under the License.
1616
#include <gflags/gflags.h>
1717

1818
#include <cstring>
19+
#include <mutex>
20+
#include <optional>
1921
#include <stdexcept>
2022

2123
#include "core/common/global_flags.h"
24+
#include "platform/stream.h"
25+
#if defined(USE_NPU)
26+
#include "platform/npu/device_capture_lock.h"
27+
#endif
2228
#include "core/util/net.h"
2329
#include "core/util/tensor_helper.h"
2430
#include "util/utils.h"
@@ -1029,16 +1035,36 @@ inline void read_mm_batch_data(const char*& buffer,
10291035
inline void deserialize_raw_forward_input(const char*& buffer,
10301036
const uint64_t buffer_size,
10311037
ForwardInput& forward_input,
1032-
const torch::Device& device) {
1038+
const torch::Device& device,
1039+
Stream* stream) {
10331040
const char* device_buffer = nullptr;
10341041
#if defined(USE_NPU)
1042+
std::optional<std::unique_lock<std::mutex>> capture_lock_guard;
1043+
torch::Tensor host_input_buffer;
10351044
if (FLAGS_use_contiguous_input_buffer) {
1036-
// h to d
1037-
auto host_input_buffer =
1038-
torch::from_blob(const_cast<char*>(buffer),
1039-
{static_cast<int64_t>(buffer_size)},
1040-
torch::dtype(torch::kUInt8));
1041-
forward_input.device_input_buffer = host_input_buffer.to(device);
1045+
host_input_buffer = torch::from_blob(const_cast<char*>(buffer),
1046+
{static_cast<int64_t>(buffer_size)},
1047+
torch::TensorOptions()
1048+
.dtype(torch::kUInt8)
1049+
.device(torch::kCPU)
1050+
.pinned_memory(true));
1051+
1052+
auto device_options =
1053+
torch::TensorOptions().dtype(torch::kUInt8).device(device);
1054+
1055+
if (stream != nullptr) {
1056+
auto& capture_lock =
1057+
::xllm::npu::DeviceCaptureLock::get_instance().get_lock(
1058+
device.index());
1059+
capture_lock_guard.emplace(capture_lock);
1060+
c10::StreamGuard stream_guard = stream->set_stream_guard();
1061+
forward_input.device_input_buffer =
1062+
safe_to(host_input_buffer, device_options, true);
1063+
} else {
1064+
forward_input.device_input_buffer =
1065+
safe_to(host_input_buffer, device_options);
1066+
}
1067+
10421068
device_buffer = (char*)forward_input.device_input_buffer.data_ptr();
10431069
}
10441070
#endif
@@ -1127,6 +1153,12 @@ inline void deserialize_raw_forward_input(const char*& buffer,
11271153
// root cause is identified and the error is resolved.
11281154
read_tensor(buffer, input_params.new_cache_slots);
11291155
read_tensor(buffer, input_params.block_tables);
1156+
1157+
#if defined(USE_NPU)
1158+
if (device_buffer != nullptr && stream != nullptr) {
1159+
stream->synchronize();
1160+
}
1161+
#endif
11301162
}
11311163

11321164
inline void serialize_raw_forward_input(const RawForwardInput& input,
@@ -1543,6 +1575,9 @@ ForwardSharedMemoryManager::ForwardSharedMemoryManager(const std::string& name,
15431575
: SharedMemoryManager(name, size, is_creator), forward_type_(type) {
15441576
control_ptr_ = static_cast<ControlMetadata*>(base_address());
15451577
metadata_addr_ = static_cast<char*>(base_address()) + sizeof(ControlMetadata);
1578+
if (FLAGS_use_contiguous_input_buffer) {
1579+
stream_ = std::make_unique<Stream>();
1580+
}
15461581
}
15471582

15481583
ForwardSharedMemoryManager::~ForwardSharedMemoryManager() = default;
@@ -1606,7 +1641,8 @@ void ForwardSharedMemoryManager::raw_input_read(ForwardInput& input,
16061641
static_cast<char*>(base_address()) + sizeof(ControlMetadata);
16071642
uint64_t total_size;
16081643
read_data(data_ptr, total_size);
1609-
deserialize_raw_forward_input(data_ptr, total_size, input, device);
1644+
deserialize_raw_forward_input(
1645+
data_ptr, total_size, input, device, stream_.get());
16101646

16111647
return;
16121648
}

xllm/core/runtime/forward_shared_memory_manager.h

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,12 +15,16 @@ limitations under the License.
1515

1616
#include <stddef.h>
1717

18+
#include <memory>
19+
1820
#include "forward_params.h"
1921
#include "params_utils.h"
2022
#include "util/shared_memory_manager.h"
2123

2224
namespace xllm {
2325

26+
class Stream;
27+
2428
constexpr int64_t kNumWaitNanoseconds = 1000; // 1us
2529

2630
struct ControlMetadata {
@@ -120,5 +124,6 @@ class ForwardSharedMemoryManager : public SharedMemoryManager {
120124
uint64_t last_version_ = 0;
121125
void* metadata_addr_ = nullptr;
122126
ControlMetadata* control_ptr_ = nullptr;
127+
std::unique_ptr<Stream> stream_ = nullptr;
123128
};
124129
} // namespace xllm

0 commit comments

Comments
 (0)