@@ -16,9 +16,15 @@ limitations under the License.
1616#include < gflags/gflags.h>
1717
1818#include < cstring>
19+ #include < mutex>
20+ #include < optional>
1921#include < stdexcept>
2022
2123#include " core/common/global_flags.h"
24+ #include " platform/stream.h"
25+ #if defined(USE_NPU)
26+ #include " platform/npu/device_capture_lock.h"
27+ #endif
2228#include " core/util/net.h"
2329#include " core/util/tensor_helper.h"
2430#include " util/utils.h"
@@ -1029,16 +1035,36 @@ inline void read_mm_batch_data(const char*& buffer,
10291035inline void deserialize_raw_forward_input (const char *& buffer,
10301036 const uint64_t buffer_size,
10311037 ForwardInput& forward_input,
1032- const torch::Device& device) {
1038+ const torch::Device& device,
1039+ Stream* stream) {
10331040 const char * device_buffer = nullptr ;
10341041#if defined(USE_NPU)
1042+ std::optional<std::unique_lock<std::mutex>> capture_lock_guard;
1043+ torch::Tensor host_input_buffer;
10351044 if (FLAGS_use_contiguous_input_buffer) {
1036- // h to d
1037- auto host_input_buffer =
1038- torch::from_blob (const_cast <char *>(buffer),
1039- {static_cast <int64_t >(buffer_size)},
1040- torch::dtype (torch::kUInt8 ));
1041- forward_input.device_input_buffer = host_input_buffer.to (device);
1045+ host_input_buffer = torch::from_blob (const_cast <char *>(buffer),
1046+ {static_cast <int64_t >(buffer_size)},
1047+ torch::TensorOptions ()
1048+ .dtype (torch::kUInt8 )
1049+ .device (torch::kCPU )
1050+ .pinned_memory (true ));
1051+
1052+ auto device_options =
1053+ torch::TensorOptions ().dtype (torch::kUInt8 ).device (device);
1054+
1055+ if (stream != nullptr ) {
1056+ auto & capture_lock =
1057+ ::xllm::npu::DeviceCaptureLock::get_instance ().get_lock(
1058+ device.index());
1059+ capture_lock_guard.emplace (capture_lock);
1060+ c10::StreamGuard stream_guard = stream->set_stream_guard ();
1061+ forward_input.device_input_buffer =
1062+ safe_to (host_input_buffer, device_options, true );
1063+ } else {
1064+ forward_input.device_input_buffer =
1065+ safe_to (host_input_buffer, device_options);
1066+ }
1067+
10421068 device_buffer = (char *)forward_input.device_input_buffer .data_ptr ();
10431069 }
10441070#endif
@@ -1127,6 +1153,12 @@ inline void deserialize_raw_forward_input(const char*& buffer,
11271153 // root cause is identified and the error is resolved.
11281154 read_tensor (buffer, input_params.new_cache_slots );
11291155 read_tensor (buffer, input_params.block_tables );
1156+
1157+ #if defined(USE_NPU)
1158+ if (device_buffer != nullptr && stream != nullptr ) {
1159+ stream->synchronize ();
1160+ }
1161+ #endif
11301162}
11311163
11321164inline void serialize_raw_forward_input (const RawForwardInput& input,
@@ -1543,6 +1575,9 @@ ForwardSharedMemoryManager::ForwardSharedMemoryManager(const std::string& name,
15431575 : SharedMemoryManager(name, size, is_creator), forward_type_(type) {
15441576 control_ptr_ = static_cast <ControlMetadata*>(base_address ());
15451577 metadata_addr_ = static_cast <char *>(base_address ()) + sizeof (ControlMetadata);
1578+ if (FLAGS_use_contiguous_input_buffer) {
1579+ stream_ = std::make_unique<Stream>();
1580+ }
15461581}
15471582
15481583ForwardSharedMemoryManager::~ForwardSharedMemoryManager () = default ;
@@ -1606,7 +1641,8 @@ void ForwardSharedMemoryManager::raw_input_read(ForwardInput& input,
16061641 static_cast <char *>(base_address ()) + sizeof (ControlMetadata);
16071642 uint64_t total_size;
16081643 read_data (data_ptr, total_size);
1609- deserialize_raw_forward_input (data_ptr, total_size, input, device);
1644+ deserialize_raw_forward_input (
1645+ data_ptr, total_size, input, device, stream_.get ());
16101646
16111647 return ;
16121648}
0 commit comments