Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mooncake-integration/transfer_engine/transfer_engine_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,15 @@ int TransferEnginePy::freeManagedBuffer(uintptr_t buffer_addr, size_t length) {
return 0;
}

int TransferEnginePy::freeRemoteSegment(const char *target_hostname) {
std::lock_guard<std::mutex> guard(mutex_);
if (handle_map_.count(target_hostname)) {
engine_->closeSegment(handle_map_[target_hostname]);
handle_map_.erase(target_hostname);
}
return 0;
}

int TransferEnginePy::transferSyncWrite(const char *target_hostname,
uintptr_t buffer,
uintptr_t peer_buffer_address,
Expand Down Expand Up @@ -684,6 +693,7 @@ PYBIND11_MODULE(engine, m) {
.def("batch_unregister_memory",
&TransferEnginePy::batchUnregisterMemory)
.def("get_local_topology", &TransferEnginePy::getLocalTopology)
.def("free_remote_segment", &TransferEnginePy::freeRemoteSegment)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renaming it as "close_remote_segment" may be better.

.def("get_first_buffer_address",
&TransferEnginePy::getFirstBufferAddress);

Expand Down
2 changes: 2 additions & 0 deletions mooncake-integration/transfer_engine/transfer_engine_py.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ class TransferEnginePy {

int batchUnregisterMemory(std::vector<uintptr_t> buffer_addresses);

int markDead(const char *target_hostname);

std::string getLocalTopology();

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class AscendDirectTransport : public Transport {

void processSliceList(const std::vector<Slice *> &slice_list);

int closeSegment(Transport::SegmentHandle handle) override;

private:
int InitAdxlEngine();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class HcclTransport : public Transport {
int unregisterLocalMemoryBatch(
const std::vector<void *> &addr_list) override;

int closeSegment(Transport::SegmentHandle handle) override;

private:
int allocateLocalSegmentID();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class CxlTransport : public Transport {

bool validateMemoryBounds(void *dest, void *src, size_t size);

int closeSegment(Transport::SegmentHandle handle) override;

private:
void *cxl_base_addr;
size_t cxl_dev_size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class NvlinkTransport : public Transport {

const char* getName() const override { return "nvlink"; }

int closeSegment(Transport::SegmentHandle handle) override;

private:
std::atomic_bool running_;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class NVMeoFTransport : public Transport {
uint64_t target_start, TransferRequest::OpCode op,
TransferTask &task, const char *file_path);

int closeSegment(Transport::SegmentHandle handle) override;

private:
void startTransfer(Slice *slice);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class RdmaTransport : public Transport {
int unregisterLocalMemoryBatch(
const std::vector<void *> &addr_list) override;

int closeSegment(Transport::SegmentHandle handle) override;

// TRANSFER

Status submitTransfer(BatchID batch_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class TcpTransport : public Transport {
Status getTransferStatus(BatchID batch_id, size_t task_id,
TransferStatus &status) override;

int closeSegment(Transport::SegmentHandle handle) override;

private:
int install(std::string &local_server_name,
std::shared_ptr<TransferMetadata> meta,
Expand Down
5 changes: 5 additions & 0 deletions mooncake-transfer-engine/include/transport/transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ class Transport {
virtual Status getTransferStatus(BatchID batch_id, size_t task_id,
TransferStatus &status) = 0;

/// @brief Close a segment handle.
/// @param handle The segment handle to close.
/// @return 0 on success, -1 on failure.
virtual int closeSegment(Transport::SegmentHandle handle) = 0;

std::shared_ptr<TransferMetadata> &meta() { return metadata_; }

struct BufferEntry {
Expand Down
8 changes: 8 additions & 0 deletions mooncake-transfer-engine/src/transfer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,14 @@ int TransferEngine::unregisterLocalMemoryBatch(
return 0;
}

int TransferEngine::closeSegment(Transport::SegmentHandle handle) {
for (auto &transport : multi_transports_->listTransports()) {
int ret = transport->closeSegment(handle);
if (ret < 0) return ret;
}
return 0;
}

#ifdef WITH_METRICS
// Helper function to convert string to lowercase for case-insensitive
// comparison
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,10 @@ void AscendDirectTransport::workerThread() {
LOG(INFO) << "AscendDirectTransport worker thread stopped";
}

int AscendDirectTransport::closeSegment(Transport::SegmentHandle handle) {
return 0;
}

void AscendDirectTransport::processSliceList(
const std::vector<Slice *> &slice_list) {
if (slice_list.empty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,10 @@ Status HcclTransport::getTransferStatus(BatchID batch_id, size_t task_id,
return Status::OK();
}

int HcclTransport::closeSegment(Transport::SegmentHandle handle) {
return 0;
}

int HcclTransport::registerLocalMemory(void *addr, size_t length,
const std::string &location,
bool remote_accessible,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ int CxlTransport::unregisterLocalMemory(void *addr, bool update_metadata) {
return metadata_->removeLocalMemoryBuffer(addr, update_metadata);
}

int CxlTransport::closeSegment(Transport::SegmentHandle handle) {
return 0;
}

int CxlTransport::registerLocalMemoryBatch(
const std::vector<Transport::BufferEntry> &buffer_list,
const std::string &location) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,20 @@ int NvlinkTransport::unregisterLocalMemoryBatch(
return metadata_->updateLocalSegmentDesc();
}

int NvlinkTransport::closeSegment(Transport::SegmentHandle handle) {
// close all opened ipc handles for this SegmentHandle.
for (auto &entry : remap_entries_) {
if (entry.first.first == handle) {
cudaError_t err = cudaIpcCloseMemHandle(entry.second.shm_addr);
if (err != cudaSuccess) {
LOG(ERROR) << "NvlinkTransport: cudaIpcCloseMemHandle failed: "
<< cudaGetErrorString(err);
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add one line to remove this entry in remap_entries_ after cudaIpcCloseMemHandle.

}
return 0;
}

void *NvlinkTransport::allocatePinnedLocalMemory(size_t size) {
if (!supportFabricMem()) {
void *ptr = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ int NVMeoFTransport::unregisterLocalMemory(void *addr, bool update_metadata) {
return 0;
}

int NVMeoFTransport::closeSegment(Transport::SegmentHandle handle) {
return 0;
}

void NVMeoFTransport::addSliceToTask(void *source_addr, uint64_t slice_len,
uint64_t target_start,
TransferRequest::OpCode op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ Status RdmaTransport::getTransferStatus(BatchID batch_id, size_t task_id,
return Status::OK();
}

int RdmaTransport::closeSegment(Transport::SegmentHandle handle) {
return 0;
}

RdmaTransport::SegmentID RdmaTransport::getSegmentID(
const std::string &segment_name) {
return metadata_->getSegmentID(segment_name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ int TcpTransport::allocateLocalSegmentID(int tcp_data_port) {
return 0;
}

int TcpTransport::closeSegment(Transport::SegmentHandle handle) {
return 0;
}

int TcpTransport::registerLocalMemory(void *addr, size_t length,
const std::string &location,
bool remote_accessible,
Expand Down
Loading