Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/api/cpp/backend/backend_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,15 @@ class nixlBackendEngine {
return NIXL_ERR_NOT_SUPPORTED;
}

// Retrieve pointers mapped into the local virtual address space for
// all memory regions described by descriptors in a prepared list.
//
// Default implementation for backends that offer no such mapping: `ptrs`
// is left untouched and NIXL_ERR_NOT_SUPPORTED is returned. Backends that
// can expose mapped pointers override this method.
virtual nixl_status_t
getMappedPtrs(const nixl_meta_dlist_t &meta_dlist,
std::vector<void *> &ptrs,
const nixl_opt_b_args_t *opt_args = nullptr) const {
return NIXL_ERR_NOT_SUPPORTED;
}

// *** Needs to be implemented if supportsRemote() is true *** //

// Gets serialized form of public metadata
Expand Down
20 changes: 20 additions & 0 deletions src/api/cpp/nixl.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,26 @@ class nixlAgent {
nixl_status_t
releasedDlistH (nixlDlistH* dlist_hndl) const;

/**
 * @brief Retrieve pointers mapped into the local virtual address space for
 *        all memory regions described by descriptors in a prepared list.
 *
 * If the descriptor refers to remote memory, this attempts to return a valid
 * local virtual address that maps to the remote address.
 * For descriptors where mapping is not supported, the corresponding pointer
 * will be nullptr.
 *
 * Exactly one backend serves the query: the first backend listed in
 * extra_params->backends that has an entry in the prepared list or, when no
 * backends are specified, the first backend found in the prepared list.
 *
 * @param  dlist_hndl   Prepared descriptor list handle. Must not be nullptr.
 * @param  ptrs [out]   Output vector of pointers (one per descriptor in the list).
 *                      Cleared on entry; a backend that supports the query
 *                      resizes it to match the descriptor count.
 * @param  extra_params Optional extra parameters in getting the mapped pointers:
 *                      `backends` restricts the backend selection and
 *                      `customParam` is forwarded to the selected backend.
 * @return nixl_status_t NIXL_SUCCESS if the query completed successfully.
 *         Individual pointers may be nullptr if mapping is not supported.
 *         NIXL_ERR_INVALID_PARAM if dlist_hndl is nullptr or no usable backend
 *         is found; otherwise the status reported by the selected backend
 *         (NIXL_ERR_NOT_SUPPORTED for backends without an implementation).
 */
nixl_status_t
getMappedPtrs(const nixlDlistH *dlist_hndl,
std::vector<void *> &ptrs,
const nixl_opt_args_t *extra_params = nullptr) const;

/*** Notification Handling ***/

Expand Down
43 changes: 43 additions & 0 deletions src/core/nixl_agent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,49 @@ nixlAgent::releasedDlistH (nixlDlistH* dlist_hndl) const {
return NIXL_SUCCESS;
}

// Resolve the backend responsible for a prepared descriptor list and ask it
// for the locally mapped pointers of every descriptor in that list.
//
// `ptrs` is always cleared first, so callers never observe stale entries.
// Returns NIXL_ERR_INVALID_PARAM for a null handle or when no backend can be
// selected; otherwise returns whatever the selected backend reports.
nixl_status_t
nixlAgent::getMappedPtrs(const nixlDlistH *dlist_hndl,
                         std::vector<void *> &ptrs,
                         const nixl_opt_args_t *extra_params) const {
    ptrs.clear();
    if (dlist_hndl == nullptr) {
        return NIXL_ERR_INVALID_PARAM;
    }

    NIXL_LOCK_GUARD(data->lock);

    const bool caller_chose_backends =
        (extra_params != nullptr) && !extra_params->backends.empty();

    nixlBackendEngine *selected = nullptr;
    if (caller_chose_backends) {
        // Honor the caller's preference order: first requested backend that
        // actually has an entry in this prepared list wins.
        for (const auto &handle : extra_params->backends) {
            if (dlist_hndl->descs.count(handle->engine) != 0) {
                selected = handle->engine;
                break;
            }
        }
    } else if (!dlist_hndl->descs.empty()) {
        // No preference given: take the first backend the list was prepared for.
        selected = dlist_hndl->descs.begin()->first;
    }

    if (selected == nullptr) {
        NIXL_ERROR_FUNC << "could not find a backend in the specified or "
                           "available list of backends for the prepped Dlist";
        return NIXL_ERR_INVALID_PARAM;
    }

    // Only the custom parameter string is forwarded to the backend.
    nixl_opt_b_args_t backend_args;
    if ((extra_params != nullptr) && !extra_params->customParam.empty()) {
        backend_args.customParam = extra_params->customParam;
    }

    nixl_meta_dlist_t *backend_dlist = dlist_hndl->descs.at(selected);
    return selected->getMappedPtrs(*backend_dlist, ptrs, &backend_args);
}

nixl_status_t
nixlAgent::getNotifs(nixl_notifs_t &notif_map,
const nixl_opt_args_t* extra_params) {
Expand Down
31 changes: 31 additions & 0 deletions src/plugins/ucx/ucx_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,11 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params)
const auto engine_config =
(engine_config_it != custom_params->end()) ? engine_config_it->second : "";

#ifdef HAVE_CUDA
// Enable cuda_ipc for same-process transfers (needed for getMappedPtrs)
setenv("UCX_CUDA_IPC_ENABLE_SAME_PROCESS", "y", 0);
#endif

uc = std::make_unique<nixlUcxContext>(
devs, init_params.enableProgTh, num_workers, init_params.syncMode, engine_config);

Expand Down Expand Up @@ -1748,6 +1753,32 @@ nixlUcxEngine::prepGpuSignal(const nixlBackendMD &meta,
}
}

