Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions mooncake-integration/transfer_engine/transfer_engine_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,14 +583,15 @@ int TransferEnginePy::transferCheckStatus(batch_id_t batch_id) {
}

int TransferEnginePy::batchRegisterMemory(std::vector<uintptr_t> buffer_addresses,
std::vector<size_t> capacities) {
std::vector<size_t> capacities,
const std::string &location) {
pybind11::gil_scoped_release release;
auto batch_size = buffer_addresses.size();
std::vector<BufferEntry> buffers;
for (int i = 0; i < batch_size; i ++ ) {
buffers.push_back(BufferEntry{(void *)buffer_addresses[i], capacities[i]});
}
return engine_->registerLocalMemoryBatch(buffers, kWildcardLocation);
return engine_->registerLocalMemoryBatch(buffers, location);
}

int TransferEnginePy::batchUnregisterMemory(std::vector<uintptr_t> buffer_addresses) {
Expand All @@ -603,9 +604,9 @@ int TransferEnginePy::batchUnregisterMemory(std::vector<uintptr_t> buffer_addres
return engine_->unregisterLocalMemoryBatch(buffers);
}

int TransferEnginePy::registerMemory(uintptr_t buffer_addr, size_t capacity) {
int TransferEnginePy::registerMemory(uintptr_t buffer_addr, size_t capacity, const std::string &location) {
char *buffer = reinterpret_cast<char *>(buffer_addr);
return engine_->registerLocalMemory(buffer, capacity);
return engine_->registerLocalMemory(buffer, capacity, location);
}

int TransferEnginePy::unregisterMemory(uintptr_t buffer_addr) {
Expand Down Expand Up @@ -656,9 +657,19 @@ PYBIND11_MODULE(engine, m) {
.def("write_bytes_to_buffer", &TransferEnginePy::writeBytesToBuffer)
.def("read_bytes_from_buffer",
&TransferEnginePy::readBytesFromBuffer)
.def("register_memory", &TransferEnginePy::registerMemory)
.def("register_memory",
&TransferEnginePy::registerMemory,
py::arg("buffer_addr"),
py::arg("capacity"),
py::arg("location") = kWildcardLocation
)
.def("unregister_memory", &TransferEnginePy::unregisterMemory)
.def("batch_register_memory", &TransferEnginePy::batchRegisterMemory)
.def("batch_register_memory",
&TransferEnginePy::batchRegisterMemory,
py::arg("buffer_addresses"),
py::arg("capacities"),
py::arg("location") = kWildcardLocation
)
.def("batch_unregister_memory", &TransferEnginePy::batchUnregisterMemory)
.def("get_first_buffer_address",
&TransferEnginePy::getFirstBufferAddress);
Expand Down
4 changes: 2 additions & 2 deletions mooncake-integration/transfer_engine/transfer_engine_py.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,12 @@ class TransferEnginePy {
}

// FOR EXPERIMENT ONLY
int registerMemory(uintptr_t buffer_addr, size_t capacity);
int registerMemory(uintptr_t buffer_addr, size_t capacity, const std::string &location = kWildcardLocation);

// must be called before TransferEnginePy::~TransferEnginePy()
int unregisterMemory(uintptr_t buffer_addr);

int batchRegisterMemory(std::vector<uintptr_t> buffer_addresses, std::vector<size_t> capacities);
int batchRegisterMemory(std::vector<uintptr_t> buffer_addresses, std::vector<size_t> capacities, const std::string &location = kWildcardLocation);

int batchUnregisterMemory(std::vector<uintptr_t> buffer_addresses);

Expand Down
Loading