Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions xllm/core/runtime/forward_shared_memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@ limitations under the License.
#include <gflags/gflags.h>

#include <cstring>
#include <mutex>
#include <optional>
#include <stdexcept>

#include "core/common/global_flags.h"
#include "platform/stream.h"
#if defined(USE_NPU)
#include "platform/npu/device_capture_lock.h"
#endif
#include "core/util/net.h"
#include "core/util/tensor_helper.h"
#include "util/utils.h"
Expand Down Expand Up @@ -1029,16 +1035,36 @@ inline void read_mm_batch_data(const char*& buffer,
inline void deserialize_raw_forward_input(const char*& buffer,
const uint64_t buffer_size,
ForwardInput& forward_input,
const torch::Device& device) {
const torch::Device& device,
Stream* stream) {
const char* device_buffer = nullptr;
#if defined(USE_NPU)
std::optional<std::unique_lock<std::mutex>> capture_lock_guard;
torch::Tensor host_input_buffer;
if (FLAGS_use_contiguous_input_buffer) {
// h to d
auto host_input_buffer =
torch::from_blob(const_cast<char*>(buffer),
{static_cast<int64_t>(buffer_size)},
torch::dtype(torch::kUInt8));
forward_input.device_input_buffer = host_input_buffer.to(device);
host_input_buffer = torch::from_blob(const_cast<char*>(buffer),
{static_cast<int64_t>(buffer_size)},
torch::TensorOptions()
.dtype(torch::kUInt8)
.device(torch::kCPU)
.pinned_memory(true));

auto device_options =
torch::TensorOptions().dtype(torch::kUInt8).device(device);

if (stream != nullptr) {
auto& capture_lock =
::xllm::npu::DeviceCaptureLock::get_instance().get_lock(
device.index());
capture_lock_guard.emplace(capture_lock);
c10::StreamGuard stream_guard = stream->set_stream_guard();
forward_input.device_input_buffer =
safe_to(host_input_buffer, device_options, true);
} else {
forward_input.device_input_buffer =
safe_to(host_input_buffer, device_options);
}

device_buffer = (char*)forward_input.device_input_buffer.data_ptr();
}
#endif
Expand Down Expand Up @@ -1127,6 +1153,12 @@ inline void deserialize_raw_forward_input(const char*& buffer,
// root cause is identified and the error is resolved.
read_tensor(buffer, input_params.new_cache_slots);
read_tensor(buffer, input_params.block_tables);

#if defined(USE_NPU)
if (device_buffer != nullptr && stream != nullptr) {
stream->synchronize();
}
#endif
}

inline void serialize_raw_forward_input(const RawForwardInput& input,
Expand Down Expand Up @@ -1543,6 +1575,9 @@ ForwardSharedMemoryManager::ForwardSharedMemoryManager(const std::string& name,
: SharedMemoryManager(name, size, is_creator), forward_type_(type) {
control_ptr_ = static_cast<ControlMetadata*>(base_address());
metadata_addr_ = static_cast<char*>(base_address()) + sizeof(ControlMetadata);
if (FLAGS_use_contiguous_input_buffer) {
stream_ = std::make_unique<Stream>();
}
}

// Destructor is defaulted out-of-line, in this translation unit.
// NOTE(review): the class appears to hold a std::unique_ptr<Stream> member
// while the header only forward-declares Stream; if so, the destructor must
// be emitted where Stream is a complete type (platform/stream.h is included
// in this file). Confirm before moving this definition into the header.
ForwardSharedMemoryManager::~ForwardSharedMemoryManager() = default;
Expand Down Expand Up @@ -1606,7 +1641,8 @@ void ForwardSharedMemoryManager::raw_input_read(ForwardInput& input,
static_cast<char*>(base_address()) + sizeof(ControlMetadata);
uint64_t total_size;
read_data(data_ptr, total_size);
deserialize_raw_forward_input(data_ptr, total_size, input, device);
deserialize_raw_forward_input(
data_ptr, total_size, input, device, stream_.get());

return;
}
Expand Down
5 changes: 5 additions & 0 deletions xllm/core/runtime/forward_shared_memory_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ limitations under the License.

#include <stddef.h>

#include <memory>

#include "forward_params.h"
#include "params_utils.h"
#include "util/shared_memory_manager.h"

namespace xllm {

// Forward declaration only; this header does not include the definition of
// Stream. NOTE(review): presumably the complete type comes from
// platform/stream.h in the implementation file — confirm.
class Stream;

// Wait duration expressed in nanoseconds (1000 ns = 1 us).
// NOTE(review): the call sites are not visible here — confirm what operation
// this interval is applied to.
constexpr int64_t kNumWaitNanoseconds = 1000; // 1us

struct ControlMetadata {
Expand Down Expand Up @@ -120,5 +124,6 @@ class ForwardSharedMemoryManager : public SharedMemoryManager {
uint64_t last_version_ = 0;
void* metadata_addr_ = nullptr;
ControlMetadata* control_ptr_ = nullptr;
std::unique_ptr<Stream> stream_ = nullptr;
};
} // namespace xllm
Loading