// Translate every descriptor of a prepared list into a pointer that is valid
// in the local virtual address space, using the UCX rkey stored in the
// descriptor's metadata.
//
// @param meta_dlist Prepared descriptor list; each descriptor's metadataP is
//                   treated as nixlUcxPublicMetadata.
// @param ptrs [out] On success, holds one entry per descriptor. An entry is
//                   nullptr when rkey::getPtr() reports no local mapping for
//                   that descriptor. On failure, `ptrs` is left untouched so
//                   callers never see a partially filled result.
// @param opt_args   Optional backend arguments; may select the worker whose
//                   rkey is used (see getWorkerIdFromOptArgs).
// @return NIXL_SUCCESS on success, NIXL_ERR_INVALID_PARAM if a descriptor
//         carries no metadata, NIXL_ERR_BACKEND if the rkey lookup throws.
nixl_status_t
nixlUcxEngine::getMappedPtrs(const nixl_meta_dlist_t &meta_dlist,
                             std::vector<void *> &ptrs,
                             const nixl_opt_b_args_t *opt_args) const {
    const auto opt_worker_id = getWorkerIdFromOptArgs(opt_args);
    const size_t worker_id = opt_worker_id.value_or(getWorkerId());

    // Collect into a scratch vector and publish it only once every
    // descriptor has been resolved; the output must not be modified when an
    // error is returned.
    std::vector<void *> mapped;
    mapped.resize(meta_dlist.descCount(), nullptr);

    for (int i = 0; i < meta_dlist.descCount(); ++i) {
        const nixlMetaDesc &desc = meta_dlist[i];
        const auto *ucx_meta = static_cast<const nixlUcxPublicMetadata *>(desc.metadataP);

        if (ucx_meta == nullptr) {
            NIXL_ERROR << "getMappedPtrs failed for descriptor " << i << ": missing metadata";
            return NIXL_ERR_INVALID_PARAM;
        }

        try {
            const auto &rkey_obj = ucx_meta->getRkey(worker_id);
            mapped[i] = rkey_obj.getPtr(desc.addr);
        }
        catch (const std::exception &e) {
            NIXL_ERROR << "getMappedPtrs failed for descriptor " << i << ": " << e.what();
            return NIXL_ERR_BACKEND;
        }
    }

    ptrs.swap(mapped);
    return NIXL_SUCCESS;
}

int nixlUcxEngine::progress() {
// TODO: add listen for connection handling if necessary
int ret = 0;
Expand Down
5 changes: 5 additions & 0 deletions src/plugins/ucx/ucx_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ class nixlUcxEngine : public nixlBackendEngine {
void *signal,
const nixl_opt_b_args_t *opt_args = nullptr) const override;

// Resolve each descriptor of a prepared list to a pointer in the local
// virtual address space via its UCX rkey. `ptrs` receives one entry per
// descriptor; an entry is nullptr when the rkey yields no local mapping.
nixl_status_t
getMappedPtrs(const nixl_meta_dlist_t &meta_dlist,
std::vector<void *> &ptrs,
const nixl_opt_b_args_t *opt_args = nullptr) const override;

int
progress();

Expand Down
17 changes: 17 additions & 0 deletions src/utils/ucx/rkey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,21 @@ rkey::unpackUcpRkey(const nixlUcxEp &ep, const void *rkey_buffer) {
}
return rkey;
}

