diff --git a/.github/workflows/build-release.yaml b/.github/workflows/build-release.yaml index 63c3637..a146c02 100644 --- a/.github/workflows/build-release.yaml +++ b/.github/workflows/build-release.yaml @@ -111,6 +111,8 @@ jobs: permissions: id-token: write name: publish + env: + CMAKE_POLICY_VERSION_MINIMUM: 3.5 steps: - name: Free Disk Space uses: jlumbroso/free-disk-space@main @@ -131,7 +133,6 @@ jobs: gcc \ g++ \ make \ - cmake \ binutils \ libibverbs-dev \ librdmacm-dev \ @@ -143,6 +144,24 @@ jobs: libhiredis-dev \ liburing-dev + # install cmake + cd /tmp && \ + wget -q https://cmake.org/files/v4.1/cmake-4.1.3-linux-x86_64.sh && \ + sudo bash cmake-4.1.3-linux-x86_64.sh --skip-license --prefix=/usr && \ + rm cmake-4.1.3-linux-x86_64.sh && \ + cd - + + # libevhtp uses TestEndianess.c.in + sudo ln -s /usr/share/cmake-4.1/Modules/TestEndianness.c.in /usr/share/cmake-4.1/Modules/TestEndianess.c.in + + # install ucx + cd /tmp && \ + wget -q https://github.com/openucx/ucx/releases/download/v1.19.0/ucx-1.19.0-ubuntu22.04-mofed5-cuda12-x86_64.tar.bz2 && \ + tar xvf ucx-1.19.0-ubuntu22.04-mofed5-cuda12-x86_64.tar.bz2 && \ + sudo dpkg -i ucx-1.19.0.deb ucx-cuda-1.19.0.deb ucx-xpmem-1.19.0.deb && \ + rm /tmp/ucx-1.19.0-ubuntu22.04-mofed5-cuda12-x86_64.tar.bz2 /tmp/ucx-1.19.0.deb /tmp/ucx-cuda-1.19.0.deb /tmp/ucx-xpmem-1.19.0.deb && \ + cd - + # install python dependencies pip install -r requirements.txt - name: Build diff --git a/.github/workflows/build-test.yaml b/.github/workflows/build-test.yaml index e01f916..dd0c3d3 100644 --- a/.github/workflows/build-test.yaml +++ b/.github/workflows/build-test.yaml @@ -44,6 +44,8 @@ jobs: matrix: os: [ubuntu-22.04] python-version: ["3.11", "3.12"] + env: + CMAKE_POLICY_VERSION_MINIMUM: 3.5 steps: - name: Free Disk Space uses: jlumbroso/free-disk-space@main @@ -71,7 +73,6 @@ jobs: gcc \ g++ \ make \ - cmake \ binutils \ libibverbs-dev \ librdmacm-dev \ @@ -83,6 +84,24 @@ jobs: libhiredis-dev \ liburing-dev + # install cmake + cd 
/tmp && \ + wget -q https://cmake.org/files/v4.1/cmake-4.1.3-linux-x86_64.sh && \ + sudo bash cmake-4.1.3-linux-x86_64.sh --skip-license --prefix=/usr && \ + rm cmake-4.1.3-linux-x86_64.sh && \ + cd - + + # libevhtp uses TestEndianess.c.in + sudo ln -s /usr/share/cmake-4.1/Modules/TestEndianness.c.in /usr/share/cmake-4.1/Modules/TestEndianess.c.in + + # install ucx + cd /tmp && \ + wget -q https://github.com/openucx/ucx/releases/download/v1.19.0/ucx-1.19.0-ubuntu22.04-mofed5-cuda12-x86_64.tar.bz2 && \ + tar xvf ucx-1.19.0-ubuntu22.04-mofed5-cuda12-x86_64.tar.bz2 && \ + sudo dpkg -i ucx-1.19.0.deb ucx-cuda-1.19.0.deb ucx-xpmem-1.19.0.deb && \ + rm /tmp/ucx-1.19.0-ubuntu22.04-mofed5-cuda12-x86_64.tar.bz2 /tmp/ucx-1.19.0.deb /tmp/ucx-cuda-1.19.0.deb /tmp/ucx-xpmem-1.19.0.deb && \ + cd - + # install clang-format-19 sudo curl -fsSL --retry 3 --retry-delay 2 https://github.com/muttleyxd/clang-tools-static-binaries/releases/download/master-46b8640/clang-format-19_linux-amd64 -o /usr/bin/clang-format-19 || { echo "ERROR: Failed to download clang-format-19"; exit 1; } test -s /usr/bin/clang-format-19 || { echo "ERROR: Downloaded file is empty"; exit 1; } diff --git a/README.md b/README.md index d5ffee8..f4c25be 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,10 @@ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) PrisKV is specifically designed for modern high-performance computing (HPC) and -artificial intelligence (AI) computing. It solely supports RDMA. PrisKV also -supports GDR (GPU Direct RDMA), enabling the value of a key to be directly -transferred between PrisKV and the GPU. +artificial intelligence (AI) computing. It supports common transport protocols, +including RDMA, TCP, and shared memory, to enable efficient communication for +different scenarios. PrisKV also supports GDR (GPU Direct RDMA), enabling the +value of a key to be directly transferred between PrisKV and the GPU. 
## How to Build diff --git a/client/Makefile b/client/Makefile index 0ba9c03..2539cb7 100644 --- a/client/Makefile +++ b/client/Makefile @@ -33,7 +33,7 @@ VALKEY_LDFLAGS = -I$(VALKEY_INCLUDE_PATH) PRISKV_TARGETS_ALL = priskv-client priskv-benchmark priskv-example priskv-test_runtime PRISKV_TARGETS = priskv-client priskv-benchmark PRISKV_TARGETS_SRCS = $(patsubst priskv-%, %.c, $(PRISKV_TARGETS_ALL)) -SRCS = $(filter-out $(PRISKV_TARGETS_SRCS) $(VALKEY_BENCHMARK_SRC), $(wildcard *.c)) +SRCS = $(filter-out $(PRISKV_TARGETS_SRCS) $(VALKEY_BENCHMARK_SRC), $(wildcard *.c transport/*.c)) OBJS := $(SRCS:%.c=%.o) DEPS := $(OBJS:%.o=%.d) @@ -47,6 +47,9 @@ LIBIBVERBS_VERSION_MAJOR = $(shell echo ${LIBIBVERBS_VERSION} | cut -d "." -f 1) LIBIBVERBS_VERSION_MINOR = $(shell echo ${LIBIBVERBS_VERSION} | cut -d "." -f 2) RDMA_LDFLAGS = -lrdmacm -libverbs -DLIBIBVERBS_VERSION_MAJOR=$(LIBIBVERBS_VERSION_MAJOR) -DLIBIBVERBS_VERSION_MINOR=$(LIBIBVERBS_VERSION_MINOR) +UCX_CFLAGS = $(shell pkg-config --cflags ucx) +UCX_LDFLAGS = $(shell pkg-config --libs ucx) + CUDA_HOME ?= /usr/local/cuda CUDA_LIB ?= $(CUDA_HOME)/lib64 CUDA_INC ?= $(CUDA_HOME)/include @@ -102,13 +105,13 @@ $(THIRDPARTY_SRCS): git submodule update --init --recursive ../thirdparty/linenoise $(PRISKV_TARGETS): priskv-%: %.c $(STATIC_LIB) $(THIRDPARTY_SRCS) - $(CC) $< $(THIRDPARTY_SRCS) -o $@ $(STATIC_LIB) $(CFLAGS) $(RDMA_LDFLAGS) $(THIRDPARTY_INCFLAGS) + $(CC) $< $(THIRDPARTY_SRCS) -o $@ $(STATIC_LIB) $(CFLAGS) $(RDMA_LDFLAGS) $(UCX_CFLAGS) $(UCX_LDFLAGS) $(THIRDPARTY_INCFLAGS) $(VALKEY_STATIC_LIB): cd $(VALKEY_PATH) && USE_RDMA=1 make $(VALKEY_BENCHMARK_NAME): $(VALKEY_BENCHMARK_SRC) $(STATIC_LIB) $(VALKEY_STATIC_LIB) - $(CC) $^ $(CFLAGS) $(RDMA_LDFLAGS) $(CUDA_LDFLAGS) $(VALKEY_LDFLAGS) -o $@ + $(CC) $^ $(CFLAGS) $(RDMA_LDFLAGS) $(UCX_CFLAGS) $(UCX_LDFLAGS) $(CUDA_LDFLAGS) $(VALKEY_LDFLAGS) -o $@ install: $(PRISKV_TARGETS) install -m 755 -d $(PRISKV_DESTDIR)/$(PRISKV_BINPATH) @@ -129,3 +132,4 @@ clean: format: $(FMT) 
-i *.c *.h + $(FMT) -i transport/*.c transport/*.h diff --git a/client/benchmark.c b/client/benchmark.c index e6fa031..c8954c9 100644 --- a/client/benchmark.c +++ b/client/benchmark.c @@ -132,6 +132,7 @@ typedef struct { int (*get_fd)(void *ctx); void (*handler)(int fd, void *opaque, uint32_t events); const char *(*status_str)(int status); + bool (*is_error)(int status); void (*get)(void *ctx, const char *key, void *value, uint32_t value_len, void (*cb)(int, void *), void *cbarg); void (*set)(void *ctx, const char *key, void *value, uint32_t value_len, @@ -1049,6 +1050,12 @@ static void priskv_drv_test(void *ctx, const char *key, void (*cb)(int, void *), priskv_test_async(priskv_ctx->client, key, (uint64_t)priskv_req_ctx, priskv_req_cb); } +static bool priskv_drv_is_error(int status) +{ + return status != PRISKV_STATUS_OK && status != PRISKV_STATUS_NO_SUCH_KEY && + status != PRISKV_STATUS_KEY_UPDATING; +} + static const kv_driver priskv_drv = { .name = "priskv", .transfer = false, @@ -1057,6 +1064,7 @@ static const kv_driver priskv_drv = { .get_fd = priskv_drv_get_fd, .handler = priskv_drv_handler, .status_str = priskv_drv_status_str, + .is_error = priskv_drv_is_error, .get = priskv_drv_get, .set = priskv_drv_set, .del = priskv_drv_del, @@ -1071,6 +1079,7 @@ static const kv_driver priskv_drv_transfer = { .get_fd = priskv_drv_get_fd, .handler = priskv_drv_handler, .status_str = priskv_drv_status_str, + .is_error = priskv_drv_is_error, .get = priskv_drv_get_transfer, .set = priskv_drv_set_transfer, .del = priskv_drv_del, @@ -1161,7 +1170,7 @@ static void job_cb(int status, void *arg) { job_context *job = arg; - if (status) { + if (job->kv_drv->is_error(status)) { job_set_error(job, "resp status[%d]: %s", status, job->kv_drv->status_str(status)); } @@ -1647,9 +1656,6 @@ static int job_init(job_context *job, int threadid) } job->threadid = threadid; - priskv_set_fd_handler(job->epollfd, job_process, NULL, job); - 
priskv_thread_add_event_handler(priskv_threadpool_get_iothread(g_threadpool, job->threadid), - job->epollfd); printf("job[%d] finish to init\n", threadid); return 0; @@ -1677,6 +1683,9 @@ static void job_kick(job_context *job) { uint64_t u = 1; + priskv_set_fd_handler(job->epollfd, job_process, NULL, job); + priskv_thread_add_event_handler(priskv_threadpool_get_iothread(g_threadpool, job->threadid), + job->epollfd); write(job->eventfd, &u, sizeof(u)); } diff --git a/client/client.c b/client/client.c index ccafef6..edb1e7d 100644 --- a/client/client.c +++ b/client/client.c @@ -37,6 +37,7 @@ #include "priskv-log.h" #include "priskv-logo.h" #include "linenoise.h" +#include "transport/transport.h" #define VALUE_MAX_LEN (64 * 1024) diff --git a/client/priskv.h b/client/priskv.h index afaf4e3..8e36e0b 100644 --- a/client/priskv.h +++ b/client/priskv.h @@ -44,7 +44,8 @@ typedef struct priskv_memory priskv_memory; * @lport: local port. Ignore port on NULL @laddr * @nqueue: the number of worker threads */ -priskv_client *priskv_connect(const char *raddr, int rport, const char *laddr, int lport, int nqueue); +priskv_client *priskv_connect(const char *raddr, int rport, const char *laddr, int lport, + int nqueue); /* Close a client context created by @priskv_connect */ void priskv_close(priskv_client *client); @@ -52,8 +53,8 @@ void priskv_close(priskv_client *client); /* Get memory handler * @offset: uint64_t type on valid fd(>= 0); or void * type on invalid fd. 
*/ -priskv_memory *priskv_reg_memory(priskv_client *client, uint64_t offset, size_t length, uint64_t iova, - int fd); +priskv_memory *priskv_reg_memory(priskv_client *client, uint64_t offset, size_t length, + uint64_t iova, int fd); /* Put a memory handler * @mem: created by @priskv_reg_memory @@ -122,8 +123,8 @@ typedef enum priskv_status { /* RDMA disconnected from the server side */ PRISKV_STATUS_DISCONNECTED = 0xF00, - /* local RDMA error occurs */ - PRISKV_STATUS_RDMA_ERROR, + /* local transport error occurs */ + PRISKV_STATUS_TRANSPORT_ERROR, /* does inflight requests exceed @max_inflight_command? */ PRISKV_STATUS_BUSY, @@ -180,8 +181,8 @@ static inline const char *priskv_status_str(priskv_status status) case PRISKV_STATUS_DISCONNECTED: return "Disconnected"; - case PRISKV_STATUS_RDMA_ERROR: - return "RDMA error"; + case PRISKV_STATUS_TRANSPORT_ERROR: + return "Transport error"; case PRISKV_STATUS_BUSY: return "Busy"; @@ -218,30 +219,31 @@ typedef void (*priskv_generic_cb)(uint64_t request_id, priskv_status status, voi * PRISKV_STATUS_VALUE_TOO_BIG as soon as possible. 
*/ int priskv_get_async(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, - uint64_t request_id, priskv_generic_cb cb); + uint64_t request_id, priskv_generic_cb cb); /* Set value of a key */ int priskv_set_async(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, - uint64_t timeout, uint64_t request_id, priskv_generic_cb cb); + uint64_t timeout, uint64_t request_id, priskv_generic_cb cb); /* Test a key-value exist or not */ -int priskv_test_async(priskv_client *client, const char *key, uint64_t request_id, priskv_generic_cb cb); +int priskv_test_async(priskv_client *client, const char *key, uint64_t request_id, + priskv_generic_cb cb); /* Delete a key-value exist or not */ int priskv_delete_async(priskv_client *client, const char *key, uint64_t request_id, - priskv_generic_cb cb); + priskv_generic_cb cb); /* Set expire time of a key */ -int priskv_expire_async(priskv_client *client, const char *key, uint64_t timeout, uint64_t request_id, - priskv_generic_cb cb); +int priskv_expire_async(priskv_client *client, const char *key, uint64_t timeout, + uint64_t request_id, priskv_generic_cb cb); /* Get the number of keys which match the @regex */ int priskv_nrkeys_async(priskv_client *client, const char *regex, uint64_t request_id, - priskv_generic_cb cb); + priskv_generic_cb cb); /* Flush the keys which match the @regex */ int priskv_flush_async(priskv_client *client, const char *regex, uint64_t request_id, - priskv_generic_cb cb); + priskv_generic_cb cb); /* for *KEYS* command */ typedef struct priskv_key { @@ -259,13 +261,14 @@ void priskv_keyset_free(priskv_keyset *keyset); /* Get the keys which match the @regex and return the result in priskv_generic_cb */ int priskv_keys_async(priskv_client *client, const char *regex, uint64_t request_id, - priskv_generic_cb cb); + priskv_generic_cb cb); /* sync APIs */ int priskv_get(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, - uint32_t *valuelen); + uint32_t 
*valuelen); -int priskv_set(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, uint64_t timeout); +int priskv_set(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, + uint64_t timeout); int priskv_test(priskv_client *client, const char *key, uint32_t *valuelen); diff --git a/client/rdma.h b/client/rdma.h deleted file mode 100644 index 5cfa3cf..0000000 --- a/client/rdma.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* - * Authors: - * Jinlong Xuan <15563983051@163.com> - * Xu Ji - * Yu Wang - * Bo Liu - * Zhenwei Pi - * Rui Zhang - * Changqi Lu - * Enhua Zhou - */ - -#ifndef __PRISKV_CLIENT_RDMA__ -#define __PRISKV_CLIENT_RDMA__ - -#endif /* __PRISKV_CLIENT_RDMA__ */ diff --git a/client/sync.c b/client/sync.c index 6fe17a0..cf92463 100644 --- a/client/sync.c +++ b/client/sync.c @@ -34,22 +34,22 @@ #include "priskv-log.h" #include "priskv.h" -typedef struct priskv_rdma_req_sync { +typedef struct priskv_transport_req_sync { priskv_status status; uint32_t valuelen; bool done; -} priskv_rdma_req_sync; +} priskv_transport_req_sync; static void priskv_common_sync_cb(uint64_t request_id, priskv_status status, void *result) { - priskv_rdma_req_sync *rdma_req_sync = (priskv_rdma_req_sync *)request_id; + priskv_transport_req_sync *req_sync = (priskv_transport_req_sync *)request_id; uint32_t valuelen = result ? 
*(uint32_t *)result : 0; - priskv_log_debug("RDMA: callback request_id 0x%lx, status: %s[0x%x], length %d\n", request_id, + priskv_log_debug("priskv_common_sync_cb: callback request_id 0x%lx, status: %s[0x%x], length %d\n", request_id, priskv_resp_status_str(status), status, valuelen); - rdma_req_sync->status = status; - rdma_req_sync->valuelen = valuelen; - rdma_req_sync->done = true; + req_sync->status = status; + req_sync->valuelen = valuelen; + req_sync->done = true; } static inline int priskv_sync_wait(priskv_client *client, bool *done) @@ -63,89 +63,89 @@ static inline int priskv_sync_wait(priskv_client *client, bool *done) int priskv_get(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, uint32_t *valuelen) { - priskv_rdma_req_sync rdma_req_sync = {.status = 0xffff, .done = false}; + priskv_transport_req_sync req_sync = {.status = 0xffff, .done = false}; - priskv_get_async(client, key, sgl, nsgl, (uint64_t)&rdma_req_sync, priskv_common_sync_cb); - priskv_sync_wait(client, &rdma_req_sync.done); - *valuelen = rdma_req_sync.valuelen; + priskv_get_async(client, key, sgl, nsgl, (uint64_t)&req_sync, priskv_common_sync_cb); + priskv_sync_wait(client, &req_sync.done); + *valuelen = req_sync.valuelen; - return rdma_req_sync.status; + return req_sync.status; } int priskv_set(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, uint64_t timeout) { - priskv_rdma_req_sync rdma_req_sync = {.status = 0xffff, .done = false}; + priskv_transport_req_sync req_sync = {.status = 0xffff, .done = false}; - priskv_set_async(client, key, sgl, nsgl, timeout, (uint64_t)&rdma_req_sync, priskv_common_sync_cb); - priskv_sync_wait(client, &rdma_req_sync.done); + priskv_set_async(client, key, sgl, nsgl, timeout, (uint64_t)&req_sync, priskv_common_sync_cb); + priskv_sync_wait(client, &req_sync.done); - return rdma_req_sync.status; + return req_sync.status; } int priskv_test(priskv_client *client, const char *key, uint32_t *valuelen) { - 
priskv_rdma_req_sync rdma_req_sync = {.status = 0xffff, .done = false}; + priskv_transport_req_sync req_sync = {.status = 0xffff, .done = false}; - priskv_test_async(client, key, (uint64_t)&rdma_req_sync, priskv_common_sync_cb); - priskv_sync_wait(client, &rdma_req_sync.done); - *valuelen = rdma_req_sync.valuelen; + priskv_test_async(client, key, (uint64_t)&req_sync, priskv_common_sync_cb); + priskv_sync_wait(client, &req_sync.done); + *valuelen = req_sync.valuelen; - return rdma_req_sync.status; + return req_sync.status; } int priskv_delete(priskv_client *client, const char *key) { - priskv_rdma_req_sync rdma_req_sync = {.status = 0xffff, .done = false}; + priskv_transport_req_sync req_sync = {.status = 0xffff, .done = false}; - priskv_delete_async(client, key, (uint64_t)&rdma_req_sync, priskv_common_sync_cb); - priskv_sync_wait(client, &rdma_req_sync.done); + priskv_delete_async(client, key, (uint64_t)&req_sync, priskv_common_sync_cb); + priskv_sync_wait(client, &req_sync.done); - return rdma_req_sync.status; + return req_sync.status; } int priskv_expire(priskv_client *client, const char *key, uint64_t timeout) { - priskv_rdma_req_sync rdma_req_sync = {.status = 0xffff, .done = false}; + priskv_transport_req_sync req_sync = {.status = 0xffff, .done = false}; - priskv_expire_async(client, key, timeout, (uint64_t)&rdma_req_sync, priskv_common_sync_cb); - priskv_sync_wait(client, &rdma_req_sync.done); + priskv_expire_async(client, key, timeout, (uint64_t)&req_sync, priskv_common_sync_cb); + priskv_sync_wait(client, &req_sync.done); - return rdma_req_sync.status; + return req_sync.status; } int priskv_nrkeys(priskv_client *client, const char *regex, uint32_t *nkey) { - priskv_rdma_req_sync rdma_req_sync = {.status = 0xffff, .done = false}; + priskv_transport_req_sync req_sync = {.status = 0xffff, .done = false}; - priskv_nrkeys_async(client, regex, (uint64_t)&rdma_req_sync, priskv_common_sync_cb); - priskv_sync_wait(client, &rdma_req_sync.done); - *nkey = 
rdma_req_sync.valuelen; + priskv_nrkeys_async(client, regex, (uint64_t)&req_sync, priskv_common_sync_cb); + priskv_sync_wait(client, &req_sync.done); + *nkey = req_sync.valuelen; - return rdma_req_sync.status; + return req_sync.status; } int priskv_flush(priskv_client *client, const char *regex, uint32_t *nkey) { - priskv_rdma_req_sync rdma_req_sync = {.status = 0xffff, .done = false}; + priskv_transport_req_sync req_sync = {.status = 0xffff, .done = false}; - priskv_flush_async(client, regex, (uint64_t)&rdma_req_sync, priskv_common_sync_cb); - priskv_sync_wait(client, &rdma_req_sync.done); - *nkey = rdma_req_sync.valuelen; + priskv_flush_async(client, regex, (uint64_t)&req_sync, priskv_common_sync_cb); + priskv_sync_wait(client, &req_sync.done); + *nkey = req_sync.valuelen; - return rdma_req_sync.status; + return req_sync.status; } -typedef struct priskv_rdma_keys_sync { +typedef struct priskv_transport_keys_sync { priskv_status status; bool done; priskv_keyset **keyset; -} priskv_rdma_keys_sync; +} priskv_transport_keys_sync; static void priskv_keys_sync_cb(uint64_t request_id, priskv_status status, void *result) { - priskv_rdma_keys_sync *keys_req_sync = (priskv_rdma_keys_sync *)request_id; + priskv_transport_keys_sync *keys_req_sync = (priskv_transport_keys_sync *)request_id; - priskv_log_debug("RDMA: callback request_id 0x%lx, status: %s[0x%x]\n", request_id, + priskv_log_debug("priskv_keys_sync_cb: callback request_id 0x%lx, status: %s[0x%x]\n", request_id, priskv_resp_status_str(status), status); keys_req_sync->status = status; keys_req_sync->done = true; @@ -157,7 +157,7 @@ static void priskv_keys_sync_cb(uint64_t request_id, priskv_status status, void int priskv_keys(priskv_client *client, const char *regex, priskv_keyset **keyset) { - priskv_rdma_keys_sync keys_req_sync = {0}; + priskv_transport_keys_sync keys_req_sync = {0}; keys_req_sync.keyset = keyset; priskv_keys_async(client, regex, (uint64_t)&keys_req_sync, priskv_keys_sync_cb); diff --git 
a/client/rdma.c b/client/transport/rdma.c similarity index 59% rename from client/rdma.c rename to client/transport/rdma.c index 91aa641..826a6b1 100644 --- a/client/rdma.c +++ b/client/transport/rdma.c @@ -37,6 +37,7 @@ #include #include +#include "../priskv.h" #include "priskv-threads.h" #include "priskv-event.h" #include "priskv-workqueue.h" @@ -44,174 +45,30 @@ #include "priskv-protocol-helper.h" #include "priskv-utils.h" #include "priskv-log.h" -#include "priskv.h" #include "list.h" +#include "transport.h" #define PRISKV_RDMA_DEFAULT_INFLIGHT_COMMAND 128 -#define PRISKV_RDMA_DEF_ADDR(id) \ - char local_addr[PRISKV_ADDR_LEN] = {0}; \ - char peer_addr[PRISKV_ADDR_LEN] = {0}; \ - priskv_inet_ntop(rdma_get_local_addr(id), local_addr); \ +#define PRISKV_RDMA_DEF_ADDR(id) \ + char local_addr[PRISKV_ADDR_LEN] = {0}; \ + char peer_addr[PRISKV_ADDR_LEN] = {0}; \ + priskv_inet_ntop(rdma_get_local_addr(id), local_addr); \ priskv_inet_ntop(rdma_get_peer_addr(id), peer_addr); -typedef struct priskv_rdma_mem priskv_rdma_mem; -typedef struct priskv_conn_operation priskv_conn_operation; -typedef enum priskv_rdma_mem_type priskv_rdma_mem_type; -typedef struct priskv_connect_param priskv_connect_param; -typedef struct priskv_rdma_conn priskv_rdma_conn; -typedef struct priskv_rdma_req priskv_rdma_req; -typedef struct priskv_sgl_private priskv_sgl_private; - -struct priskv_rdma_mem { -#define PRISKV_RDMA_MEM_NAME_LEN 32 - char name[PRISKV_RDMA_MEM_NAME_LEN]; - uint8_t *buf; - uint32_t buf_size; - struct ibv_mr *mr; -}; - -enum priskv_rdma_mem_type { - PRISKV_RDMA_MEM_REQ, - PRISKV_RDMA_MEM_RESP, - PRISKV_RDMA_MEM_KEYS, - - PRISKV_RDMA_MEM_MAX -}; - -struct priskv_connect_param { - /* the maxium count of @priskv_sgl */ - uint16_t max_sgl; - /* the maxium length of a KEY in bytes */ - uint16_t max_key_length; - /* the maxium command in flight, aka depth of commands */ - uint16_t max_inflight_command; -}; - -struct priskv_memory { - priskv_client *client; - int count; - struct 
ibv_mr **mrs; -}; - -struct priskv_rdma_conn { - struct rdma_cm_id *cm_id; - struct rdma_event_channel *cm_channel; - struct ibv_comp_channel *comp_channel; - struct ibv_cq *cq; - struct ibv_qp *qp; - - uint8_t id; - priskv_thread *thread; - - priskv_rdma_mem rmem[PRISKV_RDMA_MEM_MAX]; - - priskv_connect_param param; - uint64_t capacity; - int epollfd; - bool established; - struct list_head inflight_list; - struct list_head complete_list; - - priskv_rdma_req *keys_running_req; - priskv_memory keys_mems; - - uint64_t stats[PRISKV_COMMAND_MAX]; - uint64_t resps; - uint64_t wc_recv; - uint64_t wc_send; -}; - -struct priskv_client { - priskv_threadpool *pool; - priskv_rdma_conn **conns; - int nqueue; - int cur_conn; - int epollfd; - priskv_workqueue *wq; - priskv_conn_operation *ops; -}; - -struct priskv_sgl_private { - priskv_sgl sgl; - /* used for automatic registration memory */ - struct ibv_mr *mr; -}; - -struct priskv_rdma_req { - priskv_rdma_conn *conn; - priskv_conn_operation *ops; - priskv_workqueue *main_wq; - struct list_node entry; - priskv_request *req; - uint64_t request_id; - char *key; - priskv_sgl_private *sgl; - uint16_t nsgl; - uint16_t keylen; - uint64_t timeout; - priskv_req_command cmd; - void (*cb)(struct priskv_rdma_req *rdma_req); - priskv_generic_cb usercb; -#define PRISKV_RDMA_REQ_FLAG_SEND (1 << 0) -#define PRISKV_RDMA_REQ_FLAG_RECV (1 << 2) -#define PRISKV_RDMA_REQ_FLAG_DONE (PRISKV_RDMA_REQ_FLAG_SEND | PRISKV_RDMA_REQ_FLAG_RECV) - uint8_t flags; - uint16_t status; - uint32_t length; - void *result; - bool delaying; -}; - -struct priskv_conn_operation { - int (*init)(priskv_client *client, const char *raddr, int rport, const char *laddr, int lport, - int nqueue); - void (*deinit)(priskv_client *client); - priskv_rdma_conn *(*select_conn)(priskv_client *client); - priskv_memory *(*reg_memory)(priskv_client *client, uint64_t offset, size_t length, uint64_t iova, - int fd); - void (*dereg_memory)(priskv_memory *mem); - struct ibv_mr 
*(*get_mr)(priskv_memory *mem, int connid); - void (*rdma_req_submit)(priskv_rdma_req *rdma_req); - void (*rdma_req_cb)(priskv_rdma_req *rdma_req); -}; - -static int priskv_rdma_handle_cq(priskv_rdma_conn *conn); -static int _priskv_rdma_req_cb(void *arg); +static int priskv_rdma_handle_cq(priskv_transport_conn *conn); +static int priskv_rdma_req_cb_intl(void *arg); static int priskv_rdma_req_send(void *arg); -static inline void priskv_rdma_req_free(priskv_rdma_req *rdma_req); -static inline void priskv_rdma_req_complete(priskv_rdma_conn *conn); -static inline void priskv_rdma_req_reset(priskv_rdma_req *rdma_req); - -static int priskv_build_check(void) -{ - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_OK != (int)PRISKV_RESP_STATUS_OK); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_INVALID_COMMAND != (int)PRISKV_RESP_STATUS_INVALID_COMMAND); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_KEY_EMPTY != (int)PRISKV_RESP_STATUS_KEY_EMPTY); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_KEY_TOO_BIG != (int)PRISKV_RESP_STATUS_KEY_TOO_BIG); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_VALUE_EMPTY != (int)PRISKV_RESP_STATUS_VALUE_EMPTY); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_VALUE_TOO_BIG != (int)PRISKV_RESP_STATUS_VALUE_TOO_BIG); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_NO_SUCH_COMMAND != (int)PRISKV_RESP_STATUS_NO_SUCH_COMMAND); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_NO_SUCH_KEY != (int)PRISKV_RESP_STATUS_NO_SUCH_KEY); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_INVALID_SGL != (int)PRISKV_RESP_STATUS_INVALID_SGL); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_INVALID_REGEX != (int)PRISKV_RESP_STATUS_INVALID_REGEX); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_KEY_UPDATING != (int)PRISKV_RESP_STATUS_KEY_UPDATING); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_CONNECT_ERROR != (int)PRISKV_RESP_STATUS_CONNECT_ERROR); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_SERVER_ERROR != (int)PRISKV_RESP_STATUS_SERVER_ERROR); - PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_NO_MEM != (int)PRISKV_RESP_STATUS_NO_MEM); - return 0; -} 
+static inline void priskv_rdma_req_free(priskv_transport_req *rdma_req); +static inline void priskv_rdma_req_complete(priskv_transport_conn *conn); +static inline void priskv_rdma_req_reset(priskv_transport_req *rdma_req); +static inline priskv_transport_req * +priskv_rdma_req_new(priskv_client *client, priskv_transport_conn *conn, uint64_t request_id, + const char *key, uint16_t keylen, priskv_sgl *sgl, uint16_t nsgl, + uint64_t timeout, priskv_req_command cmd, priskv_generic_cb usercb); -/* use 64 bytes aligned request buffer. */ -static inline unsigned int priskv_request_size_aligend(priskv_rdma_conn *conn) -{ - uint16_t s = priskv_request_size(conn->param.max_sgl, conn->param.max_key_length); - - return ALIGN_UP(s, 64); -} - -static int priskv_rdma_mem_new(priskv_rdma_conn *conn, priskv_rdma_mem *rmem, const char *name, - uint32_t size, bool remote_write) +static int priskv_rdma_mem_new(priskv_transport_conn *conn, priskv_transport_mem *rmem, + const char *name, uint32_t size, bool remote_write) { uint32_t page_size = getpagesize(); uint8_t *buf; @@ -237,7 +94,7 @@ static int priskv_rdma_mem_new(priskv_rdma_conn *conn, priskv_rdma_mem *rmem, co goto free_mem; } - strncpy(rmem->name, name, PRISKV_RDMA_MEM_NAME_LEN - 1); + strncpy(rmem->name, name, PRISKV_TRANSPORT_MEM_NAME_LEN - 1); rmem->buf = buf; rmem->buf_size = size; @@ -249,12 +106,12 @@ static int priskv_rdma_mem_new(priskv_rdma_conn *conn, priskv_rdma_mem *rmem, co munmap(rmem->buf, rmem->buf_size); error: - memset(rmem, 0x00, sizeof(priskv_rdma_mem)); + memset(rmem, 0x00, sizeof(priskv_transport_mem)); return ret; } -static void priskv_rdma_mem_free(priskv_rdma_conn *conn, priskv_rdma_mem *rmem) +static void priskv_rdma_mem_free(priskv_transport_conn *conn, priskv_transport_mem *rmem) { if (rmem->mr) { ibv_dereg_mr(rmem->mr); @@ -266,58 +123,61 @@ static void priskv_rdma_mem_free(priskv_rdma_conn *conn, priskv_rdma_mem *rmem) } priskv_log_info("RDMA: free rmem %s, size %d\n", rmem->name, 
rmem->buf_size); - memset(rmem, 0x00, sizeof(priskv_rdma_mem)); + memset(rmem, 0x00, sizeof(priskv_transport_mem)); } -static inline void priskv_rdma_mem_free_all(priskv_rdma_conn *conn) +static inline void priskv_rdma_mem_free_all(priskv_transport_conn *conn) { - for (int i = 0; i < PRISKV_RDMA_MEM_MAX; i++) { - priskv_rdma_mem *rmem = &conn->rmem[i]; + for (int i = 0; i < PRISKV_TRANSPORT_MEM_MAX; i++) { + priskv_transport_mem *rmem = &conn->rmem[i]; priskv_rdma_mem_free(conn, rmem); } } #define PRISKV_RDMA_REQUEST_FREE_COMMAND 0xffff -static void priskv_request_free(priskv_request *req, priskv_rdma_conn *conn) +static void priskv_rdma_request_free(priskv_request *req, priskv_transport_conn *conn) { uint8_t *ptr = (uint8_t *)req; - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_REQ]; + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_REQ]; assert(ptr >= rmem->buf); assert(ptr < rmem->buf + rmem->buf_size); - assert(!((ptr - rmem->buf) % priskv_request_size_aligend(conn))); + assert(!((ptr - rmem->buf) % priskv_rdma_max_request_size_aligned(conn->param.max_sgl, + conn->param.max_key_length))); req->command = PRISKV_RDMA_REQUEST_FREE_COMMAND; } -static int priskv_rdma_mem_new_all(priskv_rdma_conn *conn) +static int priskv_rdma_mem_new_all(priskv_transport_conn *conn) { uint32_t page_size = getpagesize(), size; /* #step 1, prepare buffer & MR for request to server */ - int reqsize = priskv_request_size_aligend(conn); + int reqsize = + priskv_rdma_max_request_size_aligned(conn->param.max_sgl, conn->param.max_key_length); size = reqsize * conn->param.max_inflight_command; - if (priskv_rdma_mem_new(conn, &conn->rmem[PRISKV_RDMA_MEM_REQ], "Request", size, false)) { + if (priskv_rdma_mem_new(conn, &conn->rmem[PRISKV_TRANSPORT_MEM_REQ], "Request", size, false)) { goto error; } /* additional work: set priskv_request::command as PRISKV_RDMA_REQUEST_FREE_COMMAND */ - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_REQ]; + priskv_transport_mem *rmem = 
&conn->rmem[PRISKV_TRANSPORT_MEM_REQ]; for (uint16_t i = 0; i < conn->param.max_inflight_command; i++) { priskv_request *req = (priskv_request *)(rmem->buf + i * reqsize); - priskv_request_free(req, conn); + priskv_rdma_request_free(req, conn); } /* #step 2, prepare buffer & MR for response from server */ size = sizeof(priskv_response) * conn->param.max_inflight_command; - if (priskv_rdma_mem_new(conn, &conn->rmem[PRISKV_RDMA_MEM_RESP], "Response", size, false)) { + if (priskv_rdma_mem_new(conn, &conn->rmem[PRISKV_TRANSPORT_MEM_RESP], "Response", size, + false)) { goto error; } /* #step 3, prepare buffer & MR for keys */ size = page_size; - if (priskv_rdma_mem_new(conn, &conn->rmem[PRISKV_RDMA_MEM_KEYS], "Keys", size, true)) { + if (priskv_rdma_mem_new(conn, &conn->rmem[PRISKV_TRANSPORT_MEM_KEYS], "Keys", size, true)) { goto error; } @@ -329,10 +189,11 @@ static int priskv_rdma_mem_new_all(priskv_rdma_conn *conn) return -ENOMEM; } -static priskv_request *priskv_rdma_unused_command(priskv_rdma_conn *conn, uint16_t *idx) +static priskv_request *priskv_rdma_unused_command(priskv_transport_conn *conn, uint16_t *idx) { - uint16_t req_buf_size = priskv_request_size_aligend(conn); - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_REQ]; + uint16_t req_buf_size = + priskv_rdma_max_request_size_aligned(conn->param.max_sgl, conn->param.max_key_length); + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_REQ]; for (uint16_t i = 0; i < conn->param.max_inflight_command; i++) { priskv_request *req = (priskv_request *)(rmem->buf + i * req_buf_size); @@ -347,27 +208,27 @@ static priskv_request *priskv_rdma_unused_command(priskv_rdma_conn *conn, uint16 return NULL; } -static void priskv_rdma_close_conn(priskv_rdma_conn *conn) +static void priskv_rdma_close_conn(priskv_transport_conn *conn) { - priskv_rdma_req *rdma_req, *tmp; + priskv_transport_req *rdma_req, *tmp; - if (conn->established) { + if (conn->state == PRISKV_TRANSPORT_CONN_STATE_ESTABLISHED) { 
PRISKV_RDMA_DEF_ADDR(conn->cm_id) priskv_log_notice("RDMA: <%s - %s> close. Requests GET %ld, SET %ld, TEST %ld, DELETE %ld, " - "Responses %ld\n", - local_addr, peer_addr, conn->stats[PRISKV_COMMAND_GET], - conn->stats[PRISKV_COMMAND_SET], conn->stats[PRISKV_COMMAND_TEST], - conn->stats[PRISKV_COMMAND_DELETE], conn->resps); + "Responses %ld\n", + local_addr, peer_addr, conn->stats[PRISKV_COMMAND_GET], + conn->stats[PRISKV_COMMAND_SET], conn->stats[PRISKV_COMMAND_TEST], + conn->stats[PRISKV_COMMAND_DELETE], conn->resps); } - conn->established = false; + conn->state = PRISKV_TRANSPORT_CONN_STATE_CLOSED; priskv_rdma_req_complete(conn); list_for_each_safe (&conn->inflight_list, rdma_req, tmp, entry) { list_del(&rdma_req->entry); - priskv_request_free(rdma_req->req, conn); + priskv_rdma_request_free(rdma_req->req, conn); rdma_req->status = PRISKV_STATUS_DISCONNECTED; rdma_req->cb(rdma_req); } @@ -416,12 +277,12 @@ static void priskv_rdma_close_conn(priskv_rdma_conn *conn) } /* return negative number on failure, return received buffer size on success */ -static int priskv_rdma_recv_resp(priskv_rdma_conn *conn, priskv_response *resp) +static int priskv_rdma_recv_resp(priskv_transport_conn *conn, priskv_response *resp) { struct ibv_sge sge; struct ibv_recv_wr recv_wr, *bad_wr; uint16_t resp_buf_size = sizeof(priskv_response); - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_RESP]; + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_RESP]; int ret; sge.addr = (uint64_t)resp; @@ -445,12 +306,12 @@ static int priskv_rdma_recv_resp(priskv_rdma_conn *conn, priskv_response *resp) static void priskv_rdma_cq_process(int fd, void *opaque, uint32_t ev) { - priskv_rdma_conn *conn = opaque; + priskv_transport_conn *conn = opaque; priskv_rdma_handle_cq(conn); } -static int priskv_rdma_new_qp(priskv_rdma_conn *conn) +static int priskv_rdma_new_qp(priskv_transport_conn *conn) { struct ibv_qp_init_attr init_attr = {0}; @@ -497,22 +358,20 @@ static int 
priskv_rdma_new_qp(priskv_rdma_conn *conn) return 0; } -static int priskv_rdma_connect(priskv_rdma_conn *conn) +static int priskv_rdma_connect(priskv_transport_conn *conn) { struct rdma_cm_id *cm_id = conn->cm_id; struct rdma_conn_param conn_param = {0}; priskv_connect_param *param = &conn->param; - priskv_rdma_cm_req cm_req = {0}; + priskv_cm_cap cm_req = {0}; int ret; - assert(!priskv_build_check()); - ret = priskv_rdma_new_qp(conn); if (ret) { return ret; } - cm_req.version = htobe16(PRISKV_RDMA_CM_VERSION); + cm_req.version = htobe16(PRISKV_CM_VERSION); cm_req.max_sgl = htobe16(param->max_sgl); cm_req.max_key_length = htobe16(param->max_key_length); cm_req.max_inflight_command = htobe16(param->max_inflight_command); @@ -533,7 +392,7 @@ static int priskv_rdma_connect(priskv_rdma_conn *conn) return 0; } -static int priskv_rdma_establish_qp(priskv_rdma_conn *conn) +static int priskv_rdma_establish_qp(priskv_transport_conn *conn) { struct ibv_qp_attr qp_attr; int qp_attr_mask, ret; @@ -586,8 +445,8 @@ static int priskv_rdma_establish_qp(priskv_rdma_conn *conn) return 0; } -static int priskv_rdma_modify_max_inflight_command(priskv_rdma_conn *conn, - uint16_t max_inflight_command) +static int priskv_rdma_modify_max_inflight_command(priskv_transport_conn *conn, + uint16_t max_inflight_command) { /* auto detect max_inflight_command from server */ if (max_inflight_command == PRISKV_RDMA_DEFAULT_INFLIGHT_COMMAND) { @@ -605,7 +464,7 @@ static int priskv_rdma_modify_max_inflight_command(priskv_rdma_conn *conn, conn->param.max_inflight_command = priskv_min_u16(max_inflight_command, PRISKV_RDMA_DEFAULT_INFLIGHT_COMMAND); priskv_log_warn("RDMA: ignore modify max_inflight_command, use %d\n", - conn->param.max_inflight_command); + conn->param.max_inflight_command); return 0; /* not fatal error */ } @@ -635,28 +494,29 @@ static int priskv_rdma_modify_max_inflight_command(priskv_rdma_conn *conn, return 0; } -static int priskv_rdma_responsed(struct rdma_cm_event *ev, 
priskv_rdma_conn *conn) +static int priskv_rdma_responsed(struct rdma_cm_event *ev, priskv_transport_conn *conn) { struct rdma_conn_param *rep_param = &ev->param.conn; - unsigned char exp_len = sizeof(priskv_rdma_cm_rep); + unsigned char exp_len = sizeof(priskv_cm_cap); int ret = -EPROTO; if (rep_param->private_data_len < exp_len) { priskv_log_error("RDMA: unexpected CM REQ length %d, expetected %d\n", - rep_param->private_data_len, exp_len); + rep_param->private_data_len, exp_len); return -EPROTO; } - priskv_rdma_cm_rep *rep = (priskv_rdma_cm_rep *)rep_param->private_data; + priskv_cm_cap *rep = (priskv_cm_cap *)rep_param->private_data; uint16_t version = be16toh(rep->version); conn->param.max_sgl = be16toh(rep->max_sgl); conn->param.max_key_length = be16toh(rep->max_key_length); uint16_t max_inflight_command = be16toh(rep->max_inflight_command); conn->capacity = be64toh(rep->capacity); - priskv_log_info("RDMA: response version %d, max_sgl %d, max_key_length %d, max_inflight_command " - "%d, capacity %ld from server\n", - version, conn->param.max_sgl, conn->param.max_key_length, max_inflight_command, - conn->capacity); + priskv_log_info( + "RDMA: response version %d, max_sgl %d, max_key_length %d, max_inflight_command " + "%d, capacity %ld from server\n", + version, conn->param.max_sgl, conn->param.max_key_length, max_inflight_command, + conn->capacity); ret = priskv_rdma_modify_max_inflight_command(conn, max_inflight_command); if (ret) { @@ -664,9 +524,9 @@ static int priskv_rdma_responsed(struct rdma_cm_event *ev, priskv_rdma_conn *con } priskv_log_info("RDMA: update connection parameters, max_sgl %d, max_key_length %d, " - "max_inflight_command %d\n", - conn->param.max_sgl, conn->param.max_key_length, - conn->param.max_inflight_command); + "max_inflight_command %d\n", + conn->param.max_sgl, conn->param.max_key_length, + conn->param.max_inflight_command); ret = priskv_rdma_establish_qp(conn); if (ret) { @@ -678,7 +538,7 @@ static int priskv_rdma_responsed(struct 
rdma_cm_event *ev, priskv_rdma_conn *con return ret; } - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_RESP]; + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_RESP]; priskv_response *resp = (priskv_response *)rmem->buf; for (int i = 0; i < conn->param.max_inflight_command; i++) { ret = priskv_rdma_recv_resp(conn, resp + i); @@ -690,28 +550,28 @@ static int priskv_rdma_responsed(struct rdma_cm_event *ev, priskv_rdma_conn *con return 0; } -static int priskv_rdma_rejected(struct rdma_cm_event *ev, priskv_rdma_conn *conn) +static int priskv_rdma_rejected(struct rdma_cm_event *ev, priskv_transport_conn *conn) { struct rdma_conn_param *rep_param = &ev->param.conn; - unsigned char exp_len = sizeof(priskv_rdma_cm_rej); + unsigned char exp_len = sizeof(priskv_cm_rej); if (rep_param->private_data_len < exp_len) { priskv_log_error("RDMA: unexpected REJECT REQ length %d, expetected %d\n", - rep_param->private_data_len, exp_len); + rep_param->private_data_len, exp_len); return -EPROTO; } - priskv_rdma_cm_rej *rej = (priskv_rdma_cm_rej *)rep_param->private_data; + priskv_cm_rej *rej = (priskv_cm_rej *)rep_param->private_data; uint16_t version = be16toh(rej->version); uint16_t status = be16toh(rej->status); uint64_t value = be64toh(rej->value); priskv_log_error("RDMA: reject version %d, status: %s(%d), supported value %ld from server\n", - version, priskv_rdma_cm_status_str(status), status, value); + version, priskv_cm_status_str(status), status, value); return -ECONNREFUSED; } -static int priskv_rdma_handle_cm_event(priskv_rdma_conn *conn) +static int priskv_rdma_handle_cm_event(priskv_transport_conn *conn) { struct rdma_event_channel *cm_channel = conn->cm_id->channel; struct rdma_cm_event *ev; @@ -734,7 +594,7 @@ static int priskv_rdma_handle_cm_event(priskv_rdma_conn *conn) case RDMA_CM_EVENT_ADDR_RESOLVED: ret = rdma_resolve_route(ev->id, 1000); if (ret) { - priskv_log_error("RMDA: rdma_resolve_route failed: %m"); + priskv_log_error("RMDA: 
rdma_resolve_route failed: %m\n"); } break; @@ -744,7 +604,7 @@ static int priskv_rdma_handle_cm_event(priskv_rdma_conn *conn) case RDMA_CM_EVENT_CONNECT_RESPONSE: ret = priskv_rdma_responsed(ev, conn); - conn->established = true; + conn->state = PRISKV_TRANSPORT_CONN_STATE_ESTABLISHED; break; case RDMA_CM_EVENT_REJECTED: @@ -752,7 +612,7 @@ static int priskv_rdma_handle_cm_event(priskv_rdma_conn *conn) break; case RDMA_CM_EVENT_DISCONNECTED: - conn->established = false; + conn->state = PRISKV_TRANSPORT_CONN_STATE_CLOSED; break; default: @@ -769,7 +629,7 @@ static int priskv_rdma_handle_cm_event(priskv_rdma_conn *conn) goto again; } -static int priskv_rdma_wait_established(priskv_rdma_conn *conn) +static int priskv_rdma_wait_established(priskv_transport_conn *conn) { struct epoll_event event; struct timeval start, end; @@ -777,7 +637,7 @@ static int priskv_rdma_wait_established(priskv_rdma_conn *conn) gettimeofday(&start, NULL); - while (!conn->established) { + while (conn->state == PRISKV_TRANSPORT_CONN_STATE_INIT) { ret = epoll_wait(conn->epollfd, &event, 1, 1000); if (ret < 0) { if (errno == EINTR) { @@ -799,36 +659,28 @@ static int priskv_rdma_wait_established(priskv_rdma_conn *conn) gettimeofday(&end, NULL); PRISKV_RDMA_DEF_ADDR(conn->cm_id) priskv_log_debug("RDMA: <%s - %s> wait established delay %d us\n", local_addr, peer_addr, - priskv_time_elapsed_us(&start, &end)); + priskv_time_elapsed_us(&start, &end)); return 0; } static void priskv_rdma_cm_process(int fd, void *opaque, uint32_t ev) { - priskv_rdma_conn *conn = opaque; + priskv_transport_conn *conn = opaque; priskv_rdma_handle_cm_event(conn); priskv_rdma_handle_cq(conn); } -static void priskv_conn_process(int fd, void *opaque, uint32_t ev) -{ - priskv_rdma_conn *conn = opaque; - - assert(conn->epollfd == fd); - - priskv_events_process(conn->epollfd, -1); -} - -static priskv_rdma_conn *priskv_conn_connect(const char *raddr, int rport, const char *laddr, int lport) +static priskv_transport_conn 
*priskv_rdma_conn_connect(const char *raddr, int rport, + const char *laddr, int lport) { struct rdma_addrinfo hints = {0}, *addrinfo = NULL; - priskv_rdma_conn *conn = NULL; + priskv_transport_conn *conn = NULL; char _port[6]; /* strlen("65535") */ - conn = calloc(sizeof(struct priskv_rdma_conn), 1); + conn = calloc(sizeof(struct priskv_transport_conn), 1); if (!conn) { priskv_log_error("RDMA: failed to allocate memory for RDMA connection\n"); return NULL; @@ -845,13 +697,14 @@ static priskv_rdma_conn *priskv_conn_connect(const char *raddr, int rport, const conn->keys_mems.count = 1; conn->keys_mems.mrs = calloc(sizeof(struct ibv_mr *), 1); + conn->state = PRISKV_TRANSPORT_CONN_STATE_INIT; conn->epollfd = epoll_create1(0); if (conn->epollfd < 0) { priskv_log_error("RDMA: failed to create epoll fd\n"); return NULL; } - priskv_set_fd_handler(conn->epollfd, priskv_conn_process, NULL, conn); + priskv_set_fd_handler(conn->epollfd, priskv_transport_conn_process, NULL, conn); priskv_set_nonblock(conn->epollfd); conn->cm_channel = rdma_create_event_channel(); @@ -926,7 +779,7 @@ static priskv_rdma_conn *priskv_conn_connect(const char *raddr, int rport, const return conn; } -int priskv_conn_close(void *conn) +int priskv_rdma_conn_close(void *conn) { if (!conn) { return 0; @@ -938,8 +791,8 @@ int priskv_conn_close(void *conn) return 0; } -static struct ibv_mr *priskv_conn_reg_memory(priskv_rdma_conn *conn, uint64_t offset, size_t length, - uint64_t iova, int fd) +static struct ibv_mr *priskv_rdma_conn_reg_memory(priskv_transport_conn *conn, uint64_t offset, + size_t length, uint64_t iova, int fd) { struct ibv_pd *pd = conn->cm_id->pd; unsigned int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; @@ -956,21 +809,22 @@ static struct ibv_mr *priskv_conn_reg_memory(priskv_rdma_conn *conn, uint64_t of } if (!mr) { - priskv_log_error("RDMA: failed to reg mr 0x%lx:%ld %m. 
If you are using GPU memory, check if " - "the nvidia_peermem module is installed\n", - offset, length); + priskv_log_error( + "RDMA: failed to reg mr 0x%lx:%ld %m. If you are using GPU memory, check if " + "the nvidia_peermem module is installed\n", + offset, length); } return mr; } -static void priskv_conn_dereg_memory(struct ibv_mr *mr) +static void priskv_rdma_conn_dereg_memory(struct ibv_mr *mr) { ibv_dereg_mr(mr); } -static int priskv_mq_init(priskv_client *client, const char *raddr, int rport, const char *laddr, - int lport, int nqueue) +static int priskv_rdma_mq_init(priskv_client *client, const char *raddr, int rport, + const char *laddr, int lport, int nqueue) { client->wq = priskv_workqueue_create(client->epollfd); if (!client->wq) { @@ -984,7 +838,7 @@ static int priskv_mq_init(priskv_client *client, const char *raddr, int rport, c return -1; } - client->conns = calloc(nqueue, sizeof(priskv_rdma_conn *)); + client->conns = calloc(nqueue, sizeof(priskv_transport_conn *)); if (!client->conns) { priskv_log_error("RDMA: failed to allocate memory for connections\n"); return -1; @@ -992,7 +846,7 @@ static int priskv_mq_init(priskv_client *client, const char *raddr, int rport, c client->nqueue = nqueue; for (uint8_t i = 0; i < nqueue; i++) { - client->conns[i] = priskv_conn_connect(raddr, rport, laddr, lport); + client->conns[i] = priskv_rdma_conn_connect(raddr, rport, laddr, lport); if (!client->conns[i]) { priskv_log_error("RDMA: failed to connect to %s:%d\n", raddr, rport); return -1; @@ -1009,12 +863,12 @@ static int priskv_mq_init(priskv_client *client, const char *raddr, int rport, c return 0; } -static void priskv_mq_deinit(priskv_client *client) +static void priskv_rdma_mq_deinit(priskv_client *client) { if (client->conns) { for (int i = 0; i < client->nqueue; i++) { priskv_thread_call_function(priskv_threadpool_get_iothread(client->pool, i), - priskv_conn_close, client->conns[i]); + priskv_rdma_conn_close, client->conns[i]); } } @@ -1023,13 +877,13 @@ 
static void priskv_mq_deinit(priskv_client *client) free(client->conns); } -static priskv_rdma_conn *priskv_mq_select_conn(priskv_client *client) +static priskv_transport_conn *priskv_rdma_mq_select_conn(priskv_client *client) { return client->conns[client->cur_conn++ % client->nqueue]; } -static priskv_memory *priskv_mq_reg_memory(priskv_client *client, uint64_t offset, size_t length, - uint64_t iova, int fd) +static priskv_memory *priskv_rdma_mq_reg_memory(priskv_client *client, uint64_t offset, + size_t length, uint64_t iova, int fd) { priskv_memory *mem = malloc(sizeof(priskv_memory)); @@ -1038,57 +892,58 @@ static priskv_memory *priskv_mq_reg_memory(priskv_client *client, uint64_t offse mem->mrs = malloc(client->nqueue * sizeof(struct ibv_mr *)); for (int i = 0; i < mem->count; i++) { - mem->mrs[i] = priskv_conn_reg_memory(client->conns[i], offset, length, iova, fd); + mem->mrs[i] = priskv_rdma_conn_reg_memory(client->conns[i], offset, length, iova, fd); } return mem; } -static void priskv_mq_dereg_memory(priskv_memory *mem) +static void priskv_rdma_mq_dereg_memory(priskv_memory *mem) { for (int i = 0; i < mem->count; i++) { - priskv_conn_dereg_memory(mem->mrs[i]); + priskv_rdma_conn_dereg_memory(mem->mrs[i]); } free(mem->mrs); free(mem); } -static struct ibv_mr *priskv_mq_get_mr(priskv_memory *mem, int connid) +static struct ibv_mr *priskv_rdma_mq_get_mr(priskv_memory *mem, int connid) { return mem->mrs[connid]; } -static void priskv_mq_rdma_req_submit(priskv_rdma_req *rdma_req) +static void priskv_rdma_mq_req_submit(priskv_transport_req *rdma_req) { priskv_thread_submit_function(rdma_req->conn->thread, priskv_rdma_req_send, rdma_req); } -static void priskv_mq_rdma_req_cb(priskv_rdma_req *rdma_req) +static void priskv_rdma_mq_req_cb(priskv_transport_req *rdma_req) { - priskv_workqueue_submit(rdma_req->main_wq, _priskv_rdma_req_cb, rdma_req); + priskv_workqueue_submit(rdma_req->main_wq, priskv_rdma_req_cb_intl, rdma_req); } -static priskv_conn_operation 
priskv_mq_ops = { - .init = priskv_mq_init, - .deinit = priskv_mq_deinit, - .select_conn = priskv_mq_select_conn, - .reg_memory = priskv_mq_reg_memory, - .dereg_memory = priskv_mq_dereg_memory, - .get_mr = priskv_mq_get_mr, - .rdma_req_submit = priskv_mq_rdma_req_submit, - .rdma_req_cb = priskv_mq_rdma_req_cb, +static priskv_conn_operation priskv_rdma_mq_ops = { + .init = priskv_rdma_mq_init, + .deinit = priskv_rdma_mq_deinit, + .select_conn = priskv_rdma_mq_select_conn, + .reg_memory = priskv_rdma_mq_reg_memory, + .dereg_memory = priskv_rdma_mq_dereg_memory, + .get_mr = priskv_rdma_mq_get_mr, + .submit_req = priskv_rdma_mq_req_submit, + .req_cb = priskv_rdma_mq_req_cb, + .new_req = priskv_rdma_req_new, }; -static int priskv_sq_init(priskv_client *client, const char *raddr, int rport, const char *laddr, - int lport, int nqueue) +static int priskv_rdma_sq_init(priskv_client *client, const char *raddr, int rport, + const char *laddr, int lport, int nqueue) { - client->conns = calloc(1, sizeof(priskv_rdma_conn *)); + client->conns = calloc(1, sizeof(priskv_transport_conn *)); if (!client->conns) { priskv_log_error("RDMA: failed to allocate memory for connections\n"); return -1; } - client->conns[0] = priskv_conn_connect(raddr, rport, laddr, lport); + client->conns[0] = priskv_rdma_conn_connect(raddr, rport, laddr, lport); if (!client->conns[0]) { priskv_log_error("RDMA: failed to connect to %s:%d\n", raddr, rport); return -1; @@ -1099,19 +954,19 @@ static int priskv_sq_init(priskv_client *client, const char *raddr, int rport, c return 0; } -static void priskv_sq_deinit(priskv_client *client) +static void priskv_rdma_sq_deinit(priskv_client *client) { - priskv_conn_close(client->conns[0]); + priskv_rdma_conn_close(client->conns[0]); free(client->conns); } -static priskv_rdma_conn *priskv_sq_select_conn(priskv_client *client) +static priskv_transport_conn *priskv_rdma_sq_select_conn(priskv_client *client) { return client->conns[0]; } -static priskv_memory 
*priskv_sq_reg_memory(priskv_client *client, uint64_t offset, size_t length, - uint64_t iova, int fd) +static priskv_memory *priskv_rdma_sq_reg_memory(priskv_client *client, uint64_t offset, + size_t length, uint64_t iova, int fd) { priskv_memory *mem = malloc(sizeof(priskv_memory)); @@ -1119,122 +974,57 @@ static priskv_memory *priskv_sq_reg_memory(priskv_client *client, uint64_t offse mem->count = 1; mem->mrs = malloc(sizeof(struct ibv_mr *)); - mem->mrs[0] = priskv_conn_reg_memory(client->conns[0], offset, length, iova, fd); + mem->mrs[0] = priskv_rdma_conn_reg_memory(client->conns[0], offset, length, iova, fd); return mem; } -static void priskv_sq_dereg_memory(priskv_memory *mem) +static void priskv_rdma_sq_dereg_memory(priskv_memory *mem) { - priskv_conn_dereg_memory(mem->mrs[0]); + priskv_rdma_conn_dereg_memory(mem->mrs[0]); free(mem->mrs); free(mem); } -static struct ibv_mr *priskv_sq_get_mr(priskv_memory *mem, int connid) +static struct ibv_mr *priskv_rdma_sq_get_mr(priskv_memory *mem, int connid) { return mem->mrs[0]; } -static void priskv_sq_rdma_req_submit(priskv_rdma_req *rdma_req) +static void priskv_rdma_sq_req_submit(priskv_transport_req *rdma_req) { priskv_rdma_req_send(rdma_req); } -static void priskv_sq_rdma_req_cb(priskv_rdma_req *rdma_req) +static void priskv_rdma_sq_req_cb(priskv_transport_req *rdma_req) { - _priskv_rdma_req_cb(rdma_req); + priskv_rdma_req_cb_intl(rdma_req); } -static priskv_conn_operation priskv_sq_ops = { - .init = priskv_sq_init, - .deinit = priskv_sq_deinit, - .select_conn = priskv_sq_select_conn, - .reg_memory = priskv_sq_reg_memory, - .dereg_memory = priskv_sq_dereg_memory, - .get_mr = priskv_sq_get_mr, - .rdma_req_submit = priskv_sq_rdma_req_submit, - .rdma_req_cb = priskv_sq_rdma_req_cb, +static priskv_conn_operation priskv_rdma_sq_ops = { + .init = priskv_rdma_sq_init, + .deinit = priskv_rdma_sq_deinit, + .select_conn = priskv_rdma_sq_select_conn, + .reg_memory = priskv_rdma_sq_reg_memory, + .dereg_memory = 
priskv_rdma_sq_dereg_memory, + .get_mr = priskv_rdma_sq_get_mr, + .submit_req = priskv_rdma_sq_req_submit, + .req_cb = priskv_rdma_sq_req_cb, + .new_req = priskv_rdma_req_new, }; -priskv_client *priskv_connect(const char *raddr, int rport, const char *laddr, int lport, int nqueue) -{ - priskv_client *client = NULL; - - if (lport && nqueue > 1) { - priskv_log_error("RDMA: unable to bind local port when queues > 1\n"); - return NULL; - } - - client = calloc(sizeof(priskv_client), 1); - if (!client) { - priskv_log_error("RDMA: failed to allocate memory for client\n"); - return NULL; - } - - client->epollfd = epoll_create1(0); - if (client->epollfd < 0) { - priskv_log_error("RDMA: failed to create epoll fd\n"); - goto err; - } - priskv_set_nonblock(client->epollfd); - - if (nqueue > 0) { - client->ops = &priskv_mq_ops; - } else { - client->ops = &priskv_sq_ops; - } - - if (client->ops->init(client, raddr, rport, laddr, lport, nqueue)) { - priskv_log_error("RDMA: failed to initialize client\n"); - goto err; - } - - return client; -err: - client->ops->deinit(client); - close(client->epollfd); - free(client); - - return NULL; -} - -void priskv_close(priskv_client *client) -{ - - client->ops->deinit(client); - close(client->epollfd); - client->epollfd = -1; - free(client); -} - -int priskv_get_fd(priskv_client *client) +static inline void priskv_rdma_fillup_sgl(priskv_transport_req *rdma_req, + priskv_keyed_sgl *keyed_sgl) { - return client->epollfd; -} - -int priskv_process(priskv_client *client, uint32_t event) -{ - if (client->epollfd < 0) { - return -1; - } - - priskv_events_process(client->epollfd, -1); - - return 0; -} - -static inline void priskv_fillup_sql(priskv_rdma_req *rdma_req, priskv_keyed_sgl *keyed_sgl) -{ - priskv_rdma_conn *conn = rdma_req->conn; + priskv_transport_conn *conn = rdma_req->conn; for (uint16_t i = 0; i < rdma_req->nsgl; i++) { priskv_sgl_private *_sgl = &rdma_req->sgl[i]; struct ibv_mr *mr; if (!_sgl->sgl.mem) { - mr = _sgl->mr = - 
priskv_conn_reg_memory(conn, _sgl->sgl.iova, _sgl->sgl.length, _sgl->sgl.iova, -1); + mr = _sgl->mr = priskv_rdma_conn_reg_memory(conn, _sgl->sgl.iova, _sgl->sgl.length, + _sgl->sgl.iova, -1); } else { if (rdma_req->cmd != PRISKV_COMMAND_KEYS) { mr = rdma_req->ops->get_mr(_sgl->sgl.mem, conn->id); @@ -1250,13 +1040,13 @@ static inline void priskv_fillup_sql(priskv_rdma_req *rdma_req, priskv_keyed_sgl _keyed_sgl->key = htobe32(mr->rkey); priskv_log_debug("RDMA: addr 0x%lx@%x rkey 0x%x\n", _sgl->sgl.iova, _sgl->sgl.length, - mr->rkey); + mr->rkey); } } -static int _priskv_rdma_req_cb(void *arg) +static int priskv_rdma_req_cb_intl(void *arg) { - priskv_rdma_req *rdma_req = arg; + priskv_transport_req *rdma_req = arg; if (rdma_req->usercb) { rdma_req->usercb(rdma_req->request_id, rdma_req->status, rdma_req->result); @@ -1267,15 +1057,15 @@ static int _priskv_rdma_req_cb(void *arg) return 0; } -static void priskv_rdma_req_cb(priskv_rdma_req *rdma_req) +static void priskv_rdma_req_cb(priskv_transport_req *rdma_req) { - rdma_req->ops->rdma_req_cb(rdma_req); + rdma_req->ops->req_cb(rdma_req); } -static void priskv_rdma_keys_req_cb(priskv_rdma_req *rdma_req) +static void priskv_rdma_keys_req_cb(priskv_transport_req *rdma_req) { - priskv_rdma_conn *conn = rdma_req->conn; - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_KEYS]; + priskv_transport_conn *conn = rdma_req->conn; + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_KEYS]; uint32_t valuelen = rdma_req->length; if (rdma_req->status == PRISKV_STATUS_OK) { @@ -1322,7 +1112,7 @@ static void priskv_rdma_keys_req_cb(priskv_rdma_req *rdma_req) priskv_rdma_mem_free(conn, rmem); if (priskv_rdma_mem_new(conn, rmem, "Keys", valuelen + valuelen / 8, true)) { priskv_log_error("RDMA: failed to resize KEYS buffer to valuelen %d\n", valuelen); - rdma_req->status = PRISKV_STATUS_RDMA_ERROR; + rdma_req->status = PRISKV_STATUS_TRANSPORT_ERROR; goto exit; } @@ -1345,26 +1135,12 @@ static void 
priskv_rdma_keys_req_cb(priskv_rdma_req *rdma_req) priskv_rdma_req_cb(rdma_req); } -void priskv_keyset_free(priskv_keyset *keyset) -{ - if (!keyset) { - return; - } - - for (uint32_t i = 0; i < keyset->nkey; i++) { - free(keyset->keys[i].key); - } - free(keyset->keys); - free(keyset); -} - -static inline priskv_rdma_req *priskv_rdma_req_new(priskv_client *client, priskv_rdma_conn *conn, - uint64_t request_id, const char *key, - uint16_t keylen, priskv_sgl *sgl, uint16_t nsgl, - uint64_t timeout, priskv_req_command cmd, - priskv_generic_cb usercb) +static inline priskv_transport_req * +priskv_rdma_req_new(priskv_client *client, priskv_transport_conn *conn, uint64_t request_id, + const char *key, uint16_t keylen, priskv_sgl *sgl, uint16_t nsgl, + uint64_t timeout, priskv_req_command cmd, priskv_generic_cb usercb) { - priskv_rdma_req *rdma_req = calloc(1, sizeof(priskv_rdma_req)); + priskv_transport_req *rdma_req = calloc(1, sizeof(priskv_transport_req)); if (!rdma_req) { return NULL; } @@ -1389,7 +1165,7 @@ static inline priskv_rdma_req *priskv_rdma_req_new(priskv_client *client, priskv memcpy(&rdma_req->sgl[i], &sgl[i], sizeof(priskv_sgl)); } } else if (cmd == PRISKV_COMMAND_KEYS) { - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_KEYS]; + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_KEYS]; conn->keys_mems.count = 1; conn->keys_mems.mrs[0] = rmem->mr; @@ -1405,12 +1181,12 @@ static inline priskv_rdma_req *priskv_rdma_req_new(priskv_client *client, priskv return rdma_req; } -static inline void priskv_rdma_req_free(priskv_rdma_req *rdma_req) +static inline void priskv_rdma_req_free(priskv_transport_req *rdma_req) { for (int i = 0; i < rdma_req->nsgl; i++) { priskv_sgl_private *_sgl = &rdma_req->sgl[i]; if (_sgl->mr) { - priskv_conn_dereg_memory(_sgl->mr); + priskv_rdma_conn_dereg_memory(_sgl->mr); _sgl->mr = NULL; } } @@ -1420,7 +1196,7 @@ static inline void priskv_rdma_req_free(priskv_rdma_req *rdma_req) free(rdma_req); } -static inline void 
priskv_rdma_req_reset(priskv_rdma_req *rdma_req) +static inline void priskv_rdma_req_reset(priskv_transport_req *rdma_req) { rdma_req->flags = 0; rdma_req->req = NULL; @@ -1431,15 +1207,15 @@ static inline void priskv_rdma_req_reset(priskv_rdma_req *rdma_req) static int priskv_rdma_req_send(void *arg) { - priskv_rdma_req *rdma_req = arg; - priskv_rdma_conn *conn = rdma_req->conn; + priskv_transport_req *rdma_req = arg; + priskv_transport_conn *conn = rdma_req->conn; struct ibv_send_wr wr = {0}, *bad_wr; struct ibv_sge rsge; priskv_request *req; uint16_t req_idx; - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_REQ]; + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_REQ]; - if (!conn->established) { + if (conn->state != PRISKV_TRANSPORT_CONN_STATE_ESTABLISHED) { rdma_req->status = PRISKV_STATUS_DISCONNECTED; rdma_req->cb(rdma_req); return -1; @@ -1484,13 +1260,13 @@ static int priskv_rdma_req_send(void *arg) gettimeofday(&client_metadata_send_time, NULL); req->runtime.client_metadata_send_time = client_metadata_send_time; - priskv_fillup_sql(rdma_req, req->sgls); - memcpy(priskv_request_key(req, rdma_req->nsgl), rdma_req->key, rdma_req->keylen); + priskv_rdma_fillup_sgl(rdma_req, req->sgls); + memcpy(priskv_rdma_request_key(req), rdma_req->key, rdma_req->keylen); rdma_req->req = req; rsge.addr = (uint64_t)req; - rsge.length = priskv_request_size(rdma_req->nsgl, rdma_req->keylen); + rsge.length = priskv_rdma_request_size(req); rsge.lkey = rmem->mr->lkey; wr.wr_id = (uint64_t)req; @@ -1503,15 +1279,15 @@ static int priskv_rdma_req_send(void *arg) if (ret) { PRISKV_RDMA_DEF_ADDR(conn->cm_id) priskv_log_notice("RDMA: <%s - %s> close. 
Requests GET %ld, SET %ld, TEST %ld, DELETE %ld, " - "Responses %ld\n", - local_addr, peer_addr, conn->stats[PRISKV_COMMAND_GET], - conn->stats[PRISKV_COMMAND_SET], conn->stats[PRISKV_COMMAND_TEST], - conn->stats[PRISKV_COMMAND_DELETE], conn->resps); + "Responses %ld\n", + local_addr, peer_addr, conn->stats[PRISKV_COMMAND_GET], + conn->stats[PRISKV_COMMAND_SET], conn->stats[PRISKV_COMMAND_TEST], + conn->stats[PRISKV_COMMAND_DELETE], conn->resps); priskv_log_error( "RDMA: ibv_post_send addr %p, length %d. wc_recv %ld, wc_send %ld, failed: %d\n", req, rsge.length, conn->wc_recv, conn->wc_send, ret); - rdma_req->status = PRISKV_STATUS_RDMA_ERROR; + rdma_req->status = PRISKV_STATUS_TRANSPORT_ERROR; rdma_req->cb(rdma_req); return -1; } @@ -1521,14 +1297,14 @@ static int priskv_rdma_req_send(void *arg) return 0; } -static inline void priskv_rdma_req_submit(priskv_rdma_req *rdma_req) +static inline void priskv_rdma_req_submit(priskv_transport_req *rdma_req) { - rdma_req->ops->rdma_req_submit(rdma_req); + rdma_req->ops->submit_req(rdma_req); } -static inline void priskv_rdma_req_delay_send(priskv_rdma_conn *conn) +static inline void priskv_rdma_req_delay_send(priskv_transport_conn *conn) { - priskv_rdma_req *rdma_req, *tmp; + priskv_transport_req *rdma_req, *tmp; list_for_each_safe (&conn->inflight_list, rdma_req, tmp, entry) { list_del(&rdma_req->entry); @@ -1539,31 +1315,31 @@ static inline void priskv_rdma_req_delay_send(priskv_rdma_conn *conn) } } -static inline void priskv_rdma_req_done(priskv_rdma_conn *conn, priskv_rdma_req *rdma_req) +static inline void priskv_rdma_req_done(priskv_transport_conn *conn, priskv_transport_req *rdma_req) { - if ((rdma_req->flags & PRISKV_RDMA_REQ_FLAG_DONE) == PRISKV_RDMA_REQ_FLAG_DONE) { + if ((rdma_req->flags & PRISKV_TRANSPORT_REQ_FLAG_DONE) == PRISKV_TRANSPORT_REQ_FLAG_DONE) { list_add_tail(&conn->complete_list, &rdma_req->entry); } } -static inline void priskv_rdma_req_complete(priskv_rdma_conn *conn) +static inline void 
priskv_rdma_req_complete(priskv_transport_conn *conn) { - priskv_rdma_req *rdma_req, *tmp; + priskv_transport_req *rdma_req, *tmp; list_for_each_safe (&conn->complete_list, rdma_req, tmp, entry) { list_del(&rdma_req->entry); - priskv_request_free(rdma_req->req, conn); + priskv_rdma_request_free(rdma_req->req, conn); rdma_req->cb(rdma_req); } } -static int priskv_rdma_handle_recv(priskv_rdma_conn *conn, priskv_response *resp, uint32_t len) +static int priskv_rdma_handle_recv(priskv_transport_conn *conn, priskv_response *resp, uint32_t len) { uint64_t request_id = be64toh(resp->request_id); uint16_t status = be16toh(resp->status); uint32_t length = be32toh(resp->length); - priskv_rdma_req *rdma_req; + priskv_transport_req *rdma_req; if (len != sizeof(priskv_response)) { priskv_log_warn("RDMA: recv %d, expected %ld\n", len, sizeof(priskv_response)); @@ -1571,8 +1347,8 @@ static int priskv_rdma_handle_recv(priskv_rdma_conn *conn, priskv_response *resp } priskv_log_debug("Response request_id 0x%lx, status(%d) %s, length %d\n", request_id, status, - priskv_resp_status_str(status), length); - rdma_req = (priskv_rdma_req *)request_id; + priskv_resp_status_str(status), length); + rdma_req = (priskv_transport_req *)request_id; rdma_req->status = status; rdma_req->length = length; @@ -1580,7 +1356,7 @@ static int priskv_rdma_handle_recv(priskv_rdma_conn *conn, priskv_response *resp rdma_req->result = &rdma_req->length; } - rdma_req->flags |= PRISKV_RDMA_REQ_FLAG_RECV; + rdma_req->flags |= PRISKV_TRANSPORT_REQ_FLAG_RECV; priskv_rdma_req_done(conn, rdma_req); conn->resps++; @@ -1589,15 +1365,15 @@ static int priskv_rdma_handle_recv(priskv_rdma_conn *conn, priskv_response *resp return 0; } -static void priskv_rdma_handle_send(priskv_rdma_conn *conn, priskv_request *req) +static void priskv_rdma_handle_send(priskv_transport_conn *conn, priskv_request *req) { - priskv_rdma_req *rdma_req = (priskv_rdma_req *)be64toh(req->request_id); + priskv_transport_req *rdma_req = 
(priskv_transport_req *)be64toh(req->request_id); - rdma_req->flags |= PRISKV_RDMA_REQ_FLAG_SEND; + rdma_req->flags |= PRISKV_TRANSPORT_REQ_FLAG_SEND; priskv_rdma_req_done(conn, rdma_req); } -static int priskv_rdma_handle_cq(priskv_rdma_conn *conn) +static int priskv_rdma_handle_cq(priskv_transport_conn *conn) { struct ibv_cq *ev_cq = NULL; @@ -1630,11 +1406,11 @@ static int priskv_rdma_handle_cq(priskv_rdma_conn *conn) } priskv_log_debug("RDMA: CQ handle status: %s[0x%x], wr_id: %p, opcode: 0x%x, byte_len: %d\n", - ibv_wc_status_str(wc.status), wc.status, (void *)wc.wr_id, wc.opcode, - wc.byte_len); + ibv_wc_status_str(wc.status), wc.status, (void *)wc.wr_id, wc.opcode, + wc.byte_len); if (wc.status != IBV_WC_SUCCESS) { priskv_log_error("CQ handle error status: %s[0x%x], opcode : 0x%x\n", - ibv_wc_status_str(wc.status), wc.status, wc.opcode); + ibv_wc_status_str(wc.status), wc.status, wc.opcode); return -EIO; } @@ -1654,128 +1430,26 @@ static int priskv_rdma_handle_cq(priskv_rdma_conn *conn) break; default: - priskv_log_error("unexpected opcode 0x%x", wc.opcode); + priskv_log_error("unexpected opcode 0x%x\n", wc.opcode); return -EIO; } goto poll_cq; } -priskv_memory *priskv_reg_memory(priskv_client *client, uint64_t offset, size_t length, uint64_t iova, - int fd) -{ - return client->ops->reg_memory(client, offset, length, iova, fd); -} - -void priskv_dereg_memory(priskv_memory *mem) -{ - priskv_client *client = mem->client; - - client->ops->dereg_memory(mem); -} - -static inline priskv_rdma_conn *priskv_select_conn(priskv_client *client) +priskv_conn_operation *priskv_rdma_get_sq_ops(void) { - return client->ops->select_conn(client); + return &priskv_rdma_sq_ops; } -static void priskv_send_command(priskv_client *client, uint64_t request_id, const char *key, - priskv_sgl *sgl, uint16_t nsgl, uint64_t timeout, priskv_req_command cmd, - priskv_generic_cb cb) +priskv_conn_operation *priskv_rdma_get_mq_ops(void) { - priskv_rdma_conn *conn = priskv_select_conn(client); 
- priskv_connect_param *param = &conn->param; - priskv_rdma_req *rdma_req; - uint16_t keylen = strlen(key); - - assert(cmd < PRISKV_COMMAND_MAX); - if (!key || !keylen) { - cb(request_id, PRISKV_STATUS_KEY_EMPTY, NULL); - } - - if (keylen > param->max_key_length) { - cb(request_id, PRISKV_STATUS_KEY_TOO_BIG, NULL); - } - - if (nsgl > param->max_sgl) { - priskv_log_error("RDMA: nsgl %d > max_sgl %d\n", nsgl, param->max_sgl); - cb(request_id, PRISKV_STATUS_INVALID_SGL, NULL); - } - - rdma_req = - priskv_rdma_req_new(client, conn, request_id, key, keylen, sgl, nsgl, timeout, cmd, cb); - if (!rdma_req) { - cb(request_id, PRISKV_STATUS_NO_MEM, NULL); - return; - } - - priskv_rdma_req_submit(rdma_req); + return &priskv_rdma_mq_ops; } -int priskv_get_async(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, - uint64_t request_id, priskv_generic_cb cb) -{ - if (!sgl || !nsgl) { - cb(request_id, PRISKV_STATUS_VALUE_EMPTY, 0); - return 0; - } - - priskv_send_command(client, request_id, key, sgl, nsgl, 0, PRISKV_COMMAND_GET, cb); - return 0; -} - -int priskv_set_async(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, - uint64_t timeout, uint64_t request_id, priskv_generic_cb cb) -{ - if (!sgl || !nsgl) { - cb(request_id, PRISKV_STATUS_VALUE_EMPTY, 0); - return 0; - } - - priskv_send_command(client, request_id, key, sgl, nsgl, timeout, PRISKV_COMMAND_SET, cb); - return 0; -} - -int priskv_test_async(priskv_client *client, const char *key, uint64_t request_id, priskv_generic_cb cb) -{ - priskv_send_command(client, request_id, key, NULL, 0, 0, PRISKV_COMMAND_TEST, cb); - return 0; -} - -int priskv_delete_async(priskv_client *client, const char *key, uint64_t request_id, priskv_generic_cb cb) -{ - priskv_send_command(client, request_id, key, NULL, 0, 0, PRISKV_COMMAND_DELETE, cb); - return 0; -} - -int priskv_expire_async(priskv_client *client, const char *key, uint64_t timeout, uint64_t request_id, - priskv_generic_cb cb) -{ - 
priskv_send_command(client, request_id, key, NULL, 0, timeout, PRISKV_COMMAND_EXPIRE, cb); - return 0; -} - -int priskv_keys_async(priskv_client *client, const char *regex, uint64_t request_id, priskv_generic_cb cb) -{ - priskv_send_command(client, request_id, regex, NULL, 0, 0, PRISKV_COMMAND_KEYS, cb); - return 0; -} - -int priskv_nrkeys_async(priskv_client *client, const char *regex, uint64_t request_id, - priskv_generic_cb cb) -{ - priskv_send_command(client, request_id, regex, NULL, 0, 0, PRISKV_COMMAND_NRKEYS, cb); - return 0; -} - -int priskv_flush_async(priskv_client *client, const char *regex, uint64_t request_id, - priskv_generic_cb cb) -{ - priskv_send_command(client, request_id, regex, NULL, 0, 0, PRISKV_COMMAND_FLUSH, cb); - return 0; -} - -uint64_t priskv_capacity(priskv_client *client) -{ - return client->conns[0]->capacity; -} +priskv_transport_driver priskv_transport_driver_rdma = { + .name = "rdma", + .init = NULL, + .get_sq_ops = priskv_rdma_get_sq_ops, + .get_mq_ops = priskv_rdma_get_mq_ops, +}; diff --git a/client/transport/transport.c b/client/transport/transport.c new file mode 100644 index 0000000..c74ba2a --- /dev/null +++ b/client/transport/transport.c @@ -0,0 +1,315 @@ +// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "transport.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "priskv-event.h" +#include "priskv-log.h" + +priskv_transport_driver *g_client_driver = NULL; + +extern priskv_transport_driver priskv_transport_driver_ucx; +extern priskv_transport_driver priskv_transport_driver_rdma; + +static int priskv_build_check(void) +{ + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_OK != (int)PRISKV_RESP_STATUS_OK); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_INVALID_COMMAND != + (int)PRISKV_RESP_STATUS_INVALID_COMMAND); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_KEY_EMPTY != (int)PRISKV_RESP_STATUS_KEY_EMPTY); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_KEY_TOO_BIG != (int)PRISKV_RESP_STATUS_KEY_TOO_BIG); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_VALUE_EMPTY != (int)PRISKV_RESP_STATUS_VALUE_EMPTY); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_VALUE_TOO_BIG != (int)PRISKV_RESP_STATUS_VALUE_TOO_BIG); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_NO_SUCH_COMMAND != + (int)PRISKV_RESP_STATUS_NO_SUCH_COMMAND); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_NO_SUCH_KEY != (int)PRISKV_RESP_STATUS_NO_SUCH_KEY); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_INVALID_SGL != (int)PRISKV_RESP_STATUS_INVALID_SGL); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_INVALID_REGEX != (int)PRISKV_RESP_STATUS_INVALID_REGEX); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_KEY_UPDATING != (int)PRISKV_RESP_STATUS_KEY_UPDATING); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_CONNECT_ERROR != (int)PRISKV_RESP_STATUS_CONNECT_ERROR); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_SERVER_ERROR != (int)PRISKV_RESP_STATUS_SERVER_ERROR); + PRISKV_BUILD_BUG_ON((int)PRISKV_STATUS_NO_MEM != (int)PRISKV_RESP_STATUS_NO_MEM); + return 0; +} + +static void __attribute__((constructor)) priskv_client_transport_init(void) +{ + assert(!priskv_build_check()); + + const char *transport_env = getenv("PRISKV_TRANSPORT"); + priskv_transport_backend backend = PRISKV_TRANSPORT_BACKEND_RDMA; 
+ if (transport_env) { + if (strcasecmp(transport_env, "UCX") == 0) { + backend = PRISKV_TRANSPORT_BACKEND_UCX; + } else if (strcasecmp(transport_env, "RDMA") == 0) { + backend = PRISKV_TRANSPORT_BACKEND_RDMA; + } else { + priskv_log_error("Unknown transport backend: %s\n", transport_env); + } + } + + priskv_transport_driver *driver = NULL; + switch (backend) { + case PRISKV_TRANSPORT_BACKEND_UCX: + driver = &priskv_transport_driver_ucx; + priskv_log_notice("Using UCX transport backend\n"); + break; + case PRISKV_TRANSPORT_BACKEND_RDMA: + driver = &priskv_transport_driver_rdma; + priskv_log_notice("Using RDMA transport backend\n"); + break; + default: + priskv_log_error("Unknown transport backend: %d\n", backend); + break; + } + + if (driver && driver->init) { + if (driver->init() != 0) { + priskv_log_error("Failed to initialize transport driver: %s\n", driver->name); + driver = NULL; + } + } + + if (driver) { + g_client_driver = driver; + return 0; + } + + return -1; +} + +void priskv_keyset_free(priskv_keyset *keyset) +{ + if (!keyset) { + return; + } + + for (uint32_t i = 0; i < keyset->nkey; i++) { + free(keyset->keys[i].key); + } + free(keyset->keys); + free(keyset); +} + +priskv_client *priskv_connect(const char *raddr, int rport, const char *laddr, int lport, + int nqueue) +{ + priskv_client *client = NULL; + + if (lport && nqueue > 1) { + priskv_log_error("Transport: unable to bind local port when queues > 1\n"); + return NULL; + } + + client = calloc(sizeof(priskv_client), 1); + if (!client) { + priskv_log_error("Transport: failed to allocate memory for client\n"); + return NULL; + } + + client->epollfd = epoll_create1(0); + if (client->epollfd < 0) { + priskv_log_error("Transport: failed to create epoll fd\n"); + goto err; + } + priskv_set_nonblock(client->epollfd); + + if (nqueue > 0) { + client->ops = g_client_driver->get_mq_ops(); + } else { + client->ops = g_client_driver->get_sq_ops(); + } + + if (client->ops->init(client, raddr, rport, laddr, lport, 
nqueue)) { + priskv_log_error("Transport: failed to initialize client\n"); + goto err; + } + + return client; +err: + client->ops->deinit(client); + close(client->epollfd); + free(client); + + return NULL; +} + +void priskv_close(priskv_client *client) +{ + + client->ops->deinit(client); + close(client->epollfd); + client->epollfd = -1; + free(client); +} + +int priskv_get_fd(priskv_client *client) +{ + return client->epollfd; +} + +int priskv_process(priskv_client *client, uint32_t event) +{ + if (client->epollfd < 0) { + return -1; + } + + priskv_events_process(client->epollfd, -1); + + return 0; +} + +priskv_memory *priskv_reg_memory(priskv_client *client, uint64_t offset, size_t length, + uint64_t iova, int fd) +{ + return client->ops->reg_memory(client, offset, length, iova, fd); +} + +void priskv_dereg_memory(priskv_memory *mem) +{ + priskv_client *client = mem->client; + + client->ops->dereg_memory(mem); +} + +static inline priskv_transport_conn *priskv_select_conn(priskv_client *client) +{ + return client->ops->select_conn(client); +} + +static void priskv_send_command(priskv_client *client, uint64_t request_id, const char *key, + priskv_sgl *sgl, uint16_t nsgl, uint64_t timeout, + priskv_req_command cmd, priskv_generic_cb cb) +{ + priskv_transport_conn *conn = priskv_select_conn(client); + priskv_connect_param *param = &conn->param; + priskv_transport_req *req; + uint16_t keylen = strlen(key); + + assert(cmd < PRISKV_COMMAND_MAX); + if (!key || !keylen) { + cb(request_id, PRISKV_STATUS_KEY_EMPTY, NULL); + } + + if (keylen > param->max_key_length) { + cb(request_id, PRISKV_STATUS_KEY_TOO_BIG, NULL); + } + + if (nsgl > param->max_sgl) { + priskv_log_error("Transport: nsgl %d > max_sgl %d\n", nsgl, param->max_sgl); + cb(request_id, PRISKV_STATUS_INVALID_SGL, NULL); + } + + req = client->ops->new_req(client, conn, request_id, key, keylen, sgl, nsgl, timeout, cmd, cb); + if (!req) { + cb(request_id, PRISKV_STATUS_NO_MEM, NULL); + return; + } + + 
client->ops->submit_req(req); +} + +int priskv_get_async(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, + uint64_t request_id, priskv_generic_cb cb) +{ + if (!sgl || !nsgl) { + cb(request_id, PRISKV_STATUS_VALUE_EMPTY, 0); + return 0; + } + + priskv_send_command(client, request_id, key, sgl, nsgl, 0, PRISKV_COMMAND_GET, cb); + return 0; +} + +int priskv_set_async(priskv_client *client, const char *key, priskv_sgl *sgl, uint16_t nsgl, + uint64_t timeout, uint64_t request_id, priskv_generic_cb cb) +{ + if (!sgl || !nsgl) { + cb(request_id, PRISKV_STATUS_VALUE_EMPTY, 0); + return 0; + } + + priskv_send_command(client, request_id, key, sgl, nsgl, timeout, PRISKV_COMMAND_SET, cb); + return 0; +} + +int priskv_test_async(priskv_client *client, const char *key, uint64_t request_id, + priskv_generic_cb cb) +{ + priskv_send_command(client, request_id, key, NULL, 0, 0, PRISKV_COMMAND_TEST, cb); + return 0; +} + +int priskv_delete_async(priskv_client *client, const char *key, uint64_t request_id, + priskv_generic_cb cb) +{ + priskv_send_command(client, request_id, key, NULL, 0, 0, PRISKV_COMMAND_DELETE, cb); + return 0; +} + +int priskv_expire_async(priskv_client *client, const char *key, uint64_t timeout, + uint64_t request_id, priskv_generic_cb cb) +{ + priskv_send_command(client, request_id, key, NULL, 0, timeout, PRISKV_COMMAND_EXPIRE, cb); + return 0; +} + +int priskv_keys_async(priskv_client *client, const char *regex, uint64_t request_id, + priskv_generic_cb cb) +{ + priskv_send_command(client, request_id, regex, NULL, 0, 0, PRISKV_COMMAND_KEYS, cb); + return 0; +} + +int priskv_nrkeys_async(priskv_client *client, const char *regex, uint64_t request_id, + priskv_generic_cb cb) +{ + priskv_send_command(client, request_id, regex, NULL, 0, 0, PRISKV_COMMAND_NRKEYS, cb); + return 0; +} + +int priskv_flush_async(priskv_client *client, const char *regex, uint64_t request_id, + priskv_generic_cb cb) +{ + priskv_send_command(client, request_id, regex, 
NULL, 0, 0, PRISKV_COMMAND_FLUSH, cb); + return 0; +} + +uint64_t priskv_capacity(priskv_client *client) +{ + return client->conns[0]->capacity; +} + +void priskv_transport_conn_process(int fd, void *opaque, uint32_t ev) +{ + priskv_transport_conn *conn = opaque; + + assert(conn->epollfd == fd); + + priskv_log_debug("Transport: process event %d\n", ev); + priskv_events_process(conn->epollfd, -1); +} diff --git a/client/transport/transport.h b/client/transport/transport.h new file mode 100644 index 0000000..67566fc --- /dev/null +++ b/client/transport/transport.h @@ -0,0 +1,221 @@ +// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef __PRISKV_CLIENT_TRANSPORT__ +#define __PRISKV_CLIENT_TRANSPORT__ + +#if defined(__cplusplus) +extern "C" +{ +#endif + +#include +#include +#include + +#include "../priskv.h" +#include "list.h" +#include "priskv-protocol.h" +#include "priskv-threads.h" +#include "priskv-ucx.h" +#include "priskv-utils.h" +#include "priskv-workqueue.h" +#include "uthash.h" + +// forward declaration +typedef struct priskv_conn_operation priskv_conn_operation; +typedef struct priskv_transport_conn priskv_transport_conn; + +typedef enum priskv_transport_backend { + PRISKV_TRANSPORT_BACKEND_UCX, + PRISKV_TRANSPORT_BACKEND_RDMA, + PRISKV_TRANSPORT_BACKEND_MAX, +} priskv_transport_backend; + +typedef enum priskv_transport_mem_type { + PRISKV_TRANSPORT_MEM_REQ, + PRISKV_TRANSPORT_MEM_RESP, + PRISKV_TRANSPORT_MEM_KEYS, + + PRISKV_TRANSPORT_MEM_MAX +} priskv_transport_mem_type; + +typedef struct priskv_transport_mem { +#define PRISKV_TRANSPORT_MEM_NAME_LEN 32 + char name[PRISKV_TRANSPORT_MEM_NAME_LEN]; + uint8_t *buf; + uint32_t buf_size; + union { + struct { + struct ibv_mr *mr; + }; // rdma + struct { + priskv_ucx_memh *memh; + }; // ucx + }; +} priskv_transport_mem; + +typedef struct priskv_connect_param { + /* the maxium count of @priskv_sgl */ + uint16_t max_sgl; + /* the maxium length of a KEY in bytes */ + uint16_t max_key_length; + /* the maxium command in flight, aka depth of commands */ + uint16_t max_inflight_command; +} priskv_connect_param; + +typedef struct priskv_memory { + priskv_client *client; + int count; + struct { + struct ibv_mr **mrs; + }; // rdma + struct { + priskv_ucx_memh **memhs; + }; // ucx +} priskv_memory; + +typedef struct priskv_sgl_private { + priskv_sgl sgl; + /* used for automatic registration memory */ + union { + struct ibv_mr *mr; // rdma + priskv_ucx_memh *memh; // ucx + }; +} priskv_sgl_private; + +typedef struct priskv_transport_req { + priskv_transport_conn *conn; + priskv_conn_operation *ops; + priskv_workqueue *main_wq; + struct list_node 
entry; + priskv_request *req; + uint64_t request_id; + char *key; + priskv_sgl_private *sgl; + uint16_t nsgl; + uint16_t keylen; + uint64_t timeout; + priskv_req_command cmd; + void (*cb)(struct priskv_transport_req *req); + priskv_generic_cb usercb; +#define PRISKV_TRANSPORT_REQ_FLAG_SEND (1 << 0) +#define PRISKV_TRANSPORT_REQ_FLAG_RECV (1 << 2) +#define PRISKV_TRANSPORT_REQ_FLAG_DONE \ + (PRISKV_TRANSPORT_REQ_FLAG_SEND | PRISKV_TRANSPORT_REQ_FLAG_RECV) + uint8_t flags; + uint16_t status; + uint32_t length; + void *result; + bool delaying; +} priskv_transport_req; + +typedef enum priskv_transport_conn_state { + PRISKV_TRANSPORT_CONN_STATE_INIT, + PRISKV_TRANSPORT_CONN_STATE_ESTABLISHED, + PRISKV_TRANSPORT_CONN_STATE_CLOSING, + PRISKV_TRANSPORT_CONN_STATE_CLOSED, +} priskv_transport_conn_state; + +typedef struct priskv_transport_conn { + char local_addr[PRISKV_ADDR_LEN]; + char peer_addr[PRISKV_ADDR_LEN]; + union { + struct { + struct rdma_cm_id *cm_id; + struct rdma_event_channel *cm_channel; + struct ibv_comp_channel *comp_channel; + struct ibv_cq *cq; + struct ibv_qp *qp; + }; // rdma + struct { + int connfd; + priskv_ucx_worker *worker; + priskv_ucx_ep *ep; + priskv_ucx_request *inflight_reqs; + }; // ucx + }; + + uint8_t id; + priskv_thread *thread; + + priskv_transport_mem rmem[PRISKV_TRANSPORT_MEM_MAX]; + + priskv_connect_param param; + uint64_t capacity; + int epollfd; + priskv_transport_conn_state state; + struct list_head inflight_list; + struct list_head complete_list; + + priskv_transport_req *keys_running_req; + priskv_memory keys_mems; + + uint64_t stats[PRISKV_COMMAND_MAX]; + uint64_t resps; + uint64_t wc_recv; + uint64_t wc_send; +} priskv_transport_conn; + +typedef struct priskv_client { + priskv_threadpool *pool; + priskv_transport_conn **conns; + int nqueue; + int cur_conn; + int epollfd; + priskv_workqueue *wq; + priskv_conn_operation *ops; +} priskv_client; + +typedef struct priskv_conn_operation { + int (*init)(priskv_client *client, const 
char *raddr, int rport, const char *laddr, int lport, + int nqueue); + void (*deinit)(priskv_client *client); + priskv_transport_conn *(*select_conn)(priskv_client *client); + priskv_memory *(*reg_memory)(priskv_client *client, uint64_t offset, size_t length, + uint64_t iova, int fd); + void (*dereg_memory)(priskv_memory *mem); + union { + struct ibv_mr *(*get_mr)(priskv_memory *mem, int connid); + priskv_ucx_memh *(*get_memh)(priskv_memory *mem, int connid); + }; + void (*submit_req)(priskv_transport_req *req); + void (*req_cb)(priskv_transport_req *req); + priskv_transport_req *(*new_req)(priskv_client *client, priskv_transport_conn *conn, + uint64_t request_id, const char *key, uint16_t keylen, + priskv_sgl *sgl, uint16_t nsgl, uint64_t timeout, + priskv_req_command cmd, priskv_generic_cb usercb); +} priskv_conn_operation; + +typedef struct priskv_transport_driver { + const char *name; + int (*init)(void); + priskv_conn_operation *(*get_sq_ops)(void); + priskv_conn_operation *(*get_mq_ops)(void); +} priskv_transport_driver; + +/** + * @brief Process events for the connection. + * + * @param fd The file descriptor. + * @param opaque The opaque pointer. + * @param ev The event. + */ +void priskv_transport_conn_process(int fd, void *opaque, uint32_t ev); + +#if defined(__cplusplus) +} +#endif + +#endif diff --git a/client/transport/ucx.c b/client/transport/ucx.c new file mode 100644 index 0000000..4ab6ed7 --- /dev/null +++ b/client/transport/ucx.c @@ -0,0 +1,1446 @@ +// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "transport.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "priskv-protocol.h" +#include "priskv-protocol-helper.h" +#include "priskv-log.h" +#include "priskv-utils.h" +#include "priskv-threads.h" +#include "priskv-event.h" +#include "priskv-ucx.h" +#include "transport.h" +#include "uthash.h" + +#define PRISKV_UCX_DEFAULT_INFLIGHT_COMMAND 128 + +static priskv_ucx_context *g_ucx_ctx = NULL; + +typedef struct priskv_ucx_conn_resp { + priskv_transport_conn *conn; + priskv_response *resp; +} priskv_ucx_conn_resp; + +static int priskv_ucx_req_cb_intl(void *arg); +static inline void priskv_ucx_req_complete(priskv_transport_conn *conn); +static inline void priskv_ucx_req_free(priskv_transport_req *ucx_req); +static inline priskv_transport_req * +priskv_ucx_req_new(priskv_client *client, priskv_transport_conn *conn, uint64_t request_id, + const char *key, uint16_t keylen, priskv_sgl *sgl, uint16_t nsgl, + uint64_t timeout, priskv_req_command cmd, priskv_generic_cb usercb); +static inline void priskv_ucx_req_reset(priskv_transport_req *ucx_req); +static inline void priskv_ucx_req_done(priskv_transport_conn *conn, priskv_transport_req *ucx_req); +static int priskv_ucx_recv_resp(priskv_transport_conn *conn, priskv_response *resp); +static void priskv_ucx_req_delay_send(priskv_transport_conn *conn); +static int priskv_ucx_send_req(void *arg); + +static int priskv_ucx_init(void) +{ + g_ucx_ctx = priskv_ucx_context_init(0); + if (g_ucx_ctx == NULL) { + 
priskv_log_error("ucx context init failed\n"); + return -1; + } + + return 0; +} + +static int priskv_ucx_mem_new(priskv_transport_conn *conn, priskv_transport_mem *rmem, + const char *name, uint32_t size, bool remote_write) +{ + uint32_t page_size = getpagesize(); + uint8_t *buf; + int ret; + + size = ALIGN_UP(size, page_size); + buf = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (buf == MAP_FAILED) { + priskv_log_error("UCX: failed to allocate %s buffer: %m\n", name); + ret = -ENOMEM; + goto error; + } + + rmem->memh = priskv_ucx_mmap(g_ucx_ctx, buf, size, UCS_MEMORY_TYPE_HOST); + if (!rmem->memh) { + priskv_log_error("UCX: failed to reg MR for %s buffer: %m\n", name); + ret = -errno; + goto free_mem; + } + + strncpy(rmem->name, name, PRISKV_TRANSPORT_MEM_NAME_LEN - 1); + rmem->buf = buf; + rmem->buf_size = size; + + priskv_log_info("UCX: new rmem %s, size %d\n", name, size); + priskv_log_debug("UCX: new rmem %s, buf %p\n", name, buf); + return 0; + +free_mem: + munmap(rmem->buf, rmem->buf_size); + +error: + memset(rmem, 0x00, sizeof(priskv_transport_mem)); + + return ret; +} + +static inline void priskv_ucx_mem_free(priskv_transport_conn *conn, priskv_transport_mem *rmem) +{ + if (rmem->memh) { + priskv_ucx_munmap(rmem->memh); + rmem->memh = NULL; + } + + if (rmem->buf) { + priskv_log_debug("UCX: free rmem %s, buf %p\n", rmem->name, rmem->buf); + munmap(rmem->buf, rmem->buf_size); + } + + priskv_log_info("UCX: free rmem %s, size %d\n", rmem->name, rmem->buf_size); + memset(rmem, 0x00, sizeof(priskv_transport_mem)); +} + +static inline void priskv_ucx_mem_free_all(priskv_transport_conn *conn) +{ + for (int i = 0; i < PRISKV_TRANSPORT_MEM_MAX; i++) { + priskv_transport_mem *rmem = &conn->rmem[i]; + + priskv_ucx_mem_free(conn, rmem); + } +} + +#define PRISKV_UCX_REQUEST_FREE_COMMAND 0xffff +static void priskv_ucx_request_free(priskv_request *req, priskv_transport_conn *conn) +{ + uint8_t *ptr = (uint8_t *)req; + 
priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_REQ]; + + assert(ptr >= rmem->buf); + assert(ptr < rmem->buf + rmem->buf_size); + assert(!((ptr - rmem->buf) % + priskv_ucx_max_request_size_aligned(conn->param.max_sgl, conn->param.max_key_length))); + + req->command = PRISKV_UCX_REQUEST_FREE_COMMAND; +} + +static int priskv_ucx_mem_new_all(priskv_transport_conn *conn) +{ + uint32_t page_size = getpagesize(), size; + + /* #step 1, prepare buffer & MR for request to server */ + int reqsize = + priskv_ucx_max_request_size_aligned(conn->param.max_sgl, conn->param.max_key_length); + size = reqsize * conn->param.max_inflight_command; + if (priskv_ucx_mem_new(conn, &conn->rmem[PRISKV_TRANSPORT_MEM_REQ], "Request", size, false)) { + goto error; + } + + /* additional work: set priskv_request::command as PRISKV_UCX_REQUEST_FREE_COMMAND */ + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_REQ]; + for (uint16_t i = 0; i < conn->param.max_inflight_command; i++) { + priskv_request *req = (priskv_request *)(rmem->buf + i * reqsize); + priskv_ucx_request_free(req, conn); + } + + /* #step 2, prepare buffer & MR for response from server */ + size = sizeof(priskv_response) * conn->param.max_inflight_command; + if (priskv_ucx_mem_new(conn, &conn->rmem[PRISKV_TRANSPORT_MEM_RESP], "Response", size, false)) { + goto error; + } + + /* #step 3, prepare buffer & MR for keys */ + size = page_size; + if (priskv_ucx_mem_new(conn, &conn->rmem[PRISKV_TRANSPORT_MEM_KEYS], "Keys", size, true)) { + goto error; + } + + return 0; + +error: + priskv_ucx_mem_free_all(conn); + + return -ENOMEM; +} + +static priskv_request *priskv_ucx_unused_command(priskv_transport_conn *conn, uint16_t *idx) +{ + uint16_t req_buf_size = + priskv_ucx_max_request_size_aligned(conn->param.max_sgl, conn->param.max_key_length); + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_REQ]; + + for (uint16_t i = 0; i < conn->param.max_inflight_command; i++) { + priskv_request *req = 
(priskv_request *)(rmem->buf + i * req_buf_size); + if (req->command == PRISKV_UCX_REQUEST_FREE_COMMAND) { + priskv_log_debug("UCX: use request %d\n", i); + req->command = 0xc001; + *idx = i; + return req; + } + } + + return NULL; +} + +static void priskv_ucx_close_conn(priskv_transport_conn *conn) +{ + if (conn->state == PRISKV_TRANSPORT_CONN_STATE_CLOSED) { + return; + } + + priskv_transport_req *req, *tmp; + priskv_ucx_request *ucx_req, *ucx_tmp; + size_t total_inflight = 0; + + if (conn->state == PRISKV_TRANSPORT_CONN_STATE_ESTABLISHED) { + priskv_log_notice( + "UCX: <%s - %s> ep %s close. Requests GET %ld, SET %ld, TEST %ld, DELETE %ld, " + "Responses %ld\n", + conn->local_addr, conn->peer_addr, conn->ep->name, conn->stats[PRISKV_COMMAND_GET], + conn->stats[PRISKV_COMMAND_SET], conn->stats[PRISKV_COMMAND_TEST], + conn->stats[PRISKV_COMMAND_DELETE], conn->resps); + } + + conn->state = PRISKV_TRANSPORT_CONN_STATE_CLOSED; + + HASH_ITER(hh, conn->inflight_reqs, ucx_req, ucx_tmp) + { + total_inflight++; + HASH_DEL(conn->inflight_reqs, ucx_req); + priskv_ucx_request_cancel(ucx_req); + } + if (total_inflight > 0) { + priskv_log_notice("UCX: <%s - %s> ep %s close. 
%ld requests are still inflight\n", + conn->local_addr, conn->peer_addr, conn->ep->name, total_inflight); + } + + priskv_ucx_req_complete(conn); + + list_for_each_safe (&conn->inflight_list, req, tmp, entry) { + list_del(&req->entry); + + priskv_ucx_request_free(req->req, conn); + req->status = PRISKV_STATUS_DISCONNECTED; + req->cb(req); + } + + if (conn->ep) { + priskv_ucx_ep_destroy(conn->ep); + conn->ep = NULL; + } + + if (conn->worker) { + priskv_ucx_worker_destroy(conn->worker); + conn->worker = NULL; + } + + if (conn->connfd >= 0) { + ucs_close_fd(&conn->connfd); + conn->connfd = -1; + } + + priskv_ucx_mem_free_all(conn); + + free(conn->keys_mems.memhs); + free(conn); +} + +static void priskv_ucx_recv_resp_cb(ucs_status_t status, ucp_tag_t sender_tag, size_t length, + void *arg) +{ + if (ucs_unlikely(arg == NULL)) { + priskv_log_error("UCX: priskv_ucx_recv_resp_cb, arg is NULL\n"); + return; + } + + priskv_ucx_conn_resp *conn_resp = arg; + priskv_transport_conn *conn = conn_resp->conn; + priskv_response *resp = conn_resp->resp; + free(conn_resp); + + priskv_ucx_request *handle = NULL; + HASH_FIND_PTR(conn->inflight_reqs, &conn_resp, handle); + if (handle) { + priskv_log_debug("UCX: remove request %p from inflight_reqs\n", handle); + HASH_DEL(conn->inflight_reqs, handle); + } + + if (ucs_unlikely(status != UCS_OK)) { + if (status == UCS_ERR_CANCELED) { + priskv_log_debug("UCX: priskv_ucx_recv_resp_cb, status: %s\n", + ucs_status_string(status)); + } else { + priskv_log_error("UCX: priskv_ucx_recv_resp_cb, status: %s\n", + ucs_status_string(status)); + } + return; + } + + if (length != sizeof(priskv_response)) { + priskv_log_warn("UCX: recv %d, expected %ld\n", length, sizeof(priskv_response)); + return; + } + + uint64_t request_id = be64toh(resp->request_id); + if (status != UCS_OK) { + priskv_log_error("UCX: priskv_ucx_recv_resp_cb, status: %s, resp: %p, request_id: 0x%lx\n", + ucs_status_string(status), resp, request_id); + return; + } + + uint16_t 
resp_status = be16toh(resp->status); + uint32_t resp_length = be32toh(resp->length); + priskv_transport_req *ucx_req; + + priskv_log_debug("Response request_id 0x%lx, status(%d) %s, length %d\n", request_id, + resp_status, priskv_resp_status_str(resp_status), resp_length); + ucx_req = (priskv_transport_req *)request_id; + ucx_req->status = resp_status; + ucx_req->length = resp_length; + + if (ucx_req->cmd != PRISKV_COMMAND_KEYS) { + ucx_req->result = &ucx_req->length; + } + + ucx_req->flags |= PRISKV_TRANSPORT_REQ_FLAG_RECV; + priskv_ucx_req_done(conn, ucx_req); + + conn->wc_recv++; + conn->resps++; + priskv_ucx_recv_resp(conn, resp); +} + +/* return negative number on failure, return received buffer size on success */ +static int priskv_ucx_recv_resp(priskv_transport_conn *conn, priskv_response *resp) +{ + uint16_t resp_buf_size = sizeof(priskv_response); + + priskv_log_debug("UCX: priskv_ucx_recv_resp addr %p, length %d\n", resp, resp_buf_size); + priskv_ucx_conn_resp *conn_resp = malloc(sizeof(priskv_ucx_conn_resp)); + if (ucs_unlikely(conn_resp == NULL)) { + priskv_log_error("UCX: priskv_ucx_recv_resp, malloc conn_resp failed\n"); + return -ENOMEM; + } + + conn_resp->conn = conn; + conn_resp->resp = resp; + ucs_status_ptr_t handle = + priskv_ucx_ep_post_tag_recv(conn->ep, resp, resp_buf_size, PRISKV_PROTO_TAG_CTRL, + PRISKV_PROTO_FULL_TAG_MASK, priskv_ucx_recv_resp_cb, conn_resp); + if (UCS_PTR_IS_ERR(handle)) { + ucs_status_t status = UCS_PTR_STATUS(handle); + priskv_log_error("UCX: <%s - %s> priskv_ucx_ep_post_tag_recv failed, status: %s\n", + conn->local_addr, conn->peer_addr, ucs_status_string(status)); + return -EIO; + } else if (UCS_PTR_IS_PTR(handle)) { + // still in progress + priskv_ucx_request *req = (priskv_ucx_request *)handle; + if (req->status == UCS_INPROGRESS) { + req->key = conn_resp; + HASH_ADD_PTR(conn->inflight_reqs, key, req); + } + } else { + // Operation completed immediately + } + + return resp_buf_size; +} + +static int 
priskv_ucx_modify_max_inflight_command(priskv_transport_conn *conn, + uint16_t max_inflight_command) +{ + /* auto detect max_inflight_command from server */ + if (max_inflight_command == PRISKV_UCX_DEFAULT_INFLIGHT_COMMAND) { + conn->param.max_inflight_command = PRISKV_UCX_DEFAULT_INFLIGHT_COMMAND; + return 0; /* no need to change */ + } + + priskv_log_warn("UCX: not support modify max_inflight_command\n"); + return 0; +} + +static void priskv_ucx_conn_close_cb(ucs_status_t status, void *arg) +{ + priskv_transport_conn *conn = arg; + if (status != UCS_OK) { + priskv_log_error("UCX: <%s - %s> ep close, status: %s\n", conn->local_addr, conn->peer_addr, + ucs_status_string(status)); + } + + // mark conn as closing + conn->state = PRISKV_TRANSPORT_CONN_STATE_CLOSING; +} + +static int priskv_ucx_handshake(priskv_transport_conn *conn, ucp_address_t **address, + uint32_t *address_len) +{ + int ret; + uint8_t *peer_worker_address = NULL; + + size_t hs_size = sizeof(priskv_cm_ucx_handshake) + conn->worker->address_len; + priskv_cm_ucx_handshake *hs = malloc(hs_size); + if (ucs_unlikely(hs == NULL)) { + priskv_log_error("UCX: priskv_ucx_handshake, malloc hs failed\n"); + ret = -1; + goto error; + } + + /* send handshake msg to server */ + if (priskv_get_log_level() >= priskv_log_debug) { + size_t print_len = conn->worker->address_len > 128 ? 
128 : conn->worker->address_len; + char worker_address_hex[print_len * 2 + 1]; + priskv_ucx_to_hex(worker_address_hex, conn->worker->address, print_len); + priskv_log_debug( + "UCX: send worker address to server, address_len %d, address (first %d) %s\n", + conn->worker->address_len, print_len, worker_address_hex); + } + + if (priskv_set_block(conn->connfd)) { + priskv_log_error("UCX: failed to set block mode for connfd\n"); + ret = -1; + goto error; + } + + hs->cap.version = htobe16(PRISKV_CM_VERSION); + hs->cap.max_sgl = htobe16(conn->param.max_sgl); + hs->cap.max_key_length = htobe16(conn->param.max_key_length); + hs->cap.max_inflight_command = htobe16(conn->param.max_inflight_command); + hs->address_len = htobe32(conn->worker->address_len); + memcpy(hs->address, conn->worker->address, conn->worker->address_len); + ret = priskv_safe_send(conn->connfd, hs, hs_size, NULL, NULL); + free(hs); + if (ret) { + priskv_log_error("UCX: failed to send capability to server\n"); + ret = -1; + goto error; + } + + /* receive response from server */ + priskv_cm_ucx_handshake peer_hs; + ret = priskv_safe_recv(conn->connfd, &peer_hs, sizeof(peer_hs), NULL, NULL); + if (ret) { + priskv_log_error("UCX: failed to receive handshake msg from server\n"); + ret = -1; + goto error; + } + + /* check reject message */ + if (peer_hs.flag == 0) { + uint16_t version = be16toh(peer_hs.version); + priskv_cm_status status = be16toh(peer_hs.status); + uint64_t value = be64toh(peer_hs.value); + priskv_log_error( + "UCX: reject version %d, status: %s(%d), supported value %ld from server\n", version, + priskv_cm_status_str(status), status, value); + ret = -1; + goto error; + } + + /* accept */ + uint16_t version = be16toh(peer_hs.cap.version); + conn->param.max_sgl = be16toh(peer_hs.cap.max_sgl); + conn->param.max_key_length = be16toh(peer_hs.cap.max_key_length); + uint16_t max_inflight_command = be16toh(peer_hs.cap.max_inflight_command); + conn->capacity = be64toh(peer_hs.cap.capacity); + uint32_t 
peer_worker_address_len = be32toh(peer_hs.address_len); + if (peer_worker_address_len > 0) { + peer_worker_address = malloc(peer_worker_address_len); + if (peer_worker_address == NULL) { + priskv_log_error("UCX: failed to allocate memory for peer_worker_address\n"); + ret = -1; + goto error; + } + ret = priskv_safe_recv(conn->connfd, peer_worker_address, peer_worker_address_len, NULL, + NULL); + if (ret) { + priskv_log_error("UCX: failed to receive peer_worker_address from server\n"); + ret = -1; + goto error; + } + } + + if (!peer_worker_address) { + priskv_log_error("UCX: peer_worker_address is NULL\n"); + ret = -1; + goto error; + } + + if (priskv_set_nonblock(conn->connfd)) { + priskv_log_error("UCX: failed to set nonblock mode for connfd\n"); + ret = -1; + goto error; + } + + if (priskv_get_log_level() >= priskv_log_debug) { + size_t print_len = peer_worker_address_len > 128 ? 128 : peer_worker_address_len; + char worker_address_hex[print_len * 2 + 1]; + priskv_ucx_to_hex(worker_address_hex, peer_worker_address, print_len); + priskv_log_debug( + "UCX: got peer worker address from server, address_len %d, address (first %d) %s\n", + peer_worker_address_len, print_len, worker_address_hex); + } + + priskv_log_info( + "UCX: got response version %d, max_sgl %d, max_key_length %d, max_inflight_command " + "%d, capacity %ld, address_len %d from server\n", + version, conn->param.max_sgl, conn->param.max_key_length, max_inflight_command, + conn->capacity, peer_worker_address_len); + + ret = priskv_ucx_modify_max_inflight_command(conn, max_inflight_command); + if (ret) { + goto error; + } + + priskv_log_info("UCX: update connection parameters, max_sgl %d, max_key_length %d, " + "max_inflight_command %d\n", + conn->param.max_sgl, conn->param.max_key_length, + conn->param.max_inflight_command); + + *address_len = peer_worker_address_len; + *address = peer_worker_address; + + return 0; +error: + priskv_log_error("UCX: <%s - %s> connect failed\n", conn->local_addr, 
conn->peer_addr); + if (peer_worker_address) { + free(peer_worker_address); + peer_worker_address = NULL; + } + return ret; +} + +static inline void priskv_ucx_conn_pollin_progress(int fd, void *opaque, uint32_t ev) +{ + priskv_log_debug("UCX: pollin event %d\n", ev); + priskv_transport_conn *conn = opaque; + priskv_ucx_worker_progress(conn->worker); + if (conn->state != PRISKV_TRANSPORT_CONN_STATE_ESTABLISHED) { + priskv_ucx_close_conn(conn); + return; + } + + priskv_ucx_req_complete(conn); + priskv_ucx_req_delay_send(conn); +} + +static inline void priskv_ucx_conn_connfd_progress(int fd, void *opaque, uint32_t ev) +{ + priskv_transport_conn *conn = opaque; + priskv_ucx_close_conn(conn); +} + +static int priskv_ucx_bind(const char *addr, int port, int *connfd) +{ + int ret; + int sockfd = -1; + int optval = 1; + struct addrinfo hints, *res, *t; + char service[8]; + char err_str[64]; + ucs_status_t status; + + ucs_snprintf_safe(service, sizeof(service), "%u", port); + memset(&hints, 0, sizeof(hints)); + hints.ai_flags = 0; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + ret = getaddrinfo(addr, service, &hints, &res); + if (ret < 0) { + priskv_log_error("UCX: getaddrinfo failed, addr %s, port %s, error %s\n", addr, service, + gai_strerror(ret)); + ret = -1; + goto out; + } + + if (res == NULL) { + priskv_log_error("UCX: getaddrinfo returned empty list\n"); + ret = -1; + goto out; + } + + for (t = res; t != NULL; t = t->ai_next) { + sockfd = socket(t->ai_family, t->ai_socktype, t->ai_protocol); + if (sockfd < 0) { + snprintf(err_str, 64, "socket failed: %m"); + continue; + } + + status = ucs_socket_setopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)); + if (status != UCS_OK) { + snprintf(err_str, 64, "setopt failed: %m"); + continue; + } + + if (bind(sockfd, t->ai_addr, t->ai_addrlen) == 0) { + break; + } + + snprintf(err_str, 64, "bind failed: %m"); + ucs_close_fd(&sockfd); + sockfd = -1; + } + + if (sockfd < 0) { + 
priskv_log_error("UCX: bind failed, addr %s, port %s, error %s\n", addr, service, err_str); + ret = -1; + goto out_free_res; + } + + *connfd = sockfd; + return 0; + +out_free_res: + freeaddrinfo(res); +out: + return ret; +} + +static int priskv_ucx_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen, + int timeout_ms, char *err_str) +{ + int epoll_fd, ret; + + ret = priskv_set_nonblock(sockfd); + if (ret) { + snprintf(err_str, 64, "set NONBLOCK failed: %m"); + return -1; + } + + ret = connect(sockfd, addr, addrlen); + if (ret == 0) { + return 0; + } + + if (errno != EINPROGRESS) { + snprintf(err_str, 64, "connect failed: %m"); + return -1; + } + + epoll_fd = epoll_create1(0); + if (epoll_fd < 0) { + snprintf(err_str, 64, "epoll_create1 failed: %m"); + return -1; + } + + struct epoll_event ev, event; + ev.events = EPOLLOUT; + ev.data.fd = sockfd; + + if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, sockfd, &ev) < 0) { + snprintf(err_str, 64, "epoll_ctl failed: %m"); + ret = -1; + goto out_free_res; + } + +poll_again: + ret = epoll_wait(epoll_fd, &event, 1, timeout_ms); + if (ret == 0) { + snprintf(err_str, 64, "connect timeout"); + errno = ETIMEDOUT; + ret = -1; + goto out_free_res; + } else if (ret < 0) { + if (errno == EINTR) { + goto poll_again; + } + snprintf(err_str, 64, "connect failed: %m"); + ret = -1; + goto out_free_res; + } + + int opt = 0; + socklen_t optlen = sizeof(opt); + if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &opt, &optlen) < 0) { + snprintf(err_str, 64, "getsockopt failed: %m"); + ret = -1; + goto out_free_res; + } + + if (opt != 0) { + errno = opt; + snprintf(err_str, 64, "connect failed: %m"); + ret = -1; + goto out_free_res; + } + + ret = 0; +out_free_res: + if (epoll_fd >= 0) { + close(epoll_fd); + } + return ret; +} + +static priskv_transport_conn *priskv_ucx_conn_connect(const char *raddr, int rport, + const char *laddr, int lport) +{ + priskv_transport_conn *conn = NULL; + int ret; + int connfd = -1; + int connected = -1; + struct 
addrinfo hints, *res, *t; + char service[8]; + char err_str[64]; + + /* bind local address if user specify one */ + if (laddr) { + ret = priskv_ucx_bind(laddr, lport, &connfd); + if (ret) { + return NULL; + } + } + + ucs_snprintf_safe(service, sizeof(service), "%u", rport); + memset(&hints, 0, sizeof(hints)); + hints.ai_flags = 0; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + ret = getaddrinfo(raddr, service, &hints, &res); + if (ret < 0) { + priskv_log_error("UCX: getaddrinfo failed, server %s, port %s, error %s\n", raddr, service, + gai_strerror(ret)); + goto out_free_res; + } + + if (res == NULL) { + priskv_log_error("UCX: getaddrinfo returned empty list\n"); + goto out_free_res; + } + + for (t = res; t != NULL; t = t->ai_next) { + if (!laddr) { + connfd = socket(t->ai_family, t->ai_socktype, t->ai_protocol); + if (connfd < 0) { + snprintf(err_str, 64, "socket failed: %m"); + continue; + } + } + + if (priskv_ucx_connect(connfd, t->ai_addr, t->ai_addrlen, 1000, err_str) == 0) { + connected = 1; + break; + } + + if (!laddr) { + ucs_close_fd(&connfd); + connfd = -1; + } + } + + if (connfd < 0 || connected != 1) { + priskv_log_error("UCX: connect failed, server %s, port %s, error %s\n", raddr, service, + err_str); + goto out_free_res; + } + + conn = calloc(sizeof(struct priskv_transport_conn), 1); + if (!conn) { + priskv_log_error("UCX: failed to allocate memory for UCX connection\n"); + goto out_free_res; + } + + conn->param.max_sgl = 0; + conn->param.max_key_length = 0; + conn->param.max_inflight_command = PRISKV_UCX_DEFAULT_INFLIGHT_COMMAND; + + list_head_init(&conn->inflight_list); + list_head_init(&conn->complete_list); + + conn->keys_running_req = NULL; + conn->keys_mems.count = 1; + conn->keys_mems.memhs = calloc(sizeof(priskv_ucx_memh *), 1); + conn->state = PRISKV_TRANSPORT_CONN_STATE_INIT; + conn->connfd = connfd; + connfd = -1; + + ucs_socket_getname_str(conn->connfd, conn->local_addr, sizeof(conn->local_addr)); + 
priskv_inet_ntop(t->ai_addr, conn->peer_addr); + + conn->worker = priskv_ucx_worker_create(g_ucx_ctx, 0); + if (!conn->worker) { + priskv_log_error("UCX: failed to create worker\n"); + goto err_free_address; + } + + ucp_address_t *address; + uint32_t address_len; + ret = priskv_ucx_handshake(conn, &address, &address_len); + if (ret) { + goto error; + } + + conn->ep = priskv_ucx_ep_create_from_worker_addr(conn->worker, address, + priskv_ucx_conn_close_cb, conn); + if (!conn->ep) { + priskv_log_error("UCX: failed to create endpoint\n"); + goto err_free_address; + } + + free(address); + + priskv_set_fd_handler(conn->worker->efd, priskv_ucx_conn_pollin_progress, NULL, conn); + priskv_set_fd_handler(conn->connfd, priskv_ucx_conn_connfd_progress, NULL, conn); + conn->epollfd = conn->worker->efd; + + ret = priskv_ucx_mem_new_all(conn); + if (ret) { + goto error; + } + + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_RESP]; + priskv_response *resp = (priskv_response *)rmem->buf; + for (int i = 0; i < conn->param.max_inflight_command; i++) { + ret = priskv_ucx_recv_resp(conn, resp + i); + if (ret < 0) { + goto error; + } + } + + conn->state = PRISKV_TRANSPORT_CONN_STATE_ESTABLISHED; + + priskv_log_notice("UCX: <%s - %s> established \n", conn->local_addr, conn->peer_addr); + + return conn; + +err_free_address: + free(address); +error: + priskv_ucx_close_conn(conn); + conn = NULL; +out_free_res: + if (connfd >= 0) { + ucs_close_fd(&connfd); + } + freeaddrinfo(res); + return conn; +} + +int priskv_ucx_conn_close(void *conn) +{ + if (!conn) { + return 0; + } + + priskv_ucx_close_conn(conn); + + return 0; +} + +static priskv_ucx_memh *priskv_ucx_conn_reg_memory(priskv_transport_conn *conn, uint64_t offset, + size_t length, uint64_t iova, int fd) +{ + priskv_ucx_memh *memh = NULL; + + if (fd >= 0) { + priskv_log_error("UCX: not support register memory with fd\n"); + } else { + memh = priskv_ucx_mmap(g_ucx_ctx, (void *)offset, length, UCS_MEMORY_TYPE_HOST); + } + + 
if (!memh) { + priskv_log_error( + "UCX: failed to reg mr 0x%lx:%ld %m. If you are using GPU memory, check if " + "the nvidia_peermem module is installed\n", + offset, length); + } + + return memh; +} + +static void priskv_ucx_conn_dereg_memory(priskv_ucx_memh *memh) +{ + priskv_ucx_munmap(memh); +} + +static int priskv_ucx_mq_init(priskv_client *client, const char *raddr, int rport, + const char *laddr, int lport, int nqueue) +{ + client->wq = priskv_workqueue_create(client->epollfd); + if (!client->wq) { + priskv_log_error("UCX: failed to create workqueue\n"); + return -1; + } + + client->pool = priskv_threadpool_create("priskv", nqueue, 0, 0); + if (!client->pool) { + priskv_log_error("UCX: failed to create threadpool\n"); + return -1; + } + + client->conns = calloc(nqueue, sizeof(priskv_transport_conn *)); + if (!client->conns) { + priskv_log_error("UCX: failed to allocate memory for connections\n"); + return -1; + } + client->nqueue = nqueue; + + for (uint8_t i = 0; i < nqueue; i++) { + client->conns[i] = priskv_ucx_conn_connect(raddr, rport, laddr, lport); + if (!client->conns[i]) { + priskv_log_error("UCX: failed to connect to %s:%d\n", raddr, rport); + return -1; + } + + client->conns[i]->id = i; + client->conns[i]->thread = priskv_threadpool_get_iothread(client->pool, i); + + priskv_thread_add_event_handler(client->conns[i]->thread, client->conns[i]->epollfd); + priskv_thread_add_event_handler(client->conns[i]->thread, client->conns[i]->connfd); + } + + client->cur_conn = 0; + + return 0; +} + +static void priskv_ucx_mq_deinit(priskv_client *client) +{ + if (client->conns) { + for (int i = 0; i < client->nqueue; i++) { + priskv_thread_call_function(priskv_threadpool_get_iothread(client->pool, i), + priskv_ucx_conn_close, client->conns[i]); + } + } + + priskv_threadpool_destroy(client->pool); + priskv_workqueue_destroy(client->wq); + free(client->conns); +} + +static priskv_transport_conn *priskv_ucx_mq_select_conn(priskv_client *client) +{ + return 
client->conns[client->cur_conn++ % client->nqueue]; +} + +static priskv_memory *priskv_ucx_mq_reg_memory(priskv_client *client, uint64_t offset, + size_t length, uint64_t iova, int fd) +{ + priskv_memory *mem = malloc(sizeof(priskv_memory)); + + mem->client = client; + mem->count = client->nqueue; + mem->memhs = malloc(client->nqueue * sizeof(priskv_ucx_memh *)); + + for (int i = 0; i < mem->count; i++) { + mem->memhs[i] = priskv_ucx_conn_reg_memory(client->conns[i], offset, length, iova, fd); + } + + return mem; +} + +static void priskv_ucx_mq_dereg_memory(priskv_memory *mem) +{ + for (int i = 0; i < mem->count; i++) { + priskv_ucx_conn_dereg_memory(mem->memhs[i]); + } + free(mem->memhs); + free(mem); +} + +static priskv_ucx_memh *priskv_ucx_mq_get_memh(priskv_memory *mem, int connid) +{ + return mem->memhs[connid]; +} + +static void priskv_ucx_mq_req_submit(priskv_transport_req *ucx_req) +{ + priskv_thread_submit_function(ucx_req->conn->thread, priskv_ucx_send_req, ucx_req); +} + +static void priskv_ucx_mq_req_cb(priskv_transport_req *ucx_req) +{ + priskv_workqueue_submit(ucx_req->main_wq, priskv_ucx_req_cb_intl, ucx_req); +} + +static priskv_conn_operation priskv_ucx_mq_ops = { + .init = priskv_ucx_mq_init, + .deinit = priskv_ucx_mq_deinit, + .select_conn = priskv_ucx_mq_select_conn, + .reg_memory = priskv_ucx_mq_reg_memory, + .dereg_memory = priskv_ucx_mq_dereg_memory, + .get_memh = priskv_ucx_mq_get_memh, + .submit_req = priskv_ucx_mq_req_submit, + .req_cb = priskv_ucx_mq_req_cb, + .new_req = priskv_ucx_req_new, +}; + +static int priskv_ucx_sq_init(priskv_client *client, const char *raddr, int rport, + const char *laddr, int lport, int nqueue) +{ + client->conns = calloc(1, sizeof(priskv_transport_conn *)); + if (!client->conns) { + priskv_log_error("UCX: failed to allocate memory for connections\n"); + return -1; + } + + client->conns[0] = priskv_ucx_conn_connect(raddr, rport, laddr, lport); + if (!client->conns[0]) { + priskv_log_error("UCX: failed to 
connect to %s:%d\n", raddr, rport); + return -1; + } + + priskv_add_event_fd(client->epollfd, client->conns[0]->epollfd); + priskv_add_event_fd(client->epollfd, client->conns[0]->connfd); + + return 0; +} + +static void priskv_ucx_sq_deinit(priskv_client *client) +{ + priskv_ucx_conn_close(client->conns[0]); + free(client->conns); +} + +static priskv_transport_conn *priskv_ucx_sq_select_conn(priskv_client *client) +{ + return client->conns[0]; +} + +static priskv_memory *priskv_ucx_sq_reg_memory(priskv_client *client, uint64_t offset, + size_t length, uint64_t iova, int fd) +{ + priskv_memory *mem = malloc(sizeof(priskv_memory)); + + mem->client = client; + mem->count = 1; + mem->memhs = malloc(sizeof(priskv_ucx_memh *)); + + mem->memhs[0] = priskv_ucx_conn_reg_memory(client->conns[0], offset, length, iova, fd); + + return mem; +} + +static void priskv_ucx_sq_dereg_memory(priskv_memory *mem) +{ + priskv_ucx_conn_dereg_memory(mem->memhs[0]); + free(mem->memhs); + free(mem); +} + +static priskv_ucx_memh *priskv_ucx_sq_get_memh(priskv_memory *mem, int connid) +{ + return mem->memhs[0]; +} + +static void priskv_ucx_sq_req_submit(priskv_transport_req *ucx_req) +{ + priskv_ucx_send_req(ucx_req); +} + +static void priskv_ucx_sq_req_cb(priskv_transport_req *ucx_req) +{ + priskv_ucx_req_cb_intl(ucx_req); +} + +static priskv_conn_operation priskv_ucx_sq_ops = { + .init = priskv_ucx_sq_init, + .deinit = priskv_ucx_sq_deinit, + .select_conn = priskv_ucx_sq_select_conn, + .reg_memory = priskv_ucx_sq_reg_memory, + .dereg_memory = priskv_ucx_sq_dereg_memory, + .get_memh = priskv_ucx_sq_get_memh, + .submit_req = priskv_ucx_sq_req_submit, + .req_cb = priskv_ucx_sq_req_cb, + .new_req = priskv_ucx_req_new, +}; + +static inline void priskv_ucx_fillup_sgl(priskv_transport_req *req, priskv_keyed_sgl *keyed_sgl) +{ + priskv_transport_conn *conn = req->conn; + uint8_t *keyed_sgl_base = (uint8_t *)keyed_sgl; + + for (uint16_t i = 0; i < req->nsgl; i++) { + priskv_sgl_private *_sgl = 
&req->sgl[i]; + priskv_ucx_memh *memh; + + if (!_sgl->sgl.mem) { + priskv_log_warn("UCX: SGL %d without registered memory, iova 0x%lx, length 0x%zx\n", i, + _sgl->sgl.iova, _sgl->sgl.length); + memh = _sgl->memh = priskv_ucx_conn_reg_memory(conn, _sgl->sgl.iova, _sgl->sgl.length, + _sgl->sgl.iova, -1); + } else { + if (req->cmd != PRISKV_COMMAND_KEYS) { + memh = req->ops->get_memh(_sgl->sgl.mem, conn->id); + } else { + memh = req->ops->get_memh(_sgl->sgl.mem, 0); + } + } + + assert(memh->rkey_length <= priskv_ucx_max_rkey_len); + + priskv_keyed_sgl *_keyed_sgl = (priskv_keyed_sgl *)keyed_sgl_base; + + _keyed_sgl->addr = htobe64(_sgl->sgl.iova); + _keyed_sgl->length = htobe32(_sgl->sgl.length); + _keyed_sgl->packed_rkey_len = htobe32(memh->rkey_length); + memcpy(_keyed_sgl->packed_rkey, memh->rkey_buffer, memh->rkey_length); + keyed_sgl_base += sizeof(priskv_keyed_sgl) + memh->rkey_length; + + if (priskv_get_log_level() >= priskv_log_debug) { + char rkey_hex[priskv_ucx_max_rkey_len * 2 + 1]; + priskv_ucx_to_hex(rkey_hex, _keyed_sgl->packed_rkey, memh->rkey_length); + priskv_log_debug("UCX: addr 0x%lx@%x rkey (%d) %s\n", _sgl->sgl.iova, _sgl->sgl.length, + memh->rkey_length, rkey_hex); + } + } +} + +static int priskv_ucx_req_cb_intl(void *arg) +{ + priskv_transport_req *req = arg; + + if (req->usercb) { + req->usercb(req->request_id, req->status, req->result); + } + + priskv_ucx_req_free(req); + + return 0; +} + +static void priskv_ucx_req_cb(priskv_transport_req *ucx_req) +{ + ucx_req->ops->req_cb(ucx_req); +} + +static void priskv_ucx_keys_req_cb(priskv_transport_req *ucx_req) +{ + priskv_transport_conn *conn = ucx_req->conn; + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_KEYS]; + uint32_t valuelen = ucx_req->length; + + if (ucx_req->status == PRISKV_STATUS_OK) { + uint32_t nkey = 0; + uint16_t keylen; + uint8_t *buf = rmem->buf; + + while ((buf - rmem->buf) < valuelen) { + priskv_keys_resp *keys_resp = (priskv_keys_resp *)buf; + keylen = 
be16toh(keys_resp->keylen); + + buf += sizeof(priskv_keys_resp); + buf += keylen; + nkey++; + } + + if ((buf - rmem->buf) != valuelen) { + priskv_log_error("UCX: KEYS protocol error\n"); + ucx_req->status = PRISKV_STATUS_PROTOCOL_ERROR; + goto exit; + } + + priskv_keyset *keyset = calloc(1, sizeof(priskv_keyset)); + keyset->keys = calloc(nkey, sizeof(priskv_key)); + keyset->nkey = nkey; + priskv_key *curkey = &keyset->keys[0]; + buf = rmem->buf; + while ((buf - rmem->buf) < valuelen) { + priskv_keys_resp *keys_resp = (priskv_keys_resp *)buf; + keylen = be16toh(keys_resp->keylen); + curkey->valuelen = be32toh(keys_resp->valuelen); + + buf += sizeof(priskv_keys_resp); + curkey->key = malloc(keylen + 1); + memcpy(curkey->key, buf, keylen); + curkey->key[keylen] = '\0'; + buf += keylen; + curkey++; + } + + ucx_req->result = keyset; + } else if (ucx_req->status == PRISKV_STATUS_VALUE_TOO_BIG) { + priskv_log_info("UCX: resize KEYS buffer to valuelen %d\n", valuelen); + priskv_ucx_mem_free(conn, rmem); + if (priskv_ucx_mem_new(conn, rmem, "Keys", valuelen + valuelen / 8, true)) { + priskv_log_error("UCX: failed to resize KEYS buffer to valuelen %d\n", valuelen); + ucx_req->status = PRISKV_STATUS_TRANSPORT_ERROR; + goto exit; + } + + conn->keys_mems.count = 1; + conn->keys_mems.memhs[0] = rmem->memh; + + priskv_ucx_req_reset(ucx_req); + + ucx_req->nsgl = 1; + ucx_req->sgl[0].sgl.iova = (uint64_t)(rmem->buf); + ucx_req->sgl[0].sgl.length = rmem->buf_size; + ucx_req->sgl[0].sgl.mem = &conn->keys_mems; + + priskv_ucx_send_req(ucx_req); + return; + } + +exit: + conn->keys_running_req = NULL; + priskv_ucx_req_cb(ucx_req); +} + +static inline priskv_transport_req * +priskv_ucx_req_new(priskv_client *client, priskv_transport_conn *conn, uint64_t request_id, + const char *key, uint16_t keylen, priskv_sgl *sgl, uint16_t nsgl, + uint64_t timeout, priskv_req_command cmd, priskv_generic_cb usercb) +{ + priskv_transport_req *ucx_req = calloc(1, sizeof(priskv_transport_req)); + if 
(!ucx_req) { + return NULL; + } + + ucx_req->conn = conn; + ucx_req->ops = client->ops; + ucx_req->main_wq = client->wq; + ucx_req->cmd = cmd; + ucx_req->timeout = timeout; + ucx_req->key = strdup(key); + ucx_req->keylen = keylen; + ucx_req->request_id = request_id; + ucx_req->usercb = usercb; + ucx_req->cb = priskv_ucx_req_cb; + ucx_req->delaying = false; + list_node_init(&ucx_req->entry); + + if (sgl && nsgl) { + ucx_req->nsgl = nsgl; + ucx_req->sgl = calloc(nsgl, sizeof(priskv_sgl_private)); + for (int i = 0; i < nsgl; i++) { + memcpy(&ucx_req->sgl[i], &sgl[i], sizeof(priskv_sgl)); + } + } else if (cmd == PRISKV_COMMAND_KEYS) { + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_KEYS]; + conn->keys_mems.count = 1; + conn->keys_mems.memhs[0] = rmem->memh; + + ucx_req->nsgl = 1; + ucx_req->sgl = malloc(sizeof(priskv_sgl_private)); + ucx_req->sgl[0].sgl.iova = (uint64_t)(rmem->buf); + ucx_req->sgl[0].sgl.length = rmem->buf_size; + ucx_req->sgl[0].sgl.mem = &conn->keys_mems; + ucx_req->sgl[0].memh = NULL; + ucx_req->cb = priskv_ucx_keys_req_cb; + } + + return ucx_req; +} + +static inline void priskv_ucx_req_free(priskv_transport_req *req) +{ + for (int i = 0; i < req->nsgl; i++) { + priskv_sgl_private *_sgl = &req->sgl[i]; + if (_sgl->memh) { + priskv_ucx_conn_dereg_memory(_sgl->memh); + _sgl->memh = NULL; + } + } + + free(req->sgl); + free(req->key); + free(req); +} + +static inline void priskv_ucx_req_reset(priskv_transport_req *ucx_req) +{ + ucx_req->flags = 0; + ucx_req->req = NULL; + ucx_req->status = PRISKV_STATUS_OK; + ucx_req->length = 0; + ucx_req->delaying = false; +} + +static void priskv_ucx_send_req_cb(ucs_status_t status, void *arg) +{ + if (ucs_unlikely(arg == NULL)) { + priskv_log_error("UCX: priskv_ucx_send_req_cb, arg is NULL\n"); + return; + } + + priskv_transport_req *ucx_req = (priskv_transport_req *)arg; + priskv_transport_conn *conn = ucx_req->conn; + priskv_ucx_request *req; + + HASH_FIND_PTR(conn->inflight_reqs, &ucx_req, req); + 
if (req) { + priskv_log_debug("UCX: remove request %p from inflight_reqs\n", req); + HASH_DEL(conn->inflight_reqs, req); + } + + if (status != UCS_OK) { + priskv_log_error("UCX: priskv_ucx_send_req_cb, status: %s, request: %p, request_id: %lu\n", + ucs_status_string(status), ucx_req, ucx_req->request_id); + ucx_req->status = PRISKV_STATUS_TRANSPORT_ERROR; + ucx_req->cb(ucx_req); + return; + } else { + conn->wc_send++; + ucx_req->flags |= PRISKV_TRANSPORT_REQ_FLAG_SEND; + priskv_ucx_req_done(ucx_req->conn, ucx_req); + } +} + +static int priskv_ucx_send_req(void *arg) +{ + priskv_transport_req *ucx_req = arg; + priskv_transport_conn *conn = ucx_req->conn; + priskv_request *req; + uint16_t req_idx; + + if (conn->state != PRISKV_TRANSPORT_CONN_STATE_ESTABLISHED) { + ucx_req->status = PRISKV_STATUS_DISCONNECTED; + ucx_req->cb(ucx_req); + return -1; + } + + if (ucx_req->cmd == PRISKV_COMMAND_KEYS) { + if (conn->keys_running_req && conn->keys_running_req != ucx_req) { + ucx_req->status = PRISKV_STATUS_BUSY; + ucx_req->cb(ucx_req); + return -1; + } else { + conn->keys_running_req = ucx_req; + } + } + + req = priskv_ucx_unused_command(conn, &req_idx); + if (!req) { + if (ucx_req->delaying) { + list_add(&conn->inflight_list, &ucx_req->entry); + } else { + list_add_tail(&conn->inflight_list, &ucx_req->entry); + ucx_req->delaying = true; + } + return EAGAIN; + } + + req->request_id = htobe64((uint64_t)ucx_req); + req->command = htobe16(ucx_req->cmd); + req->nsgl = htobe16(ucx_req->nsgl); + req->timeout = htobe64(ucx_req->timeout); + req->key_length = htobe16(ucx_req->keylen); + + struct timeval client_metadata_send_time; + gettimeofday(&client_metadata_send_time, NULL); + req->runtime.client_metadata_send_time = client_metadata_send_time; + + priskv_ucx_fillup_sgl(ucx_req, req->sgls); + memcpy(priskv_ucx_request_key(req), ucx_req->key, ucx_req->keylen); + + uint16_t req_length = priskv_ucx_request_size(req); + if (priskv_get_log_level() >= priskv_log_debug) { + char 
key_short[16] = {0}; + priskv_string_shorten(ucx_req->key, ucx_req->keylen, key_short, sizeof(key_short)); + + priskv_log_debug( + "UCX: Request command length %u, request_id 0x%lx, %s[0x%x], nsgl %u, key[%u] %s\n", + req_length, ucx_req->request_id, priskv_command_str(ucx_req->cmd), ucx_req->cmd, + ucx_req->nsgl, ucx_req->keylen, key_short); + } + + ucx_req->req = req; + + ucs_status_ptr_t handle = priskv_ucx_ep_post_tag_send( + conn->ep, req, req_length, PRISKV_PROTO_TAG_CTRL, priskv_ucx_send_req_cb, ucx_req); + if (UCS_PTR_IS_ERR(handle)) { + priskv_log_notice( + "UCX: <%s - %s> ep %s close. Requests GET %ld, SET %ld, TEST %ld, DELETE %ld, " + "Responses %ld\n", + conn->local_addr, conn->peer_addr, conn->ep->name, conn->stats[PRISKV_COMMAND_GET], + conn->stats[PRISKV_COMMAND_SET], conn->stats[PRISKV_COMMAND_TEST], + conn->stats[PRISKV_COMMAND_DELETE], conn->resps); + + ucs_status_t status = UCS_PTR_STATUS(handle); + priskv_log_error("UCX: priskv_ucx_ep_post_tag_send addr %p, length %d. 
wc_recv %ld, " + "wc_send %ld, failed: %s\n", + req, req_length, conn->wc_recv, conn->wc_send, ucs_status_string(status)); + return -1; + } else if (UCS_PTR_IS_PTR(handle)) { + // still in progress + priskv_ucx_request *request = handle; + if (request->status == UCS_INPROGRESS) { + request->key = ucx_req; + HASH_ADD_PTR(conn->inflight_reqs, key, request); + } + } else { + // Operation completed immediately + } + + conn->stats[ucx_req->cmd]++; + + return 0; +} + +static inline void priskv_ucx_req_submit(priskv_transport_req *ucx_req) +{ + ucx_req->ops->submit_req(ucx_req); +} + +static void priskv_ucx_req_delay_send(priskv_transport_conn *conn) +{ + priskv_transport_req *ucx_req, *tmp; + + list_for_each_safe (&conn->inflight_list, ucx_req, tmp, entry) { + list_del(&ucx_req->entry); + + if (priskv_ucx_send_req(ucx_req) == EAGAIN) { + return; + } + } +} + +static inline void priskv_ucx_req_done(priskv_transport_conn *conn, priskv_transport_req *ucx_req) +{ + if ((ucx_req->flags & PRISKV_TRANSPORT_REQ_FLAG_DONE) == PRISKV_TRANSPORT_REQ_FLAG_DONE) { + list_add_tail(&conn->complete_list, &ucx_req->entry); + } +} + +static inline void priskv_ucx_req_complete(priskv_transport_conn *conn) +{ + priskv_transport_req *req, *tmp; + + list_for_each_safe (&conn->complete_list, req, tmp, entry) { + list_del(&req->entry); + + priskv_ucx_request_free(req->req, conn); + req->cb(req); + } +} + +priskv_conn_operation *priskv_ucx_get_sq_ops(void) +{ + return &priskv_ucx_sq_ops; +} + +priskv_conn_operation *priskv_ucx_get_mq_ops(void) +{ + return &priskv_ucx_mq_ops; +} + +priskv_transport_driver priskv_transport_driver_ucx = { + .name = "ucx", + .init = priskv_ucx_init, + .get_sq_ops = priskv_ucx_get_sq_ops, + .get_mq_ops = priskv_ucx_get_mq_ops, +}; diff --git a/cluster/client/Makefile b/cluster/client/Makefile index 45c0d9e..423e50e 100644 --- a/cluster/client/Makefile +++ b/cluster/client/Makefile @@ -2,8 +2,10 @@ LIB_NAME = libpriskvcluster STATIC_LIB = $(LIB_NAME).a STATIC_LIB_TMP = 
$(LIB_NAME).tmp.a INCFLAGS = -I../../client/ -I../../include/ -LIBS = -lrdmacm -libverbs -lpthread -lhiredis -lncurses -luuid -CFLAGS = -fPIC -Wall -g -O0 $(INCFLAGS) -D_GNU_SOURCE -Wshadow -Wformat=2 -Wwrite-strings -fstack-protector-strong -Wnull-dereference -Wunreachable-code $(LIBS) +UCX_CFLAGS = $(shell pkg-config --cflags ucx) +UCX_LDFLAGS = $(shell pkg-config --libs ucx) +LIBS = -lrdmacm -libverbs -lpthread -lhiredis -lncurses -luuid $(UCX_LDFLAGS) +CFLAGS = -fPIC -Wall -g -O0 $(INCFLAGS) -D_GNU_SOURCE -Wshadow -Wformat=2 -Wwrite-strings -fstack-protector-strong -Wnull-dereference -Wunreachable-code $(UCX_CFLAGS) $(LIBS) CC = gcc AR = ar FMT = clang-format-19 diff --git a/docker/Dockerfile_ubuntu2204 b/docker/Dockerfile_ubuntu2204 index 61523d9..9fa6516 100644 --- a/docker/Dockerfile_ubuntu2204 +++ b/docker/Dockerfile_ubuntu2204 @@ -1,11 +1,29 @@ FROM docker.io/library/ubuntu:22.04 ENV DEBIAN_FRONTEND=noninteractive -RUN apt update && apt install -y git gcc make cmake librdmacm-dev rdma-core libibverbs-dev libncurses5-dev libmount-dev libevent-dev libssl-dev python3-pybind11 python3-dev python3-pip libhiredis-dev liburing-dev +RUN apt update && apt install -y git gcc make wget librdmacm-dev rdma-core libibverbs-dev libncurses5-dev libmount-dev libevent-dev libssl-dev python3-pybind11 python3-dev python3-pip libhiredis-dev liburing-dev + +# install cmake +RUN cd /tmp && \ + wget -q https://cmake.org/files/v4.1/cmake-4.1.3-linux-x86_64.sh && \ + bash cmake-4.1.3-linux-x86_64.sh --skip-license --prefix=/usr && \ + rm cmake-4.1.3-linux-x86_64.sh + +# libevhtp uses TestEndianess.c.in +RUN ln -s /usr/share/cmake-4.1/Modules/TestEndianness.c.in /usr/share/cmake-4.1/Modules/TestEndianess.c.in + +# install ucx +RUN cd /tmp && \ + wget -q https://github.com/openucx/ucx/releases/download/v1.19.0/ucx-1.19.0-ubuntu22.04-mofed5-cuda12-x86_64.tar.bz2 && \ + tar xvf ucx-1.19.0-ubuntu22.04-mofed5-cuda12-x86_64.tar.bz2 && \ + dpkg -i ucx-1.19.0.deb ucx-cuda-1.19.0.deb 
ucx-xpmem-1.19.0.deb && \ + rm /tmp/ucx-1.19.0-ubuntu22.04-mofed5-cuda12-x86_64.tar.bz2 /tmp/ucx-1.19.0.deb /tmp/ucx-cuda-1.19.0.deb /tmp/ucx-xpmem-1.19.0.deb ADD . /root/priskv WORKDIR /root/priskv +ENV CMAKE_POLICY_VERSION_MINIMUM=3.5 + RUN mkdir /workspace RUN make all RUN cp server/priskv-server /workspace/ diff --git a/include/priskv-protocol-helper.h b/include/priskv-protocol-helper.h index 552e015..8da080e 100644 --- a/include/priskv-protocol-helper.h +++ b/include/priskv-protocol-helper.h @@ -34,23 +34,73 @@ extern "C" #include #include "priskv-protocol.h" +#include "priskv-utils.h" -static inline uint16_t priskv_request_key_off(uint16_t nsgl) +static uint32_t priskv_ucx_max_rkey_len = 256; + +static inline uint16_t priskv_rdma_request_key_off(priskv_request *req) { - // size_t offset = offsetof(priskv_request, sgls); - // return (uint16_t)(offset + nsgl * sizeof(priskv_keyed_sgl)); - return sizeof(priskv_request) + sizeof(priskv_keyed_sgl) * nsgl; + return sizeof(priskv_request) + sizeof(priskv_keyed_sgl) * be16toh(req->nsgl); } -static inline uint16_t priskv_request_size(uint16_t nsgl, uint16_t keylen) +static inline uint16_t priskv_rdma_request_size(priskv_request *req) { - return priskv_request_key_off(nsgl) + keylen; + return priskv_rdma_request_key_off(req) + be16toh(req->key_length); } -static inline uint8_t *priskv_request_key(priskv_request *req, uint16_t nsgl) +static inline unsigned int priskv_rdma_max_request_size_aligned(uint16_t max_sgl, + uint16_t max_key_length) +{ + uint16_t s = sizeof(priskv_request) + sizeof(priskv_keyed_sgl) * max_sgl + max_key_length; + + return ALIGN_UP(s, 64); +} + +static inline uint8_t *priskv_rdma_request_key(priskv_request *req) { unsigned char *base = (unsigned char *)req; - return base + priskv_request_key_off(nsgl); + return base + priskv_rdma_request_key_off(req); +} + +static inline uint16_t priskv_ucx_request_key_off(priskv_request *req) +{ + uint16_t nsgl = be16toh(req->nsgl); + uint16_t off = 
sizeof(priskv_request); + uint16_t i = 0; + for (; i < nsgl; i++) { + off += sizeof(priskv_keyed_sgl) + be32toh(req->sgls[i].packed_rkey_len); + } + return off; +} + +static inline uint16_t priskv_ucx_request_size(priskv_request *req) +{ + return priskv_ucx_request_key_off(req) + be16toh(req->key_length); +} + +static inline unsigned int priskv_ucx_max_request_size_aligned(uint16_t max_sgl, + uint16_t max_key_length) +{ + uint16_t s = sizeof(priskv_request) + + (sizeof(priskv_keyed_sgl) + priskv_ucx_max_rkey_len) * max_sgl + max_key_length; + + return ALIGN_UP(s, 64); +} + +static inline uint8_t *priskv_ucx_request_key(priskv_request *req) +{ + unsigned char *base = (unsigned char *)req; + return base + priskv_ucx_request_key_off(req); +} + +static inline void priskv_ucx_to_hex(uint8_t *hex, uint8_t *str, uint32_t str_len) +{ + for (size_t j = 0; j < str_len; j++) { + unsigned char b = str[j]; + hex[j * 2] = "0123456789abcdef"[b >> 4]; + hex[j * 2 + 1] = "0123456789abcdef"[b & 0xF]; + } + hex[str_len * 2] = '\0'; } static inline const char *priskv_command_str(priskv_req_command cmd) @@ -114,28 +164,31 @@ static inline const char *priskv_resp_status_str(priskv_resp_status status) return "Unknown"; } -static inline const char *priskv_rdma_cm_status_str(priskv_rdma_cm_status status) +static inline const char *priskv_cm_status_str(priskv_cm_status status) { switch (status) { - case PRISKV_RDMA_CM_REJ_STATUS_INVALID_CM_REP: + case PRISKV_CM_REJ_STATUS_INVALID_CM_REP: return "Invalid CM Reply"; - case PRISKV_RDMA_CM_REJ_STATUS_INVALID_VERSION: + case PRISKV_CM_REJ_STATUS_INVALID_VERSION: return "Invalid version"; - case PRISKV_RDMA_CM_REJ_STATUS_INVALID_SGL: + case PRISKV_CM_REJ_STATUS_INVALID_SGL: return "Invalid SGL"; - case PRISKV_RDMA_CM_REJ_STATUS_INVALID_KEY_LENGTH: + case PRISKV_CM_REJ_STATUS_INVALID_KEY_LENGTH: return "Invalid Key length"; - case PRISKV_RDMA_CM_REJ_STATUS_INVALID_INFLIGHT_COMMAND: + case PRISKV_CM_REJ_STATUS_INVALID_INFLIGHT_COMMAND: return 
"Invalid inflight command"; - case PRISKV_RDMA_CM_REJ_STATUS_ACL_REFUSE: + case PRISKV_CM_REJ_STATUS_ACL_REFUSE: return "ACL refuse"; - case PRISKV_RDMA_CM_REJ_STATUS_SERVER_ERROR: + case PRISKV_CM_REJ_STATUS_INVALID_WORKER_ADDR: + return "Invalid worker address"; + + case PRISKV_CM_REJ_STATUS_SERVER_ERROR: return "server error"; } diff --git a/include/priskv-protocol.h b/include/priskv-protocol.h index 29b0ec9..af23258 100644 --- a/include/priskv-protocol.h +++ b/include/priskv-protocol.h @@ -42,7 +42,15 @@ extern "C" typedef struct priskv_keyed_sgl { uint64_t addr; uint32_t length; - uint32_t key; + union { + struct { + uint32_t key; + }; // rdma + struct { + uint32_t packed_rkey_len; + uint8_t packed_rkey[0]; + }; // ucx + }; } priskv_keyed_sgl; /* @@ -133,65 +141,79 @@ typedef struct priskv_response { } priskv_response; /* - * currently version 0x01 is supported only. + * currently version 0x02 is supported only. */ -#define PRISKV_RDMA_CM_VERSION 0x01 +#define PRISKV_CM_VERSION 0x02 /* - * rdma connect request + * CM capability * - * @version: must be PRISKV_RDMA_CM_VERSION - * @max_sgl: request max SGLs from client. - * @max_key_length: request max key length in bytes from client. - * @max_inflight_command: request max inflight command(aka command depth) from client. + * @version: must be PRISKV_CM_VERSION + * @max_sgl: max SGLs supported by server/client. + * @max_key_length: max key length in bytes supported by server/client. + * @max_inflight_command: max inflight command(aka command depth) supported by + * server/client. + * @capacity: max capacity in bytes supported by server. * - * @max_sgl, @max_key_length, @max_inflight_command must be less than or equal to the - * limitations from the server side, otherwise the server rejects connection. Or specify 0 to - * use the maximum value from server. 
+ * @max_sgl, @max_key_length, @max_inflight_command of client must be less than + * or equal to the limitations from the server side, otherwise the connection + * will be disconnected. Or specify 0 to use the maximum value from server. */ -typedef struct priskv_rdma_cm_req { - uint16_t version; - uint16_t max_sgl; - uint16_t max_key_length; - uint16_t max_inflight_command; - uint8_t reserved[24]; -} priskv_rdma_cm_req; - -/* - * rdma connect reply - */ -typedef struct priskv_rdma_cm_rep { +typedef struct priskv_cm_cap { uint16_t version; uint16_t max_sgl; uint16_t max_key_length; uint16_t max_inflight_command; uint64_t capacity; uint8_t reserved[16]; -} priskv_rdma_cm_rep; +} priskv_cm_cap; + +typedef struct __attribute__((packed)) priskv_cm_ucx_handshake { + uint8_t flag; // 0: reject, 1: others + union { + struct __attribute__((packed)) { + uint16_t version; + uint16_t status; + uint64_t value; + }; // reject + struct __attribute__((packed)) { + priskv_cm_cap cap; + uint32_t address_len; + uint8_t address[0]; + }; // others + }; +} priskv_cm_ucx_handshake; /* - * status on RDMA CM rejection + * CM status */ -typedef enum priskv_rdma_cm_status { - PRISKV_RDMA_CM_REJ_STATUS_INVALID_CM_REP = 0x01, - PRISKV_RDMA_CM_REJ_STATUS_INVALID_VERSION = 0x02, - PRISKV_RDMA_CM_REJ_STATUS_INVALID_SGL = 0x03, - PRISKV_RDMA_CM_REJ_STATUS_INVALID_KEY_LENGTH = 0x04, - PRISKV_RDMA_CM_REJ_STATUS_INVALID_INFLIGHT_COMMAND = 0x05, - PRISKV_RDMA_CM_REJ_STATUS_ACL_REFUSE = 0x06, - - PRISKV_RDMA_CM_REJ_STATUS_SERVER_ERROR = 0x10 -} priskv_rdma_cm_status; +typedef enum priskv_cm_status { + PRISKV_CM_REJ_STATUS_INVALID_CM_REP = 0x01, + PRISKV_CM_REJ_STATUS_INVALID_VERSION = 0x02, + PRISKV_CM_REJ_STATUS_INVALID_SGL = 0x03, + PRISKV_CM_REJ_STATUS_INVALID_KEY_LENGTH = 0x04, + PRISKV_CM_REJ_STATUS_INVALID_INFLIGHT_COMMAND = 0x05, + PRISKV_CM_REJ_STATUS_ACL_REFUSE = 0x06, + PRISKV_CM_REJ_STATUS_INVALID_WORKER_ADDR = 0x07, + + PRISKV_CM_REJ_STATUS_SERVER_ERROR = 0x10 +} priskv_cm_status; /* - * 
rdma connect reject + * connect reject */ -typedef struct priskv_rdma_cm_rej { +typedef struct priskv_cm_rej { uint16_t version; - uint16_t status; /* priskv_rdma_cm_status */ + uint16_t status; /* priskv_cm_status */ uint8_t reserved[4]; uint64_t value; /* indicate the supported value */ -} priskv_rdma_cm_rej; +} priskv_cm_rej; + +typedef enum priskv_proto_tag { + PRISKV_PROTO_TAG_HANDSHAKE = 0x01, + PRISKV_PROTO_TAG_CTRL = 0x02, + PRISKV_PROTO_FULL_TAG_MASK = ~0LL, +} priskv_proto_tag; /* *assuming max timeout means no timeout diff --git a/include/priskv-ucx.h b/include/priskv-ucx.h new file mode 100644 index 0000000..222aebb --- /dev/null +++ b/include/priskv-ucx.h @@ -0,0 +1,452 @@ +// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef __PRISKV_UCX__ +#define __PRISKV_UCX__ + +#if defined(__cplusplus) +extern "C" +{ +#endif + +#include +#include + +#include "priskv-utils.h" +#include "uthash.h" + +// forward declaration +typedef struct priskv_ucx_worker priskv_ucx_worker; +typedef struct priskv_ucx_ep priskv_ucx_ep; +typedef struct priskv_ucx_conn_request priskv_ucx_conn_request; + +typedef struct priskv_ucx_context { + ucp_context_h handle; + uint8_t busy_polling; +} priskv_ucx_context; + +typedef struct priskv_ucx_payload { + void *buffer; + size_t length; + union { + struct { + ucp_tag_t tag; + } tag_send; + struct { + ucp_tag_t tag; + ucp_tag_t mask; + } tag_recv; + struct { + uint64_t raddr; + } rma; + }; +} priskv_ucx_payload; + +typedef struct priskv_ucx_rkey { + priskv_ucx_ep *ep; + ucp_rkey_h handle; +} priskv_ucx_rkey; + +typedef struct priskv_ucx_memh { + priskv_ucx_context *context; + ucp_mem_h handle; + ucs_memory_type_t type; + uint64_t addr; + size_t len; + void *rkey_buffer; + size_t rkey_length; +} priskv_ucx_memh; + +typedef void (*priskv_ucx_request_cb)(ucs_status_t status, void *arg); +typedef void (*priskv_ucx_tag_recv_cb)(ucs_status_t status, ucp_tag_t sender_tag, size_t length, + void *arg); +typedef struct priskv_ucx_request { + const char *name; + void *key; + void *handle; + ucs_status_t status; + priskv_ucx_worker *worker; + priskv_ucx_ep *ep; + priskv_ucx_payload payload; + union { + priskv_ucx_request_cb cb; + priskv_ucx_tag_recv_cb tag_recv_cb; + }; + void *cb_data; + UT_hash_handle hh; +} priskv_ucx_request; + +typedef struct priskv_ucx_worker { + priskv_ucx_context *context; + ucp_worker_h handle; + int efd; + uint64_t client_id; + ucp_address_t *address; + uint32_t address_len; +} priskv_ucx_worker; + +typedef struct priskv_ucx_listener { + priskv_ucx_worker *worker; + ucp_listener_h handle; + priskv_ucx_conn_request *conn_request; + char addr[PRISKV_ADDR_LEN]; +} priskv_ucx_listener; + +typedef void (*priskv_ucx_conn_cb)(priskv_ucx_conn_request 
*conn_request, void *arg); +typedef struct priskv_ucx_conn_request { + ucp_conn_request_h handle; + ucp_conn_request_attr_t attr; + priskv_ucx_conn_cb cb; + void *cb_data; + char peer_addr[PRISKV_ADDR_LEN]; +} priskv_ucx_conn_request; + +typedef void (*priskv_ucx_ep_close_cb)(ucs_status_t status, void *arg); +typedef struct priskv_ucx_ep { + char name[UCP_ENTITY_NAME_MAX]; + priskv_ucx_worker *worker; + ucp_ep_h handle; + ucs_status_t status; + atomic_uint closing; + priskv_ucx_ep_close_cb close_cb; + void *close_cb_data; + uint8_t paired_worker_ep; + ucp_ep_attr_t attr; +} priskv_ucx_ep; + +/** + * @brief Check if a UCX operation is successful, and handle errors if not. + * + * @param STATUS The status code to check. + * @param MSG The error message to log if the status is not UCS_OK or UCS_INPROGRESS. + * @param CLEANUP The cleanup code to execute if the status is not UCS_OK or UCS_INPROGRESS. + * @param RET The return value to use if the status is not UCS_OK or UCS_INPROGRESS. + */ +#define PRISKV_UCX_RETURN_IF_ERROR(STATUS, MSG, CLEANUP, RET) \ + { \ + switch (STATUS) { \ + case UCS_OK: \ + break; \ + case UCS_INPROGRESS: \ + break; \ + default: \ + const char *status_str = ucs_status_string(STATUS); \ + priskv_log_error(MSG ", status: %s\n", status_str); \ + CLEANUP; \ + return RET; \ + } \ + } + +/** + * @brief Check if CUDA is disabled by UCX_TLS. + * + * @param tls The UCX_TLS environment variable value. + * @return int 1 if CUDA is disabled, 0 otherwise. + */ +int priskv_ucx_tls_has_cuda_disabled(const char *tls); + +/** + * @brief Initialize a UCX context. + * + * @param busy_polling Whether to enable busy polling. + * @return priskv_ucx_context* The initialized UCX context. + */ +priskv_ucx_context *priskv_ucx_context_init(uint8_t busy_polling); + +/** + * @brief Check if a UCX context has CUDA support. + * + * @param context The UCX context. + * @return int 1 if CUDA support is enabled, 0 otherwise. 
+ */ +int priskv_ucx_context_has_cuda_support(priskv_ucx_context *context); + +/** + * @brief Map or allocate memory for zero-copy operations. + * + * @param context The UCX context. + * @param buffer The memory buffer to mmap. + * @param count The size of the memory buffer in bytes. + * @param type The memory type. + * @return priskv_ucx_memh* The created UCX memory handle. + */ +priskv_ucx_memh *priskv_ucx_mmap(priskv_ucx_context *context, void *buffer, size_t count, + ucs_memory_type_t type); + +/** + * @brief Unmap memory segment. + * + * @param memh The UCX memory handle. + * @return ucs_status_t The status of the unmap operation. + */ +ucs_status_t priskv_ucx_munmap(priskv_ucx_memh *memh); + +/** + * @brief Cancel a UCX request. + * + * @param request The UCX request to cancel. + * @return ucs_status_t The status of the request cancellation. + */ +ucs_status_t priskv_ucx_request_cancel(priskv_ucx_request *request); + +/** + * @brief Create a UCX worker. + * + * @param context The UCX context. + * @param client_id The client ID. + * @return priskv_ucx_worker* The created UCX worker. + */ +priskv_ucx_worker *priskv_ucx_worker_create(priskv_ucx_context *context, uint64_t client_id); + +/** + * @brief Post a non-blocking tag receive request. + * + * @param worker The UCX worker. + * @param buffer The buffer to store the received data. + * @param count The number of bytes to receive. + * @param tag The tag to match. + * @param mask The mask to apply to the tag. + * @param cb The callback function to invoke when the operation completes. + * @param cb_data User data to pass to the callback. + * @return ucs_status_ptr_t The status or pointer to the request handle. + */ +ucs_status_ptr_t priskv_ucx_worker_post_tag_recv(priskv_ucx_worker *worker, void *buffer, + size_t count, ucp_tag_t tag, ucp_tag_t mask, + priskv_ucx_tag_recv_cb cb, void *cb_data); + +/** + * @brief Post a non-blocking flush request. + * + * @param worker The UCX worker. 
+ * @param cb The callback function to invoke when the operation completes. + * @param cb_data User data to pass to the callback. + * @return ucs_status_ptr_t The status or pointer to the request handle. + */ +ucs_status_ptr_t priskv_ucx_worker_post_flush(priskv_ucx_worker *worker, priskv_ucx_request_cb cb, + void *cb_data); + +/** + * @brief Progress a UCX worker. + * + * @param worker The UCX worker to progress. + * @return int 0 on success, other values on error. + */ +int priskv_ucx_worker_progress(priskv_ucx_worker *worker); + +/** + * @brief Signal a UCX worker. + * + * @param worker The UCX worker to signal. + */ +void priskv_ucx_worker_signal(priskv_ucx_worker *worker); + +/** + * @brief Destroy a UCX worker. + * + * @param worker The UCX worker to destroy. + * @return ucs_status_t The status of the destroy operation. + */ +ucs_status_t priskv_ucx_worker_destroy(priskv_ucx_worker *worker); + +/** + * @brief Create a UCX listener. + * + * @param worker The UCX worker. + * @param ip_or_hostname The IP address or hostname to listen on. + * @param port The port to listen on. + * @param conn_cb The callback function to invoke when a connection is accepted. + * @param conn_cb_data User data to pass to the callback. + * @return priskv_ucx_listener* The created UCX listener. + */ +priskv_ucx_listener *priskv_ucx_listener_create(priskv_ucx_worker *worker, + const char *ip_or_hostname, uint16_t port, + priskv_ucx_conn_cb conn_cb, void *conn_cb_data); + +/** + * @brief Accept a connection request on a UCX listener. + * + * @param listener The UCX listener. + * @param conn_req The connection request handle. + * @param close_cb The callback function to invoke when the endpoint is closed. + * @param close_cb_data User data to pass to the callback. + * @return priskv_ucx_ep* The created UCX endpoint. 
+ */ +priskv_ucx_ep *priskv_ucx_listener_accept(priskv_ucx_listener *listener, + priskv_ucx_conn_request *conn_req, + priskv_ucx_ep_close_cb close_cb, void *close_cb_data); + +/** + * @brief Reject a connection request on a UCX listener. + * + * @param listener The UCX listener. + * @param conn_req The connection request handle. + * @return ucs_status_t The status of the reject operation. + */ +ucs_status_t priskv_ucx_listener_reject(priskv_ucx_listener *listener, + priskv_ucx_conn_request *conn_req); + +/** + * @brief Destroy a UCX listener. + * + * @param listener The UCX listener to destroy. + * @return ucs_status_t The status of the destroy operation. + */ +ucs_status_t priskv_ucx_listener_destroy(priskv_ucx_listener *listener); + +/** + * @brief Create a UCX endpoint from a worker address. + * + * @param worker The UCX worker. + * @param addr The worker address. + * @param close_cb The callback function to invoke when the endpoint is closed. + * @param close_cb_data User data to pass to the callback. + * @return priskv_ucx_ep* The created UCX endpoint. + */ +priskv_ucx_ep *priskv_ucx_ep_create_from_worker_addr(priskv_ucx_worker *worker, ucp_address_t *addr, + priskv_ucx_ep_close_cb close_cb, + void *close_cb_data); +/** + * @brief Create a UCX endpoint from an IP address or hostname. + * + * @param worker The UCX worker. + * @param ip_or_hostname The IP address or hostname to connect to. + * @param port The port to connect to. + * @param close_cb The callback function to invoke when the endpoint is closed. + * @param close_cb_data User data to pass to the callback. + * @return priskv_ucx_ep* The created UCX endpoint. + */ +priskv_ucx_ep *priskv_ucx_ep_create_from_addr(priskv_ucx_worker *worker, const char *ip_or_hostname, + uint16_t port, priskv_ucx_ep_close_cb close_cb, + void *close_cb_data); + +/** + * @brief Post a non-blocking tag receive request. + * + * @param ep The UCX endpoint. + * @param buffer The buffer to store the received data. 
+ * @param count The number of bytes to receive. + * @param tag The tag to match. + * @param mask The mask to apply to the tag. + * @param cb The callback function to invoke when the operation completes. + * @param cb_data User data to pass to the callback. + * @return ucs_status_ptr_t The status or pointer to the request handle. + */ +ucs_status_ptr_t priskv_ucx_ep_post_tag_recv(priskv_ucx_ep *ep, void *buffer, size_t count, + ucp_tag_t tag, ucp_tag_t mask, + priskv_ucx_tag_recv_cb cb, void *cb_data); + +/** + * @brief Post a non-blocking tag send operation. + * + * @param ep The UCX endpoint. + * @param buffer The buffer to send. + * @param count The number of bytes to send. + * @param tag The tag to send. + * @param cb The callback function to invoke when the operation completes. + * @param cb_data User data to pass to the callback. + * @return ucs_status_ptr_t The status or pointer to the request handle. + */ +ucs_status_ptr_t priskv_ucx_ep_post_tag_send(priskv_ucx_ep *ep, void *buffer, size_t count, + ucp_tag_t tag, priskv_ucx_request_cb cb, + void *cb_data); + +/** + * @brief Post a non-blocking put operation. + * + * @param ep The UCX endpoint. + * @param buffer The buffer to send. + * @param count The number of bytes to send. + * @param rkey The remote key to use. + * @param raddr The remote address to write to. + * @param cb The callback function to invoke when the operation completes. + * @param cb_data User data to pass to the callback. + * @return ucs_status_ptr_t The status or pointer to the request handle. + */ +ucs_status_ptr_t priskv_ucx_ep_post_put(priskv_ucx_ep *ep, void *buffer, size_t count, + priskv_ucx_rkey *rkey, uint64_t raddr, + priskv_ucx_request_cb cb, void *cb_data); + +/** + * @brief Post a non-blocking get operation. + * + * @param ep The UCX endpoint. + * @param buffer The buffer to receive. + * @param count The number of bytes to receive. + * @param rkey The remote key to use. + * @param raddr The remote address to read from. 
+ * @param cb The callback function to invoke when the operation completes. + * @param cb_data User data to pass to the callback. + * @return ucs_status_ptr_t The status or pointer to the request handle. + */ +ucs_status_ptr_t priskv_ucx_ep_post_get(priskv_ucx_ep *ep, void *buffer, size_t count, + priskv_ucx_rkey *rkey, uint64_t raddr, + priskv_ucx_request_cb cb, void *cb_data); + +/** + * @brief Post a non-blocking flush operation. + * + * @param ep The UCX endpoint. + * @param cb The callback function to invoke when the operation completes. + * @param cb_data User data to pass to the callback. + * @return ucs_status_ptr_t The status or pointer to the request handle. + */ +ucs_status_ptr_t priskv_ucx_ep_post_flush(priskv_ucx_ep *ep, priskv_ucx_request_cb cb, + void *cb_data); + +/** + * @brief Get the local address of a UCX endpoint. + * + * @param ep The UCX endpoint. + * @return struct sockaddr_storage* The local address of the endpoint. + */ +struct sockaddr_storage *priskv_ucx_ep_get_local_addr(priskv_ucx_ep *ep); + +/** + * @brief Get the peer address of a UCX endpoint. + * + * @param ep The UCX endpoint. + * @return struct sockaddr_storage* The peer address of the endpoint. + */ +struct sockaddr_storage *priskv_ucx_ep_get_peer_addr(priskv_ucx_ep *ep); + +/** + * @brief Destroy a UCX endpoint. + * + * @param ep The UCX endpoint to destroy. + * @return ucs_status_t The status of the destroy operation. + */ +ucs_status_t priskv_ucx_ep_destroy(priskv_ucx_ep *ep); + +/** + * @brief Create a UCX remote key from a packed remote key. + * + * @param ep The UCX endpoint. + * @param packed_rkey The packed remote key to unpack. + * @return priskv_ucx_rkey* The unpacked remote key. + */ +priskv_ucx_rkey *priskv_ucx_rkey_create(priskv_ucx_ep *ep, const void *packed_rkey); + +/** + * @brief Destroy a UCX remote key. + * + * @param rkey The UCX remote key to destroy. + * @return ucs_status_t The status of the destroy operation. 
+ */ +ucs_status_t priskv_ucx_rkey_destroy(priskv_ucx_rkey *rkey); + +#if defined(__cplusplus) +} +#endif + +#endif /* __PRISKV_UCX__ */ diff --git a/include/priskv-utils.h b/include/priskv-utils.h index 8729d0e..7f75a0c 100644 --- a/include/priskv-utils.h +++ b/include/priskv-utils.h @@ -44,6 +44,8 @@ extern "C" #include #include #include +#include +#include #ifndef offsetof #define offsetof(TYPE, MEMBER) ((size_t)&((TYPE *)0)->MEMBER) @@ -76,6 +78,13 @@ static inline int priskv_set_nonblock(int fd) return fcntl(fd, F_SETFL, flags | O_NONBLOCK); } +static inline int priskv_set_block(int fd) +{ + int flags = fcntl(fd, F_GETFL, 0); + + return fcntl(fd, F_SETFL, flags & ~O_NONBLOCK); +} + static inline int priskv_add_event_fd(int epollfd, int fd) { struct epoll_event event = {0}; @@ -139,6 +148,55 @@ static inline void priskv_inet_ntop(struct sockaddr *addr, char *dst) } } +static inline int priskv_sock_io(int sock, ssize_t (*sock_call)(int, void *, size_t, int), + int poll_events, void *data, size_t size, + void (*progress)(void *arg), void *arg, const char *name) +{ + size_t total = 0; + struct pollfd pfd; + int ret; + + while (total < size) { + pfd.fd = sock; + pfd.events = poll_events; + pfd.revents = 0; + + ret = poll(&pfd, 1, 1); /* poll for 1ms */ + if (ret > 0) { + ret = sock_call(sock, (char *)data + total, size - total, 0); + if ((ret == 0) && (poll_events & POLLIN)) { + return -1; + } + if (ret < 0) { + return -1; + } + total += ret; + } else if ((ret < 0) && (errno != EINTR)) { + return -1; + } + + /* progress user context */ + if (progress != NULL) { + progress(arg); + } + } + return 0; +} + +static inline int priskv_safe_send(int sock, void *data, size_t size, void (*progress)(void *arg), + void *arg) +{ + typedef ssize_t (*sock_call)(int, void *, size_t, int); + + return priskv_sock_io(sock, (sock_call)send, POLLOUT, data, size, progress, arg, "send"); +} + +static inline int priskv_safe_recv(int sock, void *data, size_t size, void (*progress)(void 
*arg), + void *arg) +{ + return priskv_sock_io(sock, recv, POLLIN, data, size, progress, arg, "recv"); +} + static inline unsigned long priskv_rdtsc(void) { unsigned long low, high; diff --git a/lib/Makefile b/lib/Makefile index 0901d3a..27aa026 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -1,5 +1,6 @@ PREFIX = /usr -CFLAGS = -fPIC -Wall -g -O0 -I../include -D_GNU_SOURCE -Wshadow -Wformat=2 -Wwrite-strings -fstack-protector-strong -Wnull-dereference -Wunreachable-code +UCX_CFLAGS = $(shell pkg-config --cflags ucx) +CFLAGS = -fPIC -Wall -g -O0 -I../include -D_GNU_SOURCE -Wshadow -Wformat=2 -Wwrite-strings -fstack-protector-strong -Wnull-dereference -Wunreachable-code $(UCX_CFLAGS) CC = gcc AR = ar FMT = clang-format-19 diff --git a/lib/test/Makefile b/lib/test/Makefile index 5b5c0f0..3197bb3 100644 --- a/lib/test/Makefile +++ b/lib/test/Makefile @@ -3,6 +3,9 @@ VERSION = 0.1 TEST_EVENT = test-event TEST_THREADS = test-threads TEST_CODEC = test-codec +TEST_UCX = test-ucx +UCX_CFLAGS = $(shell pkg-config --cflags ucx) +UCX_LIBS = $(shell pkg-config --libs ucx) CFLAGS = -fPIC -Wall -g -O0 -I .. 
-I ../../include -I ../../thirdparty/json-c/build/include -D_GNU_SOURCE -Wshadow -Wformat=2 -Wwrite-strings -fstack-protector-strong -Wnull-dereference -Wunreachable-code FMT = clang-format-19 @@ -20,7 +23,7 @@ include ../libjsonc.mk .PHONY: all valgrind rebuild clean format -all: $(TEST_EVENT) $(TEST_THREADS) $(TEST_CODEC) +all: $(TEST_EVENT) $(TEST_THREADS) $(TEST_CODEC) $(TEST_UCX) $(TEST_EVENT): $(CC) test_event.c ../event.c $(CFLAGS) -o $(TEST_EVENT) -lpthread @@ -31,6 +34,9 @@ $(TEST_THREADS): $(TEST_CODEC): test_codec.c ../codec.c ../../lib/log.c $(LIBJSONC) $(CC) $^ $(CFLAGS) -o $(TEST_CODEC) +$(TEST_UCX): test_ucx.c ../ucx.c ../../lib/log.c + $(CC) $^ $(CFLAGS) $(UCX_CFLAGS) -o $(TEST_UCX) $(UCX_LIBS) + valgrind: $(TEST_EVENT) $(TEST_THREADS) valgrind -s --track-origins=yes --show-possibly-lost=no --leak-check=full ./$(TEST_EVENT) valgrind -s --track-origins=yes --show-possibly-lost=no --leak-check=full ./$(TEST_THREADS) @@ -40,7 +46,7 @@ rebuild: clean make all clean: - rm -f $(TEST_EVENT) $(TEST_THREADS) $(TEST_CODEC) + rm -f $(TEST_EVENT) $(TEST_THREADS) $(TEST_CODEC) $(TEST_UCX) format: $(FMT) -i *.c diff --git a/lib/test/test_ucx.c b/lib/test/test_ucx.c new file mode 100644 index 0000000..1899219 --- /dev/null +++ b/lib/test/test_ucx.c @@ -0,0 +1,35 @@ +#include +#include + +#include "priskv-ucx.h" + +static int test_priskv_ucx_tls_has_cuda_disabled(const char *tls, int expect_disabled) +{ + int d = priskv_ucx_tls_has_cuda_disabled(tls); + printf("TLS='%s' disabled=%d expect=%d\n", tls ? tls : "(null)", d, expect_disabled); + fflush(stdout); + return d == expect_disabled ? 
0 : 1; +} + +int main(void) +{ + int fails = 0; + fails += test_priskv_ucx_tls_has_cuda_disabled("^cuda", 1); + fails += test_priskv_ucx_tls_has_cuda_disabled("^cuda_copy", 1); + fails += test_priskv_ucx_tls_has_cuda_disabled("^cuda_ipc", 0); + fails += test_priskv_ucx_tls_has_cuda_disabled("all", 0); + fails += test_priskv_ucx_tls_has_cuda_disabled("rc,sm,self", 1); + fails += test_priskv_ucx_tls_has_cuda_disabled("rc,cuda_ipc", 0); + fails += test_priskv_ucx_tls_has_cuda_disabled("cuda", 0); + fails += test_priskv_ucx_tls_has_cuda_disabled("cuda_copy", 0); + fails += test_priskv_ucx_tls_has_cuda_disabled(" cuda_copy , rc ", 0); + fails += test_priskv_ucx_tls_has_cuda_disabled(" ^ cuda_copy ", 1); + fails += test_priskv_ucx_tls_has_cuda_disabled(" ^cuda_copy", 1); + fails += test_priskv_ucx_tls_has_cuda_disabled("", 1); + if (fails == 0) { + printf("ok\n"); + return 0; + } + printf("failed=%d\n", fails); + return 1; +} diff --git a/lib/threads.c b/lib/threads.c index 9ec6b5a..d36b065 100644 --- a/lib/threads.c +++ b/lib/threads.c @@ -62,6 +62,7 @@ struct priskv_thread { } __attribute__((aligned(64))); struct priskv_threadpool { + uint32_t thread_flags; priskv_thread *iothreads; int niothread; priskv_thread *bgthreads; @@ -252,6 +253,7 @@ priskv_threadpool *priskv_threadpool_create_with_hooks(const char *prefix, int n pool = calloc(1, sizeof(priskv_threadpool)); assert(pool); + pool->thread_flags = flags; pool->niothread = niothread; pool->iothreads = calloc(niothread, sizeof(priskv_thread)); assert(pool->iothreads); diff --git a/lib/ucx.c b/lib/ucx.c new file mode 100644 index 0000000..4125c17 --- /dev/null +++ b/lib/ucx.c @@ -0,0 +1,1307 @@ +// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "priskv-log.h" +#include "priskv-ucx.h" + +// forward declaration +static priskv_ucx_ep *priskv_ucx_ep_create(priskv_ucx_worker *worker, ucp_ep_params_t *params, + priskv_ucx_ep_close_cb close_cb, void *close_cb_data, + uint8_t paired_worker_ep); + +static int token_eq(const char *s, size_t len, const char *t) +{ + size_t tl = strlen(t); + if (len != tl) + return 0; + return strncmp(s, t, tl) == 0; +} + +int priskv_ucx_tls_has_cuda_disabled(const char *tls) +{ + if (!tls) + return 1; + const char *q = tls; + while (*q == ' ' || *q == '\t') + q++; + if (*q == '\0') + return 1; + if (*q == '^') { + const char *p = q + 1; + while (*p) { + const char *start = p; + const char *comma = strchr(p, ','); + size_t len = comma ? 
(size_t)(comma - start) : strlen(start); + while (len > 0 && (start[0] == ' ' || start[0] == '\t')) { + start++; + len--; + } + while (len > 0 && (start[len - 1] == ' ' || start[len - 1] == '\t')) { + len--; + } + if (token_eq(start, len, "cuda") || token_eq(start, len, "cuda_copy")) + return 1; + if (!comma) + break; + p = comma + 1; + } + return 0; + } + const char *end = q + strlen(q); + while (end > q && ((end[-1] == ' ') || (end[-1] == '\t'))) + end--; + if (token_eq(q, (size_t)(end - q), "all")) + return 0; + if (strstr(q, "cuda") != NULL) + return 0; + return 1; +} + +priskv_ucx_context *priskv_ucx_context_init(uint8_t busy_polling) +{ + ucp_config_t *config; + ucs_status_t status = ucp_config_read(NULL, NULL, &config); + PRISKV_UCX_RETURN_IF_ERROR(status, "priskv_ucx_init: failed to read config", {}, NULL); + + // enable reuseaddr + ucp_config_modify(config, "CM_REUSEADDR", "y"); + + if (priskv_get_log_level() >= priskv_log_info) { + ucp_config_print(config, stdout, "UCX Config", UCS_CONFIG_PRINT_CONFIG); + } + + ucp_params_t params; + memset(¶ms, 0, sizeof(params)); + params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_MT_WORKERS_SHARED; + params.features = UCP_FEATURE_TAG | UCP_FEATURE_RMA; + params.mt_workers_shared = 1; + + if (!busy_polling) { + params.features |= UCP_FEATURE_WAKEUP; + } + + priskv_ucx_context *context = malloc(sizeof(priskv_ucx_context)); + if (ucs_unlikely(!context)) { + priskv_log_error("priskv_ucx_init: failed to malloc context\n"); + return NULL; + } + status = ucp_init(¶ms, config, &context->handle); + ucp_config_release(config); + PRISKV_UCX_RETURN_IF_ERROR( + status, "priskv_ucx_init: failed to init context", { free(context); }, NULL); + context->busy_polling = busy_polling; + if (priskv_get_log_level() >= priskv_log_info) { + ucp_context_print_info(context->handle, stdout); + } + return context; +} + +int priskv_ucx_context_has_cuda_support(priskv_ucx_context *context) +{ + ucp_context_attr_t attr = {.field_mask = 
UCP_ATTR_FIELD_MEMORY_TYPES}; + ucp_context_query(context->handle, &attr); + uint8_t has_cuda_support = (attr.memory_types & UCS_MEMORY_TYPE_CUDA) == UCS_MEMORY_TYPE_CUDA; + if (has_cuda_support) { + priskv_log_info("priskv_ucx_init: UCX supports CUDA memory type\n"); + const char *tls_env = getenv("UCX_TLS"); + if (priskv_ucx_tls_has_cuda_disabled(tls_env)) { + has_cuda_support = 0; + } + if (has_cuda_support) { + priskv_log_info("priskv_ucx_init: CUDA support enabled by UCX_TLS\n"); + } else { + priskv_log_info("priskv_ucx_init: CUDA support disabled by UCX_TLS\n"); + } + } + + return has_cuda_support; +} + +priskv_ucx_memh *priskv_ucx_mmap(priskv_ucx_context *context, void *buffer, size_t count, + ucs_memory_type_t type) +{ + priskv_ucx_memh *memh = malloc(sizeof(priskv_ucx_memh)); + if (ucs_unlikely(!memh)) { + priskv_log_error("priskv_ucx_mem_register: failed to malloc memh\n"); + return NULL; + } + memh->context = context; + + ucp_mem_map_params_t params = {.field_mask = UCP_MEM_MAP_PARAM_FIELD_LENGTH | + UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE, + .length = count, + .memory_type = type}; + if (buffer == NULL) { + priskv_log_debug("priskv_ucx_mmap: allocate memory of size %zu\n", count); + params.field_mask |= UCP_MEM_MAP_PARAM_FIELD_FLAGS; + params.flags = UCP_MEM_MAP_ALLOCATE; + } else { + priskv_log_debug("priskv_ucx_mmap: map memory %p of size %zu\n", buffer, count); + params.field_mask |= UCP_MEM_MAP_PARAM_FIELD_ADDRESS; + params.address = buffer; + } + + ucs_status_t status = ucp_mem_map(context->handle, ¶ms, &memh->handle); + PRISKV_UCX_RETURN_IF_ERROR(status, "priskv_ucx_mmap: failed to mmap", { free(memh); }, NULL); + + ucp_mem_attr_t attr = {.field_mask = UCP_MEM_ATTR_FIELD_ADDRESS | UCP_MEM_ATTR_FIELD_LENGTH | + UCP_MEM_ATTR_FIELD_MEM_TYPE}; + + status = ucp_mem_query(memh->handle, &attr); + PRISKV_UCX_RETURN_IF_ERROR( + status, "priskv_ucx_mmap: failed to query", + { + ucp_mem_unmap(context->handle, memh->handle); + free(memh); + }, + NULL); + + 
memh->addr = (uint64_t)attr.address; + memh->len = attr.length; + memh->type = attr.mem_type; + memh->rkey_buffer = NULL; + + priskv_log_debug("priskv_ucx_mmap: memh %p [addr=%lu, len=%zu, type=%s] mapped\n", memh, + memh->addr, memh->len, ucs_memory_type_names[memh->type]); + + status = ucp_rkey_pack(context->handle, memh->handle, &memh->rkey_buffer, &memh->rkey_length); + PRISKV_UCX_RETURN_IF_ERROR( + status, "priskv_ucx_mmap: failed to pack", + { + ucp_mem_unmap(context->handle, memh->handle); + free(memh); + }, + NULL); + + return memh; +} + +ucs_status_t priskv_ucx_munmap(priskv_ucx_memh *memh) +{ + if (ucs_unlikely(memh == NULL)) { + priskv_log_error("priskv_ucx_munmap: memh is NULL\n"); + return UCS_ERR_INVALID_PARAM; + } + + if (ucs_unlikely(memh->handle == NULL)) { + priskv_log_error("priskv_ucx_munmap: handle is NULL\n"); + return UCS_ERR_INVALID_PARAM; + } + + if (memh->rkey_buffer) { + ucp_memh_buffer_release_params_t params = {.field_mask = 0}; + ucp_memh_buffer_release(memh->rkey_buffer, ¶ms); + memh->rkey_buffer = NULL; + } + ucp_mem_unmap(memh->context->handle, memh->handle); + memh->handle = NULL; + free(memh); + return UCS_OK; +} + +priskv_ucx_worker *priskv_ucx_worker_create(priskv_ucx_context *context, uint64_t client_id) +{ + if (ucs_unlikely(context == NULL)) { + priskv_log_error("priskv_ucx_worker_init: context is NULL\n"); + return NULL; + } + + priskv_ucx_worker *worker = malloc(sizeof(priskv_ucx_worker)); + if (ucs_unlikely(worker == NULL)) { + priskv_log_error("priskv_ucx_worker_init: failed to malloc worker\n"); + return NULL; + } + + worker->context = context; + worker->address = NULL; + + ucp_worker_params_t params = {.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE | + UCP_WORKER_PARAM_FIELD_CLIENT_ID, + .thread_mode = UCS_THREAD_MODE_SINGLE, + .client_id = client_id}; + + ucs_status_t status = ucp_worker_create(context->handle, ¶ms, &worker->handle); + PRISKV_UCX_RETURN_IF_ERROR( + status, "priskv_ucx_worker_init: failed to init 
worker", { free(worker); }, NULL); + if (!context->busy_polling) { + status = ucp_worker_get_efd(worker->handle, &worker->efd); + PRISKV_UCX_RETURN_IF_ERROR( + status, "priskv_ucx_worker_init: failed to get efd", { free(worker); }, NULL); + if (priskv_set_nonblock(worker->efd)) { + priskv_log_error("priskv_ucx_worker_init: failed to set nonblock\n"); + free(worker); + return NULL; + } + status = ucp_worker_arm(worker->handle); + PRISKV_UCX_RETURN_IF_ERROR( + status, "priskv_ucx_worker_init: failed to arm worker", { free(worker); }, NULL); + } else { + worker->efd = -1; + } + + status = ucp_worker_get_address(worker->handle, &worker->address, &worker->address_len); + PRISKV_UCX_RETURN_IF_ERROR( + status, "priskv_ucx_worker_init: failed to get address", { free(worker); }, NULL); + + if (priskv_get_log_level() >= priskv_log_info) { + ucp_worker_print_info(worker->handle, stdout); + } + return worker; +} + +static inline void priskv_ucx_request_init(priskv_ucx_request *request) +{ + request->status = UCS_INPROGRESS; + request->worker = NULL; + request->ep = NULL; + request->key = NULL; +} + +static void priskv_ucx_request_complete(priskv_ucx_request *request, ucs_status_t status, + const ucp_tag_recv_info_t *info) +{ + if (ucs_unlikely(request->handle == NULL)) { + goto cb; + } + + if (ucs_unlikely(request->status != UCS_INPROGRESS)) { + priskv_log_warn("priskv_ucx_request_complete: request %p completed with status %s, " + "but status is already set to %s\n", + request, ucs_status_string(status), ucs_status_string(request->status)); + } + + request->status = status; + + if (UCS_PTR_IS_PTR(request->handle)) { + ucp_request_free(request->handle); + request->handle = NULL; + } + +cb: + if (request->cb != NULL) { + priskv_log_debug("priskv_ucx_request_complete: call cb %p\n", request->cb); + if (info == NULL) { + request->cb(status, request->cb_data); + } else { + request->tag_recv_cb(status, info->sender_tag, info->length, request->cb_data); + } + request->cb = NULL; + } 
+} + +static ucs_status_ptr_t priskv_ucx_request_progress(priskv_ucx_request *request, + ucp_request_param_t *params) +{ + ucs_status_t status = UCS_INPROGRESS; + ucs_status_ptr_t handle = request->handle; + + if (UCS_PTR_IS_ERR(handle)) { + status = UCS_PTR_STATUS(handle); + } else if (UCS_PTR_IS_PTR(handle)) { + // still in progress + status = UCS_INPROGRESS; + } else { + // Operation completed immediately + status = UCS_OK; + } + + if (status != UCS_INPROGRESS) { + ucp_tag_recv_info_t *info = NULL; + if (params->recv_info.tag_info) { + info = params->recv_info.tag_info; + } + priskv_ucx_request_complete(request, status, info); + + request->handle = NULL; + free(request); + return UCS_STATUS_PTR(status); + } + + return request; +} + +ucs_status_t priskv_ucx_request_cancel(priskv_ucx_request *request) +{ + if (ucs_unlikely(request == NULL || request->handle == NULL)) { + priskv_log_error("priskv_ucx_request_cancel: failed to cancel request\n"); + if (request) { + free(request); + } + return UCS_ERR_INVALID_PARAM; + } + + if (request->status == UCS_INPROGRESS) { + if (UCS_PTR_IS_ERR(request->handle)) { + ucs_status_t status = UCS_PTR_STATUS(request->handle); + priskv_log_debug( + "priskv_ucx_request_cancel: unprocessed request %p failed with status %s, " + "no need to cancel\n", + request, ucs_status_string(status)); + } else if (request->handle) { + priskv_log_debug("priskv_ucx_request_cancel: unprocessed request %p with handle %p, " + "canceling...\n", + request, request->handle); + ucp_request_cancel(request->worker->handle, request->handle); + return UCS_OK; + } + } else { + priskv_log_debug("priskv_ucx_request_cancel: request %p already completed with status %s, " + "no need to cancel\n", + request, ucs_status_string(request->status)); + } + request->handle = NULL; + + free(request); + return UCS_OK; +} + +static inline ucs_status_t priskv_ucx_request_common_cb_intl(priskv_ucx_request *req, void *request, + ucs_status_t status, + const ucp_tag_recv_info_t 
*info) +{ + priskv_ucx_request_complete(req, status, info); + free(req); + return UCS_OK; +} + +static void priskv_ucx_request_tag_recv_cb_intl(void *request, ucs_status_t status, + const ucp_tag_recv_info_t *info, void *arg) +{ + priskv_ucx_request *req = (priskv_ucx_request *)arg; + if (ucs_unlikely(request == NULL)) { + priskv_log_error("priskv_ucx_request_tag_recv_cb_intl: failed to get request\n"); + return; + } + + if (req->ep) { + priskv_log_debug("priskv_ucx_request_tag_recv_cb_intl: ep %p request %p [buf=%p, len=%zu, " + "tag=%lu, mask=%lu] completed " + "with status %s\n", + req->ep, request, req->payload.buffer, req->payload.length, + req->payload.tag_recv.tag, req->payload.tag_recv.mask, + ucs_status_string(status)); + } else { + priskv_log_debug( + "priskv_ucx_request_tag_recv_cb_intl: worker %p request %p [buf=%p, len=%zu, " + "tag=%lu, mask=%lu] completed " + "with status %s\n", + req->worker, request, req->payload.buffer, req->payload.length, + req->payload.tag_recv.tag, req->payload.tag_recv.mask, ucs_status_string(status)); + } + + priskv_ucx_request_common_cb_intl(req, request, status, info); +} + +static void priskv_ucx_request_tag_send_cb_intl(void *request, ucs_status_t status, void *arg) +{ + priskv_ucx_request *req = (priskv_ucx_request *)arg; + if (ucs_unlikely(request == NULL)) { + priskv_log_error("priskv_ucx_request_tag_send_cb_intl: failed to get request\n"); + return; + } + + priskv_log_debug("priskv_ucx_request_tag_send_cb_intl: ep %p request %p [buf=%p, len=%zu, " + "tag=%lu] completed " + "with status %s\n", + req->ep, request, req->payload.buffer, req->payload.length, + req->payload.tag_send.tag, ucs_status_string(status)); + + priskv_ucx_request_common_cb_intl(req, request, status, NULL); +} + +static void priskv_ucx_request_put_cb_intl(void *request, ucs_status_t status, void *arg) +{ + priskv_ucx_request *req = (priskv_ucx_request *)arg; + if (ucs_unlikely(request == NULL)) { + priskv_log_error("priskv_ucx_request_put_cb_intl: 
failed to get request\n"); + return; + } + + priskv_log_debug( + "priskv_ucx_request_put_cb_intl: ep %p request %p [buf=%p, len=%zu, raddr=%lu] completed " + "with status %s\n", + req->ep, request, req->payload.buffer, req->payload.length, req->payload.rma.raddr, + ucs_status_string(status)); + + priskv_ucx_request_common_cb_intl(req, request, status, NULL); +} + +static void priskv_ucx_request_get_cb_intl(void *request, ucs_status_t status, void *arg) +{ + priskv_ucx_request *req = (priskv_ucx_request *)arg; + if (ucs_unlikely(request == NULL)) { + priskv_log_error("priskv_ucx_request_get_cb_intl: failed to get request\n"); + return; + } + + priskv_log_debug( + "priskv_ucx_request_get_cb_intl: ep %p request %p [buf=%p, len=%zu, raddr=%lu] completed " + "with status %s\n", + req->ep, request, req->payload.buffer, req->payload.length, req->payload.rma.raddr, + ucs_status_string(status)); + + priskv_ucx_request_common_cb_intl(req, request, status, NULL); +} + +static void priskv_ucx_request_flush_cb_intl(void *request, ucs_status_t status, void *arg) +{ + priskv_ucx_request *req = (priskv_ucx_request *)arg; + if (ucs_unlikely(request == NULL)) { + priskv_log_error("priskv_ucx_request_flush_cb_intl: failed to get request\n"); + return; + } + + if (req->ep) { + priskv_log_debug( + "priskv_ucx_request_flush_cb_intl: ep %p request %p completed with status %s\n", + req->ep, request, ucs_status_string(status)); + } else { + priskv_log_debug( + "priskv_ucx_request_flush_cb_intl: worker %p request %p completed with status %s\n", + req->worker, request, ucs_status_string(status)); + } + + priskv_ucx_request_common_cb_intl(req, request, status, NULL); +} + +static ucs_status_ptr_t priskv_ucx_post_tag_recv(priskv_ucx_worker *worker, priskv_ucx_ep *ep, + void *buffer, size_t count, ucp_tag_t tag, + ucp_tag_t mask, priskv_ucx_tag_recv_cb cb, + void *cb_data) +{ + priskv_ucx_request *request = malloc(sizeof(priskv_ucx_request)); + ucs_status_t status; + if (ucs_unlikely(request == 
NULL)) { + priskv_log_error("priskv_ucx_post_tag_recv: failed to malloc request\n"); + status = UCS_ERR_NO_MEMORY; + goto callback; + } + + priskv_ucx_request_init(request); + + ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_USER_DATA, + .datatype = ucp_dt_make_contig(1), + .cb.recv = priskv_ucx_request_tag_recv_cb_intl, + .user_data = request}; + request->handle = ucp_tag_recv_nbx(worker->handle, buffer, count, tag, mask, &param); + request->name = "tag_recv"; + request->status = UCS_INPROGRESS; + request->worker = worker; + request->ep = ep; + request->payload.buffer = buffer; + request->payload.length = count; + request->payload.tag_recv.tag = tag; + request->payload.tag_recv.mask = mask; + request->tag_recv_cb = cb; + request->cb_data = cb_data; + + if (ep) { + priskv_log_debug("priskv_ucx_post_tag_recv: ep %p request %p [buf=%p, len=%zu, tag=%lu, " + "mask=%lu] posted\n", + ep, request, buffer, count, tag, mask); + } else { + priskv_log_debug("priskv_ucx_post_tag_recv: worker %p request %p [buf=%p, len=%zu, " + "tag=%lu, mask=%lu] posted\n", + worker, request, buffer, count, tag, mask); + } + + return priskv_ucx_request_progress(request, &param); + +callback: + if (cb) { + cb(status, tag & mask, 0, cb_data); + } + return NULL; +} + +ucs_status_ptr_t priskv_ucx_worker_post_tag_recv(priskv_ucx_worker *worker, void *buffer, + size_t count, ucp_tag_t tag, ucp_tag_t mask, + priskv_ucx_tag_recv_cb cb, void *cb_data) +{ + return priskv_ucx_post_tag_recv(worker, NULL, buffer, count, tag, mask, cb, cb_data); +} + +ucs_status_ptr_t priskv_ucx_worker_post_flush(priskv_ucx_worker *worker, priskv_ucx_request_cb cb, + void *cb_data) +{ + ucs_status_t status; + if (ucs_unlikely(worker == NULL)) { + priskv_log_error("priskv_ucx_worker_post_flush: worker is NULL\n"); + status = UCS_ERR_INVALID_PARAM; + goto callback; + } + + priskv_ucx_request *request = malloc(sizeof(priskv_ucx_request)); + if 
(ucs_unlikely(request == NULL)) { + priskv_log_error("priskv_ucx_worker_post_flush: failed to malloc request\n"); + status = UCS_ERR_NO_MEMORY; + goto callback; + } + + priskv_ucx_request_init(request); + + ucp_request_param_t param = {.op_attr_mask = + UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, + .cb.send = priskv_ucx_request_flush_cb_intl, + .user_data = request}; + request->handle = ucp_worker_flush_nbx(worker->handle, &param); + request->name = "flush"; + request->status = UCS_INPROGRESS; + request->worker = worker; + request->ep = NULL; + request->cb = cb; + request->cb_data = cb_data; + + priskv_log_debug("priskv_ucx_worker_post_flush: worker %p request %p posted\n", worker, + request); + + return priskv_ucx_request_progress(request, &param); +callback: + if (cb) { + cb(status, cb_data); + } + return NULL; +} + +static void priskv_ucx_worker_drain_cb_intl(void *request, ucs_status_t status, + const ucp_tag_recv_info_t *info, void *arg) +{ + // do nothing +} + +static void priskv_ucx_worker_drain_tag_recv(priskv_ucx_worker *worker) +{ + if (ucs_unlikely(worker == NULL || worker->handle == NULL)) { + return; + } + + ucs_status_ptr_t status; + ucp_tag_message_h message; + ucp_tag_recv_info_t info; + ucp_worker_h handle = worker->handle; + ucp_request_param_t param = {.op_attr_mask = + UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE, + .cb = {.recv = priskv_ucx_worker_drain_cb_intl}, + .datatype = ucp_dt_make_contig(1)}; + + while ((message = ucp_tag_probe_nb(handle, 0, 0, 1, &info)) != NULL) { + priskv_log_debug("draining tag receive message, tag: 0x%lx, length: %lu\n", info.sender_tag, + info.length); + + char *buf = malloc(info.length); + if (ucs_unlikely(buf == NULL)) { + priskv_log_error("priskv_ucx_worker_drain_tag_recv: failed to malloc %lu bytes\n", + info.length); + return; + } + + status = ucp_tag_msg_recv_nbx(handle, buf, info.length, message, &param); + + if (status != NULL) { + while (UCS_PTR_STATUS(status) == UCS_INPROGRESS) { + 
priskv_ucx_worker_progress(worker); + } + } + free(buf); + } +} + +int priskv_ucx_worker_progress(priskv_ucx_worker *worker) +{ + int ret = 0, in_progress = 0; + ucs_status_t status; + for (;;) { + if ((in_progress = ucp_worker_progress(worker->handle)) != 0) { + ret += in_progress; + continue; // some progress happened but condition not met + } + + // arm the worker and clean-up fd + status = ucp_worker_arm(worker->handle); + if (UCS_OK == status) { + return ret; + } else if (UCS_ERR_BUSY == status) { + continue; // could not arm, need to progress more + } else { + return ret; + } + } + return ret; +} + +void priskv_ucx_worker_signal(priskv_ucx_worker *worker) +{ + ucp_worker_signal(worker->handle); +} + +ucs_status_t priskv_ucx_worker_destroy(priskv_ucx_worker *worker) +{ + priskv_ucx_worker_drain_tag_recv(worker); + if (worker->address) { + ucp_worker_release_address(worker->handle, worker->address); + worker->address = NULL; + } + ucp_worker_destroy(worker->handle); + priskv_log_debug("priskv_ucx_worker_destroy: worker %p destroyed\n", worker); + free(worker); + return UCS_OK; +} + +static inline struct addrinfo *priskv_ucx_get_addrinfo(const char *ip_or_hostname, uint16_t port) +{ + char port_str[6]; + struct addrinfo *result = NULL; + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_flags = AI_NUMERICSERV | AI_PASSIVE; + + int port_str_len = snprintf(port_str, sizeof(port_str), "%u", port); + if (port_str_len < 0 || port_str_len > sizeof(port_str)) { + priskv_log_error("priskv_ucx_get_addrinfo: invalid port %u\n", port); + return NULL; + } + if (getaddrinfo(ip_or_hostname, port_str, &hints, &result)) { + priskv_log_error("priskv_ucx_get_addrinfo: invalid IP or hostname\n"); + return NULL; + } + + return result; +} + +static void priskv_ucx_conn_cb_intl(ucp_conn_request_h conn_request, void *arg) +{ + priskv_ucx_conn_request *conn_req = arg; + conn_req->handle = conn_request; + conn_req->attr.field_mask = + 
UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR | UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ID; + ucs_status_t status = ucp_conn_request_query(conn_request, &conn_req->attr); + if (ucs_unlikely(status != UCS_OK)) { + priskv_log_error("priskv_ucx_conn_cb_intl: failed to query conn_request %p, status %s\n", + conn_request, ucs_status_string(status)); + } else { + priskv_inet_ntop((struct sockaddr *)&conn_req->attr.client_address, conn_req->peer_addr); + } + + if (conn_req->cb != NULL) { + priskv_log_debug( + "priskv_ucx_conn_cb_intl: conn_request id %d peer_addr %s cb %p cb_data %p\n", + conn_req->attr.client_id, conn_req->peer_addr, conn_req->cb, conn_req->cb_data); + conn_req->cb(conn_req, conn_req->cb_data); + } +} + +priskv_ucx_listener *priskv_ucx_listener_create(priskv_ucx_worker *worker, + const char *ip_or_hostname, uint16_t port, + priskv_ucx_conn_cb conn_cb, void *conn_cb_data) +{ + if (ucs_unlikely(worker == NULL)) { + priskv_log_error("priskv_ucx_listener_create: worker is NULL\n"); + return NULL; + } + + struct addrinfo *info = priskv_ucx_get_addrinfo(ip_or_hostname, port); + if (ucs_unlikely(info == NULL)) { + priskv_log_error("priskv_ucx_listener_create: failed to get addrinfo\n"); + return NULL; + } + + priskv_ucx_conn_request *conn_req = malloc(sizeof(priskv_ucx_conn_request)); + if (ucs_unlikely(conn_req == NULL)) { + priskv_log_error("priskv_ucx_listener_create: failed to malloc conn_req\n"); + freeaddrinfo(info); + return NULL; + } + + conn_req->handle = NULL; + conn_req->cb = conn_cb; + conn_req->cb_data = conn_cb_data; + + ucp_listener_params_t params = { + .field_mask = UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | UCP_LISTENER_PARAM_FIELD_CONN_HANDLER, + .sockaddr = {.addr = info->ai_addr, .addrlen = info->ai_addrlen}, + .conn_handler = {.cb = priskv_ucx_conn_cb_intl, .arg = conn_req}}; + + freeaddrinfo(info); + + priskv_ucx_listener *listener = malloc(sizeof(priskv_ucx_listener)); + if (ucs_unlikely(listener == NULL)) { + 
priskv_log_error("priskv_ucx_listener_create: failed to malloc listener\n"); + free(conn_req); + return NULL; + } + + listener->conn_request = conn_req; + ucs_status_t status = ucp_listener_create(worker->handle, &params, &listener->handle); + PRISKV_UCX_RETURN_IF_ERROR( + status, "priskv_ucx_listener_create: failed to create listener", + { + free(conn_req); + free(listener); + }, + NULL); + + ucp_listener_attr_t attr = {.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR}; + status = ucp_listener_query(listener->handle, &attr); + PRISKV_UCX_RETURN_IF_ERROR( + status, "priskv_ucx_listener_create: failed to query listener", + { + ucp_listener_destroy(listener->handle); + free(conn_req); + free(listener); + }, + NULL); + + priskv_inet_ntop((struct sockaddr *)&attr.sockaddr, listener->addr); + + listener->worker = worker; + + priskv_log_debug("priskv_ucx_listener_create: listener %s created\n", listener->addr); + + return listener; +} + +priskv_ucx_ep *priskv_ucx_listener_accept(priskv_ucx_listener *listener, + priskv_ucx_conn_request *conn_req, + priskv_ucx_ep_close_cb close_cb, void *close_cb_data) +{ + if (ucs_unlikely(listener == NULL)) { + priskv_log_error("priskv_ucx_listener_accept: listener is NULL\n"); + return NULL; + } + + if (ucs_unlikely(conn_req == NULL)) { + priskv_log_error("priskv_ucx_listener_accept: conn_req is NULL\n"); + return NULL; + } + + if (ucs_unlikely(conn_req->handle == NULL)) { + priskv_log_error("priskv_ucx_listener_accept: conn_req handle is NULL\n"); + return NULL; + } + + ucp_ep_params_t params = {.field_mask = + UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST, + .flags = UCP_EP_PARAMS_FLAGS_NO_LOOPBACK, + .conn_request = conn_req->handle}; + + priskv_ucx_worker *worker = priskv_ucx_worker_create(listener->worker->context, 0); + if (ucs_unlikely(worker == NULL)) { + priskv_log_error("priskv_ucx_listener_accept: failed to create worker\n"); + return NULL; + } + return priskv_ucx_ep_create(worker, &params, close_cb, close_cb_data, 1); +} 
+ +ucs_status_t priskv_ucx_listener_reject(priskv_ucx_listener *listener, + priskv_ucx_conn_request *conn_req) +{ + if (ucs_unlikely(listener == NULL)) { + priskv_log_error("priskv_ucx_listener_reject: listener is NULL\n"); + return UCS_ERR_INVALID_PARAM; + } + + if (ucs_unlikely(conn_req == NULL)) { + priskv_log_error("priskv_ucx_listener_reject: conn_req is NULL\n"); + return UCS_ERR_INVALID_PARAM; + } + + if (ucs_unlikely(conn_req->handle == NULL)) { + priskv_log_error("priskv_ucx_listener_reject: conn_req handle is NULL\n"); + return UCS_ERR_INVALID_PARAM; + } + + return ucp_listener_reject(listener->handle, conn_req->handle); +} + +ucs_status_t priskv_ucx_listener_destroy(priskv_ucx_listener *listener) +{ + if (ucs_unlikely(listener == NULL)) { + priskv_log_error("priskv_ucx_listener_destroy: listener is NULL\n"); + return UCS_ERR_INVALID_PARAM; + } + + ucp_listener_destroy(listener->handle); + + priskv_ucx_worker_progress(listener->worker); + + priskv_log_debug("priskv_ucx_listener_destroy: listener %s destroyed\n", listener->addr); + + if (listener->conn_request) { + free(listener->conn_request); + listener->conn_request = NULL; + } + + free(listener); + + return UCS_OK; +} + +static void priskv_ucx_ep_error_cb_intl(void *arg, ucp_ep_h handle, ucs_status_t status) +{ + priskv_ucx_ep *ep = (priskv_ucx_ep *)arg; + if (ucs_unlikely(ep == NULL)) { + priskv_log_error("priskv_ucx_ep_error_cb_intl: ep is NULL\n"); + return; + } + + if (atomic_exchange(&ep->closing, 1) == 1) { + // ep is closing, ignore the error + return; + } + + ep->status = status; + + if (ep->close_cb != NULL) { + priskv_log_debug("priskv_ucx_ep_error_cb_intl: call close_cb %p\n", ep->close_cb); + ep->close_cb(status, ep->close_cb_data); + ep->close_cb = NULL; + } + + // Connection reset and timeout often represent just a normal remote + // endpoint disconnect, log only in debug mode. 
+ if (status == UCS_ERR_CONNECTION_RESET || status == UCS_ERR_ENDPOINT_TIMEOUT) { + priskv_log_debug("priskv_ucx_ep_error_cb_intl: ep %p error %s\n", ep, + ucs_status_string(status)); + } else { + priskv_log_error("priskv_ucx_ep_error_cb_intl: ep %p error %s\n", ep, + ucs_status_string(status)); + } +} + +static priskv_ucx_ep *priskv_ucx_ep_create(priskv_ucx_worker *worker, ucp_ep_params_t *params, + priskv_ucx_ep_close_cb close_cb, void *close_cb_data, + uint8_t paired_worker_ep) +{ + priskv_ucx_ep *ep = malloc(sizeof(priskv_ucx_ep)); + if (ucs_unlikely(ep == NULL)) { + priskv_log_error("priskv_ucx_ep_create: failed to malloc ep\n"); + return NULL; + } + + params->field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; + // sm does not support peer error handling mode + // params->err_mode = UCP_ERR_HANDLING_MODE_PEER; + params->err_mode = UCP_ERR_HANDLING_MODE_NONE; + params->err_handler.cb = priskv_ucx_ep_error_cb_intl; + params->err_handler.arg = ep; + ucs_status_t status = ucp_ep_create(worker->handle, params, &ep->handle); + PRISKV_UCX_RETURN_IF_ERROR( + status, "priskv_ucx_ep_create: failed to create ep", { free(ep); }, NULL); + atomic_init(&ep->closing, 0); + ep->worker = worker; + ep->close_cb = close_cb; + ep->close_cb_data = close_cb_data; + ep->paired_worker_ep = paired_worker_ep; + ep->status = UCS_INPROGRESS; + + ucp_ep_attr_t attr = {.field_mask = UCP_EP_ATTR_FIELD_NAME}; + status = ucp_ep_query(ep->handle, &attr); + if (ucs_unlikely(status != UCS_OK)) { + priskv_log_error("priskv_ucx_ep_create: failed to query ep name, status: %s\n", + ucs_status_string(status)); + strcpy(ep->name, "nil"); + } else { + strcpy(ep->name, attr.name); + } + + return ep; +} + +priskv_ucx_ep *priskv_ucx_ep_create_from_worker_addr(priskv_ucx_worker *worker, ucp_address_t *addr, + priskv_ucx_ep_close_cb close_cb, + void *close_cb_data) +{ + if (ucs_unlikely(worker == NULL)) { + priskv_log_error("priskv_ucx_ep_create_from_worker_addr: " + "worker 
is NULL\n"); + return NULL; + } + + if (ucs_unlikely(addr == NULL)) { + priskv_log_error("priskv_ucx_ep_create_from_worker_addr: addr is NULL\n"); + return NULL; + } + + ucp_ep_params_t params = { + .field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_REMOTE_ADDRESS, + .flags = UCP_EP_PARAMS_FLAGS_SEND_CLIENT_ID, // send worker's client id + .address = addr, + }; + + return priskv_ucx_ep_create(worker, &params, close_cb, close_cb_data, 0); +} + +priskv_ucx_ep *priskv_ucx_ep_create_from_addr(priskv_ucx_worker *worker, const char *ip_or_hostname, + uint16_t port, priskv_ucx_ep_close_cb close_cb, + void *close_cb_data) +{ + if (ucs_unlikely(worker == NULL)) { + priskv_log_error("priskv_ucx_ep_create_from_hostname: worker is NULL\n"); + return NULL; + } + + struct addrinfo *info = priskv_ucx_get_addrinfo(ip_or_hostname, port); + if (ucs_unlikely(info == NULL)) { + priskv_log_error("priskv_ucx_ep_create_from_hostname: failed to get addrinfo\n"); + return NULL; + } + + ucp_ep_params_t params = {.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR, + .flags = + UCP_EP_PARAMS_FLAGS_CLIENT_SERVER | + UCP_EP_PARAMS_FLAGS_SEND_CLIENT_ID, // send worker's client id + .sockaddr = {.addrlen = info->ai_addrlen, .addr = info->ai_addr}}; + + priskv_ucx_ep *ep = priskv_ucx_ep_create(worker, &params, close_cb, close_cb_data, 0); + freeaddrinfo(info); + return ep; +} + +ucs_status_ptr_t priskv_ucx_ep_post_tag_recv(priskv_ucx_ep *ep, void *buffer, size_t count, + ucp_tag_t tag, ucp_tag_t mask, + priskv_ucx_tag_recv_cb cb, void *cb_data) +{ + return priskv_ucx_post_tag_recv(ep->worker, ep, buffer, count, tag, mask, cb, cb_data); +} + +ucs_status_ptr_t priskv_ucx_ep_post_tag_send(priskv_ucx_ep *ep, void *buffer, size_t count, + ucp_tag_t tag, priskv_ucx_request_cb cb, void *cb_data) +{ + ucs_status_t status; + if (ucs_unlikely(ep == NULL)) { + priskv_log_error("priskv_ucx_ep_post_tag_send: ep is NULL\n"); + status = UCS_ERR_INVALID_PARAM; + goto callback; + } + + if 
(ucs_unlikely(ep->status != UCS_INPROGRESS)) { + priskv_log_error("priskv_ucx_ep_post_tag_send: ep is closed\n"); + status = UCS_ERR_INVALID_PARAM; + goto callback; + } + + priskv_ucx_request *request = malloc(sizeof(priskv_ucx_request)); + if (ucs_unlikely(request == NULL)) { + priskv_log_error("priskv_ucx_ep_post_tag_send: failed to malloc request\n"); + status = UCS_ERR_NO_MEMORY; + goto callback; + } + + priskv_ucx_request_init(request); + + ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_USER_DATA, + .datatype = ucp_dt_make_contig(1), + .cb.send = priskv_ucx_request_tag_send_cb_intl, + .user_data = request}; + request->handle = ucp_tag_send_nbx(ep->handle, buffer, count, tag, &param); + request->name = "tag_send"; + request->status = UCS_INPROGRESS; + request->worker = ep->worker; + request->ep = ep; + request->payload.buffer = buffer; + request->payload.length = count; + request->payload.tag_send.tag = tag; + request->cb = cb; + request->cb_data = cb_data; + + priskv_log_debug( + "priskv_ucx_ep_post_tag_send: ep %p request %p [buf=%p, len=%zu, tag=%lu] posted\n", ep, + request, buffer, count, tag); + + return priskv_ucx_request_progress(request, &param); +callback: + if (cb) { + cb(status, cb_data); + } + return NULL; +} + +ucs_status_ptr_t priskv_ucx_ep_post_put(priskv_ucx_ep *ep, void *buffer, size_t count, + priskv_ucx_rkey *rkey, uint64_t raddr, + priskv_ucx_request_cb cb, void *cb_data) +{ + ucs_status_t status; + if (ucs_unlikely(ep == NULL)) { + priskv_log_error("priskv_ucx_ep_post_put: ep is NULL\n"); + status = UCS_ERR_INVALID_PARAM; + goto callback; + } + + if (ucs_unlikely(ep->status != UCS_INPROGRESS)) { + priskv_log_error("priskv_ucx_ep_post_put: ep is closed\n"); + status = UCS_ERR_INVALID_PARAM; + goto callback; + } + + priskv_ucx_request *request = malloc(sizeof(priskv_ucx_request)); + if (ucs_unlikely(request == NULL)) { + priskv_log_error("priskv_ucx_ep_post_put: failed to 
malloc request\n"); + status = UCS_ERR_NO_MEMORY; + goto callback; + } + + priskv_ucx_request_init(request); + + ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_USER_DATA, + .datatype = ucp_dt_make_contig(1), + .cb.send = priskv_ucx_request_put_cb_intl, + .user_data = request}; + request->handle = ucp_put_nbx(ep->handle, buffer, count, raddr, rkey->handle, &param); + request->name = "put"; + request->status = UCS_INPROGRESS; + request->worker = ep->worker; + request->ep = ep; + request->payload.buffer = buffer; + request->payload.length = count; + request->payload.rma.raddr = raddr; + request->cb = cb; + request->cb_data = cb_data; + + priskv_log_debug( + "priskv_ucx_ep_post_put: ep %p request %p [buf=%p, len=%zu, raddr=%lu] posted\n", ep, + request, buffer, count, raddr); + + return priskv_ucx_request_progress(request, &param); +callback: + if (cb) { + cb(status, cb_data); + } + return NULL; +} + +ucs_status_ptr_t priskv_ucx_ep_post_get(priskv_ucx_ep *ep, void *buffer, size_t count, + priskv_ucx_rkey *rkey, uint64_t raddr, + priskv_ucx_request_cb cb, void *cb_data) +{ + ucs_status_t status; + if (ucs_unlikely(ep == NULL)) { + priskv_log_error("priskv_ucx_ep_post_get: ep is NULL\n"); + status = UCS_ERR_INVALID_PARAM; + goto callback; + } + + if (ucs_unlikely(ep->status != UCS_INPROGRESS)) { + priskv_log_error("priskv_ucx_ep_post_get: ep is closed\n"); + status = UCS_ERR_INVALID_PARAM; + goto callback; + } + + priskv_ucx_request *request = malloc(sizeof(priskv_ucx_request)); + if (ucs_unlikely(request == NULL)) { + priskv_log_error("priskv_ucx_ep_post_get: failed to malloc request\n"); + status = UCS_ERR_NO_MEMORY; + goto callback; + } + + priskv_ucx_request_init(request); + + ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_USER_DATA, + .datatype = ucp_dt_make_contig(1), + .cb.send = priskv_ucx_request_get_cb_intl, + 
.user_data = request}; + request->handle = ucp_get_nbx(ep->handle, buffer, count, raddr, rkey->handle, &param); + request->name = "get"; + request->status = UCS_INPROGRESS; + request->worker = ep->worker; + request->ep = ep; + request->payload.buffer = buffer; + request->payload.length = count; + request->payload.rma.raddr = raddr; + request->cb = cb; + request->cb_data = cb_data; + + priskv_log_debug( + "priskv_ucx_ep_post_get: ep %p request %p [buf=%p, len=%zu, raddr=%lu] posted\n", ep, + request, buffer, count, raddr); + + return priskv_ucx_request_progress(request, &param); +callback: + if (cb) { + cb(status, cb_data); + } + return NULL; +} + +ucs_status_ptr_t priskv_ucx_ep_post_flush(priskv_ucx_ep *ep, priskv_ucx_request_cb cb, + void *cb_data) +{ + ucs_status_t status; + if (ucs_unlikely(ep == NULL)) { + priskv_log_error("priskv_ucx_ep_post_flush: ep is NULL\n"); + status = UCS_ERR_INVALID_PARAM; + goto callback; + } + + if (ucs_unlikely(ep->status != UCS_INPROGRESS)) { + priskv_log_error("priskv_ucx_ep_post_flush: ep is closed\n"); + status = UCS_ERR_INVALID_PARAM; + goto callback; + } + + priskv_ucx_request *request = malloc(sizeof(priskv_ucx_request)); + if (ucs_unlikely(request == NULL)) { + priskv_log_error("priskv_ucx_ep_post_flush: failed to malloc request\n"); + status = UCS_ERR_NO_MEMORY; + goto callback; + } + + priskv_ucx_request_init(request); + + ucp_request_param_t param = {.op_attr_mask = + UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, + .cb.send = priskv_ucx_request_flush_cb_intl, + .user_data = request}; + request->handle = ucp_ep_flush_nbx(ep->handle, &param); + request->name = "flush"; + request->status = UCS_INPROGRESS; + request->worker = ep->worker; + request->ep = ep; + request->cb = cb; + request->cb_data = cb_data; + + priskv_log_debug("priskv_ucx_ep_post_flush: ep %p request %p posted\n", ep, request); + + return priskv_ucx_request_progress(request, &param); +callback: + if (cb) { + cb(status, cb_data); + } + return NULL; +} + +struct 
sockaddr_storage *priskv_ucx_ep_get_local_addr(priskv_ucx_ep *ep) +{ + ep->attr.field_mask = UCP_EP_ATTR_FIELD_LOCAL_SOCKADDR; + ucs_status_t status = ucp_ep_query(ep->handle, &ep->attr); + if (ucs_unlikely(status != UCS_OK)) { + priskv_log_error( + "priskv_ucx_ep_get_local_addr: failed to query ep local addr, status: %s\n", + ucs_status_string(status)); + return NULL; + } else { + return &ep->attr.local_sockaddr; + } +} + +struct sockaddr_storage *priskv_ucx_ep_get_peer_addr(priskv_ucx_ep *ep) +{ + ep->attr.field_mask = UCP_EP_ATTR_FIELD_REMOTE_SOCKADDR; + ucs_status_t status = ucp_ep_query(ep->handle, &ep->attr); + if (ucs_unlikely(status != UCS_OK)) { + priskv_log_error("priskv_ucx_ep_get_peer_addr: failed to query ep peer addr, status: %s\n", + ucs_status_string(status)); + return NULL; + } else { + return &ep->attr.remote_sockaddr; + } +} + +ucs_status_t priskv_ucx_ep_destroy(priskv_ucx_ep *ep) +{ + if (atomic_exchange(&ep->closing, 1) || ep->handle == NULL) { + return UCS_OK; + } + + ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, + .flags = UCP_EP_CLOSE_FLAG_FORCE}; + + ucs_status_ptr_t status = ucp_ep_close_nbx(ep->handle, &param); + if (UCS_PTR_IS_PTR(status)) { + // wait for close request to complete + ucs_status_t s; + while ((s = ucp_request_check_status(status)) == UCS_INPROGRESS) { + priskv_ucx_worker_progress(ep->worker); + } + ep->status = s; + } else if (UCS_PTR_STATUS(status) != UCS_OK) { + priskv_log_error("priskv_ucx_ep_destroy: failed to close ep %p, status %s\n", ep, + ucs_status_string(UCS_PTR_STATUS(status))); + } + + if (UCS_PTR_IS_PTR(status)) { + ucp_request_free(status); + } + + if (ep->close_cb != NULL) { + priskv_log_debug("priskv_ucx_ep_destroy: call close_cb %p\n", ep->close_cb); + ep->close_cb(ep->status, ep->close_cb_data); + ep->close_cb = NULL; + } + + priskv_ucx_worker *worker = ep->worker; + uint8_t paired_worker_ep = ep->paired_worker_ep; + free(ep); + + if (paired_worker_ep) { + return 
priskv_ucx_worker_destroy(worker); + } + + return UCS_OK; +} + +priskv_ucx_rkey *priskv_ucx_rkey_create(priskv_ucx_ep *ep, const void *packed_rkey) +{ + if (ucs_unlikely(ep == NULL)) { + priskv_log_error("priskv_ucx_rkey_create: ep is NULL\n"); + return NULL; + } + + priskv_ucx_rkey *rkey = malloc(sizeof(priskv_ucx_rkey)); + if (ucs_unlikely(rkey == NULL)) { + priskv_log_error("priskv_ucx_rkey_create: failed to malloc rkey\n"); + return NULL; + } + + rkey->ep = ep; + ucs_status_t status = ucp_ep_rkey_unpack(ep->handle, packed_rkey, &rkey->handle); + if (ucs_unlikely(status != UCS_OK)) { + priskv_log_error("priskv_ucx_rkey_create: failed to unpack rkey\n"); + free(rkey); + return NULL; + } + + return rkey; +} + +ucs_status_t priskv_ucx_rkey_destroy(priskv_ucx_rkey *rkey) +{ + if (ucs_unlikely(rkey == NULL)) { + priskv_log_error("priskv_ucx_rkey_destroy: rkey is NULL\n"); + return UCS_ERR_INVALID_PARAM; + } + + if (ucs_unlikely(rkey->handle == NULL)) { + priskv_log_error("priskv_ucx_rkey_destroy: rkey handle is NULL\n"); + return UCS_ERR_INVALID_PARAM; + } + + ucp_rkey_destroy(rkey->handle); + + free(rkey); + + return UCS_OK; +} diff --git a/man/libpriskv.7 b/man/libpriskv.7 index 202d5e7..0fecce6 100644 --- a/man/libpriskv.7 +++ b/man/libpriskv.7 @@ -3,7 +3,7 @@ libpriskv \- Library and header file for PrisKV. .SH DESCRIPTION -PrisKV is specifically designed for modern high-performance computing (HPC) and artificial intelligence (AI) computing. It solely supports RDMA. PrisKV also supports GDR (GPU Direct RDMA), enabling the value of a key to be directly transferred between PrisKV and the GPU. +PrisKV is specifically designed for modern high-performance computing (HPC) and artificial intelligence (AI) computing. It supports common transport protocols, including RDMA, TCP, and shared memory, to enable efficient communication for different scenarios. 
PrisKV also supports GDR (GPU Direct RDMA), enabling the value of a key to be directly transferred between PrisKV and the GPU. .SH EXAMPLES diff --git a/man/priskv-server.man b/man/priskv-server.man index a3c806a..5630c6a 100644 --- a/man/priskv-server.man +++ b/man/priskv-server.man @@ -15,9 +15,10 @@ .SH DESCRIPTION PrisKV is specifically designed for modern high-performance computing (HPC) -and artificial intelligence (AI) computing. It solely supports RDMA. -PrisKV also supports GDR (GPU Direct RDMA), enabling the value of a key to -be directly transferred between PrisKV and the GPU. +and artificial intelligence (AI) computing. It supports common transport protocols, +including RDMA, TCP, and shared memory, to enable efficient communication for +different scenarios. PrisKV also supports GDR (GPU Direct RDMA), enabling the +value of a key to be directly transferred between PrisKV and the GPU. .sp \fBpriskv-server\fP is a server that implements the PrisKV protocol. diff --git a/man/pypriskv.7 b/man/pypriskv.7 index 5229839..cc93f9e 100644 --- a/man/pypriskv.7 +++ b/man/pypriskv.7 @@ -3,7 +3,7 @@ pypriskv \- PrisKV client python runtime. .SH DESCRIPTION -PrisKV is specifically designed for modern high-performance computing (HPC) and artificial intelligence (AI) computing. It solely supports RDMA. PrisKV also supports GDR (GPU Direct RDMA), enabling the value of a key to be directly transferred between PrisKV and the GPU. +PrisKV is specifically designed for modern high-performance computing (HPC) and artificial intelligence (AI) computing. It supports common transport protocols, including RDMA, TCP, and shared memory, to enable efficient communication for different scenarios. PrisKV also supports GDR (GPU Direct RDMA), enabling the value of a key to be directly transferred between PrisKV and the GPU. 
.SH EXAMPLES diff --git a/pypriskv/setup.py b/pypriskv/setup.py index cdcd977..84168af 100644 --- a/pypriskv/setup.py +++ b/pypriskv/setup.py @@ -23,12 +23,22 @@ # Enhua Zhou from setuptools import setup, find_packages +import subprocess +import shlex try: from pybind11.setup_helpers import Pybind11Extension except ImportError: from setuptools import Extension as Pybind11Extension -# LIBS = ["rdmacm", "ibverbs"] +def _pkgcfg(name, flag): + try: + out = subprocess.check_output(['pkg-config', flag, name], universal_newlines=True) + return shlex.split(out.strip()) + except Exception: + return [] + +UCX_CFLAGS = _pkgcfg('ucx', '--cflags') +UCX_LDFLAGS = _pkgcfg('ucx', '--libs') ext_modules = [ Pybind11Extension( "priskv._priskv._priskv_client", @@ -37,8 +47,8 @@ extra_link_args=[ "-Wl,--start-group", "../cluster/client/libpriskvcluster.a", "-Wl,--end-group", "-lrdmacm", "-libverbs", "-lhiredis" - ], - extra_compile_args=["-g", "-O0", "-std=c++11"], + ] + UCX_LDFLAGS, + extra_compile_args=["-g", "-O0", "-std=c++11"] + UCX_CFLAGS, ), ] @@ -47,9 +57,10 @@ version='0.0.2', description= '''This is priskv's client. priskv is specifically designed for modern high-performance ''' - '''computing (HPC) and artificial intelligence (AI) computing. It solely supports RDMA. ''' - '''priskv also supports GDR (GPU Direct RDMA), enabling the value of a key to be directly ''' - '''transferred between priskv and the GPU.''', + '''computing (HPC) and artificial intelligence (AI) computing. It supports common ''' + '''transport protocols, including RDMA, TCP, and shared memory, to enable efficient ''' + '''communication for different scenarios. 
priskv also supports GDR (GPU Direct RDMA), ''' + '''enabling the value of a key to be directly transferred between priskv and the GPU.''', # package_data={ # 'priskv._priskv': ['*.pyi', '*.so'], # }, diff --git a/redhat/priskv.spec b/redhat/priskv.spec index edea33b..1316a7a 100644 --- a/redhat/priskv.spec +++ b/redhat/priskv.spec @@ -8,9 +8,10 @@ BuildRequires: git gcc gcc-c++ make cmake librdmacm rdma-core-devel libibverbs %description PrisKV is specifically designed for modern high-performance computing (HPC) \ -and artificial intelligence (AI) computing. It solely supports RDMA. \ -PrisKV also supports GDR (GPU Direct RDMA), enabling the value of a key to \ -be directly transferred between PrisKV and the GPU. +and artificial intelligence (AI) computing. It supports common transport protocols, \ +including RDMA, TCP, and shared memory, to enable efficient communication for \ +different scenarios. PrisKV also supports GDR (GPU Direct RDMA), enabling the \ +value of a key to be directly transferred between PrisKV and the GPU. %package server Summary: RPM packages for PrisKV server. 
diff --git a/run_e2e_test.py b/run_e2e_test.py index 0a9d82a..e69f114 100755 --- a/run_e2e_test.py +++ b/run_e2e_test.py @@ -8,6 +8,7 @@ import signal import sys import traceback +import tempfile server_process = None client_process = None @@ -81,14 +82,40 @@ def signal_handler(signum, frame): def create_memfile(): global mem_file_path mem_file_name = f"memfile_{int(time.time())}_{random.randint(0,99999)}" - mem_file_path = f"/run/{mem_file_name}" + def _tmpfs_path(name: str) -> str: + uid = os.getuid() + candidates = [] + try: + with open("/proc/mounts") as f: + mounts = [line.split() for line in f] + tmpfs_mounts = {m[1] for m in mounts if len(m) >= 3 and m[2] == "tmpfs"} + user_run = f"/run/user/{uid}" + if user_run in tmpfs_mounts and os.path.isdir(user_run) and os.access(user_run, os.W_OK): + candidates.append(user_run) + if "/dev/shm" in tmpfs_mounts and os.access("/dev/shm", os.W_OK): + candidates.append("/dev/shm") + if "/run" in tmpfs_mounts and os.access("/run", os.W_OK): + candidates.append("/run") + for m in tmpfs_mounts: + if os.access(m, os.W_OK): + candidates.append(m) + except Exception: + pass + for base in candidates: + return os.path.join(base, name) + for base in (f"/run/user/{uid}", "/dev/shm", "/run"): + if os.path.isdir(base) and os.access(base, os.W_OK): + return os.path.join(base, name) + return os.path.join(tempfile.gettempdir(), name) + mem_file_path = _tmpfs_path(mem_file_name) subprocess.run([ "./server/priskv-memfile", "-o", "create", "-f", mem_file_path, "--max-keys", "1024", "--max-key-length", "128", "--value-block-size", "4096", "--value-blocks", "4096" ], - stdout=subprocess.DEVNULL) + stdout=subprocess.DEVNULL, + check=True) def destroy_memfile(): @@ -160,15 +187,30 @@ def priskv_e2e_test(): print("---- E2E TEST ----") port = random.randint(24300, 24500) + + print("---- E2E TEST (RDMA) ----") rdma_ip = find_rdma_dev() if rdma_ip is None: - print("---- No RDMA IP, SKIP E2E TEST ----") - return 0 + print("---- No RDMA IP, SKIP E2E 
TEST OVER RDMA ----") + else: + priskv_e2e_test_run(rdma_ip, port) + + ucx_wireup_ip = "0.0.0.0" + os.environ["PRISKV_TRANSPORT"] = "ucx" + + print("---- E2E TEST (UCX TCP) ----") + os.environ["UCX_TLS"] = "tcp" + priskv_e2e_test_run(ucx_wireup_ip, port) + + print("---- E2E TEST (UCX SM) ----") + os.environ["UCX_TLS"] = "sm" + priskv_e2e_test_run(ucx_wireup_ip, port) +def priskv_e2e_test_run(ip, port): create_memfile() print("---- E2E TEST: create memfile [OK] ----") try: - create_server(rdma_ip, port) + create_server(ip, port) # wait for server ready time.sleep(10) @@ -176,7 +218,7 @@ def priskv_e2e_test(): value = 456 # step 1, get key from empty KV - ret = do_test(rdma_ip, port, f"get {key}", STATUS_NO_SUCH_KEY) + ret = do_test(ip, port, f"get {key}", STATUS_NO_SUCH_KEY) if ret != 0: print("---- E2E TEST: get key from empty KV [FAILED] ----") return ret @@ -184,7 +226,7 @@ def priskv_e2e_test(): print("---- E2E TEST: get key from empty KV [OK] ----") # step 2, set key to empty KV - ret = do_test(rdma_ip, port, f"set {key} {value}", STATUS_OK) + ret = do_test(ip, port, f"set {key} {value}", STATUS_OK) if ret != 0: print("---- E2E TEST: set key to empty KV [FAILED] ----") return ret @@ -192,7 +234,7 @@ def priskv_e2e_test(): print("---- E2E TEST: set key to empty KV [OK] ----") ## step 3, verify key from filled KV - ret = do_test(rdma_ip, port, f"get {key}", STATUS_OK, str(value)) + ret = do_test(ip, port, f"get {key}", STATUS_OK, str(value)) if ret != 0: print("---- E2E TEST: verify key from filled KV [FAILED] ----") return ret @@ -200,7 +242,7 @@ def priskv_e2e_test(): print("---- E2E TEST: verify key from filled KV [OK] ----") # step 4, delete key from filled KV - ret = do_test(rdma_ip, port, f"delete {key}", STATUS_OK) + ret = do_test(ip, port, f"delete {key}", STATUS_OK) if ret != 0: print("---- E2E TEST: delete key from filled KV [FAILED] ----") return ret @@ -208,7 +250,7 @@ def priskv_e2e_test(): print("---- E2E TEST: delete key from filled KV [OK] ----") 
# step 5, get key from empty KV - ret = do_test(rdma_ip, port, f"get {key}", STATUS_NO_SUCH_KEY) + ret = do_test(ip, port, f"get {key}", STATUS_NO_SUCH_KEY) if ret != 0: print("---- E2E TEST: get key from empty KV [FAILED] ----") return ret @@ -216,7 +258,7 @@ def priskv_e2e_test(): print("---- E2E TEST: get key from empty KV [OK] ----") # step 6, set key with timeout 5s - ret = do_test(rdma_ip, port, f"set {key} {value} EX 5", STATUS_OK) + ret = do_test(ip, port, f"set {key} {value} EX 5", STATUS_OK) if ret != 0: print("---- E2E TEST: set key with timeuot 5s [FAILED] ----") return ret @@ -225,7 +267,7 @@ def priskv_e2e_test(): # step 7, get key from kv before expired and compare value time.sleep(3) - ret = do_test(rdma_ip, port, f"get {key}", STATUS_OK, str(value)) + ret = do_test(ip, port, f"get {key}", STATUS_OK, str(value)) if ret != 0: print( "---- E2E TEST: verify key from filled KV before expired [FAILED] ----" @@ -238,7 +280,7 @@ def priskv_e2e_test(): # step 8, get key after expired time.sleep(5) - ret = do_test(rdma_ip, port, f"get {key}", STATUS_NO_SUCH_KEY) + ret = do_test(ip, port, f"get {key}", STATUS_NO_SUCH_KEY) if ret != 0: print("---- E2E TEST: get key after expired [FAILED]----") return ret @@ -246,7 +288,7 @@ def priskv_e2e_test(): print("---- E2E TEST: get key after expired [OK] ----") # step 9, set key to empty KV without timeout - ret = do_test(rdma_ip, port, f"set {key} {value}", STATUS_OK) + ret = do_test(ip, port, f"set {key} {value}", STATUS_OK) if ret != 0: print( "---- E2E TEST: set key to empty KV without timeout [FAILED]----" @@ -257,7 +299,7 @@ def priskv_e2e_test(): # step 10, get keys from KV and compare values after a while time.sleep(7) - ret = do_test(rdma_ip, port, f"get {key}", STATUS_OK, str(value)) + ret = do_test(ip, port, f"get {key}", STATUS_OK, str(value)) if ret != 0: print("---- E2E TEST: verify keys from filled KV [FAILED]----") return ret @@ -265,7 +307,7 @@ def priskv_e2e_test(): print("---- E2E TEST: verify keys 
from filled KV [OK] ----") # step 11, set expire time 5s - ret = do_test(rdma_ip, port, f"expire {key} 5", STATUS_OK) + ret = do_test(ip, port, f"expire {key} 5", STATUS_OK) if ret != 0: print("---- E2E TEST: set expire time 5s [FAILED] ----") return ret @@ -274,7 +316,7 @@ def priskv_e2e_test(): # step 12, get keys from empty KV time.sleep(7) - ret = do_test(rdma_ip, port, f"get {key}", STATUS_NO_SUCH_KEY) + ret = do_test(ip, port, f"get {key}", STATUS_NO_SUCH_KEY) if ret != 0: print("---- E2E TEST: get keys from empty KV [FAILED] ----") return ret diff --git a/server/Makefile b/server/Makefile index 77ecb8a..339a76a 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,5 +1,7 @@ VERSION = 0.1 -LIBS = -lrdmacm -libverbs -lpthread -lmount -levent -levent_openssl -lssl -lcrypto -luring -lhiredis +UCX_CFLAGS = $(shell pkg-config --cflags ucx) +UCX_LIBS = $(shell pkg-config --libs ucx) +LIBS = $(UCX_LIBS) -lrdmacm -libverbs -lpthread -lmount -levent -levent_openssl -lssl -lcrypto -luring -lhiredis INCFLAGS = -I../include -I../client CFLAGS = -fPIC -Wall -g $(INCFLAGS) -D_GNU_SOURCE -Wshadow -Wformat=2 -Wwrite-strings -fstack-protector-strong -Wnull-dereference -Wunreachable-code FMT = clang-format-19 @@ -61,7 +63,7 @@ endif rm -f $@.tmp $(PRISKV_SERVER_TARGETS): priskv-%: %.c $(PRISKV_SERVER_OBJS) $(COMMON_LIB_OBJS) $(DEPS_STATIC_LIBS) - $(CC) $^ -o $@ $(CFLAGS) $(LIBS) + $(CC) $^ -o $@ $(CFLAGS) $(UCX_CFLAGS) $(LIBS) MAN_DIR =../man $(MAN_DIR)/priskv-server.1: $(MAN_DIR)/priskv-server.man @@ -91,4 +93,5 @@ clean: format: $(FMT) -i *.c *.h $(FMT) -i backend/*.c backend/*.h + $(FMT) -i transport/*.c transport/*.h make -C test format diff --git a/server/acl.c b/server/acl.c index 167bfae..4cc7461 100644 --- a/server/acl.c +++ b/server/acl.c @@ -30,7 +30,10 @@ #include #include #include -#include +#include +#include +#include +#include #include "priskv-utils.h" #include "priskv-log.h" @@ -43,23 +46,24 @@ static const char *acl_any = "any"; static int 
priskv_acl_addr(const char *addr, struct sockaddr *saddr) { - struct rdma_addrinfo hints = {0}, *servinfo; + struct addrinfo hints = {0}, *res = NULL; const char *_port = "0"; int ret; - hints.ai_flags = RAI_PASSIVE; - hints.ai_port_space = RDMA_PS_TCP; - ret = rdma_getaddrinfo(addr, _port, &hints, &servinfo); + hints.ai_flags = AI_PASSIVE; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + ret = getaddrinfo(addr, _port, &hints, &res); if (ret) { - priskv_log_error("ACL: getaddrinfo %s failed: %d, errno %d", addr, ret, errno); - return ret; - } else if (!servinfo) { + priskv_log_error("ACL: getaddrinfo %s failed: %d", addr, ret); + return -EINVAL; + } else if (!res) { priskv_log_error("ACL: getaddrinfo %s: no availabe address", addr); return -EINVAL; } - memcpy(saddr, servinfo->ai_src_addr, servinfo->ai_src_len); - rdma_freeaddrinfo(servinfo); + memcpy(saddr, res->ai_addr, res->ai_addrlen); + freeaddrinfo(res); return 0; } diff --git a/server/info.c b/server/info.c index 7e956f5..7ce662a 100644 --- a/server/info.c +++ b/server/info.c @@ -39,7 +39,7 @@ #include "acl.h" #include "memory.h" #include "kv.h" -#include "rdma.h" +#include "transport/transport.h" void priskv_info_get_acl(void *data) { @@ -68,7 +68,7 @@ void priskv_info_get_memory(void *data) void priskv_info_get_kv(void *data) { priskv_kv_info *info = (priskv_kv_info *)data; - void *kv = priskv_rdma_get_kv(); + void *kv = priskv_transport_get_kv(); info->bucket_count = priskv_get_bucket_count(kv); info->keys_inuse = priskv_get_keys_inuse(kv); @@ -86,11 +86,11 @@ void priskv_info_get_connection(void *data) { priskv_connection_info *info = (priskv_connection_info *)data; - priskv_rdma_listener *listeners = priskv_rdma_get_listeners(&info->nlisteners); + priskv_transport_listener *listeners = priskv_transport_get_listeners(&info->nlisteners); info->listeners = calloc(info->nlisteners, sizeof(priskv_conn_listener_info)); for (int i = 0; i < info->nlisteners; i++) { - priskv_rdma_listener 
*listener = &listeners[i]; + priskv_transport_listener *listener = &listeners[i]; priskv_conn_listener_info *listener_info = &info->listeners[i]; listener_info->address = strdup(listener->address); @@ -98,7 +98,7 @@ void priskv_info_get_connection(void *data) listener_info->clients = calloc(listener->nclients, sizeof(priskv_conn_client_info)); for (int j = 0; j < listener->nclients; j++) { - priskv_rdma_client *client = &listener->clients[j]; + priskv_transport_client *client = &listener->clients[j]; priskv_conn_client_info *client_info = &listener_info->clients[j]; client_info->address = strdup(client->address); @@ -114,7 +114,7 @@ void priskv_info_get_connection(void *data) } } - priskv_rdma_free_listeners(listeners, info->nlisteners); + priskv_transport_free_listeners(listeners, info->nlisteners); } void priskv_info_free_connection(void *data) @@ -184,7 +184,7 @@ void priskv_info_get_cpu(void *data) pid_t pid = getpid(); int ret = get_process_cpu_time(pid, &used_cpu_user_ticks, &used_cpu_sys_ticks, &clock_ticks); if (!ret) { - priskv_log_warn("failed to get process cpu time"); + priskv_log_warn("failed to get process cpu time\n"); } info->used_cpu_sys_ticks = (uint64_t)used_cpu_sys_ticks; diff --git a/server/kv.h b/server/kv.h index b89b501..48391a8 100644 --- a/server/kv.h +++ b/server/kv.h @@ -38,8 +38,8 @@ extern "C" #include "priskv-protocol.h" #include "backend/backend.h" -typedef struct priskv_rdma_conn priskv_rdma_conn; -struct priskv_rdma_rw_work; +struct priskv_transport_conn; +typedef struct priskv_transport_conn priskv_transport_conn; #define PRISKV_KV_DEFAULT_EXPIRE_ROUTINE_INTERVAL 600 @@ -103,12 +103,12 @@ uint64_t priskv_get_expire_routine_times(void *_kv); // 1. For the same key, a request has already been sent to the backend; // 2. The number of current concurrent requests exceeds the backend queue depth. 
typedef struct priskv_tiering_req { - priskv_rdma_conn *conn; priskv_thread *thread; priskv_backend_device *backend; void *kv; priskv_request *req; uint64_t request_id; + priskv_transport_conn *conn; /* kv operation context */ uint8_t *key; @@ -128,10 +128,13 @@ typedef struct priskv_tiering_req { priskv_backend_status backend_status; bool recv_reposted; - struct priskv_rdma_rw_work *rdma_work; } priskv_tiering_req; int priskv_backend_req_resubmit(void *req); +struct priskv_tiering_req *priskv_tiering_req_new(priskv_transport_conn *conn, priskv_request *req, + uint8_t *key, uint16_t keylen, uint64_t timeout, + priskv_req_command cmd, uint32_t remote_valuelen, + priskv_resp_status *resp_status); // tiering concurrency control bool priskv_key_serialize_enter(struct priskv_tiering_req *treq); diff --git a/server/memfile.c b/server/memfile.c index b331094..2d2bba5 100644 --- a/server/memfile.c +++ b/server/memfile.c @@ -34,15 +34,15 @@ #include "priskv-log.h" #include "priskv-logo.h" -#include "rdma.h" +#include "transport/transport.h" #include "memory.h" #include "priskv-threads.h" /* arguments of command line */ -static uint32_t max_key_length = PRISKV_RDMA_DEFAULT_KEY_LENGTH; -static uint32_t max_key = PRISKV_RDMA_DEFAULT_KEY; -static uint32_t value_block_size = PRISKV_RDMA_DEFAULT_VALUE_BLOCK_SIZE; -static uint64_t value_block = PRISKV_RDMA_DEFAULT_VALUE_BLOCK; +static uint32_t max_key_length = PRISKV_TRANSPORT_DEFAULT_KEY_LENGTH; +static uint32_t max_key = PRISKV_TRANSPORT_DEFAULT_KEY; +static uint32_t value_block_size = PRISKV_TRANSPORT_DEFAULT_VALUE_BLOCK_SIZE; +static uint64_t value_block = PRISKV_TRANSPORT_DEFAULT_VALUE_BLOCK; static uint8_t threads = 1; static priskv_log_level log_level = priskv_log_notice; static char *memfile; @@ -54,15 +54,15 @@ static void priskv_showhelp(void) printf(" -o/--op OPERATION\n\tsupport operations: create|info[default]\n"); printf(" -f/--memfile PATH\n\tmemory file from tmpfs/hugetlbfs\n"); printf(" -k/--max-keys KEYS\n\tthe 
maxium count of KV, default %d, max %d\n", - PRISKV_RDMA_DEFAULT_KEY, PRISKV_RDMA_MAX_KEY); + PRISKV_TRANSPORT_DEFAULT_KEY, PRISKV_TRANSPORT_MAX_KEY); printf(" -K/--max-key-length BYTES\n\tthe maxium bytes of a key, default %d, max %d\n", - PRISKV_RDMA_DEFAULT_KEY_LENGTH, PRISKV_RDMA_MAX_KEY_LENGTH); + PRISKV_TRANSPORT_DEFAULT_KEY_LENGTH, PRISKV_TRANSPORT_MAX_KEY_LENGTH); printf(" -v/--value-block-size BYTES\n\tthe block size of minimal value in bytes, " "default %d, max %d\n", - PRISKV_RDMA_DEFAULT_VALUE_BLOCK_SIZE, PRISKV_RDMA_MAX_VALUE_BLOCK_SIZE); + PRISKV_TRANSPORT_DEFAULT_VALUE_BLOCK_SIZE, PRISKV_TRANSPORT_MAX_VALUE_BLOCK_SIZE); printf(" -b/--value-blocks BLOCKS\n\tthe count of value blocks, must be power of 2, " "default %ld, max %ld\n", - PRISKV_RDMA_DEFAULT_VALUE_BLOCK, PRISKV_RDMA_MAX_VALUE_BLOCK); + PRISKV_TRANSPORT_DEFAULT_VALUE_BLOCK, PRISKV_TRANSPORT_MAX_VALUE_BLOCK); printf(" -t/--threads THREADS\n\tthe number of worker threads to clean memory, default 0\n"); printf(" -l/--log-level LEVEL\n\terror, warn, notice[default], info or debug\n"); @@ -170,7 +170,7 @@ static void priskv_parsr_arg(int argc, char *argv[]) case 'k': max_key = atoi(optarg); - if (!max_key || (max_key > PRISKV_RDMA_MAX_KEY)) { + if (!max_key || (max_key > PRISKV_TRANSPORT_MAX_KEY)) { printf("Invalid -k/--max-keys\n"); priskv_showhelp(); } @@ -178,7 +178,7 @@ static void priskv_parsr_arg(int argc, char *argv[]) case 'K': if (priskv_str2num(optarg, &key_length) < 0 || !key_length || - key_length > PRISKV_RDMA_MAX_KEY_LENGTH) { + key_length > PRISKV_TRANSPORT_MAX_KEY_LENGTH) { printf("Invalid -K/--max-key-length\n"); priskv_showhelp(); } @@ -187,7 +187,7 @@ static void priskv_parsr_arg(int argc, char *argv[]) case 'v': if (priskv_str2num(optarg, &block_size) < 0 || !block_size || - block_size > PRISKV_RDMA_MAX_VALUE_BLOCK_SIZE) { + block_size > PRISKV_TRANSPORT_MAX_VALUE_BLOCK_SIZE) { printf("Invalid -v/--value-block-size\n"); priskv_showhelp(); } @@ -197,7 +197,7 @@ static void 
priskv_parsr_arg(int argc, char *argv[]) case 'b': value_block = atoll(optarg); - if (!value_block || (value_block > PRISKV_RDMA_MAX_VALUE_BLOCK)) { + if (!value_block || (value_block > PRISKV_TRANSPORT_MAX_VALUE_BLOCK)) { priskv_showhelp(); } diff --git a/server/rdma.c b/server/rdma.c deleted file mode 100644 index eb1d93a..0000000 --- a/server/rdma.c +++ /dev/null @@ -1,1949 +0,0 @@ -// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -/* - * Authors: - * Jinlong Xuan <15563983051@163.com> - * Xu Ji - * Yu Wang - * Bo Liu - * Zhenwei Pi - * Rui Zhang - * Changqi Lu - * Enhua Zhou - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "priskv-protocol.h" -#include "priskv-protocol-helper.h" -#include "priskv-log.h" -#include "priskv-utils.h" -#include "acl.h" -#include "kv.h" -#include "rdma.h" -#include "priskv-event.h" -#include "priskv-threads.h" -#include "list.h" -#include "memory.h" -#include "crc.h" -#include "backend/backend.h" - -priskv_threadpool *g_threadpool; -uint32_t g_slow_query_threshold_latency_us = SLOW_QUERY_THRESHOLD_LATENCY_US; - -#define PRISKV_RDMA_DEF_ADDR(id) \ - char local_addr[PRISKV_ADDR_LEN] = {0}; \ - char peer_addr[PRISKV_ADDR_LEN] = {0}; \ - priskv_inet_ntop(rdma_get_local_addr(id), local_addr); \ - priskv_inet_ntop(rdma_get_peer_addr(id), peer_addr); - -typedef struct priskv_rdma_mem { -#define PRISKV_RDMA_MEM_NAME_LEN 32 - char name[PRISKV_RDMA_MEM_NAME_LEN]; - uint8_t *buf; - uint32_t buf_size; - struct ibv_mr *mr; -} priskv_rdma_mem; - -typedef enum priskv_rdma_mem_type { - PRISKV_RDMA_MEM_REQ, - PRISKV_RDMA_MEM_RESP, - PRISKV_RDMA_MEM_KEYS, - - PRISKV_RDMA_MEM_MAX -} priskv_rdma_mem_type; - -typedef struct priskv_rdma_conn { - struct rdma_cm_id *cm_id; - struct ibv_comp_channel *comp_channel; - struct ibv_cq *cq; - priskv_rdma_conn_cap conn_cap; - pthread_spinlock_t lock; - - union { - struct { - struct list_head head; - uint32_t nclients; - } s; /* for listener */ - struct { - struct priskv_rdma_conn *listener; - struct list_node node; - priskv_thread *thread; - bool closing; - priskv_rdma_stats stats[PRISKV_COMMAND_MAX]; - uint64_t resps; - } c; /* for client */ - }; - - void *kv; - uint8_t *value_base; - struct ibv_mr *value_mr; - - priskv_rdma_mem rmem[PRISKV_RDMA_MEM_MAX]; -} priskv_rdma_conn; - -typedef struct 
priskv_rdma_rw_work { - priskv_rdma_conn *conn; - uint64_t request_id; /* be64 type */ - priskv_request *req; - struct ibv_mr *mr; - uint32_t valuelen; - uint16_t nsgl; - uint16_t completed; - bool defer_resp; - void (*cb)(void *); - void *cbarg; -} priskv_rdma_rw_work; - -typedef struct priskv_rdma_server { - int epollfd; - void *kv; - int nlisteners; - priskv_rdma_conn listeners[PRISKV_RDMA_MAX_BIND_ADDR]; -} priskv_rdma_server; - -static priskv_rdma_server g_server = { - .epollfd = -1, -}; - -static uint32_t priskv_rdma_max_rw_size = 1024 * 1024 * 1024; - -static void priskv_rdma_handle_cm(int fd, void *opaque, uint32_t events); - -static int priskv_rdma_mem_new(priskv_rdma_conn *conn, priskv_rdma_mem *rmem, const char *name, - uint32_t size) -{ - uint32_t flags = IBV_ACCESS_LOCAL_WRITE; - bool guard = true; /* always enable memory guard */ - uint8_t *buf; - int ret; - - buf = priskv_mem_malloc(size, guard); - if (!buf) { - priskv_log_error("RDMA: failed to allocate %s buffer: %m\n", name); - ret = -ENOMEM; - goto error; - } - - rmem->mr = ibv_reg_mr(conn->cm_id->pd, buf, size, flags); - if (!rmem->mr) { - priskv_log_error("RDMA: failed to reg MR for %s buffer: %m\n", name); - ret = -errno; - goto free_mem; - } - - strncpy(rmem->name, name, PRISKV_RDMA_MEM_NAME_LEN - 1); - rmem->buf = buf; - rmem->buf_size = size; - - priskv_log_info("RDMA: new rmem %s, size %d\n", name, size); - priskv_log_debug("RDMA: new rmem %s, buf %p\n", name, buf); - return 0; - -free_mem: - priskv_mem_free(rmem->buf, rmem->buf_size, guard); - -error: - memset(rmem, 0x00, sizeof(priskv_rdma_mem)); - - return ret; -} - -static void priskv_rdma_mem_free(priskv_rdma_conn *conn, priskv_rdma_mem *rmem) -{ - if (rmem->mr) { - ibv_dereg_mr(rmem->mr); - } - - if (rmem->buf) { - priskv_log_debug("RDMA: free rmem %s, buf %p\n", rmem->name, rmem->buf); - priskv_mem_free(rmem->buf, rmem->buf_size, true); - } - - priskv_log_info("RDMA: free rmem %s, size %d\n", rmem->name, rmem->buf_size); - 
memset(rmem, 0x00, sizeof(priskv_rdma_mem)); -} - -static inline void priskv_rdma_free_ctrl_buffer(priskv_rdma_conn *conn) -{ - for (int i = 0; i < PRISKV_RDMA_MEM_MAX; i++) { - priskv_rdma_mem *rmem = &conn->rmem[i]; - - priskv_rdma_mem_free(conn, rmem); - } -} - -static int priskv_rdma_listen_one(char *addr, int port, void *kv, priskv_rdma_conn_cap *cap) -{ - int ret = 0, afonly = 1; - char _port[6]; /* strlen("65535") */ - struct rdma_addrinfo hints, *servinfo; - struct rdma_cm_id *listen_cmid = NULL; - struct rdma_event_channel *listen_channel = NULL; - priskv_rdma_conn *listener; - - snprintf(_port, 6, "%d", port); - memset(&hints, 0, sizeof(hints)); - hints.ai_flags = RAI_PASSIVE; - hints.ai_port_space = RDMA_PS_TCP; - ret = rdma_getaddrinfo(addr, _port, &hints, &servinfo); - if (ret) { - priskv_log_error("RDMA: getaddrinfo %s failed: %s", addr, gai_strerror(ret)); - return ret; - } else if (!servinfo) { - priskv_log_error("RDMA: getaddrinfo %s: no availabe address", addr); - return -EINVAL; - } - - listen_channel = rdma_create_event_channel(); - if (!listen_channel) { - ret = -errno; - priskv_log_error("RDMA: create event channel failed\n"); - goto freeaddr; - } - - ret = priskv_set_nonblock(listen_channel->fd); - if (ret) { - priskv_log_error("RDMA: failed to set NONBLOCK on event channel fd\n"); - goto error; - } - - if (rdma_create_id(listen_channel, &listen_cmid, NULL, RDMA_PS_TCP)) { - ret = -errno; - priskv_log_error("RDMA: create listen cm id error\n"); - goto error; - } - - rdma_set_option(listen_cmid, RDMA_OPTION_ID, RDMA_OPTION_ID_AFONLY, &afonly, sizeof(afonly)); - - if (rdma_bind_addr(listen_cmid, servinfo->ai_src_addr)) { - ret = -errno; - priskv_log_error("RDMA: Bind addr error on %s\n", addr); - goto error; - } - - if (rdma_listen(listen_cmid, 0)) { - ret = -errno; - priskv_log_error("RDMA: listen addr error on %s\n", addr); - goto error; - } - - /* TODO split into several MRs, because of max_mr_size of IB device */ - uint8_t *value_base = 
priskv_get_value_base(kv); - assert(value_base); - uint64_t size = priskv_get_value_blocks(kv) * priskv_get_value_block_size(kv); - assert(size); - uint32_t access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; - struct ibv_mr *value_mr = ibv_reg_mr(listen_cmid->pd, value_base, size, access); - if (!value_mr) { - ret = -errno; - priskv_log_error( - "RDMA: failed to reg MR for value: %m [%p, %p], value block %ld, value block size %d\n", - value_base, value_base + size, priskv_get_value_blocks(kv), - priskv_get_value_block_size(kv)); - goto error; - } - - priskv_log_debug("RDMA: Value buffer %p, length %ld\n", value_base, size); - - listener = &g_server.listeners[g_server.nlisteners++]; - listener->cm_id = listen_cmid; - listener->value_base = value_base; - listener->kv = kv; - listener->value_mr = value_mr; - listener->conn_cap = *cap; - listener->s.nclients = 0; - list_head_init(&listener->s.head); - pthread_spin_init(&listener->lock, 0); - - priskv_log_info("RDMA: <%s:%d> listener starts\n", addr, port); - - ret = 0; - goto freeaddr; - -error: - if (listen_cmid) { - rdma_destroy_id(listen_cmid); - } - if (listen_channel) { - rdma_destroy_event_channel(listen_channel); - } - -freeaddr: - rdma_freeaddrinfo(servinfo); - return ret; -} - -int priskv_rdma_listen(char **addr, int naddr, int port, void *kv, priskv_rdma_conn_cap *cap) -{ - priskv_rdma_conn *listener; - - for (int i = 0; i < naddr; i++) { - int ret = priskv_rdma_listen_one(addr[i], port, kv, cap); - if (ret) { - return ret; - } - } - - g_server.kv = kv; - - g_server.epollfd = epoll_create(g_server.nlisteners); - if (g_server.epollfd == -1) { - priskv_log_error("RDMA: failed to create epoll fd %m\n"); - return -1; - } - - for (int i = 0; i < g_server.nlisteners; i++) { - listener = &g_server.listeners[i]; - PRISKV_RDMA_DEF_ADDR(listener->cm_id); - - priskv_set_fd_handler(listener->cm_id->channel->fd, priskv_rdma_handle_cm, NULL, listener); - if 
(priskv_add_event_fd(g_server.epollfd, listener->cm_id->channel->fd)) { - priskv_log_error("RDMA: failed to add listen fd into epoll fd %m\n"); - return -1; - } - - priskv_log_notice("RDMA: <%s> ready\n", local_addr); - } - - return 0; -} - -static void priskv_rdma_get_clients(priskv_rdma_conn *listener, priskv_rdma_client **clients, - int *nclients) -{ - priskv_rdma_conn *client; - *nclients = 0; - - pthread_spin_lock(&listener->lock); - *clients = calloc(listener->s.nclients, sizeof(priskv_rdma_client)); - list_for_each (&listener->s.head, client, c.node) { - PRISKV_RDMA_DEF_ADDR(client->cm_id); - - memcpy((*clients)[*nclients].address, peer_addr, strlen(peer_addr) + 1); - memcpy((*clients)[*nclients].stats, client->c.stats, - PRISKV_COMMAND_MAX * sizeof(priskv_rdma_stats)); - (*clients)[*nclients].resps = client->c.resps; - (*clients)[*nclients].closing = client->c.closing; - (*nclients)++; - - if (*nclients == listener->s.nclients) { - break; - } - } - pthread_spin_unlock(&listener->lock); -} - -static void priskv_rdma_free_clients(priskv_rdma_client *clients) -{ - free(clients); -} - -priskv_rdma_listener *priskv_rdma_get_listeners(int *nlisteners) -{ - priskv_rdma_listener *listeners; - - *nlisteners = g_server.nlisteners; - listeners = calloc(*nlisteners, sizeof(priskv_rdma_listener)); - - for (int i = 0; i < *nlisteners; i++) { - PRISKV_RDMA_DEF_ADDR(g_server.listeners[i].cm_id); - - memcpy(listeners[i].address, local_addr, strlen(local_addr) + 1); - priskv_rdma_get_clients(&g_server.listeners[i], &listeners[i].clients, - &listeners[i].nclients); - } - - return listeners; -} - -void priskv_rdma_free_listeners(priskv_rdma_listener *listeners, int nlisteners) -{ - for (int i = 0; i < nlisteners; i++) { - priskv_rdma_free_clients(listeners[i].clients); - } - free(listeners); -} - -int priskv_rdma_get_fd(void) -{ - return g_server.epollfd; -} - -void *priskv_rdma_get_kv(void) -{ - return g_server.kv; -} - -/* use 64 bytes aligned request buffer. 
*/ -static inline unsigned int priskv_request_size_aligend(priskv_rdma_conn *conn) -{ - uint16_t s = priskv_request_size(conn->conn_cap.max_sgl, conn->conn_cap.max_key_length); - - return ALIGN_UP(s, 64); -} - -#define PRISKV_RDMA_RESPONSE_FREE_STATUS 0xffff -static inline int priskv_rdma_response_free(priskv_response *resp) -{ - if (resp->status == PRISKV_RDMA_RESPONSE_FREE_STATUS) { - return -EPROTO; - } - - resp->status = PRISKV_RDMA_RESPONSE_FREE_STATUS; - return 0; -} - -static inline uint32_t priskv_rdma_wr_size(priskv_rdma_conn *client) -{ - return client->conn_cap.max_inflight_command * (2 + client->conn_cap.max_sgl); -} - -static int priskv_rdma_new_ctrl_buffer(priskv_rdma_conn *conn) -{ - uint16_t size; - uint32_t buf_size; - - /* #step 1, prepare buffer & MR for request from client */ - size = priskv_request_size_aligend(conn); - buf_size = (uint32_t)size * priskv_rdma_wr_size(conn); - if (priskv_rdma_mem_new(conn, &conn->rmem[PRISKV_RDMA_MEM_REQ], "Request", buf_size)) { - goto error; - } - - /* #step 2, prepare buffer & MR for response to client */ - size = sizeof(priskv_response); - buf_size = size * priskv_rdma_wr_size(conn); - if (priskv_rdma_mem_new(conn, &conn->rmem[PRISKV_RDMA_MEM_RESP], "Response", buf_size)) { - goto error; - } - - for (uint16_t i = 0; i < priskv_rdma_wr_size(conn); i++) { - priskv_response *resp = (priskv_response *)(conn->rmem[PRISKV_RDMA_MEM_RESP].buf + i * size); - priskv_rdma_response_free(resp); - } - - return 0; - -error: - priskv_rdma_free_ctrl_buffer(conn); - return -ENOMEM; -} - -static void priskv_rdma_close_client(priskv_rdma_conn *client) -{ - PRISKV_RDMA_DEF_ADDR(client->cm_id) - priskv_log_notice( - "RDMA: <%s - %s> close. 
Requests GET %ld, SET %ld, TEST %ld, DELETE %ld, Responses %ld\n", - local_addr, peer_addr, client->c.stats[PRISKV_COMMAND_GET].ops, - client->c.stats[PRISKV_COMMAND_SET].ops, client->c.stats[PRISKV_COMMAND_TEST].ops, - client->c.stats[PRISKV_COMMAND_DELETE].ops, client->c.resps); - - if ((client->comp_channel) && (client->c.thread != NULL)) { - priskv_thread_del_event_handler(client->c.thread, client->comp_channel->fd); - priskv_set_fd_handler(client->comp_channel->fd, NULL, NULL, NULL); /* clear fd handler */ - client->c.thread = NULL; - } - - if (client->cm_id && client->cm_id->qp) { - rdma_destroy_qp(client->cm_id); - client->cm_id->qp = NULL; - } - - if (client->cq) { - if (ibv_destroy_cq(client->cq)) { - priskv_log_warn("ibv_destroy_cq failed\n"); - } - client->cq = NULL; - } - - if (client->comp_channel) { - if (ibv_destroy_comp_channel(client->comp_channel)) { - priskv_log_warn("ibv_destroy_comp_channel failed\n"); - } - client->comp_channel = NULL; - } - - priskv_rdma_free_ctrl_buffer(client); - - if (client->cm_id) { - rdma_destroy_id(client->cm_id); - } - - free(client); -} - -static void priskv_rdma_close_client_async(priskv_rdma_conn *client) -{ - PRISKV_RDMA_DEF_ADDR(client->cm_id); - - /* avoid re-entry of closing client: - * - CQ error - * - disconnected CM event from client - */ - pthread_spin_lock(&client->lock); - if (client->c.closing) { - pthread_spin_unlock(&client->lock); - return; - } - - client->c.closing = true; - pthread_spin_unlock(&client->lock); - - priskv_log_notice("RDMA: <%s - %s> async close client\n", local_addr, peer_addr); -} - -static void priskv_rdma_close_disconnected(priskv_rdma_conn *listener) -{ - priskv_rdma_conn *client, *tmp; - - pthread_spin_lock(&listener->lock); - list_for_each_safe (&listener->s.head, client, tmp, c.node) { - if (client->c.closing) { - listener->s.nclients--; - list_del(&client->c.node); - pthread_spin_unlock(&listener->lock); - } else { - continue; - } - - priskv_rdma_close_client(client); - - 
pthread_spin_lock(&listener->lock); - } - pthread_spin_unlock(&listener->lock); -} - -static priskv_response *priskv_rdma_unused_response(priskv_rdma_conn *conn) -{ - uint16_t resp_buf_size = sizeof(priskv_response); - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_RESP]; - - for (uint16_t i = 0; i < priskv_rdma_wr_size(conn); i++) { - priskv_response *resp = (priskv_response *)(rmem->buf + i * resp_buf_size); - if (resp->status == PRISKV_RDMA_RESPONSE_FREE_STATUS) { - priskv_log_debug("RDMA: use response %d\n", i); - resp->status = PRISKV_RESP_STATUS_OK; - return resp; - } - } - - PRISKV_RDMA_DEF_ADDR(conn->cm_id) - priskv_log_error("RDMA: <%s - %s> inflight response exceeds %d\n", local_addr, peer_addr, - priskv_rdma_wr_size(conn)); - return NULL; -} - -static int priskv_rdma_send_response(priskv_rdma_conn *conn, uint64_t request_id, - priskv_resp_status status, uint32_t length) -{ - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_RESP]; - struct ibv_send_wr wr = {0}, *bad_wr; - struct ibv_sge rsge; - priskv_response *resp; - - resp = priskv_rdma_unused_response(conn); - if (!resp) { - return -EPROTO; - } - - assert(((uint8_t *)resp >= rmem->buf) && ((uint8_t *)resp < rmem->buf + rmem->buf_size)); - - resp->request_id = request_id; /* be64 */ - resp->status = htobe16(status); - resp->length = htobe32(length); - - rsge.addr = (uint64_t)resp; - rsge.length = sizeof(priskv_response); - rsge.lkey = rmem->mr->lkey; - - wr.wr_id = (uint64_t)resp; - wr.sg_list = &rsge; - wr.num_sge = 1; - wr.opcode = IBV_WR_SEND; - wr.send_flags = IBV_SEND_SIGNALED; - - int ret = ibv_post_send(conn->cm_id->qp, &wr, &bad_wr); - if (ret) { - PRISKV_RDMA_DEF_ADDR(conn->cm_id) - priskv_log_error( - "RDMA: <%s - %s> ibv_post_send response failed: addr 0x%lx, length 0x%x ret %d\n", - local_addr, peer_addr, rsge.addr, rsge.length, ret); - } else { - conn->c.resps++; - } - - return ret; -} - -static int priskv_rdma_rw_req(priskv_rdma_conn *conn, priskv_request *req, struct ibv_mr *mr, 
- uint8_t *val, uint32_t valuelen, bool set, void (*cb)(void *), - void *cbarg, bool defer_resp, priskv_rdma_rw_work **work_out) -{ - priskv_rdma_rw_work *work; - uint32_t offset = 0; - struct ibv_send_wr wr = {0}, *bad_wr; - struct ibv_sge sge; - uint16_t nsgl = be16toh(req->nsgl); - const char *cmdstr = set ? "READ" : "WRITE"; - - if (work_out) { - *work_out = NULL; - } - - work = calloc(1, sizeof(priskv_rdma_rw_work)); - if (!work) { - priskv_log_error("RDMA: failed to allocate memory for %s request\n", cmdstr); - return -ENOMEM; - } - - work->conn = conn; - work->req = req; - work->mr = mr; - work->request_id = req->request_id; /* be64 */ - work->valuelen = valuelen; - work->completed = 0; - work->cb = cb; - work->cbarg = cbarg; - work->defer_resp = defer_resp; - - wr.wr_id = (uint64_t)work; - wr.next = NULL; - wr.sg_list = &sge; - wr.num_sge = 1; - wr.opcode = set ? IBV_WR_RDMA_READ : IBV_WR_RDMA_WRITE; - wr.send_flags = IBV_SEND_SIGNALED; - - for (uint16_t i = 0; i < nsgl; i++) { - priskv_keyed_sgl *sgl = &req->sgls[i]; - - uint32_t sgl_length = be32toh(sgl->length); - uint32_t sgl_offset = 0; - - wr.wr.rdma.rkey = be32toh(sgl->key); - do { - wr.wr.rdma.remote_addr = be64toh(sgl->addr) + sgl_offset; - - sge.addr = (uint64_t)val + offset + sgl_offset; - sge.lkey = mr->lkey; - sge.length = priskv_min_u32(sgl_length - sgl_offset, valuelen); - sge.length = priskv_min_u32(sge.length, priskv_rdma_max_rw_size); - - if (ibv_post_send(conn->cm_id->qp, &wr, &bad_wr)) { - PRISKV_RDMA_DEF_ADDR(conn->cm_id) - priskv_log_error("RDMA: <%s - %s> ibv_post_send RDMA failed: %m\n", local_addr, - peer_addr); - free(work); - return -errno; - } - - priskv_log_debug("RDMA: %s [%d/%d]:[%d/%d] wr_id 0x%lx, val %p, length 0x%x, addr 0x%lx, " - "rkey 0x%x\n", - cmdstr, i, nsgl, sgl_offset, sgl_length, wr.wr_id, - val + offset + sgl_offset, sge.length, wr.wr.rdma.remote_addr, - wr.wr.rdma.rkey); - sgl_offset += sge.length; - - work->nsgl++; - } while (sgl_offset < 
priskv_min_u32(sgl_length, valuelen)); - - offset += sgl_length; - valuelen -= sgl_length; - } - - if (work_out) { - *work_out = work; - } - - return 0; -} - -static int priskv_rdma_recv_req(priskv_rdma_conn *conn, uint8_t *req); - -static int priskv_rdma_complete_rw_work(priskv_rdma_rw_work *work, priskv_resp_status status, - uint32_t length); - -static void priskv_tiering_req_repost_recv(priskv_tiering_req *treq) -{ - if (!treq || treq->recv_reposted) { - return; - } - - if (!treq->conn || !treq->req) { - treq->recv_reposted = true; - return; - } - - priskv_rdma_recv_req(treq->conn, (uint8_t *)treq->req); - - treq->recv_reposted = true; -} - -static void priskv_tiering_req_free(priskv_tiering_req *treq) -{ - if (!treq) { - return; - } - - free(treq->key); - free(treq); -} - -static priskv_tiering_req *priskv_tiering_req_new(priskv_rdma_conn *conn, priskv_request *req, - uint8_t *key, uint16_t keylen, uint64_t timeout, - priskv_req_command cmd, uint32_t remote_valuelen, - priskv_resp_status *resp_status) -{ - priskv_resp_status status = PRISKV_RESP_STATUS_NO_MEM; - priskv_tiering_req *treq = calloc(1, sizeof(priskv_tiering_req)); - priskv_thread *thread = NULL; - priskv_backend_device *backend = NULL; - - if (!treq) { - goto error; - } - - thread = conn->c.thread; - backend = priskv_get_thread_backend(thread); - if (!backend) { - priskv_log_error("RDMA: Backend is not initialized"); - status = PRISKV_RESP_STATUS_SERVER_ERROR; - goto error; - } - - treq->key = malloc(keylen + 1); - if (!treq->key) { - goto error; - } - memcpy(treq->key, key, keylen); - treq->key[keylen] = '\0'; - - list_node_init(&treq->node); - treq->conn = conn; - treq->thread = thread; - treq->backend = backend; - treq->kv = conn->kv; - treq->req = req; - treq->request_id = req->request_id; - treq->keylen = keylen; - treq->timeout = timeout; - treq->cmd = cmd; - treq->remote_valuelen = remote_valuelen; - treq->valuelen = 0; - treq->execute = false; - treq->recv_reposted = false; - 
treq->hash_head_index = priskv_crc32(key, keylen) % priskv_get_bucket_count(conn->kv); - treq->backend_status = PRISKV_BACKEND_STATUS_ERROR; - - if (resp_status) { - *resp_status = PRISKV_RESP_STATUS_OK; - } - return treq; - -error: - free(treq->key); - free(treq); - if (resp_status) { - *resp_status = status; - } - return NULL; -} - -static void priskv_tiering_finish(priskv_tiering_req *treq, priskv_resp_status status, uint32_t length) -{ - if (!treq) { - return; - } - - if (treq->execute) { - priskv_key_serialize_exit(treq); - treq->execute = false; - } - - if (treq->cmd >= 0 && treq->cmd < PRISKV_COMMAND_MAX) { - treq->conn->c.stats[treq->cmd].ops++; - if ((treq->cmd == PRISKV_COMMAND_TEST || treq->cmd == PRISKV_COMMAND_GET) && status == PRISKV_RESP_STATUS_OK) { - treq->conn->c.stats[treq->cmd].bytes += length; - } - } - - priskv_tiering_req_repost_recv(treq); - - if (treq->rdma_work) { - priskv_rdma_complete_rw_work(treq->rdma_work, status, length); - treq->rdma_work = NULL; - } else { - priskv_rdma_send_response(treq->conn, treq->request_id, status, length); - } - - priskv_tiering_req_free(treq); -} - -static void priskv_tiering_get_rdma_complete_cb(void *arg) -{ - priskv_tiering_req *treq = arg; - assert(treq); - - priskv_get_key_end(treq->keynode); - - if (treq->cmd == PRISKV_COMMAND_GET) { - treq->conn->c.stats[treq->cmd].ops++; - treq->conn->c.stats[treq->cmd].bytes += treq->valuelen; - } - - priskv_tiering_req_repost_recv(treq); - - priskv_tiering_req_free(treq); -} - -static void priskv_tiering_get_backend_cb(priskv_backend_status backend_status, uint32_t valuelen, void *arg) -{ - priskv_tiering_req *treq = arg; - if (!treq) { - return; - } - - priskv_resp_status resp_status; - treq->backend_status = backend_status; - - switch (backend_status) { - case PRISKV_BACKEND_STATUS_OK: { - uint8_t *val = NULL; - uint32_t cached_length = 0; - void *keynode = NULL; - - // update value according to backend valuelen - priskv_update_valuelen(treq->keynode, valuelen); 
- treq->valuelen = valuelen; - - // there are value buf and valuelen in treq, here just use priskv_get_key to inc ref of keynode - priskv_get_key(treq->kv, treq->key, treq->keylen, &val, &cached_length, &keynode); - assert(keynode == treq->keynode); - - // Relaunch the next request to allow multiple GETs to execute in parallel - if (treq->execute) { - priskv_key_serialize_exit(treq); - treq->execute = false; - } - - if (priskv_rdma_rw_req(treq->conn, treq->req, treq->conn->value_mr, treq->value, - treq->valuelen, false, priskv_tiering_get_rdma_complete_cb, treq, false, - NULL)) { - priskv_tiering_finish(treq, PRISKV_RESP_STATUS_SERVER_ERROR, 0); - return; - } - return; - } - case PRISKV_BACKEND_STATUS_VALUE_TOO_BIG: - resp_status = PRISKV_RESP_STATUS_VALUE_TOO_BIG; - break; - case PRISKV_BACKEND_STATUS_NOT_FOUND: - resp_status = PRISKV_RESP_STATUS_NO_SUCH_KEY; - break; - default: - resp_status = PRISKV_RESP_STATUS_SERVER_ERROR; - break; - } - - priskv_delete_key(treq->kv, treq->key, treq->keylen); - priskv_tiering_finish(treq, resp_status, 0); - return; -} - -void priskv_tiering_get(priskv_tiering_req *treq) -{ - - priskv_resp_status status; - uint8_t *val = NULL; - uint32_t valuelen = 0; - void *keynode = NULL; - - assert(treq); - - if (!treq->execute) { - if (!priskv_key_serialize_enter(treq)) { - return; - } - treq->execute = true; - } - - // In tiering mode, priskv_get_key will not return HPKV_RESP_STATUS_KEY_UPDATING - status = priskv_get_key(treq->kv, treq->key, treq->keylen, &val, &valuelen, &keynode); - if (status == PRISKV_RESP_STATUS_OK && keynode) { - treq->value = val; - treq->valuelen = valuelen; - treq->keynode = keynode; - - if (treq->remote_valuelen < treq->valuelen) { - priskv_tiering_finish(treq, PRISKV_RESP_STATUS_VALUE_TOO_BIG, treq->valuelen); - return; - } - - if (priskv_rdma_rw_req(treq->conn, treq->req, treq->conn->value_mr, treq->value, - treq->valuelen, false, priskv_tiering_get_rdma_complete_cb, treq, false, - NULL)) { - 
priskv_get_key_end(keynode); - priskv_tiering_finish(treq, PRISKV_RESP_STATUS_SERVER_ERROR, 0); - return; - } - - // Relaunch the next request to allow multiple GETs to execute in parallel - if (treq->execute) { - priskv_key_serialize_exit(treq); - treq->execute = false; - } - - return; - } - - status = priskv_set_key(treq->kv, treq->key, treq->keylen, &val, treq->remote_valuelen, treq->timeout, - &keynode); - - // no other requests can access this keynode, for simplicity's sake, execute priskv_set_key_end here. - priskv_set_key_end(keynode); - - if (status != PRISKV_RESP_STATUS_OK || !keynode) { - priskv_tiering_finish(treq, status, 0); - return; - } - - treq->value = val; - treq->keynode = keynode; - - if (!treq->backend) { - priskv_delete_key(treq->kv, treq->key, treq->keylen); - priskv_tiering_finish(treq, PRISKV_RESP_STATUS_SERVER_ERROR, 0); - return; - } - - priskv_backend_get(treq->backend, (const char *)treq->key, treq->value, treq->remote_valuelen, - priskv_tiering_get_backend_cb, treq); -} - -static void priskv_tiering_test_backend_cb(priskv_backend_status status, uint32_t valuelen, void *arg) -{ - priskv_tiering_req *treq = arg; - priskv_resp_status resp_status; - uint32_t length = 0; - - treq->backend_status = status; - - switch (status) { - case PRISKV_BACKEND_STATUS_OK: - resp_status = PRISKV_RESP_STATUS_OK; - length = valuelen; - break; - case PRISKV_BACKEND_STATUS_NOT_FOUND: - resp_status = PRISKV_RESP_STATUS_NO_SUCH_KEY; - break; - default: - resp_status = PRISKV_RESP_STATUS_SERVER_ERROR; - break; - } - - priskv_tiering_finish(treq, resp_status, length); -} - -void priskv_tiering_test(priskv_tiering_req *treq) -{ - priskv_resp_status status; - uint8_t *val = NULL; - uint32_t valuelen = 0; - void *keynode = NULL; - - if (!treq) { - return; - } - - if (!treq->execute) { - if (!priskv_key_serialize_enter(treq)) { - return; - } - treq->execute = true; - } - - // In tiering mode, priskv_get_key will not return HPKV_RESP_STATUS_KEY_UPDATING - status = 
priskv_get_key(treq->kv, treq->key, treq->keylen, &val, &valuelen, &keynode); - - if (status == PRISKV_RESP_STATUS_OK) { - priskv_tiering_finish(treq, PRISKV_RESP_STATUS_OK, valuelen); - priskv_get_key_end(keynode); - return; - } - - priskv_backend_test(treq->backend, (const char *)treq->key, priskv_tiering_test_backend_cb, treq); -} - -static void priskv_tiering_set_backend_cb(priskv_backend_status status, uint32_t valuelen, void *arg) -{ - priskv_tiering_req *treq = arg; - if (!treq) { - return; - } - - priskv_resp_status resp_status; - uint32_t length = 0; - - // here delete new key-value - priskv_delete_key(treq->kv, treq->key, treq->keylen); - - switch (status) { - case PRISKV_BACKEND_STATUS_OK: - resp_status = PRISKV_RESP_STATUS_OK; - length = treq->valuelen; - break; - case PRISKV_BACKEND_STATUS_NO_SPACE: - resp_status = PRISKV_RESP_STATUS_NO_MEM; - break; - case PRISKV_BACKEND_STATUS_ERROR: - default: - resp_status = PRISKV_RESP_STATUS_SERVER_ERROR; - break; - } - - priskv_tiering_finish(treq, resp_status, length); -} - -static void priskv_tiering_set_rdma_complete_cb(void *arg) -{ - priskv_tiering_req *treq = arg; - - if (!treq) { - return; - } - - priskv_set_key_end(treq->keynode); - - if (!treq->backend) { - priskv_delete_key(treq->kv, treq->key, treq->keylen); - priskv_tiering_finish(treq, PRISKV_RESP_STATUS_SERVER_ERROR, 0); - return; - } - - priskv_backend_set(treq->backend, (const char *)treq->key, treq->value, treq->remote_valuelen, - treq->timeout, priskv_tiering_set_backend_cb, treq); -} - -void priskv_tiering_set(priskv_tiering_req *treq) -{ - assert(treq); - priskv_resp_status status; - - if (!treq->execute) { - if (!priskv_key_serialize_enter(treq)) { - return; - } - treq->execute = true; - } - - // delete old key-value here - status = priskv_set_key(treq->kv, treq->key, treq->keylen, &treq->value, treq->remote_valuelen, - treq->timeout, &treq->keynode); - if (status != PRISKV_RESP_STATUS_OK || !treq->keynode) { - 
priskv_set_key_end(treq->keynode); - priskv_tiering_finish(treq, status, 0); - return; - } - - if (priskv_rdma_rw_req(treq->conn, treq->req, treq->conn->value_mr, treq->value, treq->remote_valuelen, - true, priskv_tiering_set_rdma_complete_cb, treq, true, &treq->rdma_work)) { - priskv_set_key_end(treq->keynode); - priskv_delete_key(treq->kv, treq->key, treq->keylen); - priskv_tiering_finish(treq, PRISKV_RESP_STATUS_NO_MEM, 0); - return; - } -} - -static void priskv_tiering_del_backend_cb(priskv_backend_status status, uint32_t valuelen, void *arg) -{ - priskv_tiering_req *treq = arg; - priskv_resp_status resp_status; - - assert(treq); - - priskv_delete_key(treq->kv, treq->key, treq->keylen); - - treq->backend_status = status; - switch (status) { - case PRISKV_BACKEND_STATUS_OK: - resp_status = PRISKV_RESP_STATUS_OK; - break; - case PRISKV_BACKEND_STATUS_NOT_FOUND: - resp_status = PRISKV_RESP_STATUS_NO_SUCH_KEY; - break; - case PRISKV_BACKEND_STATUS_ERROR: - default: - resp_status = PRISKV_RESP_STATUS_SERVER_ERROR; - break; - } - - priskv_tiering_finish(treq, resp_status, 0); -} - -void priskv_tiering_del(priskv_tiering_req *treq) -{ - assert(treq); - - if (!treq->execute) { - if (!priskv_key_serialize_enter(treq)) { - return; - } - treq->execute = true; - } - - priskv_backend_del(treq->backend, (const char *)treq->key, priskv_tiering_del_backend_cb, treq); -} - -int priskv_backend_req_resubmit(void *req) -{ - priskv_tiering_req *treq = (priskv_tiering_req *)req; - - switch (treq->cmd) { - case PRISKV_COMMAND_GET: - priskv_tiering_get(treq); - break; - case PRISKV_COMMAND_SET: - priskv_tiering_set(treq); - break; - case PRISKV_COMMAND_TEST: - priskv_tiering_test(treq); - break; - case PRISKV_COMMAND_DELETE: - priskv_tiering_del(treq); - break; - default: - priskv_log_error("Invalid backend request type: %d", treq->cmd); - return -1; - } - - return 0; -} - -static void priskv_check_and_log_slow_query(priskv_rdma_rw_work *work) -{ - struct timeval 
server_resp_send_time; - priskv_request *req = (priskv_request *)work->req; - uint16_t command = be16toh(req->command); - uint16_t nsgl = be16toh(req->nsgl); - uint8_t *key = priskv_request_key(req, nsgl); - uint16_t keylen = be16toh(req->key_length); - char key_short[128] = {0}; - priskv_string_shorten((const char *)key, keylen, key_short, sizeof(key_short)); - - gettimeofday(&server_resp_send_time, NULL); - req->runtime.server_resp_send_time = server_resp_send_time; - if (priskv_time_elapsed_us(&req->runtime.client_metadata_send_time, - &req->runtime.server_resp_send_time) > - g_slow_query_threshold_latency_us) { - priskv_log_notice( - "Slow Query Encountered . " - "Slow Query threshold latency is %ld us |" - "Command %s key[%u] = \"%s\" |" - "thread id is %lu |" - "Client send metadata: %ld.%06ld us | " - "Server recv metadata: %ld.%06ld us | " - "Server RW KV: %ld.%06ld us | " - "Server send data: %ld.%06ld us | " - "Server recv data: %ld.%06ld us | " - "Server send resp: %ld.%06ld us | " - "Total: %ld us | " - "Steps: " - "Client->Server metadata: %ld us | " - "Server metadata->RW KV: %ld us | " - "RW KV->Send data: %ld us | " - "Send data->Recv data: %ld us | " - "Recv data->Resp send: %ld us \n", - - g_slow_query_threshold_latency_us, priskv_command_str(command), keylen, key_short, - pthread_self(), req->runtime.client_metadata_send_time.tv_sec, - req->runtime.client_metadata_send_time.tv_usec, - req->runtime.server_metadata_recv_time.tv_sec, - req->runtime.server_metadata_recv_time.tv_usec, req->runtime.server_rw_kv_time.tv_sec, - req->runtime.server_rw_kv_time.tv_usec, req->runtime.server_data_send_time.tv_sec, - req->runtime.server_data_send_time.tv_usec, req->runtime.server_data_recv_time.tv_sec, - req->runtime.server_data_recv_time.tv_usec, req->runtime.server_resp_send_time.tv_sec, - req->runtime.server_resp_send_time.tv_usec, - - priskv_time_elapsed_us(&req->runtime.client_metadata_send_time, - &req->runtime.server_resp_send_time), - - 
priskv_time_elapsed_us(&req->runtime.client_metadata_send_time, - &req->runtime.server_metadata_recv_time), - priskv_time_elapsed_us(&req->runtime.server_metadata_recv_time, - &req->runtime.server_rw_kv_time), - priskv_time_elapsed_us(&req->runtime.server_rw_kv_time, - &req->runtime.server_data_send_time), - priskv_time_elapsed_us(&req->runtime.server_data_send_time, - &req->runtime.server_data_recv_time), - priskv_time_elapsed_us(&req->runtime.server_data_recv_time, - &req->runtime.server_resp_send_time)); - } -} - -// RDMA READ/WRITE handler's bottom half -static int priskv_rdma_complete_rw_work(priskv_rdma_rw_work *work, priskv_resp_status status, - uint32_t length) -{ - if (!work) { - return -EINVAL; - } - - priskv_rdma_conn *conn = work->conn; - - int ret = priskv_rdma_send_response(conn, work->request_id, status, length); - - if (work->mr != conn->value_mr) { - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_KEYS]; - assert(work->mr == rmem->mr); - - priskv_rdma_mem_free(conn, rmem); - priskv_log_debug("RDMA: KEYS done"); - } - - priskv_check_and_log_slow_query(work); - - free(work); - return ret; -} - -static int priskv_rdma_handle_rw(priskv_rdma_conn *conn, priskv_rdma_rw_work *work) -{ - work->completed++; - assert(work->completed <= work->nsgl); - - if (work->completed < work->nsgl) { - return 0; - } - - if (work->cb) { - work->cb(work->cbarg); - } - - if (work->defer_resp) { - return 0; - } - - return priskv_rdma_complete_rw_work(work, PRISKV_RESP_STATUS_OK, work->valuelen); -} - -static inline int priskv_rdma_handle_send(priskv_rdma_conn *conn, priskv_response *resp, uint32_t len) -{ - return priskv_rdma_response_free(resp); -} - -/* return negative number on failure, return received buffer size on success */ -static int priskv_rdma_recv_req(priskv_rdma_conn *conn, uint8_t *req) -{ - struct ibv_sge sge; - struct ibv_recv_wr recv_wr, *bad_wr; - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_REQ]; - uint32_t lkey = rmem->mr->lkey; - uint16_t 
req_buf_size = priskv_request_size_aligend(conn); - int ret; - - assert((req >= rmem->buf) && (req < rmem->buf + rmem->buf_size)); - sge.addr = (uint64_t)req; - sge.length = req_buf_size; - sge.lkey = lkey; - - recv_wr.wr_id = (uint64_t)req; - recv_wr.sg_list = &sge; - recv_wr.num_sge = 1; - recv_wr.next = NULL; - - priskv_log_debug("RDMA: ibv_post_recv addr %p, length %d\n", req, req_buf_size); - ret = ibv_post_recv(conn->cm_id->qp, &recv_wr, &bad_wr); - if (ret) { - PRISKV_RDMA_DEF_ADDR(conn->cm_id) - priskv_log_error("RDMA: <%s - %s> ibv_post_recv failed: %m\n", local_addr, peer_addr); - return -errno; - } - - return req_buf_size; -} - -static int priskv_rdma_handle_recv(priskv_rdma_conn *conn, priskv_request *req, uint32_t len) -{ - uint16_t command = be16toh(req->command); - uint16_t nsgl = be16toh(req->nsgl); - uint64_t timeout = be64toh(req->timeout); - uint8_t *key; - uint16_t keylen; - uint16_t keyoff = priskv_request_key_off(nsgl); - uint8_t *val; - uint32_t valuelen = 0, nkeys = 0; - uint32_t remote_valuelen; - uint64_t bytes = 0; - void *keynode; - priskv_resp_status status; - int ret = 0; - bool tiering_inflight = false; - priskv_rdma_mem *rmem = &conn->rmem[PRISKV_RDMA_MEM_KEYS]; - PRISKV_RDMA_DEF_ADDR(conn->cm_id) - - if (len < keyoff) { - priskv_log_warn("RDMA: <%s - %s> invalid command. recv %d, less than %d, nsgl 0x%x\n", - local_addr, peer_addr, len, keyoff, nsgl); - priskv_rdma_send_response(conn, req->request_id, PRISKV_RESP_STATUS_INVALID_COMMAND, 0); - return -EPROTO; - } - - keylen = len - keyoff; - if (!keylen) { - priskv_log_warn("RDMA: <%s - %s> empty key. recv %d, less than %d, nsgl 0x%x\n", local_addr, - peer_addr, len, keyoff, nsgl); - - priskv_rdma_send_response(conn, req->request_id, PRISKV_RESP_STATUS_KEY_EMPTY, 0); - return -EPROTO; - } - - if (keylen > conn->conn_cap.max_key_length) { - priskv_log_warn("RDMA: <%s - %s> invalid key. 
key(%d) exceeds max_key_length(%d)\n", - local_addr, peer_addr, keylen, conn->conn_cap.max_key_length); - priskv_rdma_send_response(conn, req->request_id, PRISKV_RESP_STATUS_KEY_TOO_BIG, 0); - return -EPROTO; - } - - if (nsgl > conn->conn_cap.max_sgl) { - priskv_log_warn("RDMA: <%s - %s> invalid nsgl. nsgl(%d) exceeds max_sgl(%d)\n", local_addr, - peer_addr, nsgl, conn->conn_cap.max_sgl); - priskv_rdma_send_response(conn, req->request_id, PRISKV_RESP_STATUS_INVALID_SGL, 0); - return -EPROTO; - } - - key = priskv_request_key(req, nsgl); - - if (priskv_get_log_level() >= priskv_log_debug) { - char key_short[128] = {0}; - priskv_string_shorten((const char *)key, keylen, key_short, sizeof(key_short)); - priskv_log_debug("RDMA: <%s - %s> %s key[%u] = \"%s\"\n", local_addr, peer_addr, - priskv_command_str(command), keylen, key_short); - } - - switch (command) { - case PRISKV_COMMAND_GET: { - struct timeval server_rw_kv_time, server_data_send_time; - remote_valuelen = priskv_sgl_size_from_be(req->sgls, nsgl); - - if (!priskv_backend_tiering_enabled()) { - status = priskv_get_key(conn->kv, key, keylen, &val, &valuelen, &keynode); - if (status != PRISKV_RESP_STATUS_OK || !keynode) { - ret = priskv_rdma_send_response(conn, req->request_id, status, 0); - priskv_get_key_end(keynode); - break; - } - - gettimeofday(&server_rw_kv_time, NULL); - req->runtime.server_rw_kv_time = server_rw_kv_time; - - if (remote_valuelen < valuelen) { - ret = priskv_rdma_send_response(conn, req->request_id, PRISKV_RESP_STATUS_VALUE_TOO_BIG, - valuelen); - priskv_get_key_end(keynode); - break; - } - - ret = priskv_rdma_rw_req(conn, req, conn->value_mr, val, valuelen, false, - priskv_get_key_end, keynode, false, NULL); - - gettimeofday(&server_data_send_time, NULL); - req->runtime.server_data_send_time = server_data_send_time; - - bytes = valuelen; - } else { - priskv_resp_status alloc_status = PRISKV_RESP_STATUS_OK; - priskv_tiering_req *treq = priskv_tiering_req_new(conn, req, key, keylen, 
PRISKV_KEY_MAX_TIMEOUT, - PRISKV_COMMAND_GET, remote_valuelen, - &alloc_status); - if (!treq) { - ret = priskv_rdma_send_response(conn, req->request_id, alloc_status, 0); - break; - } - - tiering_inflight = true; - priskv_tiering_get(treq); - } - break; - } - case PRISKV_COMMAND_SET: { - struct timeval server_rw_kv_time, server_data_send_time; - - remote_valuelen = priskv_sgl_size_from_be(req->sgls, nsgl); - if (!remote_valuelen) { - ret = priskv_rdma_send_response(conn, req->request_id, PRISKV_RESP_STATUS_VALUE_EMPTY, 0); - break; - } - - if (!priskv_backend_tiering_enabled()) { - status = priskv_set_key(conn->kv, key, keylen, &val, remote_valuelen, timeout, &keynode); - if (status != PRISKV_RESP_STATUS_OK || !keynode) { - ret = priskv_rdma_send_response(conn, req->request_id, status, 0); - priskv_set_key_end(keynode); - break; - } - - gettimeofday(&server_rw_kv_time, NULL); - req->runtime.server_rw_kv_time = server_rw_kv_time; - - ret = priskv_rdma_rw_req(conn, req, conn->value_mr, val, remote_valuelen, true, - priskv_set_key_end, keynode, false, NULL); - - gettimeofday(&server_data_send_time, NULL); - req->runtime.server_data_send_time = server_data_send_time; - - bytes = remote_valuelen; - } else { - priskv_resp_status alloc_status = PRISKV_RESP_STATUS_OK; - priskv_tiering_req *treq = priskv_tiering_req_new(conn, req, key, keylen, timeout, - PRISKV_COMMAND_SET, remote_valuelen, - &alloc_status); - if (!treq) { - ret = priskv_rdma_send_response(conn, req->request_id, alloc_status, 0); - break; - } - - tiering_inflight = true; - priskv_tiering_set(treq); - } - break; - } - - case PRISKV_COMMAND_TEST: { - if (!priskv_backend_tiering_enabled()) { - status = priskv_get_key(conn->kv, key, keylen, &val, &valuelen, &keynode); - ret = priskv_rdma_send_response(conn, req->request_id, status, valuelen); - priskv_get_key_end(keynode); - break; - } - - priskv_resp_status alloc_status = PRISKV_RESP_STATUS_OK; - priskv_tiering_req *treq = priskv_tiering_req_new(conn, req, 
key, keylen, timeout, - PRISKV_COMMAND_TEST, 0, &alloc_status); - if (!treq) { - ret = priskv_rdma_send_response(conn, req->request_id, alloc_status, 0); - break; - } - - tiering_inflight = true; - priskv_tiering_test(treq); - break; - } - - case PRISKV_COMMAND_DELETE: { - if (!priskv_backend_tiering_enabled()) { - status = priskv_delete_key(conn->kv, key, keylen); - ret = priskv_rdma_send_response(conn, req->request_id, status, 0); - break; - } - - priskv_resp_status alloc_status = PRISKV_RESP_STATUS_OK; - priskv_tiering_req *treq = priskv_tiering_req_new(conn, req, key, keylen, timeout, - PRISKV_COMMAND_DELETE, 0, &alloc_status); - if (!treq) { - ret = priskv_rdma_send_response(conn, req->request_id, alloc_status, 0); - break; - } - - tiering_inflight = true; - priskv_tiering_del(treq); - break; - } - - case PRISKV_COMMAND_EXPIRE: - status = priskv_expire_key(conn->kv, key, keylen, timeout); - ret = priskv_rdma_send_response(conn, req->request_id, status, 0); - break; - - case PRISKV_COMMAND_KEYS: - if (rmem->mr) { - /* a single KEYS command is allowed inflight with a connection */ - priskv_rdma_send_response(conn, req->request_id, PRISKV_RESP_STATUS_NO_MEM, 0); - ret = 0; - break; - } - - remote_valuelen = priskv_sgl_size_from_be(req->sgls, nsgl); - if (priskv_rdma_mem_new(conn, rmem, "Keys", remote_valuelen)) { - ret = priskv_rdma_send_response(conn, req->request_id, PRISKV_RESP_STATUS_NO_MEM, valuelen); - break; - } - - status = - priskv_get_keys(conn->kv, key, keylen, rmem->buf, remote_valuelen, &valuelen, &nkeys); - if ((status != PRISKV_RESP_STATUS_OK) || !valuelen) { - priskv_rdma_mem_free(conn, rmem); - ret = priskv_rdma_send_response(conn, req->request_id, status, valuelen); - break; - } - - ret = priskv_rdma_rw_req(conn, req, rmem->mr, rmem->buf, valuelen, false, NULL, NULL, false, - NULL); - if (ret) { - priskv_rdma_mem_free(conn, rmem); - ret = priskv_rdma_send_response(conn, req->request_id, status, valuelen); - } - break; - - case 
PRISKV_COMMAND_NRKEYS: - status = priskv_get_keys(conn->kv, key, keylen, NULL, 0, &valuelen, &nkeys); - /* PRISKV_RESP_STATUS_VALUE_TOO_BIG is expected */ - if (status == PRISKV_RESP_STATUS_VALUE_TOO_BIG) { - ret = priskv_rdma_send_response(conn, req->request_id, PRISKV_RESP_STATUS_OK, nkeys); - break; - } - ret = priskv_rdma_send_response(conn, req->request_id, status, 0); - break; - - case PRISKV_COMMAND_FLUSH: - status = priskv_flush_keys(conn->kv, key, keylen, &nkeys); - ret = priskv_rdma_send_response(conn, req->request_id, status, nkeys); - break; - - default: - priskv_log_warn("RDMA: <%s - %s> unknown command %d\n", local_addr, peer_addr, command); - priskv_rdma_send_response(conn, req->request_id, PRISKV_RESP_STATUS_NO_SUCH_COMMAND, 0); - ret = -EPROTO; - } - - if (!tiering_inflight) { - conn->c.stats[command].ops++; - if (!ret) { - priskv_rdma_recv_req(conn, (uint8_t *)req); - conn->c.stats[command].bytes += bytes; - } - } - - return ret; -} - -static void priskv_rdma_handle_cq(int fd, void *opaque, uint32_t events) -{ - priskv_rdma_conn *conn = opaque; - struct ibv_cq *ev_cq = NULL; - void *ev_ctx = NULL; - struct ibv_wc wc; - priskv_request *req; - priskv_rdma_rw_work *work; - priskv_response *resp; - int ret; - - assert(conn->comp_channel->fd == fd); - - if (ibv_get_cq_event(conn->comp_channel, &ev_cq, &ev_ctx) < 0) { - if (errno != EAGAIN) { - priskv_log_warn("RDMA: ibv_get_cq_event failed: %m\n"); - } - goto error_close; - } else if (ibv_req_notify_cq(ev_cq, 0)) { - priskv_log_warn("RDMA: ibv_req_notify_cq failed: %m\n"); - goto error_close; - } - - ibv_ack_cq_events(conn->cq, 1); - -poll_cq: - ret = ibv_poll_cq(conn->cq, 1, &wc); - if (ret < 0) { - priskv_log_warn("RDMA: ibv_poll_cq failed: %m\n"); - goto error_close; - } else if (ret == 0) { - return; - } - - priskv_log_debug("RDMA: CQ handle status: %s[0x%x], wr_id: %p, opcode: 0x%x, byte_len: %u\n", - ibv_wc_status_str(wc.status), wc.status, (void *)wc.wr_id, wc.opcode, - wc.byte_len); - if 
(wc.status != IBV_WC_SUCCESS) { - PRISKV_RDMA_DEF_ADDR(conn->cm_id) - priskv_log_error("RDMA: <%s - %s> CQ error status: wr_id 0x%lx, %s[0x%x], opcode : 0x%x, " - "byte_len : %ld\n", - local_addr, peer_addr, wc.wr_id, ibv_wc_status_str(wc.status), wc.status, - wc.opcode, wc.byte_len); - if (wc.status == IBV_WC_LOC_QP_OP_ERR) { - priskv_log_error("RDMA: possible remote command size exceeds\n"); - } - goto error_close; - } - - switch (wc.opcode) { - case IBV_WC_RECV: { - struct timeval server_metadata_recv_time; - - req = (priskv_request *)wc.wr_id; - - gettimeofday(&server_metadata_recv_time, NULL); - req->runtime.server_metadata_recv_time = server_metadata_recv_time; - - if (priskv_rdma_handle_recv(conn, req, wc.byte_len)) { - goto error_close; - } - break; - } - - case IBV_WC_RDMA_READ: - case IBV_WC_RDMA_WRITE: { - struct timeval server_data_recv_time; - - work = (priskv_rdma_rw_work *)wc.wr_id; - req = (priskv_request *)work->req; - - gettimeofday(&server_data_recv_time, NULL); - req->runtime.server_data_recv_time = server_data_recv_time; - - if (priskv_rdma_handle_rw(conn, work)) { - goto error_close; - } - - break; - } - - case IBV_WC_SEND: { - resp = (priskv_response *)wc.wr_id; - priskv_rdma_handle_send(conn, resp, wc.byte_len); - break; - } - - default: - priskv_log_error("unexpected opcode 0x%x", wc.opcode); - goto error_close; - } - - goto poll_cq; - -error_close: - priskv_rdma_close_client_async(conn); -} - -static void priskv_rdma_reject(struct rdma_cm_id *cm_id, uint16_t status, uint64_t val) -{ - priskv_rdma_cm_rej rej = {0}; - - rej.version = htobe16(PRISKV_RDMA_CM_VERSION); - rej.status = htobe16(status); - rej.value = htobe64(val); - - rdma_reject(cm_id, &rej, sizeof(priskv_rdma_cm_rej)); -} - -static int priskv_rdma_resp(priskv_rdma_conn *client, struct rdma_cm_id *cm_id) -{ - void *kv = client->c.listener->kv; - uint64_t capacity = priskv_get_value_blocks(kv) * priskv_get_value_block_size(kv); - - priskv_rdma_cm_rep rep = {0}; - rep.version = 
htobe16(PRISKV_RDMA_CM_VERSION); - rep.max_sgl = htobe16(client->conn_cap.max_sgl); - rep.max_key_length = htobe16(client->conn_cap.max_key_length); - rep.max_inflight_command = htobe16(client->conn_cap.max_inflight_command); - rep.capacity = htobe64(capacity); - - struct rdma_conn_param resp_param = {0}; - resp_param.responder_resources = 1; - resp_param.initiator_depth = 1; - resp_param.retry_count = 5; - resp_param.private_data = &rep; - resp_param.private_data_len = sizeof(rep); - - int ret = rdma_accept(cm_id, &resp_param); - if (ret) { - PRISKV_RDMA_DEF_ADDR(client->cm_id) - priskv_log_error("RDMA: <%s - %s> rdma_accept failed: %m\n", local_addr, peer_addr); - } - - return ret; -} - -static int priskv_rdma_verify_conn_cap(priskv_rdma_conn_cap *client, priskv_rdma_conn_cap *listener, - uint64_t *val) -{ - if (!client->max_sgl) { - client->max_sgl = listener->max_sgl; - } else if (client->max_sgl > listener->max_sgl) { - *val = listener->max_sgl; - return PRISKV_RDMA_CM_REJ_STATUS_INVALID_SGL; - } - - if (!client->max_key_length) { - client->max_key_length = listener->max_key_length; - } else if (client->max_key_length > listener->max_key_length) { - *val = listener->max_key_length; - return PRISKV_RDMA_CM_REJ_STATUS_INVALID_KEY_LENGTH; - } - - if (!client->max_inflight_command) { - client->max_inflight_command = listener->max_inflight_command; - } else if (client->max_inflight_command > listener->max_inflight_command) { - *val = listener->max_inflight_command; - return PRISKV_RDMA_CM_REJ_STATUS_INVALID_INFLIGHT_COMMAND; - } - - return 0; -} - -static void priskv_rdma_handle_connect_request(struct rdma_cm_event *ev, priskv_rdma_conn *listener) -{ - priskv_rdma_conn *client; - struct rdma_cm_id *id = ev->id; - struct ibv_qp_init_attr init_attr = {0}; - struct rdma_conn_param *req_param = &ev->param.conn; - unsigned char exp_len = sizeof(struct rdma_conn_param) + sizeof(priskv_rdma_cm_req); - priskv_rdma_cm_status status; - uint64_t value = 0; - - 
PRISKV_RDMA_DEF_ADDR(id); - - client = calloc(1, sizeof(priskv_rdma_conn)); - assert(client); - id->context = client; - client->cm_id = id; - client->c.listener = listener; - client->c.thread = NULL; - client->c.closing = false; - list_node_init(&client->c.node); - pthread_spin_init(&client->lock, 0); - - pthread_spin_lock(&listener->lock); - list_add_tail(&listener->s.head, &client->c.node); - listener->s.nclients++; - pthread_spin_unlock(&listener->lock); - - /* #step0, ACL verification */ - if (priskv_acl_verify(rdma_get_peer_addr(id))) { - priskv_log_error("RDMA: <%s - %s> ACL verification failed\n", local_addr, peer_addr); - status = PRISKV_RDMA_CM_REJ_STATUS_ACL_REFUSE; - value = 0; - goto rej; - } - - /* #step1, check incoming request parameters */ - if (req_param->private_data_len != exp_len) { - priskv_log_error("RDMA: <%s - %s> unexpected CM REQ length %d, expetected %d\n", local_addr, - peer_addr, req_param->private_data_len, exp_len); - status = PRISKV_RDMA_CM_REJ_STATUS_INVALID_CM_REP; - value = exp_len; - goto rej; - } - - const priskv_rdma_cm_req *req = req_param->private_data; - uint16_t version = be16toh(req->version); - if (version != PRISKV_RDMA_CM_VERSION) { - status = PRISKV_RDMA_CM_REJ_STATUS_INVALID_VERSION; - value = PRISKV_RDMA_CM_VERSION; - goto rej; - } - - client->conn_cap.max_sgl = be16toh(req->max_sgl); - client->conn_cap.max_key_length = be16toh(req->max_key_length); - client->conn_cap.max_inflight_command = be16toh(req->max_inflight_command); - priskv_log_info("RDMA: <%s - %s> incoming connect request - version %d, max_sgl %d, " - "max_key_length %d, max_inflight_command %d\n", - local_addr, peer_addr, version, client->conn_cap.max_sgl, - client->conn_cap.max_key_length, client->conn_cap.max_inflight_command); - - status = priskv_rdma_verify_conn_cap(&client->conn_cap, &listener->conn_cap, &value); - if (status) { - goto rej; - } - - /* #step2, create QP and related resources */ - client->comp_channel = 
ibv_create_comp_channel(id->verbs); - if (!client->comp_channel) { - priskv_log_error("RDMA: <%s - %s> ibv_create_comp_channel failed: %m\n", local_addr, - peer_addr); - status = PRISKV_RDMA_CM_REJ_STATUS_SERVER_ERROR; - goto rej; - } - - priskv_set_nonblock(client->comp_channel->fd); - uint32_t wr_size = priskv_rdma_wr_size(client); - client->cq = ibv_create_cq(id->verbs, wr_size * 2 * 4, NULL, client->comp_channel, 0); - if (!client->cq) { - priskv_log_error("RDMA: <%s - %s> ibv_create_cq failed: %m\n", local_addr, peer_addr); - status = PRISKV_RDMA_CM_REJ_STATUS_SERVER_ERROR; - goto rej; - } - - ibv_req_notify_cq(client->cq, 0); - - init_attr.cap.max_send_wr = wr_size * 4; - init_attr.cap.max_recv_wr = wr_size * 4; - init_attr.cap.max_send_sge = 1; - init_attr.cap.max_recv_sge = 1; - init_attr.qp_type = IBV_QPT_RC; - init_attr.send_cq = client->cq; - init_attr.recv_cq = client->cq; - if (rdma_create_qp(id, NULL, &init_attr)) { - priskv_log_error("RDMA: <%s - %s> rdma_create_qp failed: %m\n", local_addr, peer_addr); - status = PRISKV_RDMA_CM_REJ_STATUS_SERVER_ERROR; - goto rej; - } - - /* #step3, create QP and related resources */ - if (priskv_rdma_new_ctrl_buffer(client)) { - status = PRISKV_RDMA_CM_REJ_STATUS_SERVER_ERROR; - goto rej; - } - - /* #step4, post recv all the request commands */ - uint8_t *recv_req = client->rmem[PRISKV_RDMA_MEM_REQ].buf; - for (uint16_t i = 0; i < wr_size; i++) { - int recvsize = priskv_rdma_recv_req(client, recv_req); - if (recvsize < 0) { - status = PRISKV_RDMA_CM_REJ_STATUS_SERVER_ERROR; - goto rej; - } - - recv_req += recvsize; - } - - /* #step5, accept the new client */ - if (priskv_rdma_resp(client, id)) { - goto close_client; - } - - priskv_log_info("RDMA: <%s - %s> accept connect request - version %d, max_sgl %d, " - "max_key_length %d, max_inflight_command %d\n", - local_addr, peer_addr, version, client->conn_cap.max_sgl, - client->conn_cap.max_key_length, client->conn_cap.max_inflight_command); - return; - -rej: - 
priskv_log_warn("RDMA: <%s - %s> %s, reject\n", local_addr, peer_addr, - priskv_rdma_cm_status_str(status)); - priskv_rdma_reject(id, status, value); - -close_client: - priskv_rdma_close_client_async(client); -} - -static void priskv_rdma_handle_established(struct rdma_cm_event *ev, priskv_rdma_conn *listener) -{ - struct rdma_cm_id *id = ev->id; - priskv_rdma_conn *client = id->context; - - PRISKV_RDMA_DEF_ADDR(id); - - /* initialize KV of client */ - client->value_base = listener->value_base; - client->kv = listener->kv; - client->value_mr = listener->value_mr; - - /* use the idlest worker thread handle CQ event(CM event is still handled by main thread) */ - priskv_set_fd_handler(client->comp_channel->fd, priskv_rdma_handle_cq, NULL, client); - client->c.thread = priskv_threadpool_find_iothread(g_threadpool); - priskv_thread_add_event_handler(client->c.thread, client->comp_channel->fd); - - priskv_log_notice("RDMA: <%s - %s> established\n", local_addr, peer_addr); - priskv_log_debug("RDMA: <%s - %s> assign CQ fd %d to thread %d\n", local_addr, peer_addr, - client->comp_channel->fd, client->c.thread); -} - -static void priskv_rdma_handle_disconnected(struct rdma_cm_event *ev, priskv_rdma_conn *listener) -{ - struct rdma_cm_id *id = ev->id; - priskv_rdma_conn *client = id->context; - - priskv_rdma_close_client_async(client); -} - -static void priskv_rdma_handle_cm(int fd, void *opaque, uint32_t events) -{ - priskv_rdma_conn *listener = opaque; - struct rdma_cm_event *ev; - int ret; - - assert(listener->cm_id->channel->fd == fd); - -again: - ret = rdma_get_cm_event(listener->cm_id->channel, &ev); - if (ret) { - if (errno != EAGAIN) { - priskv_log_error("RDMA: listener rdma_get_cm_event failed: %m\n"); - } - return; - } - - const char *evstr = rdma_event_str(ev->event); - char addrbuf[64] = {0}; - priskv_inet_ntop(rdma_get_local_addr(listener->cm_id), addrbuf); - priskv_log_debug("RDMA: listener<%s> cm event: %s\n", addrbuf, evstr); - - switch (ev->event) { - case 
RDMA_CM_EVENT_CONNECT_REQUEST: - priskv_rdma_handle_connect_request(ev, listener); - break; - - case RDMA_CM_EVENT_ESTABLISHED: - priskv_rdma_handle_established(ev, listener); - break; - - case RDMA_CM_EVENT_DISCONNECTED: - priskv_rdma_handle_disconnected(ev, listener); - break; - - default: - priskv_log_error("RDMA: listener<%s> listener unexpected cm event: %s\n", addrbuf, evstr); - } - - rdma_ack_cm_event(ev); - - goto again; -} - -void priskv_rdma_process(void) -{ - priskv_rdma_conn *listener; -#define PRISKV_EPOLL_MAX_CM_EVENT 32 - struct epoll_event events[PRISKV_EPOLL_MAX_CM_EVENT]; - int nevents; - - nevents = epoll_wait(g_server.epollfd, events, PRISKV_EPOLL_MAX_CM_EVENT, 1000); - if (!nevents) { - goto close_disconnected; - } - - if (nevents < 0) { - assert(errno == EINTR); - goto close_disconnected; - } - - for (int n = 0; n < nevents; n++) { - struct epoll_event *event = &events[n]; - priskv_fd_handler_event(event); - } - -close_disconnected: - for (int i = 0; i < g_server.nlisteners; i++) { - listener = &g_server.listeners[i]; - priskv_rdma_close_disconnected(listener); - } -} diff --git a/server/rdma.h b/server/rdma.h deleted file mode 100644 index 4921a79..0000000 --- a/server/rdma.h +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -/* - * Authors: - * Jinlong Xuan <15563983051@163.com> - * Xu Ji - * Yu Wang - * Bo Liu - * Zhenwei Pi - * Rui Zhang - * Changqi Lu - * Enhua Zhou - */ - -#ifndef __PRISKV_SERVER_RDMA__ -#define __PRISKV_SERVER_RDMA__ - -#if defined(__cplusplus) -extern "C" -{ -#endif - -#include -#include - -#include "priskv-protocol.h" - -#define PRISKV_RDMA_MAX_BIND_ADDR 32 -#define PRISKV_RDMA_DEFAULT_PORT ('H' << 8 | 'P') - -#define PRISKV_RDMA_MAX_INFLIGHT_COMMAND 4096 -#define PRISKV_RDMA_DEFAULT_INFLIGHT_COMMAND 128 -#define PRISKV_RDMA_MAX_SGL 8 -#define PRISKV_RDMA_DEFAULT_SGL 4 -#define PRISKV_RDMA_MAX_KEY (1 << 30) -#define PRISKV_RDMA_DEFAULT_KEY (16 * 1024) -#define PRISKV_RDMA_MAX_KEY_LENGTH 1024 -#define PRISKV_RDMA_DEFAULT_KEY_LENGTH 128 -#define PRISKV_RDMA_MAX_VALUE_BLOCK_SIZE (1 << 20) -#define PRISKV_RDMA_DEFAULT_VALUE_BLOCK_SIZE 4096 -#define PRISKV_RDMA_MAX_VALUE_BLOCK (1UL << 30) -#define PRISKV_RDMA_DEFAULT_VALUE_BLOCK (1024UL * 1024) -#define SLOW_QUERY_THRESHOLD_LATENCY_US 1000000 /* 1 second */ - -extern uint32_t g_slow_query_threshold_latency_us; - -typedef struct priskv_rdma_stats { - uint64_t ops; - uint64_t bytes; -} priskv_rdma_stats; - -typedef struct priskv_rdma_conn_cap { - uint16_t max_sgl; - uint16_t max_key_length; - uint16_t max_inflight_command; -} priskv_rdma_conn_cap; - -typedef struct priskv_rdma_client { - char address[PRISKV_ADDR_LEN]; - priskv_rdma_stats stats[PRISKV_COMMAND_MAX]; - uint64_t resps; - bool closing; -} priskv_rdma_client; - -typedef struct priskv_rdma_listener { - char address[PRISKV_ADDR_LEN]; - int nclients; - priskv_rdma_client *clients; -} priskv_rdma_listener; - -int priskv_rdma_listen(char **addr, int naddr, int port, void *kv, priskv_rdma_conn_cap *cap); -int priskv_rdma_get_fd(void); -void priskv_rdma_process(void); - -void *priskv_rdma_get_kv(void); - -priskv_rdma_listener *priskv_rdma_get_listeners(int *nlisteners); -void priskv_rdma_free_listeners(priskv_rdma_listener *listeners, int nlisteners); - -#if 
defined(__cplusplus) -} -#endif - -#endif /* __PRISKV_SERVER_RDMA__ */ diff --git a/server/server.c b/server/server.c index e64aada..e78a000 100644 --- a/server/server.c +++ b/server/server.c @@ -35,7 +35,7 @@ #include "priskv-log.h" #include "priskv-logo.h" -#include "rdma.h" +#include "transport/transport.h" #include "memory.h" #include "kv.h" #include "priskv-threads.h" @@ -45,11 +45,11 @@ /* arguments of command line */ static int naddr = 0; -static char *addresses[PRISKV_RDMA_MAX_BIND_ADDR]; -static int port = PRISKV_RDMA_DEFAULT_PORT; -static uint32_t max_key = PRISKV_RDMA_DEFAULT_KEY; -static uint32_t value_block_size = PRISKV_RDMA_DEFAULT_VALUE_BLOCK_SIZE; -static uint64_t value_block = PRISKV_RDMA_DEFAULT_VALUE_BLOCK; +static char *addresses[PRISKV_TRANSPORT_MAX_BIND_ADDR]; +static int port = PRISKV_TRANSPORT_DEFAULT_PORT; +static uint32_t max_key = PRISKV_TRANSPORT_DEFAULT_KEY; +static uint32_t value_block_size = PRISKV_TRANSPORT_DEFAULT_VALUE_BLOCK_SIZE; +static uint64_t value_block = PRISKV_TRANSPORT_DEFAULT_VALUE_BLOCK; static uint8_t threads = 1; static uint32_t thread_flags; static uint32_t expire_routine_interval = PRISKV_KV_DEFAULT_EXPIRE_ROUTINE_INTERVAL; @@ -57,9 +57,10 @@ static const char *memfile; static priskv_log_level log_level = priskv_log_notice; static const char *g_log_file = NULL; static priskv_logger *g_logger = NULL; -static priskv_rdma_conn_cap conn_cap = {.max_sgl = PRISKV_RDMA_DEFAULT_SGL, - .max_key_length = PRISKV_RDMA_DEFAULT_KEY_LENGTH, - .max_inflight_command = PRISKV_RDMA_DEFAULT_INFLIGHT_COMMAND}; +static priskv_transport_conn_cap conn_cap = {.max_sgl = PRISKV_TRANSPORT_DEFAULT_SGL, + .max_key_length = PRISKV_TRANSPORT_DEFAULT_KEY_LENGTH, + .max_inflight_command = + PRISKV_TRANSPORT_DEFAULT_INFLIGHT_COMMAND}; static priskv_http_config http_config = { .addr = NULL, @@ -78,24 +79,24 @@ static void priskv_showhelp(void) printf("\nUsage:\n"); printf(" -a/--addr ADDR\n\tbind to ADDR, support as max as %d addresses. 
Ex, -a xx.xx.xx.xx " "-a yy.yy.yy.yy\n", - PRISKV_RDMA_MAX_BIND_ADDR); - printf(" -p/--port PORT\n\tlisten to PORT, default %d\n", PRISKV_RDMA_DEFAULT_PORT); + PRISKV_TRANSPORT_MAX_BIND_ADDR); + printf(" -p/--port PORT\n\tlisten to PORT, default %d\n", PRISKV_TRANSPORT_DEFAULT_PORT); printf(" -f/--memfile PATH\n\tload memory file from tmpfs/hugetlbfs\n"); printf(" -c/--max-inflight-command COMMANDS\n\tthe maxium count of inflight command, default " "%d, max %d\n", - PRISKV_RDMA_DEFAULT_INFLIGHT_COMMAND, PRISKV_RDMA_MAX_INFLIGHT_COMMAND); + PRISKV_TRANSPORT_DEFAULT_INFLIGHT_COMMAND, PRISKV_TRANSPORT_MAX_INFLIGHT_COMMAND); printf(" -s/--max-sgl SGLS\n\tthe maxium count of scatter gather list, default %d, max %d\n", - PRISKV_RDMA_DEFAULT_SGL, PRISKV_RDMA_MAX_SGL); + PRISKV_TRANSPORT_DEFAULT_SGL, PRISKV_TRANSPORT_MAX_SGL); printf(" -k/--max-keys KEYS\n\tthe maxium count of KV, default %d, max %d\n", - PRISKV_RDMA_DEFAULT_KEY, PRISKV_RDMA_MAX_KEY); + PRISKV_TRANSPORT_DEFAULT_KEY, PRISKV_TRANSPORT_MAX_KEY); printf(" -K/--max-key-length BYTES\n\tthe maxium bytes of a key, default %d, max %d\n", - PRISKV_RDMA_DEFAULT_KEY_LENGTH, PRISKV_RDMA_MAX_KEY_LENGTH); + PRISKV_TRANSPORT_DEFAULT_KEY_LENGTH, PRISKV_TRANSPORT_MAX_KEY_LENGTH); printf(" -v/--value-block-size BYTES\n\tthe block size of minimal value in bytes, " "default %d, max %d\n", - PRISKV_RDMA_DEFAULT_VALUE_BLOCK_SIZE, PRISKV_RDMA_MAX_VALUE_BLOCK_SIZE); + PRISKV_TRANSPORT_DEFAULT_VALUE_BLOCK_SIZE, PRISKV_TRANSPORT_MAX_VALUE_BLOCK_SIZE); printf(" -b/--value-blocks BLOCKS\n\tthe count of value blocks, must be power of 2, " "default %ld, max %ld\n", - PRISKV_RDMA_DEFAULT_VALUE_BLOCK, PRISKV_RDMA_MAX_VALUE_BLOCK); + PRISKV_TRANSPORT_DEFAULT_VALUE_BLOCK, PRISKV_TRANSPORT_MAX_VALUE_BLOCK); printf(" -t/--threads THREADS\n\tthe number of worker threads, default 1\n"); printf(" -e/--expire-routine-interval INTERVAL\n\tthe interval to auto-clean expired kv in " "second, default 600\n"); @@ -171,7 +172,7 @@ static void 
priskv_parsr_arg(int argc, char *argv[]) break; case 'p': - if (port != PRISKV_RDMA_DEFAULT_PORT) { + if (port != PRISKV_TRANSPORT_DEFAULT_PORT) { printf("A single port is supported\n"); priskv_showhelp(); } @@ -202,7 +203,7 @@ static void priskv_parsr_arg(int argc, char *argv[]) case 'c': conn_cap.max_inflight_command = atoi(optarg); if (!conn_cap.max_inflight_command || - (conn_cap.max_inflight_command > PRISKV_RDMA_MAX_INFLIGHT_COMMAND)) { + (conn_cap.max_inflight_command > PRISKV_TRANSPORT_MAX_INFLIGHT_COMMAND)) { printf("Invalid -c/--max-inflight-command\n"); priskv_showhelp(); } @@ -210,7 +211,7 @@ static void priskv_parsr_arg(int argc, char *argv[]) case 's': conn_cap.max_sgl = atoi(optarg); - if (!conn_cap.max_sgl || (conn_cap.max_sgl > PRISKV_RDMA_MAX_SGL)) { + if (!conn_cap.max_sgl || (conn_cap.max_sgl > PRISKV_TRANSPORT_MAX_SGL)) { printf("Invalid -s/--max-sgl\n"); priskv_showhelp(); } @@ -218,7 +219,7 @@ static void priskv_parsr_arg(int argc, char *argv[]) case 'k': max_key = atoi(optarg); - if (max_key > PRISKV_RDMA_MAX_KEY) { + if (max_key > PRISKV_TRANSPORT_MAX_KEY) { printf("Invalid -k/--max-keys\n"); priskv_showhelp(); } @@ -226,7 +227,7 @@ static void priskv_parsr_arg(int argc, char *argv[]) case 'K': if (priskv_str2num(optarg, &max_key_length) || !max_key_length || - max_key_length > PRISKV_RDMA_MAX_KEY_LENGTH) { + max_key_length > PRISKV_TRANSPORT_MAX_KEY_LENGTH) { printf("Invalid -K/--max-key-length\n"); priskv_showhelp(); } @@ -235,7 +236,7 @@ static void priskv_parsr_arg(int argc, char *argv[]) case 'v': if (priskv_str2num(optarg, &_value_block_size) || !_value_block_size || - _value_block_size > PRISKV_RDMA_MAX_VALUE_BLOCK_SIZE) { + _value_block_size > PRISKV_TRANSPORT_MAX_VALUE_BLOCK_SIZE) { printf("Invalid -v/--value-block-size\n"); priskv_showhelp(); } @@ -245,7 +246,7 @@ static void priskv_parsr_arg(int argc, char *argv[]) case 'b': value_block = atoll(optarg); - if (!value_block || (value_block > PRISKV_RDMA_MAX_VALUE_BLOCK)) { + if 
(!value_block || (value_block > PRISKV_TRANSPORT_MAX_VALUE_BLOCK)) { priskv_showhelp(); } @@ -370,9 +371,11 @@ static void *priskv_server_create_kv() return kv; } -static void __priskv_rdma_process(evutil_socket_t fd, short events, void *arg) +extern priskv_transport_driver *g_transport_driver; + +static void __priskv_transport_process(evutil_socket_t fd, short events, void *arg) { - priskv_rdma_process(); + priskv_transport_process(); } static int priskv_server_start(struct event_base *evbase) @@ -400,11 +403,11 @@ static int priskv_server_start(struct event_base *evbase) priskv_set_expire_routine_interval(g_kv, expire_routine_interval); priskv_expire_routine(bgthread, g_kv); - if (priskv_rdma_listen(addresses, naddr, port, g_kv, &conn_cap)) { - return -1; /* priskv_rdma_listen should already print enough messages */ + if (priskv_transport_listen(addresses, naddr, port, g_kv, &conn_cap)) { + return -1; } - ev = event_new(evbase, priskv_rdma_get_fd(), EV_READ | EV_PERSIST, __priskv_rdma_process, NULL); + ev = event_new(evbase, priskv_transport_get_fd(), EV_READ | EV_PERSIST, __priskv_transport_process, NULL); return event_add(ev, NULL); }; diff --git a/server/test/Makefile b/server/test/Makefile index 5d91371..4883a5c 100644 --- a/server/test/Makefile +++ b/server/test/Makefile @@ -36,11 +36,23 @@ $(TEST_SLAB): $(OBJS) $(TEST_SLAB_MT): $(OBJS) $(CC) test_slab_mt.c ../slab.c $(CFLAGS) -pthread -o $(TEST_SLAB_MT) +UCX_CFLAGS = $(shell pkg-config --cflags ucx) +UCX_LIBS = $(shell pkg-config --libs ucx) +ifeq ($(shell pkg-config --exists ucx; echo $$?),0) $(TEST_KV): $(OBJS) - $(CC) test_kv.c ../../lib/workqueue.c ../../lib/threads.c ../../lib/event.c ../memory.c ../kv.c ../slab.c ../buddy.c ../crc.c ../../lib/log.c ../backend/backend.c ../rdma.c ../acl.c $(CFLAGS) -o $(TEST_KV) -lmount -lrdmacm -libverbs + $(CC) test_kv.c ../../lib/workqueue.c ../../lib/threads.c ../../lib/event.c ../../lib/ucx.c ../memory.c ../kv.c ../slab.c ../buddy.c ../crc.c ../../lib/log.c 
../backend/backend.c ../transport/transport.c ../transport/rdma.c ../transport/ucx.c ../tiering.c ../acl.c $(CFLAGS) $(UCX_CFLAGS) -o $(TEST_KV) -lmount -lrdmacm -libverbs $(UCX_LIBS) +else +$(TEST_KV): + @echo "UCX not available, skip $(TEST_KV)" +endif +ifeq ($(shell pkg-config --exists ucx; echo $$?),0) $(TEST_KV_MT): $(OBJS) - $(CC) test_kv_mt.c ../../lib/workqueue.c ../../lib/threads.c ../../lib/event.c ../memory.c ../kv.c ../slab.c ../buddy.c ../crc.c ../../lib/log.c ../backend/backend.c ../rdma.c ../acl.c $(CFLAGS) -o $(TEST_KV_MT) -lmount -lrdmacm -libverbs + $(CC) test_kv_mt.c ../../lib/workqueue.c ../../lib/threads.c ../../lib/event.c ../../lib/ucx.c ../memory.c ../kv.c ../slab.c ../buddy.c ../crc.c ../../lib/log.c ../backend/backend.c ../transport/transport.c ../transport/rdma.c ../transport/ucx.c ../tiering.c ../acl.c $(CFLAGS) $(UCX_CFLAGS) -o $(TEST_KV_MT) -lmount -lrdmacm -libverbs $(UCX_LIBS) +else +$(TEST_KV_MT): + @echo "UCX not available, skip $(TEST_KV_MT)" +endif $(TEST_MEMORY): $(OBJS) $(CC) test_memory.c ../memory.c ../../lib/log.c $(CFLAGS) -lmount -o $(TEST_MEMORY) @@ -48,8 +60,13 @@ $(TEST_MEMORY): $(OBJS) $(TEST_ACL): $(OBJS) $(CC) test_acl.c ../acl.c ../../lib/log.c $(CFLAGS) -lrdmacm -o $(TEST_ACL) +ifeq ($(shell pkg-config --exists ucx; echo $$?),0) $(TEST_KV_EXPIRE_ROUTINE): $(OBJS) - $(CC) test_kv_expire_routine.c ../../lib/workqueue.c ../../lib/threads.c ../../lib/event.c ../memory.c ../kv.c ../slab.c ../buddy.c ../crc.c ../../lib/log.c ../backend/backend.c ../rdma.c ../acl.c $(CFLAGS) -o $(TEST_KV_EXPIRE_ROUTINE) -lmount -lpthread -lrdmacm -libverbs + $(CC) test_kv_expire_routine.c ../../lib/workqueue.c ../../lib/threads.c ../../lib/event.c ../../lib/ucx.c ../memory.c ../kv.c ../slab.c ../buddy.c ../crc.c ../../lib/log.c ../backend/backend.c ../transport/transport.c ../transport/rdma.c ../transport/ucx.c ../tiering.c ../acl.c $(CFLAGS) $(UCX_CFLAGS) -o $(TEST_KV_EXPIRE_ROUTINE) -lmount -lpthread -lrdmacm -libverbs $(UCX_LIBS) +else 
+$(TEST_KV_EXPIRE_ROUTINE): + @echo "UCX not available, skip $(TEST_KV_EXPIRE_ROUTINE)" +endif $(TEST_BE_REDIS): $(CC) test_be_redis.c ../../lib/log.c ../../lib/event.c ../../lib/workqueue.c ../../lib/threads.c ../backend/backend.c ../backend/be_redis.c $(CFLAGS) -o $(TEST_BE_REDIS) -levent -lhiredis diff --git a/server/tiering.c b/server/tiering.c new file mode 100644 index 0000000..d9bd8cc --- /dev/null +++ b/server/tiering.c @@ -0,0 +1,366 @@ +// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "priskv-log.h" +#include "priskv-threads.h" +#include "backend/backend.h" +#include "transport/transport.h" +#include "kv.h" +#include "priskv-protocol.h" +#include "priskv-protocol-helper.h" + +static void priskv_tiering_req_repost_recv(priskv_tiering_req *treq) +{ + if (!treq || treq->recv_reposted) { + return; + } + treq->recv_reposted = true; +} + +static void priskv_tiering_req_free(priskv_tiering_req *treq) +{ + if (!treq) { + return; + } + free(treq->key); + free(treq); +} + +priskv_tiering_req *priskv_tiering_req_new(priskv_transport_conn *conn, priskv_request *req, + uint8_t *key, uint16_t keylen, uint64_t timeout, + priskv_req_command cmd, uint32_t remote_valuelen, + priskv_resp_status *resp_status) +{ + priskv_resp_status status = PRISKV_RESP_STATUS_NO_MEM; + priskv_tiering_req *treq = calloc(1, sizeof(priskv_tiering_req)); + priskv_thread *thread = NULL; + priskv_backend_device *backend = NULL; + + if (!treq) { + goto error; + } + + thread = NULL; + backend = priskv_get_thread_backend(thread); + if (!backend) { + status = PRISKV_RESP_STATUS_SERVER_ERROR; + goto error; + } + + treq->key = malloc(keylen + 1); + if (!treq->key) { + goto error; + } + memcpy(treq->key, key, keylen); + treq->key[keylen] = '\0'; + + list_node_init(&treq->node); + treq->conn = conn; + treq->thread = thread; + treq->backend = backend; + treq->kv = priskv_transport_get_kv(); + treq->req = req; + treq->request_id = req->request_id; + treq->keylen = keylen; + treq->timeout = timeout; + treq->cmd = cmd; + treq->remote_valuelen = remote_valuelen; + treq->valuelen = 0; + treq->execute = false; + treq->recv_reposted = false; + treq->hash_head_index = priskv_crc32(key, keylen) % priskv_get_bucket_count(treq->kv); + treq->backend_status = PRISKV_BACKEND_STATUS_ERROR; + + if (resp_status) { + *resp_status = PRISKV_RESP_STATUS_OK; + } + return treq; + +error: + free(treq->key); + free(treq); + if (resp_status) { + *resp_status = status; + } + 
return NULL; +} + +static void priskv_tiering_finish(priskv_tiering_req *treq, priskv_resp_status status, + uint32_t length) +{ + if (!treq) { + return; + } + + if (treq->execute) { + priskv_key_serialize_exit(treq); + treq->execute = false; + } + + priskv_tiering_req_repost_recv(treq); + + priskv_transport_send_response(treq->conn, treq->request_id, status, length); + + priskv_tiering_req_free(treq); +} + +static void priskv_tiering_get_complete_cb(void *arg) +{ + priskv_tiering_req *treq = arg; + priskv_get_key_end(treq->keynode); + priskv_tiering_req_repost_recv(treq); + priskv_tiering_req_free(treq); +} + +static void priskv_tiering_get_backend_cb(priskv_backend_status backend_status, uint32_t valuelen, + void *arg) +{ + priskv_tiering_req *treq = arg; + priskv_resp_status resp_status; + treq->backend_status = backend_status; + + switch (backend_status) { + case PRISKV_BACKEND_STATUS_OK: { + uint8_t *val = NULL; + uint32_t cached_length = 0; + void *keynode = NULL; + + priskv_update_valuelen(treq->keynode, valuelen); + treq->valuelen = valuelen; + + priskv_get_key(treq->kv, treq->key, treq->keylen, &val, &cached_length, &keynode); + priskv_transport_rw_req(treq->conn, treq->req, &treq->conn->value_memh, treq->value, + treq->valuelen, 0, priskv_tiering_get_complete_cb, treq, 0, + NULL); + return; + } + case PRISKV_BACKEND_STATUS_VALUE_TOO_BIG: + resp_status = PRISKV_RESP_STATUS_VALUE_TOO_BIG; + break; + case PRISKV_BACKEND_STATUS_NOT_FOUND: + resp_status = PRISKV_RESP_STATUS_NO_SUCH_KEY; + break; + default: + resp_status = PRISKV_RESP_STATUS_SERVER_ERROR; + break; + } + + priskv_delete_key(treq->kv, treq->key, treq->keylen); + priskv_tiering_finish(treq, resp_status, 0); +} + +void priskv_tiering_get(priskv_tiering_req *treq) +{ + priskv_resp_status status; + uint8_t *val = NULL; + uint32_t valuelen = 0; + void *keynode = NULL; + + if (!treq->execute) { + if (!priskv_key_serialize_enter(treq)) { + return; + } + treq->execute = true; + } + + status = 
priskv_get_key(treq->kv, treq->key, treq->keylen, &val, &valuelen, &keynode); + if (status == PRISKV_RESP_STATUS_OK && keynode) { + treq->value = val; + treq->valuelen = valuelen; + treq->keynode = keynode; + + if (treq->remote_valuelen < treq->valuelen) { + priskv_tiering_finish(treq, PRISKV_RESP_STATUS_VALUE_TOO_BIG, treq->valuelen); + return; + } + priskv_transport_rw_req(treq->conn, treq->req, &treq->conn->value_memh, treq->value, + treq->valuelen, 0, priskv_tiering_get_complete_cb, treq, 0, NULL); + priskv_key_serialize_exit(treq); + treq->execute = false; + return; + } + + status = priskv_set_key(treq->kv, treq->key, treq->keylen, &val, treq->remote_valuelen, + treq->timeout, &keynode); + priskv_set_key_end(keynode); + if (status != PRISKV_RESP_STATUS_OK || !keynode) { + priskv_tiering_finish(treq, status, 0); + return; + } + treq->value = val; + treq->keynode = keynode; + priskv_backend_get(treq->backend, (const char *)treq->key, treq->value, treq->remote_valuelen, + priskv_tiering_get_backend_cb, treq); +} + +static void priskv_tiering_test_backend_cb(priskv_backend_status status, uint32_t valuelen, + void *arg) +{ + priskv_tiering_req *treq = arg; + priskv_resp_status resp_status; + uint32_t length = 0; + treq->backend_status = status; + switch (status) { + case PRISKV_BACKEND_STATUS_OK: + resp_status = PRISKV_RESP_STATUS_OK; + length = valuelen; + break; + case PRISKV_BACKEND_STATUS_NOT_FOUND: + resp_status = PRISKV_RESP_STATUS_NO_SUCH_KEY; + break; + default: + resp_status = PRISKV_RESP_STATUS_SERVER_ERROR; + break; + } + priskv_tiering_finish(treq, resp_status, length); +} + +void priskv_tiering_test(priskv_tiering_req *treq) +{ + priskv_resp_status status; + uint8_t *val = NULL; + uint32_t valuelen = 0; + void *keynode = NULL; + + if (!treq) + return; + if (!treq->execute) { + if (!priskv_key_serialize_enter(treq)) { + return; + } + treq->execute = true; + } + + status = priskv_get_key(treq->kv, treq->key, treq->keylen, &val, &valuelen, &keynode); + if 
(status == PRISKV_RESP_STATUS_OK) { + priskv_tiering_finish(treq, PRISKV_RESP_STATUS_OK, valuelen); + priskv_get_key_end(keynode); + return; + } + priskv_backend_test(treq->backend, (const char *)treq->key, priskv_tiering_test_backend_cb, + treq); +} + +static void priskv_tiering_set_backend_cb(priskv_backend_status status, uint32_t valuelen, + void *arg) +{ + priskv_tiering_req *treq = arg; + priskv_resp_status resp_status; + uint32_t length = 0; + priskv_delete_key(treq->kv, treq->key, treq->keylen); + switch (status) { + case PRISKV_BACKEND_STATUS_OK: + resp_status = PRISKV_RESP_STATUS_OK; + length = treq->valuelen; + break; + case PRISKV_BACKEND_STATUS_NO_SPACE: + resp_status = PRISKV_RESP_STATUS_NO_MEM; + break; + default: + resp_status = PRISKV_RESP_STATUS_SERVER_ERROR; + break; + } + priskv_tiering_finish(treq, resp_status, length); +} + +static void priskv_tiering_set_complete_cb(void *arg) +{ + priskv_tiering_req *treq = arg; + priskv_set_key_end(treq->keynode); + if (!treq->backend) { + priskv_delete_key(treq->kv, treq->key, treq->keylen); + priskv_tiering_finish(treq, PRISKV_RESP_STATUS_SERVER_ERROR, 0); + return; + } + priskv_backend_set(treq->backend, (const char *)treq->key, treq->value, treq->remote_valuelen, + treq->timeout, priskv_tiering_set_backend_cb, treq); +} + +void priskv_tiering_set(priskv_tiering_req *treq) +{ + priskv_resp_status status; + if (!treq->execute) { + if (!priskv_key_serialize_enter(treq)) { + return; + } + treq->execute = true; + } + status = priskv_set_key(treq->kv, treq->key, treq->keylen, &treq->value, treq->remote_valuelen, + treq->timeout, &treq->keynode); + if (status != PRISKV_RESP_STATUS_OK || !treq->keynode) { + priskv_set_key_end(treq->keynode); + priskv_tiering_finish(treq, status, 0); + return; + } + priskv_transport_rw_req(treq->conn, treq->req, &treq->conn->value_memh, treq->value, + treq->remote_valuelen, 1, priskv_tiering_set_complete_cb, treq, 0, + NULL); +} + +static void 
priskv_tiering_del_backend_cb(priskv_backend_status status, uint32_t valuelen, + void *arg) +{ + priskv_tiering_req *treq = arg; + priskv_resp_status resp_status; + priskv_delete_key(treq->kv, treq->key, treq->keylen); + switch (status) { + case PRISKV_BACKEND_STATUS_OK: + resp_status = PRISKV_RESP_STATUS_OK; + break; + case PRISKV_BACKEND_STATUS_NOT_FOUND: + resp_status = PRISKV_RESP_STATUS_NO_SUCH_KEY; + break; + default: + resp_status = PRISKV_RESP_STATUS_SERVER_ERROR; + break; + } + priskv_tiering_finish(treq, resp_status, 0); +} + +void priskv_tiering_del(priskv_tiering_req *treq) +{ + if (!treq->execute) { + if (!priskv_key_serialize_enter(treq)) { + return; + } + treq->execute = true; + } + priskv_backend_del(treq->backend, (const char *)treq->key, priskv_tiering_del_backend_cb, treq); +} + +int priskv_backend_req_resubmit(void *req) +{ + priskv_tiering_req *treq = (priskv_tiering_req *)req; + switch (treq->cmd) { + case PRISKV_COMMAND_GET: + priskv_tiering_get(treq); + break; + case PRISKV_COMMAND_SET: + priskv_tiering_set(treq); + break; + case PRISKV_COMMAND_TEST: + priskv_tiering_test(treq); + break; + case PRISKV_COMMAND_DELETE: + priskv_tiering_del(treq); + break; + default: + return -1; + } + return 0; +} diff --git a/server/transport/rdma.c b/server/transport/rdma.c new file mode 100644 index 0000000..c0e3cb7 --- /dev/null +++ b/server/transport/rdma.c @@ -0,0 +1,1011 @@ +// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Authors: + * Jinlong Xuan <15563983051@163.com> + * Xu Ji + * Yu Wang + * Bo Liu + * Zhenwei Pi + * Rui Zhang + * Changqi Lu + * Enhua Zhou + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../acl.h" +#include "../backend/backend.h" +#include "../crc.h" +#include "../kv.h" +#include "../memory.h" +#include "transport.h" +#include "priskv-protocol.h" +#include "priskv-protocol-helper.h" +#include "priskv-log.h" +#include "priskv-utils.h" +#include "priskv-event.h" +#include "priskv-threads.h" +#include "list.h" + +#define PRISKV_RDMA_DEF_ADDR(id) \ + char local_addr[PRISKV_ADDR_LEN] = {0}; \ + char peer_addr[PRISKV_ADDR_LEN] = {0}; \ + priskv_inet_ntop(rdma_get_local_addr(id), local_addr); \ + priskv_inet_ntop(rdma_get_peer_addr(id), peer_addr); + +extern priskv_transport_server g_transport_server; +extern priskv_threadpool *g_threadpool; + +static uint32_t priskv_rdma_max_rw_size = 1024 * 1024 * 1024; + +static void priskv_rdma_handle_cm(int fd, void *opaque, uint32_t events); + +static int priskv_rdma_mem_new(priskv_transport_conn *conn, priskv_transport_mem *rmem, + const char *name, uint32_t size) +{ + uint32_t flags = IBV_ACCESS_LOCAL_WRITE; + bool guard = true; /* always enable memory guard */ + uint8_t *buf; + int ret; + + buf = priskv_mem_malloc(size, guard); + if (!buf) { + priskv_log_error("RDMA: failed to allocate %s buffer: %m\n", name); + ret = -ENOMEM; + goto error; + } + + rmem->memh.rdma_mr = ibv_reg_mr(conn->cm_id->pd, buf, size, flags); + if (!rmem->memh.rdma_mr) { + priskv_log_error("RDMA: failed to reg MR for %s buffer: %m\n", name); + ret = -errno; + goto free_mem; + } + + strncpy(rmem->name, name, PRISKV_TRANSPORT_MEM_NAME_LEN - 1); + rmem->buf = buf; + rmem->buf_size = size; + + 
priskv_log_info("RDMA: new rmem %s, size %d\n", name, size); + priskv_log_debug("RDMA: new rmem %s, buf %p\n", name, buf); + return 0; + +free_mem: + priskv_mem_free(rmem->buf, rmem->buf_size, guard); + +error: + memset(rmem, 0x00, sizeof(priskv_transport_mem)); + + return ret; +} + +static void priskv_rdma_mem_free(priskv_transport_conn *conn, priskv_transport_mem *rmem) +{ + if (rmem->memh.rdma_mr) { + ibv_dereg_mr(rmem->memh.rdma_mr); + } + + if (rmem->buf) { + priskv_log_debug("RDMA: free rmem %s, buf %p\n", rmem->name, rmem->buf); + priskv_mem_free(rmem->buf, rmem->buf_size, true); + } + + priskv_log_info("RDMA: free rmem %s, size %d\n", rmem->name, rmem->buf_size); + memset(rmem, 0x00, sizeof(priskv_transport_mem)); +} + +static inline void priskv_rdma_free_ctrl_buffer(priskv_transport_conn *conn) +{ + for (int i = 0; i < PRISKV_TRANSPORT_MEM_MAX; i++) { + priskv_transport_mem *rmem = &conn->rmem[i]; + + priskv_rdma_mem_free(conn, rmem); + } +} + +static int priskv_rdma_listen_one(char *addr, int port, void *kv, priskv_transport_conn_cap *cap) +{ + int ret = 0, afonly = 1; + char _port[6]; /* strlen("65535") */ + struct rdma_addrinfo hints, *servinfo; + struct rdma_cm_id *listen_cmid = NULL; + struct rdma_event_channel *listen_channel = NULL; + priskv_transport_conn *listener; + + snprintf(_port, 6, "%d", port); + memset(&hints, 0, sizeof(hints)); + hints.ai_flags = RAI_PASSIVE; + hints.ai_port_space = RDMA_PS_TCP; + ret = rdma_getaddrinfo(addr, _port, &hints, &servinfo); + if (ret) { + priskv_log_error("RDMA: getaddrinfo %s failed: %s\n", addr, gai_strerror(ret)); + return ret; + } else if (!servinfo) { + priskv_log_error("RDMA: getaddrinfo %s: no availabe address\n", addr); + return -EINVAL; + } + + listen_channel = rdma_create_event_channel(); + if (!listen_channel) { + ret = -errno; + priskv_log_error("RDMA: create event channel failed\n"); + goto freeaddr; + } + + ret = priskv_set_nonblock(listen_channel->fd); + if (ret) { + priskv_log_error("RDMA: 
failed to set NONBLOCK on event channel fd\n"); + goto error; + } + + if (rdma_create_id(listen_channel, &listen_cmid, NULL, RDMA_PS_TCP)) { + ret = -errno; + priskv_log_error("RDMA: create listen cm id error\n"); + goto error; + } + + rdma_set_option(listen_cmid, RDMA_OPTION_ID, RDMA_OPTION_ID_AFONLY, &afonly, sizeof(afonly)); + + if (rdma_bind_addr(listen_cmid, servinfo->ai_src_addr)) { + ret = -errno; + priskv_log_error("RDMA: Bind addr error on %s\n", addr); + goto error; + } + + if (rdma_listen(listen_cmid, 0)) { + ret = -errno; + priskv_log_error("RDMA: listen addr error on %s\n", addr); + goto error; + } + + /* TODO split into several MRs, because of max_mr_size of IB device */ + uint8_t *value_base = priskv_get_value_base(kv); + assert(value_base); + uint64_t size = priskv_get_value_blocks(kv) * priskv_get_value_block_size(kv); + assert(size); + uint32_t access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; + struct ibv_mr *value_mr = ibv_reg_mr(listen_cmid->pd, value_base, size, access); + if (!value_mr) { + ret = -errno; + priskv_log_error( + "RDMA: failed to reg MR for value: %m [%p, %p], value block %ld, value block size %d\n", + value_base, value_base + size, priskv_get_value_blocks(kv), + priskv_get_value_block_size(kv)); + goto error; + } + + priskv_log_debug("RDMA: Value buffer %p, length %ld\n", value_base, size); + + listener = &g_transport_server.listeners[g_transport_server.nlisteners++]; + listener->cm_id = listen_cmid; + listener->value_base = value_base; + listener->kv = kv; + listener->value_memh.rdma_mr = value_mr; + listener->conn_cap = *cap; + listener->s.nclients = 0; + list_head_init(&listener->s.head); + pthread_spin_init(&listener->lock, 0); + + priskv_log_info("RDMA: <%s:%d> listener starts\n", addr, port); + + ret = 0; + goto freeaddr; + +error: + if (listen_cmid) { + rdma_destroy_id(listen_cmid); + } + if (listen_channel) { + rdma_destroy_event_channel(listen_channel); + } + +freeaddr: + 
rdma_freeaddrinfo(servinfo); + return ret; +} + +int priskv_rdma_listen(char **addr, int naddr, int port, void *kv, priskv_transport_conn_cap *cap) +{ + priskv_transport_conn *listener; + + for (int i = 0; i < naddr; i++) { + int ret = priskv_rdma_listen_one(addr[i], port, kv, cap); + if (ret) { + return ret; + } + } + + g_transport_server.kv = kv; + + g_transport_server.epollfd = epoll_create(g_transport_server.nlisteners); + if (g_transport_server.epollfd == -1) { + priskv_log_error("RDMA: failed to create epoll fd %m\n"); + return -1; + } + + for (int i = 0; i < g_transport_server.nlisteners; i++) { + listener = &g_transport_server.listeners[i]; + PRISKV_RDMA_DEF_ADDR(listener->cm_id); + + priskv_set_fd_handler(listener->cm_id->channel->fd, priskv_rdma_handle_cm, NULL, listener); + if (priskv_add_event_fd(g_transport_server.epollfd, listener->cm_id->channel->fd)) { + priskv_log_error("RDMA: failed to add listen fd into epoll fd %m\n"); + return -1; + } + + priskv_log_notice("RDMA: <%s> ready\n", local_addr); + } + + return 0; +} + +static void priskv_rdma_get_clients(priskv_transport_conn *listener, + priskv_transport_client **clients, int *nclients) +{ + priskv_transport_conn *client; + *nclients = 0; + + pthread_spin_lock(&listener->lock); + *clients = calloc(listener->s.nclients, sizeof(priskv_transport_client)); + list_for_each (&listener->s.head, client, c.node) { + PRISKV_RDMA_DEF_ADDR(client->cm_id); + + memcpy((*clients)[*nclients].address, peer_addr, strlen(peer_addr) + 1); + memcpy((*clients)[*nclients].stats, client->c.stats, + PRISKV_COMMAND_MAX * sizeof(priskv_transport_stats)); + (*clients)[*nclients].resps = client->c.resps; + (*clients)[*nclients].closing = client->c.closing; + (*nclients)++; + + if (*nclients == listener->s.nclients) { + break; + } + } + pthread_spin_unlock(&listener->lock); +} + +static void priskv_rdma_free_clients(priskv_transport_client *clients) +{ + free(clients); +} + +priskv_transport_listener 
*priskv_rdma_get_listeners(int *nlisteners) +{ + priskv_transport_listener *listeners; + + *nlisteners = g_transport_server.nlisteners; + listeners = calloc(*nlisteners, sizeof(priskv_transport_listener)); + + for (int i = 0; i < *nlisteners; i++) { + PRISKV_RDMA_DEF_ADDR(g_transport_server.listeners[i].cm_id); + + memcpy(listeners[i].address, local_addr, strlen(local_addr) + 1); + priskv_rdma_get_clients(&g_transport_server.listeners[i], &listeners[i].clients, + &listeners[i].nclients); + } + + return listeners; +} + +void priskv_rdma_free_listeners(priskv_transport_listener *listeners, int nlisteners) +{ + for (int i = 0; i < nlisteners; i++) { + priskv_rdma_free_clients(listeners[i].clients); + } + free(listeners); +} + +int priskv_rdma_get_fd(void) +{ + return g_transport_server.epollfd; +} + +void *priskv_rdma_get_kv(void) +{ + return g_transport_server.kv; +} + +#define PRISKV_RDMA_RESPONSE_FREE_STATUS 0xffff +static inline int priskv_rdma_response_free(priskv_response *resp) +{ + if (resp->status == PRISKV_RDMA_RESPONSE_FREE_STATUS) { + return -EPROTO; + } + + resp->status = PRISKV_RDMA_RESPONSE_FREE_STATUS; + return 0; +} + +static inline uint32_t priskv_rdma_wr_size(priskv_transport_conn *client) +{ + return client->conn_cap.max_inflight_command * (2 + client->conn_cap.max_sgl); +} + +static int priskv_rdma_new_ctrl_buffer(priskv_transport_conn *conn) +{ + uint16_t size; + uint32_t buf_size; + + /* #step 1, prepare buffer & MR for request from client */ + size = + priskv_rdma_max_request_size_aligned(conn->conn_cap.max_sgl, conn->conn_cap.max_key_length); + buf_size = (uint32_t)size * priskv_rdma_wr_size(conn); + if (priskv_rdma_mem_new(conn, &conn->rmem[PRISKV_TRANSPORT_MEM_REQ], "Request", buf_size)) { + goto error; + } + + /* #step 2, prepare buffer & MR for response to client */ + size = sizeof(priskv_response); + buf_size = size * priskv_rdma_wr_size(conn); + if (priskv_rdma_mem_new(conn, &conn->rmem[PRISKV_TRANSPORT_MEM_RESP], "Response", buf_size)) 
{ + goto error; + } + + for (uint16_t i = 0; i < priskv_rdma_wr_size(conn); i++) { + priskv_response *resp = + (priskv_response *)(conn->rmem[PRISKV_TRANSPORT_MEM_RESP].buf + i * size); + priskv_rdma_response_free(resp); + } + + return 0; + +error: + priskv_rdma_free_ctrl_buffer(conn); + return -ENOMEM; +} + +static void priskv_rdma_close_client(priskv_transport_conn *client) +{ + PRISKV_RDMA_DEF_ADDR(client->cm_id) + priskv_log_notice( + "RDMA: <%s - %s> close. Requests GET %ld, SET %ld, TEST %ld, DELETE %ld, Responses %ld\n", + local_addr, peer_addr, client->c.stats[PRISKV_COMMAND_GET].ops, + client->c.stats[PRISKV_COMMAND_SET].ops, client->c.stats[PRISKV_COMMAND_TEST].ops, + client->c.stats[PRISKV_COMMAND_DELETE].ops, client->c.resps); + + if ((client->comp_channel) && (client->c.thread != NULL)) { + priskv_thread_del_event_handler(client->c.thread, client->comp_channel->fd); + priskv_set_fd_handler(client->comp_channel->fd, NULL, NULL, NULL); /* clear fd handler */ + client->c.thread = NULL; + } + + if (client->cm_id && client->cm_id->qp) { + rdma_destroy_qp(client->cm_id); + client->cm_id->qp = NULL; + } + + if (client->cq) { + if (ibv_destroy_cq(client->cq)) { + priskv_log_warn("ibv_destroy_cq failed\n"); + } + client->cq = NULL; + } + + if (client->comp_channel) { + if (ibv_destroy_comp_channel(client->comp_channel)) { + priskv_log_warn("ibv_destroy_comp_channel failed\n"); + } + client->comp_channel = NULL; + } + + priskv_rdma_free_ctrl_buffer(client); + + if (client->cm_id) { + rdma_destroy_id(client->cm_id); + } + + free(client); +} + +static priskv_response *priskv_rdma_unused_response(priskv_transport_conn *conn) +{ + uint16_t resp_buf_size = sizeof(priskv_response); + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_RESP]; + + for (uint16_t i = 0; i < priskv_rdma_wr_size(conn); i++) { + priskv_response *resp = (priskv_response *)(rmem->buf + i * resp_buf_size); + if (resp->status == PRISKV_RDMA_RESPONSE_FREE_STATUS) { + 
priskv_log_debug("RDMA: use response %d\n", i); + resp->status = PRISKV_RESP_STATUS_OK; + return resp; + } + } + + priskv_log_error("RDMA: <%s - %s> inflight response exceeds %d\n", conn->local_addr, + conn->peer_addr, priskv_rdma_wr_size(conn)); + return NULL; +} + +static int priskv_rdma_send_response(priskv_transport_conn *conn, uint64_t request_id, + priskv_resp_status status, uint32_t length) +{ + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_RESP]; + struct ibv_send_wr wr = {0}, *bad_wr; + struct ibv_sge rsge; + priskv_response *resp; + + resp = priskv_rdma_unused_response(conn); + if (!resp) { + return -EPROTO; + } + + assert(((uint8_t *)resp >= rmem->buf) && ((uint8_t *)resp < rmem->buf + rmem->buf_size)); + + resp->request_id = request_id; /* be64 */ + resp->status = htobe16(status); + resp->length = htobe32(length); + + rsge.addr = (uint64_t)resp; + rsge.length = sizeof(priskv_response); + rsge.lkey = rmem->memh.rdma_mr->lkey; + + wr.wr_id = (uint64_t)resp; + wr.sg_list = &rsge; + wr.num_sge = 1; + wr.opcode = IBV_WR_SEND; + wr.send_flags = IBV_SEND_SIGNALED; + + int ret = ibv_post_send(conn->cm_id->qp, &wr, &bad_wr); + if (ret) { + PRISKV_RDMA_DEF_ADDR(conn->cm_id) + priskv_log_error( + "RDMA: <%s - %s> ibv_post_send response failed: addr 0x%lx, length 0x%x ret %d\n", + local_addr, peer_addr, rsge.addr, rsge.length, ret); + } else { + conn->c.resps++; + } + + return ret; +} + +static int priskv_rdma_rw_req(priskv_transport_conn *conn, priskv_request *req, + priskv_transport_memh *memh, uint8_t *val, uint32_t valuelen, + bool set, void (*cb)(void *), void *cbarg, bool defer_resp, + priskv_transport_rw_work **work_out) +{ + priskv_transport_rw_work *work; + uint32_t offset = 0; + struct ibv_send_wr wr = {0}, *bad_wr; + struct ibv_sge sge; + uint16_t nsgl = be16toh(req->nsgl); + const char *cmdstr = set ? 
"READ" : "WRITE"; + + if (work_out) { + *work_out = NULL; + } + + work = calloc(1, sizeof(priskv_transport_rw_work)); + if (!work) { + priskv_log_error("RDMA: failed to allocate memory for %s request\n", cmdstr); + return -ENOMEM; + } + + work->conn = conn; + work->req = req; + work->memh.rdma_mr = memh->rdma_mr; + work->request_id = req->request_id; /* be64 */ + work->valuelen = valuelen; + work->completed = 0; + work->cb = cb; + work->cbarg = cbarg; + work->defer_resp = defer_resp; + + wr.wr_id = (uint64_t)work; + wr.next = NULL; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.opcode = set ? IBV_WR_RDMA_READ : IBV_WR_RDMA_WRITE; + wr.send_flags = IBV_SEND_SIGNALED; + + for (uint16_t i = 0; i < nsgl; i++) { + priskv_keyed_sgl *sgl = &req->sgls[i]; + + uint32_t sgl_length = be32toh(sgl->length); + uint32_t sgl_offset = 0; + + wr.wr.rdma.rkey = be32toh(sgl->key); + work->rdma_rkey = wr.wr.rdma.rkey; + do { + wr.wr.rdma.remote_addr = be64toh(sgl->addr) + sgl_offset; + + sge.addr = (uint64_t)val + offset + sgl_offset; + sge.lkey = memh->rdma_mr->lkey; + sge.length = priskv_min_u32(sgl_length - sgl_offset, valuelen); + sge.length = priskv_min_u32(sge.length, priskv_rdma_max_rw_size); + + if (ibv_post_send(conn->cm_id->qp, &wr, &bad_wr)) { + PRISKV_RDMA_DEF_ADDR(conn->cm_id) + priskv_log_error("RDMA: <%s - %s> ibv_post_send RDMA failed: %m\n", local_addr, + peer_addr); + free(work); + return -errno; + } + + priskv_log_debug( + "RDMA: %s [%d/%d]:[%d/%d] wr_id 0x%lx, val %p, length 0x%x, addr 0x%lx, " + "rkey 0x%x\n", + cmdstr, i, nsgl, sgl_offset, sgl_length, wr.wr_id, val + offset + sgl_offset, + sge.length, wr.wr.rdma.remote_addr, wr.wr.rdma.rkey); + sgl_offset += sge.length; + + work->nsgl++; + } while (sgl_offset < priskv_min_u32(sgl_length, valuelen)); + + offset += sgl_length; + valuelen -= sgl_length; + } + + if (work_out) { + *work_out = work; + } + + return 0; +} + +static int priskv_rdma_recv_req(priskv_transport_conn *conn, uint8_t *req); + +static inline int 
priskv_rdma_handle_send(priskv_transport_conn *conn, priskv_response *resp, + uint32_t len) +{ + return priskv_rdma_response_free(resp); +} + +/* return negative number on failure, return received buffer size on success */ +static int priskv_rdma_recv_req(priskv_transport_conn *conn, uint8_t *req) +{ + struct ibv_sge sge; + struct ibv_recv_wr recv_wr, *bad_wr; + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_REQ]; + uint32_t lkey = rmem->memh.rdma_mr->lkey; + uint16_t req_buf_size = + priskv_rdma_max_request_size_aligned(conn->conn_cap.max_sgl, conn->conn_cap.max_key_length); + int ret; + + assert((req >= rmem->buf) && (req < rmem->buf + rmem->buf_size)); + sge.addr = (uint64_t)req; + sge.length = req_buf_size; + sge.lkey = lkey; + + recv_wr.wr_id = (uint64_t)req; + recv_wr.sg_list = &sge; + recv_wr.num_sge = 1; + recv_wr.next = NULL; + + priskv_log_debug("RDMA: ibv_post_recv addr %p, length %d\n", req, req_buf_size); + ret = ibv_post_recv(conn->cm_id->qp, &recv_wr, &bad_wr); + if (ret) { + PRISKV_RDMA_DEF_ADDR(conn->cm_id) + priskv_log_error("RDMA: <%s - %s> ibv_post_recv failed: %m\n", local_addr, peer_addr); + return -errno; + } + + return req_buf_size; +} + +static void priskv_rdma_handle_cq(int fd, void *opaque, uint32_t events) +{ + priskv_transport_conn *conn = opaque; + struct ibv_cq *ev_cq = NULL; + void *ev_ctx = NULL; + struct ibv_wc wc; + priskv_request *req; + priskv_transport_rw_work *work; + priskv_response *resp; + int ret; + + assert(conn->comp_channel->fd == fd); + + if (ibv_get_cq_event(conn->comp_channel, &ev_cq, &ev_ctx) < 0) { + if (errno != EAGAIN) { + priskv_log_warn("RDMA: ibv_get_cq_event failed: %m\n"); + } + goto error_close; + } else if (ibv_req_notify_cq(ev_cq, 0)) { + priskv_log_warn("RDMA: ibv_req_notify_cq failed: %m\n"); + goto error_close; + } + + ibv_ack_cq_events(conn->cq, 1); + +poll_cq: + ret = ibv_poll_cq(conn->cq, 1, &wc); + if (ret < 0) { + priskv_log_warn("RDMA: ibv_poll_cq failed: %m\n"); + goto 
error_close; + } else if (ret == 0) { + return; + } + + priskv_log_debug("RDMA: CQ handle status: %s[0x%x], wr_id: %p, opcode: 0x%x, byte_len: %u\n", + ibv_wc_status_str(wc.status), wc.status, (void *)wc.wr_id, wc.opcode, + wc.byte_len); + if (wc.status != IBV_WC_SUCCESS) { + PRISKV_RDMA_DEF_ADDR(conn->cm_id) + priskv_log_error("RDMA: <%s - %s> CQ error status: wr_id 0x%lx, %s[0x%x], opcode : 0x%x, " + "byte_len : %ld\n", + local_addr, peer_addr, wc.wr_id, ibv_wc_status_str(wc.status), wc.status, + wc.opcode, wc.byte_len); + if (wc.status == IBV_WC_LOC_QP_OP_ERR) { + priskv_log_error("RDMA: possible remote command size exceeds\n"); + } + goto error_close; + } + + switch (wc.opcode) { + case IBV_WC_RECV: { + struct timeval server_metadata_recv_time; + + req = (priskv_request *)wc.wr_id; + + gettimeofday(&server_metadata_recv_time, NULL); + req->runtime.server_metadata_recv_time = server_metadata_recv_time; + + if (priskv_transport_handle_recv(conn, req, wc.byte_len)) { + goto error_close; + } + break; + } + + case IBV_WC_RDMA_READ: + case IBV_WC_RDMA_WRITE: { + struct timeval server_data_recv_time; + + work = (priskv_transport_rw_work *)wc.wr_id; + req = (priskv_request *)work->req; + + gettimeofday(&server_data_recv_time, NULL); + req->runtime.server_data_recv_time = server_data_recv_time; + + if (priskv_transport_handle_rw(conn, work)) { + goto error_close; + } + + break; + } + + case IBV_WC_SEND: { + resp = (priskv_response *)wc.wr_id; + priskv_rdma_handle_send(conn, resp, wc.byte_len); + break; + } + + default: + priskv_log_error("unexpected opcode 0x%x\n", wc.opcode); + goto error_close; + } + + goto poll_cq; + +error_close: + priskv_transport_mark_client_closed(conn); +} + +static void priskv_rdma_reject(struct rdma_cm_id *cm_id, uint16_t status, uint64_t val) +{ + priskv_cm_rej rej = {0}; + + rej.version = htobe16(PRISKV_CM_VERSION); + rej.status = htobe16(status); + rej.value = htobe64(val); + + rdma_reject(cm_id, &rej, sizeof(priskv_cm_rej)); +} + +static 
int priskv_rdma_resp(priskv_transport_conn *client, struct rdma_cm_id *cm_id) +{ + void *kv = client->c.listener->kv; + uint64_t capacity = priskv_get_value_blocks(kv) * priskv_get_value_block_size(kv); + + priskv_cm_cap rep = {0}; + rep.version = htobe16(PRISKV_CM_VERSION); + rep.max_sgl = htobe16(client->conn_cap.max_sgl); + rep.max_key_length = htobe16(client->conn_cap.max_key_length); + rep.max_inflight_command = htobe16(client->conn_cap.max_inflight_command); + rep.capacity = htobe64(capacity); + + struct rdma_conn_param resp_param = {0}; + resp_param.responder_resources = 1; + resp_param.initiator_depth = 1; + resp_param.retry_count = 5; + resp_param.private_data = &rep; + resp_param.private_data_len = sizeof(rep); + + int ret = rdma_accept(cm_id, &resp_param); + if (ret) { + PRISKV_RDMA_DEF_ADDR(client->cm_id) + priskv_log_error("RDMA: <%s - %s> rdma_accept failed: %m\n", local_addr, peer_addr); + } + + return ret; +} + +static int priskv_rdma_verify_conn_cap(priskv_transport_conn_cap *client, + priskv_transport_conn_cap *listener, uint64_t *val) +{ + if (!client->max_sgl) { + client->max_sgl = listener->max_sgl; + } else if (client->max_sgl > listener->max_sgl) { + *val = listener->max_sgl; + return PRISKV_CM_REJ_STATUS_INVALID_SGL; + } + + if (!client->max_key_length) { + client->max_key_length = listener->max_key_length; + } else if (client->max_key_length > listener->max_key_length) { + *val = listener->max_key_length; + return PRISKV_CM_REJ_STATUS_INVALID_KEY_LENGTH; + } + + if (!client->max_inflight_command) { + client->max_inflight_command = listener->max_inflight_command; + } else if (client->max_inflight_command > listener->max_inflight_command) { + *val = listener->max_inflight_command; + return PRISKV_CM_REJ_STATUS_INVALID_INFLIGHT_COMMAND; + } + + return 0; +} + +static void priskv_rdma_handle_connect_request(struct rdma_cm_event *ev, + priskv_transport_conn *listener) +{ + priskv_transport_conn *client; + struct rdma_cm_id *id = ev->id; + struct 
ibv_qp_init_attr init_attr = {0}; + struct rdma_conn_param *req_param = &ev->param.conn; + unsigned char exp_len = sizeof(struct rdma_conn_param) + sizeof(priskv_cm_cap); + priskv_cm_status status; + uint64_t value = 0; + + PRISKV_RDMA_DEF_ADDR(id); + + client = calloc(1, sizeof(priskv_transport_conn)); + assert(client); + id->context = client; + client->cm_id = id; + client->c.listener = listener; + client->c.thread = NULL; + client->c.closing = false; + list_node_init(&client->c.node); + pthread_spin_init(&client->lock, 0); + + snprintf(client->local_addr, PRISKV_ADDR_LEN, "%s", local_addr); + snprintf(client->peer_addr, PRISKV_ADDR_LEN, "%s", peer_addr); + + pthread_spin_lock(&listener->lock); + list_add_tail(&listener->s.head, &client->c.node); + listener->s.nclients++; + pthread_spin_unlock(&listener->lock); + + /* #step0, ACL verification */ + if (priskv_acl_verify(rdma_get_peer_addr(id))) { + priskv_log_error("RDMA: <%s - %s> ACL verification failed\n", local_addr, peer_addr); + status = PRISKV_CM_REJ_STATUS_ACL_REFUSE; + value = 0; + goto rej; + } + + /* #step1, check incoming request parameters */ + if (req_param->private_data_len != exp_len) { + priskv_log_error("RDMA: <%s - %s> unexpected CM REQ length %d, expetected %d\n", local_addr, + peer_addr, req_param->private_data_len, exp_len); + status = PRISKV_CM_REJ_STATUS_INVALID_CM_REP; + value = exp_len; + goto rej; + } + + const priskv_cm_cap *req = req_param->private_data; + uint16_t version = be16toh(req->version); + if (version != PRISKV_CM_VERSION) { + status = PRISKV_CM_REJ_STATUS_INVALID_VERSION; + value = PRISKV_CM_VERSION; + goto rej; + } + + client->conn_cap.max_sgl = be16toh(req->max_sgl); + client->conn_cap.max_key_length = be16toh(req->max_key_length); + client->conn_cap.max_inflight_command = be16toh(req->max_inflight_command); + priskv_log_info("RDMA: <%s - %s> incoming connect request - version %d, max_sgl %d, " + "max_key_length %d, max_inflight_command %d\n", + local_addr, peer_addr, 
version, client->conn_cap.max_sgl, + client->conn_cap.max_key_length, client->conn_cap.max_inflight_command); + + status = priskv_rdma_verify_conn_cap(&client->conn_cap, &listener->conn_cap, &value); + if (status) { + goto rej; + } + + /* #step2, create QP and related resources */ + client->comp_channel = ibv_create_comp_channel(id->verbs); + if (!client->comp_channel) { + priskv_log_error("RDMA: <%s - %s> ibv_create_comp_channel failed: %m\n", local_addr, + peer_addr); + status = PRISKV_CM_REJ_STATUS_SERVER_ERROR; + goto rej; + } + + priskv_set_nonblock(client->comp_channel->fd); + uint32_t wr_size = priskv_rdma_wr_size(client); + client->cq = ibv_create_cq(id->verbs, wr_size * 2 * 4, NULL, client->comp_channel, 0); + if (!client->cq) { + priskv_log_error("RDMA: <%s - %s> ibv_create_cq failed: %m\n", local_addr, peer_addr); + status = PRISKV_CM_REJ_STATUS_SERVER_ERROR; + goto rej; + } + + ibv_req_notify_cq(client->cq, 0); + + init_attr.cap.max_send_wr = wr_size * 4; + init_attr.cap.max_recv_wr = wr_size * 4; + init_attr.cap.max_send_sge = 1; + init_attr.cap.max_recv_sge = 1; + init_attr.qp_type = IBV_QPT_RC; + init_attr.send_cq = client->cq; + init_attr.recv_cq = client->cq; + if (rdma_create_qp(id, NULL, &init_attr)) { + priskv_log_error("RDMA: <%s - %s> rdma_create_qp failed: %m\n", local_addr, peer_addr); + status = PRISKV_CM_REJ_STATUS_SERVER_ERROR; + goto rej; + } + + /* #step3, create QP and related resources */ + if (priskv_rdma_new_ctrl_buffer(client)) { + status = PRISKV_CM_REJ_STATUS_SERVER_ERROR; + goto rej; + } + + /* #step4, post recv all the request commands */ + uint8_t *recv_req = client->rmem[PRISKV_TRANSPORT_MEM_REQ].buf; + for (uint16_t i = 0; i < wr_size; i++) { + int recvsize = priskv_rdma_recv_req(client, recv_req); + if (recvsize < 0) { + status = PRISKV_CM_REJ_STATUS_SERVER_ERROR; + goto rej; + } + + recv_req += recvsize; + } + + /* #step5, accept the new client */ + if (priskv_rdma_resp(client, id)) { + goto close_client; + } + + 
priskv_log_info("RDMA: <%s - %s> accept connect request - version %d, max_sgl %d, " + "max_key_length %d, max_inflight_command %d\n", + local_addr, peer_addr, version, client->conn_cap.max_sgl, + client->conn_cap.max_key_length, client->conn_cap.max_inflight_command); + return; + +rej: + priskv_log_warn("RDMA: <%s - %s> %s, reject\n", local_addr, peer_addr, + priskv_cm_status_str(status)); + priskv_rdma_reject(id, status, value); + +close_client: + priskv_transport_mark_client_closed(client); +} + +static void priskv_rdma_handle_established(struct rdma_cm_event *ev, + priskv_transport_conn *listener) +{ + struct rdma_cm_id *id = ev->id; + priskv_transport_conn *client = id->context; + + PRISKV_RDMA_DEF_ADDR(id); + + /* initialize KV of client */ + client->value_base = listener->value_base; + client->kv = listener->kv; + client->value_memh.rdma_mr = listener->value_memh.rdma_mr; + + /* use the idlest worker thread handle CQ event(CM event is still handled by main thread) */ + priskv_set_fd_handler(client->comp_channel->fd, priskv_rdma_handle_cq, NULL, client); + client->c.thread = priskv_threadpool_find_iothread(g_threadpool); + priskv_thread_add_event_handler(client->c.thread, client->comp_channel->fd); + + priskv_log_notice("RDMA: <%s - %s> established\n", local_addr, peer_addr); + priskv_log_debug("RDMA: <%s - %s> assign CQ fd %d to thread %d\n", local_addr, peer_addr, + client->comp_channel->fd, client->c.thread); +} + +static void priskv_rdma_handle_disconnected(struct rdma_cm_event *ev, + priskv_transport_conn *listener) +{ + struct rdma_cm_id *id = ev->id; + priskv_transport_conn *client = id->context; + + priskv_transport_mark_client_closed(client); +} + +static void priskv_rdma_handle_cm(int fd, void *opaque, uint32_t events) +{ + priskv_transport_conn *listener = opaque; + struct rdma_cm_event *ev; + int ret; + + assert(listener->cm_id->channel->fd == fd); + +again: + ret = rdma_get_cm_event(listener->cm_id->channel, &ev); + if (ret) { + if (errno != 
EAGAIN) { + priskv_log_error("RDMA: listener rdma_get_cm_event failed: %m\n"); + } + return; + } + + const char *evstr = rdma_event_str(ev->event); + char addrbuf[64] = {0}; + priskv_inet_ntop(rdma_get_local_addr(listener->cm_id), addrbuf); + priskv_log_debug("RDMA: listener<%s> cm event: %s\n", addrbuf, evstr); + + switch (ev->event) { + case RDMA_CM_EVENT_CONNECT_REQUEST: + priskv_rdma_handle_connect_request(ev, listener); + break; + + case RDMA_CM_EVENT_ESTABLISHED: + priskv_rdma_handle_established(ev, listener); + break; + + case RDMA_CM_EVENT_DISCONNECTED: + priskv_rdma_handle_disconnected(ev, listener); + break; + + default: + priskv_log_error("RDMA: listener<%s> listener unexpected cm event: %s\n", addrbuf, evstr); + } + + rdma_ack_cm_event(ev); + + goto again; +} + +priskv_transport_driver priskv_transport_driver_rdma = { + .name = "rdma", + .init = NULL, + .listen = priskv_rdma_listen, + .get_fd = priskv_rdma_get_fd, + .get_kv = priskv_rdma_get_kv, + .get_listeners = priskv_rdma_get_listeners, + .free_listeners = priskv_rdma_free_listeners, + .send_response = priskv_rdma_send_response, + .rw_req = priskv_rdma_rw_req, + .recv_req = priskv_rdma_recv_req, + .mem_new = priskv_rdma_mem_new, + .mem_free = priskv_rdma_mem_free, + .request_key_off = priskv_rdma_request_key_off, + .request_key = priskv_rdma_request_key, + .close_client = priskv_rdma_close_client, +}; diff --git a/server/transport/transport.c b/server/transport/transport.c new file mode 100644 index 0000000..83f6f55 --- /dev/null +++ b/server/transport/transport.c @@ -0,0 +1,579 @@ +// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "transport.h" +#include + +#include "../kv.h" +#include "priskv-event.h" +#include "priskv-log.h" +#include "priskv-threads.h" +#include "priskv-protocol-helper.h" + +priskv_transport_driver *g_transport_driver = NULL; +priskv_threadpool *g_threadpool = NULL; +priskv_transport_server g_transport_server = { + .nlisteners = 0, + .epollfd = -1, + .context = NULL, +}; + +extern priskv_transport_driver priskv_transport_driver_ucx; +extern priskv_transport_driver priskv_transport_driver_rdma; + +uint32_t g_slow_query_threshold_latency_us = SLOW_QUERY_THRESHOLD_LATENCY_US; + +// forward declaration +void priskv_tiering_get(priskv_tiering_req *treq); +void priskv_tiering_test(priskv_tiering_req *treq); +void priskv_tiering_set(priskv_tiering_req *treq); +void priskv_tiering_del(priskv_tiering_req *treq); + +static void __attribute__((constructor)) priskv_server_transport_init(void) +{ + const char *transport_env = getenv("PRISKV_TRANSPORT"); + priskv_transport_backend backend = PRISKV_TRANSPORT_BACKEND_RDMA; + if (transport_env) { + if (strcasecmp(transport_env, "UCX") == 0) { + backend = PRISKV_TRANSPORT_BACKEND_UCX; + } else if (strcasecmp(transport_env, "RDMA") == 0) { + backend = PRISKV_TRANSPORT_BACKEND_RDMA; + } else { + priskv_log_error("Unknown transport backend: %s\n", transport_env); + } + } + + priskv_transport_driver *driver = NULL; + switch (backend) { + case PRISKV_TRANSPORT_BACKEND_UCX: + driver = &priskv_transport_driver_ucx; + priskv_log_notice("Using UCX transport backend\n"); + break; + case PRISKV_TRANSPORT_BACKEND_RDMA: + 
driver = &priskv_transport_driver_rdma; + priskv_log_notice("Using RDMA transport backend\n"); + break; + default: + priskv_log_error("Unknown transport backend: %d\n", backend); + break; + } + + if (driver && driver->init) { + if (driver->init() != 0) { + priskv_log_error("Failed to initialize transport driver: %s\n", driver->name); + driver = NULL; + } + } + + if (driver) { + g_transport_driver = driver; + return 0; + } + + return -1; +} + +int priskv_transport_listen(char **addr, int naddr, int port, void *kv, + priskv_transport_conn_cap *cap) +{ + if (!g_transport_driver) { + priskv_log_error("Transport driver is NULL\n"); + return -1; + } + return g_transport_driver->listen(addr, naddr, port, kv, cap); +} + +int priskv_transport_get_fd(void) +{ + if (!g_transport_driver) { + priskv_log_error("Transport driver is NULL\n"); + return -1; + } + return g_transport_driver->get_fd(); +} + +void *priskv_transport_get_kv(void) +{ + if (!g_transport_driver) { + priskv_log_error("Transport driver is NULL\n"); + return NULL; + } + return g_transport_driver->get_kv(); +} + +priskv_transport_listener *priskv_transport_get_listeners(int *nlisteners) +{ + if (!g_transport_driver) { + priskv_log_error("Transport driver is NULL\n"); + return NULL; + } + return g_transport_driver->get_listeners(nlisteners); +} + +void priskv_transport_free_listeners(priskv_transport_listener *listeners, int nlisteners) +{ + if (!g_transport_driver) { + priskv_log_error("Transport driver is NULL\n"); + return; + } + g_transport_driver->free_listeners(listeners, nlisteners); +} + +int priskv_transport_send_response(priskv_transport_conn *conn, uint64_t request_id, + priskv_resp_status status, uint32_t length) +{ + if (!g_transport_driver) { + priskv_log_error("Transport driver is NULL\n"); + return -1; + } + return g_transport_driver->send_response(conn, request_id, status, length); +} + +int priskv_transport_rw_req(priskv_transport_conn *conn, priskv_request *req, + priskv_transport_memh *memh, 
uint8_t *val, uint32_t valuelen, bool set, + void (*cb)(void *), void *cbarg, bool defer_resp, + priskv_transport_rw_work **work_out) +{ + if (!g_transport_driver) { + priskv_log_error("Transport driver is NULL\n"); + return -1; + } + return g_transport_driver->rw_req(conn, req, memh, val, valuelen, set, cb, cbarg, defer_resp, + work_out); +} + +void priskv_check_and_log_slow_query(priskv_transport_rw_work *work) +{ + struct timeval server_resp_send_time; + priskv_request *req = (priskv_request *)work->req; + uint16_t command = be16toh(req->command); + uint8_t *key = g_transport_driver->request_key(req); + uint16_t keylen = be16toh(req->key_length); + + gettimeofday(&server_resp_send_time, NULL); + req->runtime.server_resp_send_time = server_resp_send_time; + if (priskv_time_elapsed_us(&req->runtime.client_metadata_send_time, + &req->runtime.server_resp_send_time) > + g_slow_query_threshold_latency_us) { + char key_short[128] = {0}; + priskv_string_shorten((const char *)key, keylen, key_short, sizeof(key_short)); + priskv_log_notice( + "Slow Query Encountered . 
" + "Slow Query threshold latency is %ld us |" + "Command %s key[%u] = \"%s\" |" + "thread id is %lu |" + "Client send metadata: %ld.%06ld us | " + "Server recv metadata: %ld.%06ld us | " + "Server RW KV: %ld.%06ld us | " + "Server send data: %ld.%06ld us | " + "Server recv data: %ld.%06ld us | " + "Server send resp: %ld.%06ld us | " + "Total: %ld us | " + "Steps: " + "Client->Server metadata: %ld us | " + "Server metadata->RW KV: %ld us | " + "RW KV->Send data: %ld us | " + "Send data->Recv data: %ld us | " + "Recv data->Resp send: %ld us \n", + + g_slow_query_threshold_latency_us, priskv_command_str(command), keylen, key_short, + pthread_self(), req->runtime.client_metadata_send_time.tv_sec, + req->runtime.client_metadata_send_time.tv_usec, + req->runtime.server_metadata_recv_time.tv_sec, + req->runtime.server_metadata_recv_time.tv_usec, req->runtime.server_rw_kv_time.tv_sec, + req->runtime.server_rw_kv_time.tv_usec, req->runtime.server_data_send_time.tv_sec, + req->runtime.server_data_send_time.tv_usec, req->runtime.server_data_recv_time.tv_sec, + req->runtime.server_data_recv_time.tv_usec, req->runtime.server_resp_send_time.tv_sec, + req->runtime.server_resp_send_time.tv_usec, + + priskv_time_elapsed_us(&req->runtime.client_metadata_send_time, + &req->runtime.server_resp_send_time), + + priskv_time_elapsed_us(&req->runtime.client_metadata_send_time, + &req->runtime.server_metadata_recv_time), + priskv_time_elapsed_us(&req->runtime.server_metadata_recv_time, + &req->runtime.server_rw_kv_time), + priskv_time_elapsed_us(&req->runtime.server_rw_kv_time, + &req->runtime.server_data_send_time), + priskv_time_elapsed_us(&req->runtime.server_data_send_time, + &req->runtime.server_data_recv_time), + priskv_time_elapsed_us(&req->runtime.server_data_recv_time, + &req->runtime.server_resp_send_time)); + } +} + +int priskv_transport_handle_recv(priskv_transport_conn *conn, priskv_request *req, uint32_t len) +{ + uint16_t command = be16toh(req->command); + uint16_t nsgl = 
be16toh(req->nsgl); + uint64_t timeout = be64toh(req->timeout); + uint8_t *key; + uint16_t keylen; + uint16_t keyoff = g_transport_driver->request_key_off(req); + uint8_t *val; + uint32_t valuelen = 0, nkeys = 0; + uint32_t remote_valuelen; + uint64_t bytes = 0; + void *keynode; + priskv_resp_status status; + int ret = 0; + bool tiering_inflight = false; + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_KEYS]; + priskv_transport_driver *driver = g_transport_driver; + + if (len < keyoff) { + priskv_log_warn("Transport: <%s - %s> invalid command. recv %d, less than %d, nsgl 0x%x\n", + conn->local_addr, conn->peer_addr, len, keyoff, nsgl); + driver->send_response(conn, req->request_id, PRISKV_RESP_STATUS_INVALID_COMMAND, 0); + return -EPROTO; + } + + keylen = len - keyoff; + if (!keylen) { + priskv_log_warn("Transport: <%s - %s> empty key. recv %d, less than %d, nsgl 0x%x\n", + conn->local_addr, conn->peer_addr, len, keyoff, nsgl); + + driver->send_response(conn, req->request_id, PRISKV_RESP_STATUS_KEY_EMPTY, 0); + return -EPROTO; + } + + if (keylen > conn->conn_cap.max_key_length) { + priskv_log_warn("Transport: <%s - %s> invalid key. key(%d) exceeds max_key_length(%d)\n", + conn->local_addr, conn->peer_addr, keylen, conn->conn_cap.max_key_length); + driver->send_response(conn, req->request_id, PRISKV_RESP_STATUS_KEY_TOO_BIG, 0); + return -EPROTO; + } + + if (nsgl > conn->conn_cap.max_sgl) { + priskv_log_warn("Transport: <%s - %s> invalid nsgl. 
nsgl(%d) exceeds max_sgl(%d)\n", + conn->local_addr, conn->peer_addr, nsgl, conn->conn_cap.max_sgl); + driver->send_response(conn, req->request_id, PRISKV_RESP_STATUS_INVALID_SGL, 0); + return -EPROTO; + } + + key = driver->request_key(req); + + if (priskv_get_log_level() >= priskv_log_debug) { + char key_short[128] = {0}; + priskv_string_shorten((const char *)key, keylen, key_short, sizeof(key_short)); + priskv_log_debug("Transport: <%s - %s> %s key[%u] = \"%s\"\n", conn->local_addr, + conn->peer_addr, priskv_command_str(command), keylen, key_short); + } + + switch (command) { + case PRISKV_COMMAND_GET: { + struct timeval server_rw_kv_time, server_data_send_time; + remote_valuelen = priskv_sgl_size_from_be(req->sgls, nsgl); + + if (!priskv_backend_tiering_enabled()) { + status = priskv_get_key(conn->kv, key, keylen, &val, &valuelen, &keynode); + if (status != PRISKV_RESP_STATUS_OK || !keynode) { + ret = driver->send_response(conn, req->request_id, status, 0); + priskv_get_key_end(keynode); + break; + } + + gettimeofday(&server_rw_kv_time, NULL); + req->runtime.server_rw_kv_time = server_rw_kv_time; + + if (remote_valuelen < valuelen) { + ret = driver->send_response(conn, req->request_id, PRISKV_RESP_STATUS_VALUE_TOO_BIG, + valuelen); + priskv_get_key_end(keynode); + break; + } + + ret = driver->rw_req(conn, req, &conn->value_memh, val, valuelen, false, + priskv_get_key_end, keynode, false, NULL); + + gettimeofday(&server_data_send_time, NULL); + req->runtime.server_data_send_time = server_data_send_time; + + bytes = valuelen; + } else { + priskv_resp_status alloc_status = PRISKV_RESP_STATUS_OK; + priskv_tiering_req *treq = + priskv_tiering_req_new(conn, req, key, keylen, PRISKV_KEY_MAX_TIMEOUT, + PRISKV_COMMAND_GET, remote_valuelen, &alloc_status); + if (!treq) { + ret = driver->send_response(conn, req->request_id, alloc_status, 0); + break; + } + + tiering_inflight = true; + priskv_tiering_get(treq); + } + break; + } + case PRISKV_COMMAND_SET: { + struct timeval 
server_rw_kv_time, server_data_send_time; + + remote_valuelen = priskv_sgl_size_from_be(req->sgls, nsgl); + if (!remote_valuelen) { + ret = driver->send_response(conn, req->request_id, PRISKV_RESP_STATUS_VALUE_EMPTY, 0); + break; + } + + if (!priskv_backend_tiering_enabled()) { + status = + priskv_set_key(conn->kv, key, keylen, &val, remote_valuelen, timeout, &keynode); + if (status != PRISKV_RESP_STATUS_OK || !keynode) { + ret = driver->send_response(conn, req->request_id, status, 0); + priskv_set_key_end(keynode); + break; + } + + gettimeofday(&server_rw_kv_time, NULL); + req->runtime.server_rw_kv_time = server_rw_kv_time; + + ret = driver->rw_req(conn, req, &conn->value_memh, val, remote_valuelen, true, + priskv_set_key_end, keynode, false, NULL); + + gettimeofday(&server_data_send_time, NULL); + req->runtime.server_data_send_time = server_data_send_time; + + bytes = remote_valuelen; + } else { + priskv_resp_status alloc_status = PRISKV_RESP_STATUS_OK; + priskv_tiering_req *treq = + priskv_tiering_req_new(conn, req, key, keylen, timeout, PRISKV_COMMAND_SET, + remote_valuelen, &alloc_status); + if (!treq) { + ret = driver->send_response(conn, req->request_id, alloc_status, 0); + break; + } + + tiering_inflight = true; + priskv_tiering_set(treq); + } + break; + } + + case PRISKV_COMMAND_TEST: { + if (!priskv_backend_tiering_enabled()) { + status = priskv_get_key(conn->kv, key, keylen, &val, &valuelen, &keynode); + ret = driver->send_response(conn, req->request_id, status, valuelen); + priskv_get_key_end(keynode); + break; + } + + priskv_resp_status alloc_status = PRISKV_RESP_STATUS_OK; + priskv_tiering_req *treq = priskv_tiering_req_new(conn, req, key, keylen, timeout, + PRISKV_COMMAND_TEST, 0, &alloc_status); + if (!treq) { + ret = driver->send_response(conn, req->request_id, alloc_status, 0); + break; + } + + tiering_inflight = true; + priskv_tiering_test(treq); + break; + } + + case PRISKV_COMMAND_DELETE: { + if (!priskv_backend_tiering_enabled()) { + status = 
priskv_delete_key(conn->kv, key, keylen); + ret = driver->send_response(conn, req->request_id, status, 0); + break; + } + + priskv_resp_status alloc_status = PRISKV_RESP_STATUS_OK; + priskv_tiering_req *treq = priskv_tiering_req_new(conn, req, key, keylen, timeout, + PRISKV_COMMAND_DELETE, 0, &alloc_status); + if (!treq) { + ret = driver->send_response(conn, req->request_id, alloc_status, 0); + break; + } + + tiering_inflight = true; + priskv_tiering_del(treq); + break; + } + + case PRISKV_COMMAND_EXPIRE: + status = priskv_expire_key(conn->kv, key, keylen, timeout); + ret = driver->send_response(conn, req->request_id, status, 0); + break; + + case PRISKV_COMMAND_KEYS: + if (rmem->memh.handle) { + /* a single KEYS command is allowed inflight with a connection */ + driver->send_response(conn, req->request_id, PRISKV_RESP_STATUS_NO_MEM, 0); + ret = 0; + break; + } + + remote_valuelen = priskv_sgl_size_from_be(req->sgls, nsgl); + if (driver->mem_new(conn, rmem, "Keys", remote_valuelen)) { + ret = driver->send_response(conn, req->request_id, PRISKV_RESP_STATUS_NO_MEM, valuelen); + break; + } + + status = + priskv_get_keys(conn->kv, key, keylen, rmem->buf, remote_valuelen, &valuelen, &nkeys); + if ((status != PRISKV_RESP_STATUS_OK) || !valuelen) { + driver->mem_free(conn, rmem); + ret = driver->send_response(conn, req->request_id, status, valuelen); + break; + } + + ret = driver->rw_req(conn, req, &rmem->memh, rmem->buf, valuelen, false, NULL, NULL, false, + NULL); + if (ret) { + driver->mem_free(conn, rmem); + ret = driver->send_response(conn, req->request_id, status, valuelen); + } + break; + + case PRISKV_COMMAND_NRKEYS: + status = priskv_get_keys(conn->kv, key, keylen, NULL, 0, &valuelen, &nkeys); + /* PRISKV_RESP_STATUS_VALUE_TOO_BIG is expected */ + if (status == PRISKV_RESP_STATUS_VALUE_TOO_BIG) { + ret = driver->send_response(conn, req->request_id, PRISKV_RESP_STATUS_OK, nkeys); + break; + } + ret = driver->send_response(conn, req->request_id, status, 0); + 
break; + + case PRISKV_COMMAND_FLUSH: + status = priskv_flush_keys(conn->kv, key, keylen, &nkeys); + ret = driver->send_response(conn, req->request_id, status, nkeys); + break; + + default: + priskv_log_warn("Transport: <%s - %s> unknown command %d\n", conn->local_addr, + conn->peer_addr, command); + ret = driver->send_response(conn, req->request_id, PRISKV_RESP_STATUS_NO_SUCH_COMMAND, 0); + } + + if (!tiering_inflight) { + conn->c.stats[command].ops++; + if (!ret) { + driver->recv_req(conn, (uint8_t *)req); + conn->c.stats[command].bytes += bytes; + } + } + + return ret; +} + +static int priskv_transport_complete_rw_work(priskv_transport_rw_work *work, + priskv_resp_status status, uint32_t length) +{ + if (!work) { + return -EINVAL; + } + + priskv_transport_conn *conn = work->conn; + + int ret = g_transport_driver->send_response(conn, work->request_id, status, length); + + if (work->memh.handle != conn->value_memh.handle) { + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_KEYS]; + assert(work->memh.handle == rmem->memh.handle); + + g_transport_driver->mem_free(conn, rmem); + priskv_log_debug("Transport: KEYS done\n"); + } + + priskv_check_and_log_slow_query(work); + + free(work); + return ret; +} + +int priskv_transport_handle_rw(priskv_transport_conn *conn, priskv_transport_rw_work *work) +{ + work->completed++; + assert(work->completed <= work->nsgl); + + if (work->completed < work->nsgl) { + return 0; + } + + if (work->cb) { + work->cb(work->cbarg); + } + + if (work->defer_resp) { + return 0; + } + + return priskv_transport_complete_rw_work(work, PRISKV_RESP_STATUS_OK, work->valuelen); +} + +void priskv_transport_mark_client_closed(priskv_transport_conn *client) +{ + pthread_spin_lock(&client->lock); + if (client->c.closing) { + pthread_spin_unlock(&client->lock); + return; + } + + client->c.closing = true; + pthread_spin_unlock(&client->lock); + + priskv_log_notice("Transport: <%s - %s> async close client\n", client->local_addr, + 
client->peer_addr); +} + +void priskv_transport_close_disconnected(priskv_transport_conn *listener) +{ + priskv_transport_conn *client, *tmp; + + pthread_spin_lock(&listener->lock); + list_for_each_safe (&listener->s.head, client, tmp, c.node) { + if (client->c.closing) { + listener->s.nclients--; + list_del(&client->c.node); + pthread_spin_unlock(&listener->lock); + } else { + continue; + } + + g_transport_driver->close_client(client); + + pthread_spin_lock(&listener->lock); + } + pthread_spin_unlock(&listener->lock); +} + +void priskv_transport_process(void) +{ + priskv_transport_conn *listener; +#define PRISKV_EPOLL_MAX_CM_EVENT 32 + struct epoll_event events[PRISKV_EPOLL_MAX_CM_EVENT]; + int nevents; + + nevents = epoll_wait(g_transport_server.epollfd, events, PRISKV_EPOLL_MAX_CM_EVENT, 1000); + if (!nevents) { + goto close_disconnected; + } + + if (nevents < 0) { + assert(errno == EINTR); + goto close_disconnected; + } + + for (int n = 0; n < nevents; n++) { + struct epoll_event *event = &events[n]; + priskv_fd_handler_event(event); + } + +close_disconnected: + for (int i = 0; i < g_transport_server.nlisteners; i++) { + listener = &g_transport_server.listeners[i]; + priskv_transport_close_disconnected(listener); + } +} diff --git a/server/transport/transport.h b/server/transport/transport.h new file mode 100644 index 0000000..661c8ba --- /dev/null +++ b/server/transport/transport.h @@ -0,0 +1,338 @@ +// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __PRISKV_SERVER_TRANSPORT__ +#define __PRISKV_SERVER_TRANSPORT__ + +#if defined(__cplusplus) +extern "C" +{ +#endif + +#include +#include +#include + +#include "list.h" +#include "priskv-protocol.h" +#include "priskv-threads.h" +#include "priskv-ucx.h" +#include "priskv-utils.h" +#include "uthash.h" + +#define PRISKV_TRANSPORT_MAX_BIND_ADDR 32 +#define PRISKV_TRANSPORT_DEFAULT_PORT ('H' << 8 | 'P') +#define PRISKV_TRANSPORT_MAX_INFLIGHT_COMMAND 4096 +#define PRISKV_TRANSPORT_DEFAULT_INFLIGHT_COMMAND 128 +#define PRISKV_TRANSPORT_MAX_SGL 8 +#define PRISKV_TRANSPORT_DEFAULT_SGL 4 +#define PRISKV_TRANSPORT_MAX_KEY (1 << 30) +#define PRISKV_TRANSPORT_DEFAULT_KEY (16 * 1024) +#define PRISKV_TRANSPORT_MAX_KEY_LENGTH 1024 +#define PRISKV_TRANSPORT_DEFAULT_KEY_LENGTH 128 +#define PRISKV_TRANSPORT_MAX_VALUE_BLOCK_SIZE (1 << 20) +#define PRISKV_TRANSPORT_DEFAULT_VALUE_BLOCK_SIZE 4096 +#define PRISKV_TRANSPORT_MAX_VALUE_BLOCK (1UL << 30) +#define PRISKV_TRANSPORT_DEFAULT_VALUE_BLOCK (1024UL * 1024) +#define SLOW_QUERY_THRESHOLD_LATENCY_US 1000000 + +extern uint32_t g_slow_query_threshold_latency_us; + +typedef enum priskv_transport_backend { + PRISKV_TRANSPORT_BACKEND_UCX, + PRISKV_TRANSPORT_BACKEND_RDMA, + PRISKV_TRANSPORT_BACKEND_MAX, +} priskv_transport_backend; + +typedef struct priskv_transport_stats { + uint64_t ops; + uint64_t bytes; +} priskv_transport_stats; + +typedef priskv_cm_cap priskv_transport_conn_cap; + +typedef struct priskv_transport_client { + char address[PRISKV_ADDR_LEN]; + priskv_transport_stats stats[PRISKV_COMMAND_MAX]; + uint64_t resps; + bool closing; +} priskv_transport_client; + +typedef struct priskv_transport_listener { + char address[PRISKV_ADDR_LEN]; + int nclients; + priskv_transport_client *clients; +} priskv_transport_listener; + +typedef union priskv_transport_memh { + void *handle; + struct ibv_mr *rdma_mr; // rdma + 
priskv_ucx_memh *ucx_memh; // ucx +} priskv_transport_memh; + +typedef struct priskv_transport_mem { +#define PRISKV_TRANSPORT_MEM_NAME_LEN 32 + char name[PRISKV_TRANSPORT_MEM_NAME_LEN]; + uint8_t *buf; + uint32_t buf_size; + priskv_transport_memh memh; +} priskv_transport_mem; + +typedef enum priskv_transport_mem_type { + PRISKV_TRANSPORT_MEM_REQ, + PRISKV_TRANSPORT_MEM_RESP, + PRISKV_TRANSPORT_MEM_KEYS, + + PRISKV_TRANSPORT_MEM_MAX +} priskv_transport_mem_type; + +typedef struct priskv_transport_conn { + priskv_transport_conn_cap conn_cap; + char local_addr[PRISKV_ADDR_LEN]; + char peer_addr[PRISKV_ADDR_LEN]; + pthread_spinlock_t lock; + + priskv_transport_memh value_memh; + union { + struct { + struct rdma_cm_id *cm_id; + struct ibv_comp_channel *comp_channel; + struct ibv_cq *cq; + }; // rdma + struct { + union { + struct { + int listenfd; + int efd; + priskv_transport_conn_cap conn_cap_be; + }; // listener + struct { + int connfd; + priskv_ucx_worker *worker; + priskv_ucx_ep *ep; + priskv_ucx_request *inflight_reqs; + }; // client + }; + }; // ucx + }; + + union { + struct { + struct list_head head; + uint32_t nclients; + } s; /* for listener */ + struct { + struct priskv_transport_conn *listener; + struct list_node node; + priskv_thread *thread; + bool closing; + priskv_transport_stats stats[PRISKV_COMMAND_MAX]; + uint64_t resps; + } c; /* for client */ + }; + + void *kv; + uint8_t *value_base; + priskv_transport_mem rmem[PRISKV_TRANSPORT_MEM_MAX]; +} priskv_transport_conn; + +typedef struct priskv_transport_rw_work { + priskv_transport_conn *conn; + uint64_t request_id; /* be64 type */ + priskv_request *req; + priskv_transport_memh memh; + union { + uint32_t rdma_rkey; // rdma + priskv_ucx_rkey *ucx_rkey; // ucx + }; + + uint32_t valuelen; + uint16_t nsgl; + uint16_t completed; + bool defer_resp; + void (*cb)(void *); + void *cbarg; +} priskv_transport_rw_work; + +typedef struct priskv_transport_server { + int epollfd; + void *kv; + int nlisteners; + 
priskv_transport_conn listeners[PRISKV_TRANSPORT_MAX_BIND_ADDR]; + priskv_transport_conn_cap cap; + union { + struct {}; // rdma + struct { + priskv_ucx_context *context; + }; // ucx + }; +} priskv_transport_server; + +typedef struct priskv_transport_driver { + const char *name; + int (*init)(void); + int (*listen)(char **addr, int naddr, int port, void *kv, priskv_transport_conn_cap *cap); + + int (*get_fd)(void); + void *(*get_kv)(void); + + // listeners + priskv_transport_listener *(*get_listeners)(int *nlisteners); + void (*free_listeners)(priskv_transport_listener *listeners, int nlisteners); + + // req/resp + int (*send_response)(priskv_transport_conn *conn, uint64_t request_id, + priskv_resp_status status, uint32_t length); + int (*rw_req)(priskv_transport_conn *conn, priskv_request *req, priskv_transport_memh *memh, + uint8_t *val, uint32_t valuelen, bool set, void (*cb)(void *), void *cbarg, + bool defer_resp, priskv_transport_rw_work **work_out); + int (*recv_req)(priskv_transport_conn *conn, uint8_t *req); + + // mem + int (*mem_new)(priskv_transport_conn *conn, priskv_transport_mem *rmem, const char *name, + uint32_t size); + void (*mem_free)(priskv_transport_conn *conn, priskv_transport_mem *rmem); + + // request layout + uint16_t (*request_key_off)(priskv_request *req); + uint8_t *(*request_key)(priskv_request *req); + + void (*close_client)(priskv_transport_conn *client); +} priskv_transport_driver; + +/** + * @brief Listen on the specified addresses and port. + * + * @param addr The addresses to listen on. + * @param naddr The number of addresses. + * @param port The port to listen on. + * @param kv The key-value store to use. + * @param cap The connection capacity. + * @return int 0 on success, others on error. + */ +int priskv_transport_listen(char **addr, int naddr, int port, void *kv, + priskv_transport_conn_cap *cap); + +/** + * @brief Get the file descriptor for the transport driver. + * + * @return int The file descriptor. 
+ */ +int priskv_transport_get_fd(void); + +/** + * @brief Process events for the transport driver. + */ +void priskv_transport_process(void); + +/** + * @brief Get the key-value store associated with the transport driver. + * + * @return void* The key-value store. + */ +void *priskv_transport_get_kv(void); + +/** + * @brief Get the listeners associated with the transport driver. + * + * @param nlisteners The number of listeners. + * @return priskv_transport_listener* The listeners. + */ +priskv_transport_listener *priskv_transport_get_listeners(int *nlisteners); + +/** + * @brief Free the listeners associated with the transport driver. + * + * @param listeners The listeners to free. + * @param nlisteners The number of listeners. + */ +void priskv_transport_free_listeners(priskv_transport_listener *listeners, int nlisteners); + +/** + * @brief Send a response to the client. + * + * @param conn The transport connection. + * @param request_id The request ID. + * @param status The response status. + * @param length The response length. + * @return int 0 on success, others on error. + */ +int priskv_transport_send_response(priskv_transport_conn *conn, uint64_t request_id, + priskv_resp_status status, uint32_t length); + +/** + * @brief Submit a read or write request to the transport driver. + * + * @param conn The transport connection. + * @param req The request to submit. + * @param memh The memory handle to use. + * @param val The value to read or write. + * @param valuelen The length of the value. + * @param set Whether to perform a write operation. + * @param cb The callback function to invoke when the request completes. + * @param cbarg The argument to pass to the callback function. + * @param defer_resp Whether to defer sending the response. + * @param work_out The output parameter to store the submitted work. + * @return int 0 on success, others on error. 
+ */ +int priskv_transport_rw_req(priskv_transport_conn *conn, priskv_request *req, + priskv_transport_memh *memh, uint8_t *val, uint32_t valuelen, bool set, + void (*cb)(void *), void *cbarg, bool defer_resp, + priskv_transport_rw_work **work_out); + +/** + * @brief Check and log slow queries. + * + * @param work The transport rw work. + */ +void priskv_check_and_log_slow_query(priskv_transport_rw_work *work); + +/** + * @brief Handle a received request. + * + * @param conn The transport connection. + * @param req The request to handle. + * @param len The length of the request. + * @return int 0 on success, others on error. + */ +int priskv_transport_handle_recv(priskv_transport_conn *conn, priskv_request *req, uint32_t len); + +/** + * @brief Handle a read or write request. + * + * @param conn The transport connection. + * @param work The transport rw work. + * @return int 0 on success, others on error. + */ +int priskv_transport_handle_rw(priskv_transport_conn *conn, priskv_transport_rw_work *work); + +/** + * @brief Mark a client connection as closed. + * + * @param client The transport connection. + */ +void priskv_transport_mark_client_closed(priskv_transport_conn *client); + +/** + * @brief Close all disconnected client connections. + * + * @param listener The transport listener. + */ +void priskv_transport_close_disconnected(priskv_transport_conn *listener); + +#if defined(__cplusplus) +} +#endif + +#endif diff --git a/server/transport/ucx.c b/server/transport/ucx.c new file mode 100644 index 0000000..1de3e32 --- /dev/null +++ b/server/transport/ucx.c @@ -0,0 +1,1119 @@ +// Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "transport.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../acl.h" +#include "../memory.h" +#include "../kv.h" +#include "../backend/backend.h" +#include "priskv-protocol.h" +#include "priskv-protocol-helper.h" +#include "priskv-log.h" +#include "priskv-utils.h" +#include "priskv-threads.h" +#include "priskv-event.h" +#include "priskv-ucx.h" +#include "transport.h" +#include "uthash.h" + +extern priskv_transport_server g_transport_server; +extern priskv_threadpool *g_threadpool; + +static uint32_t priskv_ucx_max_rw_size = 1024 * 1024 * 1024; + +// forward declaration +static int priskv_ucx_new_ctrl_buffer(priskv_transport_conn *conn); +static inline void priskv_ucx_free_ctrl_buffer(priskv_transport_conn *conn); + +typedef struct priskv_ucx_conn_aux { + priskv_transport_conn *conn; + void *aux; +} priskv_ucx_conn_aux; + +static inline uint32_t priskv_ucx_wr_size(priskv_transport_conn *client) +{ + return client->conn_cap.max_inflight_command * (2 + client->conn_cap.max_sgl); +} + +#define PRISKV_UCX_RESPONSE_FREE_STATUS 0xffff +static inline int priskv_ucx_response_free(priskv_response *resp) +{ + if (resp->status == PRISKV_UCX_RESPONSE_FREE_STATUS) { + return -EPROTO; + } + + resp->status = PRISKV_UCX_RESPONSE_FREE_STATUS; + return 0; +} + +static int priskv_ucx_init(void) +{ + g_transport_server.context = priskv_ucx_context_init(0); + if (g_transport_server.context == NULL) { + priskv_log_error("ucx context init failed\n"); + return -1; + } + + return 
0; +} + +static inline void priskv_ucx_listener_signal(priskv_transport_conn *listener) +{ + uint64_t u = 1; + + write(listener->efd, &u, sizeof(u)); +} + +static inline void priskv_ucx_mark_client_closed(priskv_transport_conn *client) +{ + priskv_transport_mark_client_closed(client); + priskv_ucx_listener_signal(client->c.listener); +} + +static void priskv_ucx_recv_req_cb(ucs_status_t status, ucp_tag_t sender_tag, size_t length, + void *arg) +{ + if (ucs_unlikely(arg == NULL)) { + priskv_log_error("UCX: priskv_ucx_recv_req_cb, arg is NULL\n"); + return; + } + + priskv_ucx_conn_aux *conn_req = arg; + priskv_transport_conn *conn = conn_req->conn; + priskv_request *req = conn_req->aux; + free(conn_req); + + priskv_ucx_request *handle = NULL; + HASH_FIND_PTR(conn->inflight_reqs, &conn_req, handle); + if (handle) { + priskv_log_debug("UCX: remove request %p from inflight_reqs\n", handle); + HASH_DEL(conn->inflight_reqs, handle); + } + + if (ucs_unlikely(status != UCS_OK)) { + if (status == UCS_ERR_CANCELED) { + priskv_log_debug("UCX: priskv_ucx_recv_req_cb, status: %s, req: %p\n", + ucs_status_string(status), arg); + } else { + priskv_log_error("UCX: priskv_ucx_recv_req_cb, status: %s, req: %p\n", + ucs_status_string(status), arg); + } + priskv_ucx_mark_client_closed(conn); + return; + } + + struct timeval server_metadata_recv_time; + + gettimeofday(&server_metadata_recv_time, NULL); + req->runtime.server_metadata_recv_time = server_metadata_recv_time; + + if (priskv_transport_handle_recv(conn, req, length)) { + priskv_ucx_mark_client_closed(conn); + } +} + +static int priskv_ucx_recv_req(priskv_transport_conn *conn, uint8_t *req) +{ + priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_REQ]; + uint16_t req_buf_size = + priskv_ucx_max_request_size_aligned(conn->conn_cap.max_sgl, conn->conn_cap.max_key_length); + + priskv_log_debug("UCX: priskv_ucx_recv_req addr %p, length %d\n", req, req_buf_size); + assert((req >= rmem->buf) && (req < rmem->buf + 
rmem->buf_size));

    /* Pair the connection with the request buffer so the receive-completion
     * callback can recover both from a single opaque argument. */
    priskv_ucx_conn_aux *conn_req = malloc(sizeof(priskv_ucx_conn_aux));
    if (ucs_unlikely(conn_req == NULL)) {
        priskv_log_error("UCX: priskv_ucx_recv_req, malloc conn_req failed\n");
        return -ENOMEM;
    }

    conn_req->conn = conn;
    conn_req->aux = req;
    /* Post a tagged receive for the next control message on this connection. */
    ucs_status_ptr_t handle =
        priskv_ucx_ep_post_tag_recv(conn->ep, req, req_buf_size, PRISKV_PROTO_TAG_CTRL,
                                    PRISKV_PROTO_FULL_TAG_MASK, priskv_ucx_recv_req_cb, conn_req);
    if (UCS_PTR_IS_ERR(handle)) {
        ucs_status_t status = UCS_PTR_STATUS(handle);
        /* NOTE(review): conn_req appears to leak on this path — presumably the
         * callback is never invoked when the post itself fails; confirm and
         * free(conn_req) here if so. */
        priskv_log_error("UCX: <%s - %s> priskv_ucx_ep_post_tag_recv failed, status: %s\n",
                         conn->local_addr, conn->peer_addr, ucs_status_string(status));
        return -EIO;
    } else if (UCS_PTR_IS_PTR(handle)) {
        // still in progress
        priskv_ucx_request *request = (priskv_ucx_request *)handle;
        if (request->status == UCS_INPROGRESS) {
            /* Track the in-flight receive so it can be cancelled at close time. */
            request->key = conn_req;
            HASH_ADD_PTR(conn->inflight_reqs, key, request);
        }
    } else {
        // Operation completed immediately
    }

    return req_buf_size;
}

/* UCX endpoint error/close callback: log the failure and flag the client
 * connection for teardown. */
static void priskv_ucx_conn_close_cb(ucs_status_t status, void *arg)
{
    priskv_transport_conn *client = arg;
    priskv_log_error("UCX: <%s - %s> ep close, name %s, status: %s\n", client->local_addr,
                     client->peer_addr, client->ep->name, ucs_status_string(status));
    priskv_ucx_mark_client_closed(client);
}

/* Event handler for the per-client UCX worker eventfd: drive worker progress.
 * `opaque` is the worker pointer (see priskv_set_fd_handler call site). */
static inline void priskv_ucx_client_efd_progress(int fd, void *opaque, uint32_t ev)
{
    priskv_log_debug("UCX: client efd progress event %d, worker %p, efd %d\n", ev, opaque, fd);
    priskv_ucx_worker_progress(opaque);
}

/* Event handler for the client's TCP control fd; any event here is treated as
 * the peer going away, so the connection is marked closed. */
static inline void priskv_ucx_client_connfd_progress(int fd, void *opaque, uint32_t ev)
{
    priskv_transport_conn *client = opaque;
    priskv_log_error("UCX: <%s - %s> ep close, name %s\n", client->local_addr, client->peer_addr,
                     client->ep->name);
    priskv_ucx_mark_client_closed(client);
}

/* Negotiate client capabilities against the listener's limits.
 * A zero value from the client means "use the listener default"; a value above
 * the listener limit is rejected, with the limit reported back through *val.
 * Returns 0 on success or a PRISKV_CM_REJ_STATUS_* code. */
static inline int priskv_ucx_verify_conn_cap(priskv_transport_conn_cap *client,
                                             priskv_transport_conn_cap *listener, uint64_t *val)
{
    if (!client->max_sgl) {
        client->max_sgl = listener->max_sgl;
    } else if (client->max_sgl > listener->max_sgl) {
        *val = listener->max_sgl;
        return PRISKV_CM_REJ_STATUS_INVALID_SGL;
    }

    if (!client->max_key_length) {
        client->max_key_length = listener->max_key_length;
    } else if (client->max_key_length > listener->max_key_length) {
        *val = listener->max_key_length;
        return PRISKV_CM_REJ_STATUS_INVALID_KEY_LENGTH;
    }

    if (!client->max_inflight_command) {
        client->max_inflight_command = listener->max_inflight_command;
    } else if (client->max_inflight_command > listener->max_inflight_command) {
        *val = listener->max_inflight_command;
        return PRISKV_CM_REJ_STATUS_INVALID_INFLIGHT_COMMAND;
    }

    return 0;
}

/* Send a connection-manager reject message (flag = 0) over the TCP control
 * socket. All multi-byte fields are big-endian on the wire. Best-effort: a
 * send failure is only logged. */
static inline void priskv_ucx_reject(priskv_transport_conn *client, priskv_cm_status status,
                                     uint64_t value)
{
    priskv_cm_ucx_handshake rej_msg_be = {
        .flag = 0,
        .version = htobe16(PRISKV_CM_VERSION),
        .status = htobe16(status),
        .value = htobe64(value),
    };
    int ret = priskv_safe_send(client->connfd, &rej_msg_be, sizeof(rej_msg_be), NULL, NULL);
    if (ret < 0) {
        priskv_log_error("UCX: send reject message failed: %m\n");
    }
}

/* Send the accept handshake (flag = 1): the negotiated capabilities plus this
 * server's UCX worker address, so the client can create its endpoint.
 * Returns 0 on success, -1 on failure. */
static inline int priskv_ucx_accept(priskv_transport_conn *client)
{
    uint32_t address_len = client->worker->address_len;
    /* Handshake header followed by the variable-length worker address. */
    size_t hs_size = sizeof(priskv_cm_ucx_handshake) + address_len;
    priskv_cm_ucx_handshake *hs = malloc(hs_size);
    if (ucs_unlikely(!hs)) {
        priskv_log_error("UCX: malloc accept message failed: %m\n");
        return -1;
    }

    hs->flag = 1;
    hs->cap.version = htobe16(client->conn_cap.version);
    hs->cap.max_sgl = htobe16(client->conn_cap.max_sgl);
    /* NOTE(review): encoded with htobe16 here, but priskv_ucx_handle_cm decodes
     * the client's max_key_length with be32toh — verify the wire width of
     * cap.max_key_length; one of the two looks wrong. */
    hs->cap.max_key_length = htobe16(client->conn_cap.max_key_length);
    hs->cap.max_inflight_command = htobe16(client->conn_cap.max_inflight_command);
    hs->cap.capacity = htobe64(client->conn_cap.capacity);
    hs->address_len = htobe32(address_len);
    memcpy(hs->address, client->worker->address, address_len);

    if (priskv_get_log_level() >= priskv_log_debug) {
        /* Dump at most the first 128 bytes of the worker address as hex. */
        size_t print_len = address_len > 128 ? 128 : address_len;
        char worker_address_hex[print_len * 2 + 1];
        priskv_ucx_to_hex(worker_address_hex, client->worker->address, print_len);
        priskv_log_debug(
            "UCX: send worker address to client %s, address_len %d, address (first %d) %s\n",
            client->peer_addr, address_len, print_len, worker_address_hex);
    }

    int ret = priskv_safe_send(client->connfd, hs, hs_size, NULL, NULL);
    if (ret < 0) {
        priskv_log_error("UCX: send accept message failed: %m\n");
        goto out_free_msg;
    }

    priskv_log_info("UCX: <%s - %s> accept connect request, name %s\n", client->local_addr,
                    client->peer_addr, client->ep->name);

out_free_msg:
    free(hs);
    return ret;
}

/* Connection-manager handler for the listening socket: accept the TCP control
 * connection, run the handshake, and bring up the per-client UCX resources.
 * (Continues past this block: handshake, capability check, worker/ep setup.) */
static inline void priskv_ucx_handle_cm(int fd, void *opaque, uint32_t ev)
{
    priskv_log_debug("UCX: listener efd progress event %d, listener %p, efd %d\n", ev, opaque, fd);

    priskv_transport_conn *listener = opaque;
    int ret;
    struct sockaddr_storage client_addr;
    socklen_t client_addr_len;
    int connfd;
    char peer_addr[PRISKV_ADDR_LEN];
    priskv_cm_ucx_handshake peer_hs;
    priskv_cm_status status;
    uint64_t value = 0;
    uint8_t *peer_worker_address = NULL;

    assert(listener->listenfd == fd);

    client_addr_len = sizeof(client_addr);
    connfd = accept(listener->listenfd, (struct sockaddr *)&client_addr, &client_addr_len);
    if (connfd < 0) {
        priskv_log_error("UCX: accept on listenfd %d failed: %m\n", listener->listenfd);
        return;
    }

    priskv_inet_ntop(&client_addr, peer_addr);
    priskv_log_info("UCX: accept on listenfd %d, connfd %d, client addr %s\n", listener->listenfd,
                    connfd, peer_addr);

    /* Allocate and initialize the per-client connection state. */
    priskv_transport_conn *client = calloc(1, sizeof(priskv_transport_conn));
    assert(client);
    client->ep = NULL;
    client->inflight_reqs = NULL;
    client->c.listener = listener;
    client->c.thread = NULL;
    client->c.closing = false;
    client->connfd = connfd;
    list_node_init(&client->c.node);
    pthread_spin_init(&client->lock, 0);

    const char *local_addr = listener->local_addr;
    snprintf(client->local_addr, PRISKV_ADDR_LEN, "%s", local_addr);
    snprintf(client->peer_addr, PRISKV_ADDR_LEN, "%s", peer_addr);

    /* Publish the client on the listener's list before the handshake so that
     * monitoring (get_clients) can already see it. */
    pthread_spin_lock(&listener->lock);
    list_add_tail(&listener->s.head, &client->c.node);
    listener->s.nclients++;
    pthread_spin_unlock(&listener->lock);

    /* #step0, recv handshake msg */
    ret = priskv_safe_recv(connfd, &peer_hs, sizeof(peer_hs), NULL, NULL);
    if (ret < 0) {
        /* NOTE(review): this path closes connfd but leaves `client` allocated
         * and linked into the listener list (nclients already incremented) —
         * confirm whether priskv_ucx_mark_client_closed should be used here
         * instead of a bare return. Same applies to the two recv failures
         * below, which additionally leak peer_worker_address. */
        priskv_log_error("UCX: recv handshake msg failed: %m\n");
        ucs_close_fd(&connfd);
        return;
    }

    /* Decode the peer's advertised capabilities (big-endian on the wire). */
    client->conn_cap.version = be16toh(peer_hs.cap.version);
    client->conn_cap.max_sgl = be16toh(peer_hs.cap.max_sgl);
    /* NOTE(review): decoded as 32-bit here but encoded with htobe16 in
     * priskv_ucx_accept and for listener->conn_cap_be — verify the wire
     * width of cap.max_key_length. */
    client->conn_cap.max_key_length = be32toh(peer_hs.cap.max_key_length);
    client->conn_cap.max_inflight_command = be16toh(peer_hs.cap.max_inflight_command);
    size_t peer_worker_address_len = be32toh(peer_hs.address_len);

    /* The client's UCX worker address follows the fixed-size handshake. */
    if (peer_worker_address_len > 0) {
        peer_worker_address = malloc(peer_worker_address_len);
        if (!peer_worker_address) {
            priskv_log_error("UCX: malloc peer address failed: %m\n");
            ucs_close_fd(&connfd);
            return;
        }
        ret = priskv_safe_recv(connfd, peer_worker_address, peer_worker_address_len, NULL, NULL);
        if (ret < 0) {
            priskv_log_error("UCX: recv peer address failed: %m\n");
            ucs_close_fd(&connfd);
            return;
        }
    }

    /* NOTE(review): peer_worker_address_len is size_t but printed with %d. */
    priskv_log_info("UCX: <%s - %s> incoming connect request - version %d, max_sgl %d, "
                    "max_key_length %d, max_inflight_command %d, address_len %d\n",
                    local_addr, peer_addr, client->conn_cap.version, client->conn_cap.max_sgl,
                    client->conn_cap.max_key_length, client->conn_cap.max_inflight_command,
                    peer_worker_address_len);

    /* Protocol version must match exactly; tell the client ours on reject. */
    if (client->conn_cap.version != PRISKV_CM_VERSION) {
        status = PRISKV_CM_REJ_STATUS_INVALID_VERSION;
        value = PRISKV_CM_VERSION;
        goto rej;
    }

    if (!peer_worker_address) {
        priskv_log_error("UCX: <%s - %s> peer worker address is empty\n", local_addr, peer_addr);
        status = PRISKV_CM_REJ_STATUS_INVALID_WORKER_ADDR;
        value = 0;
        goto rej;
    }

    if (priskv_get_log_level() >= priskv_log_debug) {
        /* Dump at most the first 128 bytes of the peer worker address as hex. */
        size_t print_len = peer_worker_address_len > 128 ? 128 : peer_worker_address_len;
        char worker_address_hex[print_len * 2 + 1];
        priskv_ucx_to_hex(worker_address_hex, peer_worker_address, print_len);
        priskv_log_debug(
            "UCX: got peer worker address from client %s, address_len %d, address (first %d) %s\n",
            peer_addr, peer_worker_address_len, print_len, worker_address_hex);
    }

    /* Clamp/negotiate the client's limits against ours. */
    status = priskv_ucx_verify_conn_cap(&client->conn_cap, &listener->conn_cap, &value);
    if (status) {
        goto rej;
    }

    /* #step1, ACL verification */
    if (priskv_acl_verify((struct sockaddr *)&client_addr)) {
        priskv_log_error("UCX: <%s - %s> ACL verification failed\n", local_addr, peer_addr);
        status = PRISKV_CM_REJ_STATUS_ACL_REFUSE;
        value = 0;
        goto rej;
    }

    /* Each client gets its own UCX worker (and therefore its own progress efd). */
    client->worker = priskv_ucx_worker_create(g_transport_server.context, 0);
    if (client->worker == NULL) {
        priskv_log_error("UCX: <%s - %s> create worker failed\n", local_addr, peer_addr);
        status = PRISKV_CM_REJ_STATUS_SERVER_ERROR;
        value = 0;
        goto rej;
    }

    client->ep = priskv_ucx_ep_create_from_worker_addr(client->worker, peer_worker_address,
                                                       priskv_ucx_conn_close_cb, client);
    if (client->ep == NULL) {
        priskv_log_error("UCX: <%s - %s> create ep failed\n", local_addr, peer_addr);
        status = PRISKV_CM_REJ_STATUS_SERVER_ERROR;
        goto rej;
    }

    free(peer_worker_address);
    peer_worker_address = NULL;

    /* #step2, create related resources */
    if (priskv_ucx_new_ctrl_buffer(client)) {
        priskv_log_error("UCX: <%s - %s> create ctrl buffer failed, name %s\n", local_addr,
                         peer_addr, client->ep->name);
        status = PRISKV_CM_REJ_STATUS_SERVER_ERROR;
        goto rej;
    }

    /* #step3, post tag recv */
    /* Pre-post one tagged receive per work-request slot; recv_req returns the
     * per-slot buffer size so recv_req advances through the request region. */
    uint32_t wr_size = priskv_ucx_wr_size(client);
    uint8_t *recv_req = client->rmem[PRISKV_TRANSPORT_MEM_REQ].buf;
    for (uint16_t i = 0; i < wr_size; i++) {
        int recvsize = priskv_ucx_recv_req(client, recv_req);
        if (recvsize < 0) {
            status = PRISKV_CM_REJ_STATUS_SERVER_ERROR;
            goto rej;
        }

        recv_req += recvsize;
    }

    /* #step4, initialize KV of client */
    /* The value area and its registration are shared with the listener. */
    client->value_base = client->c.listener->value_base;
    client->kv = client->c.listener->kv;
    client->value_memh.ucx_memh = client->c.listener->value_memh.ucx_memh;

    /* use the idlest worker thread to drive the progress */
    priskv_set_fd_handler(client->worker->efd, priskv_ucx_client_efd_progress, NULL,
                          client->worker);
    client->c.thread = priskv_threadpool_find_iothread(g_threadpool);
    priskv_thread_add_event_handler(client->c.thread, client->worker->efd);
    priskv_log_debug("UCX: <%s - %s> assign worker efd %d to thread %d\n", local_addr, peer_addr,
                     client->worker->efd, client->c.thread);

    /* connfd stays open as a liveness channel; it must be non-blocking before
     * being handed to the event loop. */
    if (priskv_set_nonblock(client->connfd)) {
        priskv_log_error("UCX: <%s - %s> failed to set nonblock mode for connfd\n", local_addr,
                         peer_addr);
        status = PRISKV_CM_REJ_STATUS_SERVER_ERROR;
        goto rej;
    }

    if (priskv_ucx_accept(client) < 0) {
        priskv_log_error("UCX: <%s - %s> accept failed, name %s\n", local_addr, peer_addr,
                         client->ep->name);
        status = PRISKV_CM_REJ_STATUS_SERVER_ERROR;
        goto rej;
    }

    priskv_set_fd_handler(client->connfd, priskv_ucx_client_connfd_progress, NULL, client);
    priskv_thread_add_event_handler(client->c.thread, client->connfd);
    priskv_log_debug("UCX: <%s - %s> assign connfd %d to thread %d\n", local_addr, peer_addr,
                     client->connfd, client->c.thread);

    priskv_log_notice("UCX: <%s - %s> established\n", local_addr, peer_addr);

    return;

rej:
    /* Reject the connection and schedule teardown of whatever was created. */
    priskv_log_warn("UCX: <%s - %s> %s, reject\n", local_addr, peer_addr,
                    priskv_cm_status_str(status));
    if (peer_worker_address) {
        free(peer_worker_address);
        peer_worker_address = NULL;
    }
    priskv_ucx_reject(client, status, value);
    priskv_ucx_mark_client_closed(client);
}

/* Create one listening TCP socket on addr:port, register the KV value region
 * with UCX, and record the listener in g_transport_server. Body follows. */
static int priskv_ucx_listen_one(char *addr, int port, void *kv, priskv_transport_conn_cap *cap)
{
    int ret;
    int listenfd = -1;
    int optval = 1;
    struct addrinfo hints, *res, *t;
    char service[8];
    char err_str[64];
    ucs_status_t status;

    /* Resolve addr:port into candidate bind addresses (TCP, passive). */
    ucs_snprintf_safe(service, sizeof(service), "%u", port);
    memset(&hints, 0, sizeof(hints));
    hints.ai_flags = AI_PASSIVE;
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    ret = getaddrinfo(addr, service, &hints, &res);
    if (ret < 0) {
        /* NOTE(review): getaddrinfo returns non-zero (positive EAI_* codes)
         * on failure, so `ret < 0` misses real failures — should be
         * `ret != 0`; confirm. */
        priskv_log_error("UCX: getaddrinfo failed, server %s, port %s, error %s\n", addr, service,
                         gai_strerror(ret));
        ret = -1;
        goto out;
    }

    if (res == NULL) {
        priskv_log_error("UCX: getaddrinfo returned empty list\n");
        ret = -1;
        goto out;
    }

    /* Try each candidate until one binds; remember the last error for logging. */
    for (t = res; t != NULL; t = t->ai_next) {
        listenfd = socket(t->ai_family, t->ai_socktype, t->ai_protocol);
        if (listenfd < 0) {
            snprintf(err_str, 64, "socket failed: %m\n");
            continue;
        }

        status = ucs_socket_setopt(listenfd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));
        if (status != UCS_OK) {
            /* NOTE(review): `continue` here (and below) leaks listenfd — the
             * next iteration overwrites it without closing; confirm a
             * ucs_close_fd(&listenfd) is needed before continuing. */
            snprintf(err_str, 64, "setopt failed: %m\n");
            continue;
        }

        ret = priskv_set_nonblock(listenfd);
        if (ret) {
            snprintf(err_str, 64, "set NONBLOCK failed: %m\n");
            continue;
        }

        if (bind(listenfd, t->ai_addr, t->ai_addrlen) == 0) {
            break;
        }

        snprintf(err_str, 64, "bind failed: %m\n");
        ucs_close_fd(&listenfd);
        listenfd = -1;
    }

    if (listenfd < 0) {
        priskv_log_error("UCX: bind failed, server %s, port %s, error %s\n", addr, service,
                         err_str);
        ret = -1;
        goto out_free_res;
    }

    /* NOTE(review): backlog 0 — presumably intentional (one CM handshake at a
     * time), but confirm it doesn't drop bursts of connects. */
    ret = listen(listenfd, 0);
    if (ret < 0) {
        priskv_log_error("UCX: listen failed: %m\n");
        ret = -1;
        goto err_close_listenfd;
    }

    priskv_transport_conn *listener = &g_transport_server.listeners[g_transport_server.nlisteners];

    /* Register the whole KV value region with UCX once per listener; clients
     * share this registration (see handle_cm #step4). */
    uint8_t *value_base = priskv_get_value_base(kv);
    assert(value_base);
    uint64_t size = priskv_get_value_blocks(kv) * priskv_get_value_block_size(kv);
    assert(size);
    listener->value_memh.ucx_memh =
        priskv_ucx_mmap(g_transport_server.context, value_base, size, UCS_MEMORY_TYPE_HOST);
    if (!listener->value_memh.ucx_memh) {
        ret = -1;
        priskv_log_error(
            "UCX: failed to reg MR for value: %m [%p, %p], value block %ld, value block size %d\n",
            value_base, value_base + size, priskv_get_value_blocks(kv),
            priskv_get_value_block_size(kv));
        goto err_close_listenfd;
    }

    priskv_log_debug("UCX: Value buffer %p, length %ld\n", value_base, size);

    /* Fill in listener state; conn_cap_be is the pre-encoded (big-endian)
     * capability block sent during handshakes. */
    listener->listenfd = listenfd;
    listener->value_base = value_base;
    listener->kv = kv;
    listener->conn_cap = *cap;
    listener->conn_cap.capacity = size;
    listener->conn_cap_be.version = htobe16(PRISKV_CM_VERSION);
    listener->conn_cap_be.max_sgl = htobe16(cap->max_sgl);
    listener->conn_cap_be.max_key_length = htobe16(cap->max_key_length);
    listener->conn_cap_be.max_inflight_command = htobe16(cap->max_inflight_command);
    listener->conn_cap_be.capacity = htobe64(size);
    listener->s.nclients = 0;
    list_head_init(&listener->s.head);
    pthread_spin_init(&listener->lock, 0);

    listener->efd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK);
    if (listener->efd < 0) {
        priskv_log_error("UCX: failed to create eventfd %m\n");
        ret = -1;
        goto err_close_listenfd;
    }

    priskv_inet_ntop(t->ai_addr, listener->local_addr);
    priskv_log_info("UCX: <%s> listener starts\n", listener->local_addr);

    /* Commit the listener slot only after full initialization. */
    g_transport_server.nlisteners++;
    ret = 0;
    goto out_free_res;

err_close_listenfd:
    ucs_close_fd(&listenfd);
out_free_res:
    freeaddrinfo(res);
out:
    return ret;
}

/* Driver entry: create one listener per address, then wire every listenfd and
 * per-listener eventfd into a single epoll fd that the caller will poll. */
static int priskv_ucx_listen(char **addrs, int naddrs, int port, void *kv,
                             priskv_transport_conn_cap *cap)
{
    priskv_transport_conn *listener;
    int efd;  /* NOTE(review): unused local — candidate for removal. */

    for (int i = 0; i < naddrs; i++) {
        int ret = priskv_ucx_listen_one(addrs[i], port, kv, cap);
        if (ret) {
            return ret;
        }
    }

    g_transport_server.kv = kv;

    g_transport_server.epollfd = epoll_create(g_transport_server.nlisteners);
    if (g_transport_server.epollfd == -1) {
        priskv_log_error("UCX: failed to create epoll fd %m\n");
        return -1;
    }

    for (int i = 0; i < g_transport_server.nlisteners; i++) {
        listener = &g_transport_server.listeners[i];

        /* Incoming connects on listenfd are handled by the CM handshake. */
        priskv_set_fd_handler(listener->listenfd, priskv_ucx_handle_cm, NULL, listener);
        if (priskv_add_event_fd(g_transport_server.epollfd, listener->listenfd)) {
            priskv_log_error("UCX: failed to add listenfd into epoll fd %m\n");
            return -1;
        }

        /* The listener eventfd is registered with no handler (wake-up only). */
        priskv_set_fd_handler(listener->efd, NULL, NULL, NULL);
        if (priskv_add_event_fd(g_transport_server.epollfd, listener->efd)) {
            priskv_log_error("UCX: failed to add eventfd into epoll fd %m\n");
            return -1;
        }

        priskv_log_notice("UCX: <%s> ready\n", listener->local_addr);
    }

    return 0;
}

/* Driver accessor: the epoll fd the event loop should poll. */
static int priskv_ucx_get_fd(void)
{
    return g_transport_server.epollfd;
}

/* Driver accessor: the KV handle registered at listen time. */
static void *priskv_ucx_get_kv(void)
{
    return g_transport_server.kv;
}

/* Allocate a guarded buffer of `size` bytes, register it with UCX, and fill
 * in rmem (name, buf, buf_size, memh). Returns 0 or a negative errno; on
 * failure rmem is zeroed. */
static int priskv_ucx_mem_new(priskv_transport_conn *conn, priskv_transport_mem *rmem,
                              const char *name, uint32_t size)
{
    bool guard = true; /* always enable memory guard */
    uint8_t *buf;
    int ret;

    buf = priskv_mem_malloc(size, guard);
    if (!buf) {
        priskv_log_error("UCX: failed to allocate %s buffer: %m\n", name);
        ret = -ENOMEM;
        goto error;
    }

    rmem->memh.ucx_memh =
        priskv_ucx_mmap(g_transport_server.context, buf, size, UCS_MEMORY_TYPE_HOST);
    if (!rmem->memh.ucx_memh) {
        priskv_log_error("UCX: failed to reg MR for %s buffer: %m\n", name);
        ret = -errno;
        goto free_mem;
    }

    strncpy(rmem->name, name, PRISKV_TRANSPORT_MEM_NAME_LEN - 1);
    rmem->buf = buf;
    rmem->buf_size = size;

    priskv_log_info("UCX: new rmem %s, size %d\n", name, size);
    priskv_log_debug("UCX: new rmem %s, buf %p\n", name, buf);
    return 0;

free_mem:
    /* NOTE(review): at this point rmem->buf/buf_size have NOT been assigned
     * yet (that happens after the mmap succeeds), so this frees whatever was
     * previously in rmem and leaks the freshly allocated `buf` — looks like
     * it should be priskv_mem_free(buf, size, guard); confirm. */
    priskv_mem_free(rmem->buf, rmem->buf_size, guard);

error:
    memset(rmem, 0x00, sizeof(priskv_transport_mem));

    return ret;
}

/* Release an rmem: deregister the UCX memh, free the guarded buffer, and zero
 * the descriptor. Safe on a partially-initialized rmem. Body continues. */
static inline void priskv_ucx_mem_free(priskv_transport_conn *conn, priskv_transport_mem *rmem)
{
    if (rmem->memh.ucx_memh) {
        priskv_ucx_munmap(rmem->memh.ucx_memh);
        rmem->memh.ucx_memh = NULL;
    }

    if (rmem->buf) {
        priskv_log_debug("UCX: free rmem %s, buf %p\n", rmem->name, rmem->buf);
        /* `true` matches the guard flag used at allocation in mem_new. */
        priskv_mem_free(rmem->buf, rmem->buf_size, true);
    }

    priskv_log_info("UCX: free rmem %s, size %d\n", rmem->name, rmem->buf_size);
    memset(rmem, 0x00, sizeof(priskv_transport_mem));
}

/* Allocate the per-connection control buffers: one request slot and one
 * response slot per work request, and mark every response slot free.
 * Returns 0 or -ENOMEM (partial allocations are torn down). */
static int priskv_ucx_new_ctrl_buffer(priskv_transport_conn *conn)
{
    uint16_t size;
    uint32_t buf_size;

    /* #step 1, prepare buffer & MR for request from client */
    size =
        priskv_ucx_max_request_size_aligned(conn->conn_cap.max_sgl, conn->conn_cap.max_key_length);
    buf_size = (uint32_t)size * priskv_ucx_wr_size(conn);
    if (priskv_ucx_mem_new(conn, &conn->rmem[PRISKV_TRANSPORT_MEM_REQ], "Request", buf_size)) {
        goto error;
    }

    /* #step 2, prepare buffer & MR for response to client */
    size = sizeof(priskv_response);
    buf_size = size * priskv_ucx_wr_size(conn);
    if (priskv_ucx_mem_new(conn, &conn->rmem[PRISKV_TRANSPORT_MEM_RESP], "Response", buf_size)) {
        goto error;
    }

    /* Stamp every response slot as free so unused_response can hand them out. */
    for (uint16_t i = 0; i < priskv_ucx_wr_size(conn); i++) {
        priskv_response *resp =
            (priskv_response *)(conn->rmem[PRISKV_TRANSPORT_MEM_RESP].buf + i * size);
        priskv_ucx_response_free(resp);
    }

    return 0;

error:
    priskv_ucx_free_ctrl_buffer(conn);
    return -ENOMEM;
}

/* Release every per-connection control buffer (safe on zeroed slots). */
static inline void priskv_ucx_free_ctrl_buffer(priskv_transport_conn *conn)
{
    for (int i = 0; i < PRISKV_TRANSPORT_MEM_MAX; i++) {
        priskv_transport_mem *rmem = &conn->rmem[i];

        priskv_ucx_mem_free(conn, rmem);
    }
}

/* Cancel every tracked in-flight UCX request for a client; run on the
 * client's IO thread (see close_client) so it doesn't race the progress loop. */
static inline void priskv_ucx_cancel_client_requests(priskv_transport_conn *client)
{
    priskv_ucx_request *req, *tmp;

    HASH_ITER(hh, client->inflight_reqs, req, tmp)
    {
        HASH_DEL(client->inflight_reqs, req);
        if (!req->handle) {
            continue;
        }
        priskv_ucx_request_cancel(req);
    }
}

/* Tear down a client connection: cancel in-flight work on its IO thread,
 * detach its fds from the event loop, then free buffers, endpoint, worker
 * and the connection object itself. */
static void priskv_ucx_close_client(priskv_transport_conn *client)
{
    /* NOTE(review): client->ep->name is dereferenced here, but client->ep is
     * NULL-checked further down — this log will crash for a client whose ep
     * was never created (e.g. rejected during handshake); confirm. Also: the
     * client is not unlinked from its listener's list here — presumably done
     * elsewhere; verify. */
    priskv_log_notice("UCX: <%s - %s> close. name %s. Requests GET %ld, SET %ld, TEST %ld, "
                      "DELETE %ld, Responses %ld\n",
                      client->local_addr, client->peer_addr, client->ep->name,
                      client->c.stats[PRISKV_COMMAND_GET].ops,
                      client->c.stats[PRISKV_COMMAND_SET].ops,
                      client->c.stats[PRISKV_COMMAND_TEST].ops,
                      client->c.stats[PRISKV_COMMAND_DELETE].ops, client->c.resps);

    if ((client->worker) && (client->c.thread != NULL)) {
        /* Cancel on the owning IO thread, then detach both fds from it. */
        priskv_thread_call_function(client->c.thread, priskv_ucx_cancel_client_requests, client);
        priskv_thread_del_event_handler(client->c.thread, client->worker->efd);
        priskv_thread_del_event_handler(client->c.thread, client->connfd);
        priskv_set_fd_handler(client->worker->efd, NULL, NULL, NULL); /* clear fd handler */
        priskv_set_fd_handler(client->connfd, NULL, NULL, NULL);      /* clear fd handler */
        client->c.thread = NULL;
    }

    priskv_ucx_free_ctrl_buffer(client);

    if (client->ep) {
        priskv_ucx_ep_destroy(client->ep);
        client->ep = NULL;
    }

    if (client->worker) {
        priskv_ucx_worker_destroy(client->worker);
        client->worker = NULL;
    }

    free(client);
}

/* Snapshot the listener's client list into a freshly allocated array for
 * monitoring/stats; holds the listener lock for the duration. */
static void priskv_ucx_get_clients(priskv_transport_conn *listener,
                                   priskv_transport_client **clients, int *nclients)
{
    priskv_transport_conn *client;
    *nclients = 0;

    pthread_spin_lock(&listener->lock);
    /* NOTE(review): calloc result is not checked — the memcpy below would
     * crash on allocation failure; confirm whether OOM is considered fatal
     * here. */
    *clients = calloc(listener->s.nclients, sizeof(priskv_transport_client));
    list_for_each (&listener->s.head, client, c.node) {
        const char *peer_addr = client->peer_addr;
        memcpy((*clients)[*nclients].address, peer_addr, strlen(peer_addr) + 1);
        memcpy((*clients)[*nclients].stats, client->c.stats,
               PRISKV_COMMAND_MAX * sizeof(priskv_transport_stats));
        (*clients)[*nclients].resps = client->c.resps;
        (*clients)[*nclients].closing = client->c.closing;
        (*nclients)++;

        /* Stop once the array is full (list and count are lock-consistent). */
        if (*nclients == listener->s.nclients) {
            break;
        }
    }
    pthread_spin_unlock(&listener->lock);
}

/* Snapshot all listeners (and their clients) for monitoring; caller frees
 * with priskv_ucx_free_listeners. Body continues below. */
static priskv_transport_listener *priskv_ucx_get_listeners(int *nlisteners)
{
    priskv_transport_listener *listeners;

    *nlisteners = g_transport_server.nlisteners;
    listeners = calloc(*nlisteners, sizeof(priskv_transport_listener));

    for (int i = 0; i < *nlisteners; i++) {
        const char *local_addr = g_transport_server.listeners[i].local_addr;
        memcpy(listeners[i].address, local_addr, strlen(local_addr) + 1);
        /* Also snapshot each listener's client list (allocated per listener). */
        priskv_ucx_get_clients(&g_transport_server.listeners[i], &listeners[i].clients,
                               &listeners[i].nclients);
    }

    return listeners;
}

/* Free the snapshot produced by priskv_ucx_get_listeners (per-listener client
 * arrays first, then the listener array itself). */
void priskv_ucx_free_listeners(priskv_transport_listener *listeners, int nlisteners)
{
    for (int i = 0; i < nlisteners; i++) {
        free(listeners[i].clients);
    }
    free(listeners);
}

/* Find a free response slot in the connection's response region, claim it
 * (status flips from the FREE sentinel to OK), and return it; NULL when all
 * wr_size slots are in flight. */
static priskv_response *priskv_ucx_unused_response(priskv_transport_conn *conn)
{
    uint16_t resp_buf_size = sizeof(priskv_response);
    priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_RESP];

    for (uint16_t i = 0; i < priskv_ucx_wr_size(conn); i++) {
        priskv_response *resp = (priskv_response *)(rmem->buf + i * resp_buf_size);
        if (resp->status == PRISKV_UCX_RESPONSE_FREE_STATUS) {
            priskv_log_debug("UCX: use response %d\n", i);
            resp->status = PRISKV_RESP_STATUS_OK;
            return resp;
        }
    }

    priskv_log_error("UCX: <%s - %s> inflight response exceeds %d\n", conn->local_addr,
                     conn->peer_addr, priskv_ucx_wr_size(conn));
    return NULL;
}

/* Completion callback for a response tag-send: drop the inflight tracking
 * entry and release the response slot (or mark the client closed on error). */
static void priskv_ucx_send_response_cb(ucs_status_t status, void *arg)
{
    if (ucs_unlikely(arg == NULL)) {
        priskv_log_error("UCX: priskv_ucx_send_response_cb, arg is NULL\n");
        return;
    }

    priskv_ucx_conn_aux *conn_resp = arg;
    priskv_transport_conn *conn = conn_resp->conn;
    priskv_response *resp = conn_resp->aux;
    priskv_ucx_request *req;
    free(conn_resp);

    /* NOTE(review): the lookup key here is the response pointer, but
     * priskv_ucx_send_response stores `req->key = conn` — these never match,
     * so the entry is seemingly never found/removed here; confirm which key
     * is intended. */
    HASH_FIND_PTR(conn->inflight_reqs, &resp, req);
    if (req) {
        priskv_log_debug("UCX: remove request %p from inflight_reqs\n", req);
        HASH_DEL(conn->inflight_reqs, req);
    }

    if (status != UCS_OK) {
        priskv_log_error(
            "UCX: priskv_ucx_send_response_cb, status: %s, response: %p, request_id: %lu\n",
            ucs_status_string(status), resp, resp->request_id);
        priskv_ucx_mark_client_closed(conn);
        return;
    }
    /* Return the slot to the free pool for priskv_ucx_unused_response. */
    priskv_ucx_response_free(resp);
}

/* Driver entry: send a response (request_id is already big-endian; status and
 * length are converted here) to the client over a UCX tag-send.
 * Returns 0, -EPROTO (no free slot), -ENOMEM, or -EIO. */
static int priskv_ucx_send_response(priskv_transport_conn *conn, uint64_t request_id,
                                    priskv_resp_status status, uint32_t length)
{
    priskv_transport_mem *rmem = &conn->rmem[PRISKV_TRANSPORT_MEM_RESP];
    priskv_response *resp;

    resp = priskv_ucx_unused_response(conn);
    if (!resp) {
        return -EPROTO;
    }

    /* The slot must live inside the registered response region. */
    assert(((uint8_t *)resp >= rmem->buf) && ((uint8_t *)resp < rmem->buf + rmem->buf_size));

    resp->request_id = request_id; /* be64 */
    resp->status = htobe16(status);
    resp->length = htobe32(length);

    priskv_ucx_conn_aux *conn_resp = malloc(sizeof(priskv_ucx_conn_aux));
    if (ucs_unlikely(conn_resp == NULL)) {
        priskv_log_error("UCX: priskv_ucx_send_response, malloc conn_resp failed\n");
        return -ENOMEM;
    }

    conn_resp->conn = conn;
    conn_resp->aux = resp;
    ucs_status_ptr_t handle =
        priskv_ucx_ep_post_tag_send(conn->ep, resp, sizeof(priskv_response), PRISKV_PROTO_TAG_CTRL,
                                    priskv_ucx_send_response_cb, conn_resp);
    if (UCS_PTR_IS_ERR(handle)) {
        ucs_status_t ucs_status = UCS_PTR_STATUS(handle);
        /* NOTE(review): on this path conn_resp leaks and the claimed response
         * slot is never released back to the free pool; confirm and clean up. */
        priskv_log_error("UCX: <%s - %s> priskv_ucx_ep_post_tag_send response failed: addr 0x%lx, "
                         "length 0x%x status %s\n",
                         conn->local_addr, conn->peer_addr, (uint64_t)resp, sizeof(priskv_response),
                         ucs_status_string(ucs_status));
        return -EIO;
    } else if (UCS_PTR_IS_PTR(handle)) {
        // still in progress
        priskv_ucx_request *req = handle;
        if (req->status == UCS_INPROGRESS) {
            /* NOTE(review): key stored as `conn`, but the completion callback
             * looks the entry up by the response pointer — see cb above. */
            req->key = conn;
            HASH_ADD_PTR(conn->inflight_reqs, key, req);
        }
    } else {
        // Operation completed immediately
    }

    conn->c.resps++;
    return 0;
}

/* Completion callback for one RDMA get/put posted by priskv_ucx_rw_req:
 * untrack the request, stamp the data-transfer timestamp, release the rkey
 * and hand the work item to the generic rw completion path. Continues below. */
static void priskv_ucx_rw_req_cb(ucs_status_t status, void *arg)
{
    struct timeval server_data_recv_time;

    priskv_transport_rw_work *work = arg;
    priskv_request *req = (priskv_request *)work->req;
    priskv_transport_conn *conn = work->conn;

    priskv_ucx_request *handle = NULL;
    /* Requests are tracked keyed by the work pointer (matches rw_req below). */
    HASH_FIND_PTR(conn->inflight_reqs, &work, handle);
    if (handle) {
        priskv_log_debug("UCX: remove request %p from inflight_reqs\n", handle);
        HASH_DEL(conn->inflight_reqs, handle);
    }

    gettimeofday(&server_data_recv_time, NULL);
    req->runtime.server_data_recv_time = server_data_recv_time;

    if (work->ucx_rkey) {
        priskv_ucx_rkey_destroy(work->ucx_rkey);
    }

    /* Generic completion path (counts completed fragments, sends the response
     * unless deferred); a failure there tears the client down. */
    if (priskv_transport_handle_rw(conn, work)) {
        priskv_ucx_mark_client_closed(conn);
    }
}

/* Driver entry: issue the RDMA transfers for a GET (set=true → RDMA get from
 * client memory) or SET (set=false → RDMA put to client memory) request.
 * Walks the request's keyed SGL list, unpacks each rkey, and splits each SGL
 * into chunks of at most priskv_ucx_max_rw_size. On success optionally
 * returns the work item via work_out. Returns 0 or a negative errno. */
static int priskv_ucx_rw_req(priskv_transport_conn *conn, priskv_request *req,
                             priskv_transport_memh *memh, uint8_t *val, uint32_t valuelen, bool set,
                             void (*cb)(void *), void *cbarg, bool defer_resp,
                             priskv_transport_rw_work **work_out)
{
    priskv_transport_rw_work *work;
    uint32_t offset = 0;
    uint16_t nsgl = be16toh(req->nsgl);
    const char *cmdstr = set ? "READ" : "WRITE";

    if (work_out) {
        *work_out = NULL;
    }

    work = calloc(1, sizeof(priskv_transport_rw_work));
    if (!work) {
        priskv_log_error("UCX: failed to allocate memory for %s request\n", cmdstr);
        return -ENOMEM;
    }

    const char *local_addr = conn->local_addr;
    const char *peer_addr = conn->peer_addr;

    work->conn = conn;
    work->req = req;
    work->memh.ucx_memh = memh->ucx_memh;
    work->ucx_rkey = NULL;
    work->request_id = req->request_id; /* be64 */
    work->valuelen = valuelen;
    work->completed = 0;
    work->cb = cb;
    work->cbarg = cbarg;
    work->defer_resp = defer_resp;

    /* SGL entries are variable-length: header + packed rkey blob. */
    uint8_t *keyed_sgl_base = (uint8_t *)req->sgls;
    for (uint16_t i = 0; i < nsgl; i++) {
        priskv_keyed_sgl *sgl = (priskv_keyed_sgl *)keyed_sgl_base;

        uint64_t remote_base = be64toh(sgl->addr);
        uint32_t sgl_length = be32toh(sgl->length);
        uint32_t packed_rkey_len = be32toh(sgl->packed_rkey_len);
        uint32_t runtime_offset = 0;
        keyed_sgl_base += sizeof(priskv_keyed_sgl) + packed_rkey_len;

        if (priskv_get_log_level() >= priskv_log_debug) {
            char rkey_hex[priskv_ucx_max_rkey_len * 2 + 1];
            priskv_ucx_to_hex(rkey_hex, sgl->packed_rkey, packed_rkey_len);
            priskv_log_debug("UCX: got rkey (%d) %s\n", packed_rkey_len, rkey_hex);
        }

        /* NOTE(review): work->ucx_rkey is overwritten on every SGL iteration;
         * with nsgl > 1 the previous rkey handle appears to leak (the
         * completion cb destroys only the last one stored) — confirm. */
        work->ucx_rkey = priskv_ucx_rkey_create(conn->ep, sgl->packed_rkey);
        if (ucs_unlikely(!work->ucx_rkey)) {
            priskv_log_error("UCX: <%s - %s> priskv_ucx_rkey_create failed: %m\n", local_addr,
                             peer_addr);
            free(work);
            return -EPROTO;
        }

        /* Split this SGL into chunks no larger than priskv_ucx_max_rw_size. */
        do {
            uint64_t local_base = (uint64_t)val + offset + runtime_offset;
            uint32_t length = priskv_min_u32(sgl_length - runtime_offset, valuelen);
            length = priskv_min_u32(length, priskv_ucx_max_rw_size);
            // need to go before post_get/put in case of immediate completion,
            // which will check nsgl
            work->nsgl++;

            ucs_status_ptr_t handle = NULL;
            if (set) {
                handle = priskv_ucx_ep_post_get(conn->ep, local_base, length, work->ucx_rkey,
                                                remote_base + runtime_offset, priskv_ucx_rw_req_cb,
                                                work);
            } else {
                handle = priskv_ucx_ep_post_put(conn->ep, local_base, length, work->ucx_rkey,
                                                remote_base + runtime_offset, priskv_ucx_rw_req_cb,
                                                work);
            }

            if (UCS_PTR_IS_ERR(handle)) {
                ucs_status_t status = UCS_PTR_STATUS(handle);
                priskv_log_error("UCX: <%s - %s> priskv_ucx_ep_post_put/get failed: addr 0x%lx, "
                                 "length 0x%x status %s\n",
                                 local_addr, peer_addr, local_base, length,
                                 ucs_status_string(status));
                // ucx will return UCS_ERR_ALREADY_EXISTS if the same buffer is posted again
                // to avoid memory corruption, we only log the error and consider it as completed
                if (status != UCS_ERR_ALREADY_EXISTS) {
                    /* NOTE(review): unlike the -EPROTO path above, this return
                     * leaks `work` and work->ucx_rkey (and leaves earlier
                     * chunks of this request in flight pointing at `work`) —
                     * confirm the intended ownership on partial failure. */
                    return -EIO;
                }
            } else if (UCS_PTR_IS_PTR(handle)) {
                // still in progress
                priskv_ucx_request *request = (priskv_ucx_request *)handle;
                if (request->status == UCS_INPROGRESS) {
                    request->key = work;
                    HASH_ADD_PTR(conn->inflight_reqs, key, request);
                }
            } else {
                // Operation completed immediately
            }

            priskv_log_debug("UCX: %s [%d/%d]:[%d/%d], val %p, length 0x%x\n", cmdstr, i, nsgl,
                             runtime_offset, sgl_length, val + offset + runtime_offset, length);

            runtime_offset += length;
        } while (runtime_offset < priskv_min_u32(sgl_length, valuelen));

        offset += sgl_length;
        /* NOTE(review): valuelen is unsigned — if sgl_length > remaining
         * valuelen this subtraction wraps; presumably the request was
         * validated upstream, but confirm. */
        valuelen -= sgl_length;
    }

    if (work_out) {
        *work_out = work;
    }

    return 0;
}

/* Transport driver vtable: wires the UCX implementation into the generic
 * priskv transport layer. */
priskv_transport_driver priskv_transport_driver_ucx = {
    .name = "ucx",
    .init = priskv_ucx_init,
    .listen = priskv_ucx_listen,
    .get_fd = priskv_ucx_get_fd,
    .get_kv = priskv_ucx_get_kv,
    .get_listeners = priskv_ucx_get_listeners,
    .free_listeners = priskv_ucx_free_listeners,
    .send_response = priskv_ucx_send_response,
    .rw_req = priskv_ucx_rw_req,
    .recv_req = priskv_ucx_recv_req,
    .mem_new = priskv_ucx_mem_new,
    .mem_free = priskv_ucx_mem_free,
    .request_key_off = priskv_ucx_request_key_off,
    .request_key = priskv_ucx_request_key,
    .close_client = priskv_ucx_close_client,
};