void *
rkey::getPtr(uint64_t raddr) const {
void *ptr = nullptr;
const ucs_status_t status = ucp_rkey_ptr(rkey_.get(), raddr, &ptr);

if (status == UCS_OK) {
return ptr;
}

if (status == UCS_ERR_UNREACHABLE) {
return nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should just return ptr as ucp_rkey_ptr shouldn't modify it in case of error.

}
Comment on lines +47 to +49
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we throw an exception in this case?
getMappedPtrs returns NIXL_SUCCESS, when every rkey::getPtr returns nullptr.


throw std::runtime_error(std::string("Failed to get pointer from UCX rkey: ") +
ucs_status_string(status));
}
} // namespace nixl::ucx
3 changes: 3 additions & 0 deletions src/utils/ucx/rkey.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class rkey {
return rkey_.get();
}

/**
 * @brief Get a local pointer for directly accessing the memory at @p raddr
 *        through this rkey (wraps ucp_rkey_ptr()).
 *
 * @param raddr Address covered by this rkey.
 * @return The mapped pointer on UCS_OK, nullptr on UCS_ERR_UNREACHABLE.
 * @throws std::runtime_error for any other UCX status.
 */
[[nodiscard]] void *
getPtr(uint64_t raddr) const;

private:
[[nodiscard]] static ucp_rkey_h
unpackUcpRkey(const nixlUcxEp &, const void *rkey_buffer);
Expand Down
44 changes: 44 additions & 0 deletions test/gtest/test_transfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,50 @@ TEST_P(TestTransfer, PrepGpuSignal) {
#endif
}

// Verifies that, for VRAM registered on a second agent, agent 0 can prepare a
// descriptor list against that memory and obtain one non-null locally mapped
// pointer per remote buffer through nixlAgent::getMappedPtrs().
TEST_P(TestTransfer, getMappedPtrs) {
    constexpr size_t buffer_size = 4096;
    constexpr size_t num_buffers = 3;

    if (!hasCudaGpu()) {
        GTEST_SKIP() << "No CUDA-capable GPU is available, skipping test.";
    }

    std::vector<MemBuffer> buffers_local;
    std::vector<MemBuffer> buffers_remote;

    createRegisteredMem(getAgent(0), buffer_size, num_buffers, VRAM_SEG, buffers_local);
    createRegisteredMem(getAgent(1), buffer_size, num_buffers, VRAM_SEG, buffers_remote);

    exchangeMDIP(0, 1);

    // The same backend selection is used for connecting, preparing the list
    // and querying the mapped pointers.
    nixl_opt_args_t extra_params = {.backends = {backend_handles[0]}};

    nixl_status_t status = getAgent(0).makeConnection(getAgentName(1), &extra_params);
    ASSERT_EQ(status, NIXL_SUCCESS) << "makeConnection failed for VRAM";

    auto remote_desc_list = makeDescList<nixlBasicDesc>(buffers_remote, VRAM_SEG);

    nixlDlistH *dlist_hndl = nullptr;
    status =
        getAgent(0).prepXferDlist(getAgentName(1), remote_desc_list, dlist_hndl, &extra_params);
    ASSERT_EQ(status, NIXL_SUCCESS) << "prepXferDlist failed for VRAM";
    ASSERT_NE(dlist_hndl, nullptr);

    std::vector<void *> ptrs;
    status = getAgent(0).getMappedPtrs(dlist_hndl, ptrs, &extra_params);
    ASSERT_EQ(status, NIXL_SUCCESS) << "getMappedPtrs failed for VRAM";
    ASSERT_EQ(ptrs.size(), num_buffers) << "Wrong number of pointers returned for VRAM";

    for (size_t i = 0; i < ptrs.size(); ++i) {
        EXPECT_NE(ptrs[i], nullptr) << "Buffer " << i << " returned null pointer";
    }

    // Cleanup failures are reported instead of being silently dropped.
    EXPECT_EQ(getAgent(0).releasedDlistH(dlist_hndl), NIXL_SUCCESS)
        << "releasedDlistH failed for VRAM";
    deregisterMem(getAgent(0), buffers_local, VRAM_SEG);
    deregisterMem(getAgent(1), buffers_remote, VRAM_SEG);
}

// Instantiate the TestTransfer suite for three UCX configurations.
// NOTE(review): the meaning of the trailing arguments is defined by
// NIXL_INSTANTIATE_TEST, which is not visible here — confirm against the macro.
NIXL_INSTANTIATE_TEST(ucx, TestTransfer, "UCX", true, 2, 0, "");
NIXL_INSTANTIATE_TEST(ucx_no_pt, TestTransfer, "UCX", false, 2, 0, "");
NIXL_INSTANTIATE_TEST(ucx_threadpool, TestTransfer, "UCX", true, 6, 4, "");
Expand Down