Skip to content
Open
39 changes: 39 additions & 0 deletions prov/efa/src/efa_domain.c
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,33 @@ int efa_domain_open(struct fid_fabric *fabric_fid, struct fi_info *info,
} else {
assert(efa_domain->info_type == EFA_INFO_DIRECT || efa_domain->info_type == EFA_INFO_DGRAM);
efa_domain->util_domain.domain_fid.ops = &efa_domain_ops;

/* Allocate and register bounce buffer for 0-byte inject (efa-direct only) */
if (efa_domain->info_type == EFA_INFO_DIRECT) {
struct iovec iov;
struct fid_mr *mr_fid;

efa_domain->zero_byte_bounce_buf = malloc(4096);
if (!efa_domain->zero_byte_bounce_buf) {
EFA_WARN(FI_LOG_DOMAIN, "Failed to allocate zero-byte bounce buffer\n");
err = -FI_ENOMEM;
goto err_free;
}

iov.iov_base = efa_domain->zero_byte_bounce_buf;
iov.iov_len = 4096;
err = efa_mr_internal_regv(&efa_domain->util_domain.domain_fid,
&iov, 1,
FI_SEND | FI_RECV | FI_READ | FI_WRITE | FI_REMOTE_READ | FI_REMOTE_WRITE,
0, 0, 0, &mr_fid, NULL);
if (err) {
EFA_WARN(FI_LOG_DOMAIN, "Failed to register zero-byte bounce buffer: %d\n", err);
free(efa_domain->zero_byte_bounce_buf);
efa_domain->zero_byte_bounce_buf = NULL;
goto err_free;
}
efa_domain->zero_byte_bounce_buf_mr = container_of(mr_fid, struct efa_mr, mr_fid);
}
}

#ifndef _WIN32
Expand Down Expand Up @@ -370,6 +397,18 @@ static int efa_domain_close(fid_t fid)
}
ofi_genlock_unlock(&efa_domain->util_domain.lock);

if (efa_domain->zero_byte_bounce_buf_mr) {
ret = fi_close(&efa_domain->zero_byte_bounce_buf_mr->mr_fid.fid);
if (ret)
EFA_WARN(FI_LOG_DOMAIN, "Failed to close zero-byte bounce buffer MR: %d\n", ret);
efa_domain->zero_byte_bounce_buf_mr = NULL;
}

if (efa_domain->zero_byte_bounce_buf) {
free(efa_domain->zero_byte_bounce_buf);
efa_domain->zero_byte_bounce_buf = NULL;
}

if (efa_domain->ibv_pd) {
ret = ibv_dealloc_pd(efa_domain->ibv_pd);
if (ret)
Expand Down
3 changes: 3 additions & 0 deletions prov/efa/src/efa_domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ struct efa_domain {
uint64_t access, uint64_t offset,
uint64_t requested_key, uint64_t flags,
struct fid_mr **mr_fid, void *context);
/* Bounce buffer for 0-byte inject operations (efa-direct only) */
void *zero_byte_bounce_buf;
struct efa_mr *zero_byte_bounce_buf_mr;
};

extern struct dlist_entry g_efa_domain_list;
Expand Down
21 changes: 16 additions & 5 deletions prov/efa/src/efa_msg.c
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,24 @@ static ssize_t efa_ep_recvv(struct fid_ep *ep_fid, const struct iovec *iov, void

static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi_msg *msg, uint64_t flags)
{
struct efa_domain *domain = base_ep->domain;
struct efa_qp *qp = base_ep->qp;
struct efa_conn *conn;
struct ibv_sge sg_list[2]; /* efa device support up to 2 iov */
struct ibv_data_buf inline_data_list[2];
size_t len, i;
size_t iov_count = msg->iov_count;
bool use_inline;
int ret = 0;
uintptr_t wr_id;

efa_tracepoint(send_begin_msg_context, (size_t) msg->context, (size_t) msg->addr);

len = ofi_total_iov_len(msg->msg_iov, msg->iov_count);

EFA_DBG(FI_LOG_EP_DATA,
"total len: %zu, addr: %lu, context: %lx, flags: %lx\n",
ofi_total_iov_len(msg->msg_iov, msg->iov_count),
msg->addr, (size_t) msg->context, flags);
len, msg->addr, (size_t) msg->context, flags);

dump_msg(msg, "send");

Expand All @@ -214,8 +217,6 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi

assert(msg->iov_count <= base_ep->info->tx_attr->iov_limit);

len = ofi_total_iov_len(msg->msg_iov, msg->iov_count);

if (qp->ibv_qp->qp_type == IBV_QPT_UD) {
assert(msg->msg_iov[0].iov_len >= base_ep->info->ep_attr->msg_prefix_size);
len -= base_ep->info->ep_attr->msg_prefix_size;
Expand All @@ -229,6 +230,15 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi
wr_id = (uintptr_t) efa_fill_context(
msg->context, msg->addr, flags, FI_SEND | FI_MSG);

/* Handle 0-byte send with bounce buffer */
if (len == 0) {
inline_data_list[0].addr = domain->zero_byte_bounce_buf;
inline_data_list[0].length = 0;
iov_count = 1;
use_inline = true;
goto post;
}

/* Determine if we should use inline data */
use_inline = (len <= base_ep->domain->device->efa_attr.inline_buf_size &&
(!msg->desc || !efa_mr_is_hmem(msg->desc[0])));
Expand Down Expand Up @@ -268,8 +278,9 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi
}
}

post:
/* Use consolidated send function */
ret = efa_qp_post_send(qp, sg_list, inline_data_list, msg->iov_count,
ret = efa_qp_post_send(qp, sg_list, inline_data_list, iov_count,
use_inline, wr_id, msg->data, flags,
conn->ah, conn->ep_addr->qpn, conn->ep_addr->qkey);
if (OFI_UNLIKELY(ret))
Expand Down
175 changes: 140 additions & 35 deletions prov/efa/src/efa_rma.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep,
const struct fi_msg_rma *msg,
uint64_t flags)
{
struct efa_domain *domain = base_ep->domain;
struct efa_mr *efa_mr;
struct efa_conn *conn;
size_t iov_count = msg->iov_count;
#ifndef _WIN32
struct ibv_sge sge_list[msg->iov_count];
#else
Expand All @@ -48,48 +50,56 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep,
#endif
uintptr_t wr_id;
int i, err = 0;
size_t total_len;

efa_tracepoint(read_begin_msg_context, (size_t) msg->context, (size_t) msg->addr);

total_len = ofi_total_iov_len(msg->msg_iov, msg->iov_count);

EFA_DBG(FI_LOG_EP_DATA,
"total len: %zu, addr: %lu, context: %lx, flags: %lx\n",
ofi_total_iov_len(msg->msg_iov, msg->iov_count),
msg->addr, (size_t) msg->context, flags);
total_len, msg->addr, (size_t) msg->context, flags);

assert(msg->iov_count > 0 &&
msg->iov_count <= base_ep->domain->info->tx_attr->iov_limit);
assert(msg->iov_count <= base_ep->domain->info->tx_attr->iov_limit);
assert(msg->rma_iov_count > 0 &&
msg->rma_iov_count <= base_ep->domain->info->tx_attr->rma_iov_limit);
assert(ofi_total_iov_len(msg->msg_iov, msg->iov_count) <=
base_ep->domain->device->max_rdma_size);
assert(total_len <= base_ep->domain->device->max_rdma_size);

ofi_genlock_lock(&base_ep->util_ep.lock);

/* Prepare work request ID */
wr_id = (uintptr_t) efa_fill_context(
msg->context, msg->addr, flags, FI_RMA | FI_READ);

/* Prepare SGE list */
for (i = 0; i < msg->iov_count; ++i) {
sge_list[i].addr = (uint64_t)msg->msg_iov[i].iov_base;
sge_list[i].length = msg->msg_iov[i].iov_len;
if (OFI_UNLIKELY(!msg->desc || !msg->desc[i])) {
EFA_WARN(FI_LOG_EP_CTRL,
"EFA direct requires FI_MR_LOCAL but "
"application does not provide a valid desc\n");
err = -FI_EINVAL;
goto out_err;
/* Handle 0-byte read with bounce buffer */
if (total_len == 0) {
sge_list[0].addr = (uint64_t)domain->zero_byte_bounce_buf;
sge_list[0].length = 0;
sge_list[0].lkey = domain->zero_byte_bounce_buf_mr->ibv_mr->lkey;
iov_count = 1;
} else {
/* Prepare SGE list */
for (i = 0; i < msg->iov_count; ++i) {
sge_list[i].addr = (uint64_t)msg->msg_iov[i].iov_base;
sge_list[i].length = msg->msg_iov[i].iov_len;
if (OFI_UNLIKELY(!msg->desc || !msg->desc[i])) {
EFA_WARN(FI_LOG_EP_CTRL,
"EFA direct requires FI_MR_LOCAL but "
"application does not provide a valid desc\n");
err = -FI_EINVAL;
goto out_err;
}
efa_mr = (struct efa_mr *)msg->desc[i];
sge_list[i].lkey = efa_mr->ibv_mr->lkey;
}
efa_mr = (struct efa_mr *)msg->desc[i];
sge_list[i].lkey = efa_mr->ibv_mr->lkey;
}

conn = efa_av_addr_to_conn(base_ep->av, msg->addr);
assert(conn && conn->ep_addr);

/* Use consolidated RDMA read function */
/* ep->domain->info->tx_attr->rma_iov_limit is set to 1 */
err = efa_qp_post_read(base_ep->qp, sge_list, msg->iov_count,
err = efa_qp_post_read(base_ep->qp, sge_list, iov_count,
msg->rma_iov[0].key, msg->rma_iov[0].addr,
wr_id, flags,
conn->ah, conn->ep_addr->qpn, conn->ep_addr->qkey);
Expand Down Expand Up @@ -178,7 +188,9 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep,
const struct fi_msg_rma *msg,
uint64_t flags)
{
struct efa_domain *domain = base_ep->domain;
struct efa_conn *conn;
size_t iov_count = msg->iov_count;
#ifndef _WIN32
struct ibv_sge sge_list[msg->iov_count];
#else
Expand All @@ -189,6 +201,7 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep,
#endif
uintptr_t wr_id;
int i, err = 0;
size_t total_len;

if (flags & FI_INJECT) {
EFA_WARN(FI_LOG_EP_DATA,
Expand All @@ -198,36 +211,45 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep,

efa_tracepoint(write_begin_msg_context, (size_t) msg->context, (size_t) msg->addr);

total_len = ofi_total_iov_len(msg->msg_iov, msg->iov_count);

EFA_DBG(FI_LOG_EP_DATA,
"total len: %zu, addr: %lu, context: %lx, flags: %lx\n",
ofi_total_iov_len(msg->msg_iov, msg->iov_count),
msg->addr, (size_t) msg->context, flags);
total_len, msg->addr, (size_t) msg->context, flags);

ofi_genlock_lock(&base_ep->util_ep.lock);

/* Prepare work request ID */
wr_id = (uintptr_t) efa_fill_context(
msg->context, msg->addr, flags, FI_RMA | FI_WRITE);

/* Prepare SGE list */
for (i = 0; i < msg->iov_count; ++i) {
sge_list[i].addr = (uint64_t)msg->msg_iov[i].iov_base;
sge_list[i].length = msg->msg_iov[i].iov_len;
if (OFI_UNLIKELY(!msg->desc || !msg->desc[i])) {
EFA_WARN(FI_LOG_EP_CTRL,
"EFA direct requires FI_MR_LOCAL but "
"application does not provide a valid desc\n");
err = -FI_EINVAL;
goto out_err;
/* Handle 0-byte write with bounce buffer */
if (total_len == 0) {
sge_list[0].addr = (uint64_t)domain->zero_byte_bounce_buf;
sge_list[0].length = 0;
sge_list[0].lkey = domain->zero_byte_bounce_buf_mr->ibv_mr->lkey;
iov_count = 1;
} else {
/* Prepare SGE list */
for (i = 0; i < msg->iov_count; ++i) {
sge_list[i].addr = (uint64_t)msg->msg_iov[i].iov_base;
sge_list[i].length = msg->msg_iov[i].iov_len;
if (OFI_UNLIKELY(!msg->desc || !msg->desc[i])) {
EFA_WARN(FI_LOG_EP_CTRL,
"EFA direct requires FI_MR_LOCAL but "
"application does not provide a valid desc\n");
err = -FI_EINVAL;
goto out_err;
}
sge_list[i].lkey = ((struct efa_mr *)msg->desc[i])->ibv_mr->lkey;
}
sge_list[i].lkey = ((struct efa_mr *)msg->desc[i])->ibv_mr->lkey;
}

conn = efa_av_addr_to_conn(base_ep->av, msg->addr);
assert(conn && conn->ep_addr);

/* Use consolidated RDMA write function */
err = efa_qp_post_write(base_ep->qp, sge_list, msg->iov_count,
err = efa_qp_post_write(base_ep->qp, sge_list, iov_count,
msg->rma_iov[0].key, msg->rma_iov[0].addr,
wr_id, msg->data, flags,
conn->ah, conn->ep_addr->qpn, conn->ep_addr->qkey);
Expand Down Expand Up @@ -324,6 +346,89 @@ ssize_t efa_rma_writedata(struct fid_ep *ep_fid, const void *buf, size_t len,
return efa_rma_post_write(base_ep, &msg, FI_REMOTE_CQ_DATA | efa_tx_flags(base_ep));
}

/**
 * @brief Post a 0-byte RDMA write without a user-supplied context.
 *
 * Only zero-length payloads are accepted. The work request carries a single
 * zero-length SGE that points at the domain's zero-byte bounce buffer and
 * uses the lkey of that buffer's MR.
 *
 * @param[in] ep_fid     endpoint fid
 * @param[in] buf        not read by this function (payload must be empty)
 * @param[in] len        payload length; anything other than 0 is rejected
 * @param[in] dest_addr  destination address, resolved through the endpoint's AV
 * @param[in] addr       remote target address
 * @param[in] key        remote memory key
 *
 * @return 0 on success; the error from efa_rma_check_cap(); -FI_ENOSYS when
 *         len is non-zero; -FI_EAGAIN when the post reports ENOMEM; otherwise
 *         the negated error returned by efa_qp_post_write().
 */
ssize_t efa_rma_inject_write(struct fid_ep *ep_fid, const void *buf, size_t len,
			     fi_addr_t dest_addr, uint64_t addr, uint64_t key)
{
	struct efa_base_ep *ep =
		container_of(ep_fid, struct efa_base_ep, util_ep.ep_fid);
	struct efa_domain *dom = ep->domain;
	struct efa_conn *peer;
	struct ibv_sge zero_sge;
	uintptr_t ctx_id;
	int rc;

	rc = efa_rma_check_cap(ep);
	if (rc)
		return rc;

	/* Only support 0-byte inject for efa-direct */
	if (len != 0)
		return -FI_ENOSYS;

	ofi_genlock_lock(&ep->util_ep.lock);

	/*
	 * NOTE(review): inject has no user context, so NULL is handed to
	 * efa_fill_context(); confirm that helper tolerates a NULL context.
	 */
	ctx_id = (uintptr_t) efa_fill_context(NULL, dest_addr, FI_INJECT,
					      FI_RMA | FI_WRITE);

	/* Zero-length SGE backed by the domain bounce buffer. */
	zero_sge.addr = (uint64_t)dom->zero_byte_bounce_buf;
	zero_sge.length = 0;
	zero_sge.lkey = dom->zero_byte_bounce_buf_mr->ibv_mr->lkey;

	peer = efa_av_addr_to_conn(ep->av, dest_addr);
	assert(peer && peer->ep_addr);

	rc = efa_qp_post_write(ep->qp, &zero_sge, 1, key, addr,
			       ctx_id, 0, 0, peer->ah, peer->ep_addr->qpn,
			       peer->ep_addr->qkey);
	if (OFI_UNLIKELY(rc)) {
		/* A full send queue (ENOMEM) is reported as a retryable error. */
		if (rc == ENOMEM)
			rc = -FI_EAGAIN;
		else
			rc = -rc;
	}

	ofi_genlock_unlock(&ep->util_ep.lock);
	return rc;
}

/**
 * @brief Post a 0-byte RDMA write that carries remote CQ data, without a
 *        user-supplied context.
 *
 * Only zero-length payloads are accepted. The work request carries a single
 * zero-length SGE that points at the domain's zero-byte bounce buffer and
 * uses the lkey of that buffer's MR.
 *
 * @param[in] ep         endpoint fid
 * @param[in] buf        not read by this function (payload must be empty)
 * @param[in] len        payload length; anything other than 0 is rejected
 * @param[in] data       remote CQ data handed to efa_qp_post_write()
 * @param[in] dest_addr  destination address, resolved through the endpoint's AV
 * @param[in] addr       remote target address
 * @param[in] key        remote memory key
 *
 * @return 0 on success; the error from efa_rma_check_cap(); -FI_ENOSYS when
 *         len is non-zero; -FI_EAGAIN when the post reports ENOMEM; otherwise
 *         the negated error returned by efa_qp_post_write().
 */
static ssize_t efa_rma_inject_writedata(struct fid_ep *ep, const void *buf, size_t len,
					uint64_t data, fi_addr_t dest_addr,
					uint64_t addr, uint64_t key)
{
	struct efa_base_ep *base_ep;
	struct efa_domain *domain;
	struct efa_conn *conn;
	struct ibv_sge sge;
	uintptr_t wr_id;
	int err;

	base_ep = container_of(ep, struct efa_base_ep, util_ep.ep_fid);
	domain = base_ep->domain;
	err = efa_rma_check_cap(base_ep);
	if (err)
		return err;

	/* Only support 0-byte inject for efa-direct */
	if (len != 0)
		return -FI_ENOSYS;

	ofi_genlock_lock(&base_ep->util_ep.lock);

	/*
	 * NOTE(review): inject has no user context, so NULL is handed to
	 * efa_fill_context(); confirm that helper tolerates a NULL context.
	 */
	wr_id = (uintptr_t) efa_fill_context(NULL, dest_addr,
					     FI_INJECT | FI_REMOTE_CQ_DATA,
					     FI_RMA | FI_WRITE);

	/* Zero-length SGE backed by the domain bounce buffer. */
	sge.addr = (uint64_t)domain->zero_byte_bounce_buf;
	sge.length = 0;
	sge.lkey = domain->zero_byte_bounce_buf_mr->ibv_mr->lkey;

	conn = efa_av_addr_to_conn(base_ep->av, dest_addr);
	assert(conn && conn->ep_addr);

	/*
	 * The flags argument of efa_qp_post_write() is fed libfabric FI_*
	 * flags by the other writers in this file (efa_rma_writedata passes
	 * FI_REMOTE_CQ_DATA), so an ibverbs constant such as IBV_SEND_INLINE
	 * does not belong here. Pass FI_REMOTE_CQ_DATA so that `data` is
	 * requested to travel with the write -- presumably this is what
	 * selects the write-with-immediate path; confirm against
	 * efa_qp_post_write().
	 */
	err = efa_qp_post_write(base_ep->qp, &sge, 1, key, addr,
				wr_id, data, FI_REMOTE_CQ_DATA, conn->ah,
				conn->ep_addr->qpn, conn->ep_addr->qkey);
	if (OFI_UNLIKELY(err))
		/* A full send queue (ENOMEM) is reported as a retryable error. */
		err = (err == ENOMEM) ? -FI_EAGAIN : -err;

	ofi_genlock_unlock(&base_ep->util_ep.lock);
	return err;
}

struct fi_ops_rma efa_dgram_ep_rma_ops = {
.size = sizeof(struct fi_ops_rma),
.read = fi_no_rma_read,
Expand All @@ -345,7 +450,7 @@ struct fi_ops_rma efa_rma_ops = {
.write = efa_rma_write,
.writev = efa_rma_writev,
.writemsg = efa_rma_writemsg,
.inject = fi_no_rma_inject,
.inject = efa_rma_inject_write,
.writedata = efa_rma_writedata,
.injectdata = fi_no_rma_injectdata,
.injectdata = efa_rma_inject_writedata,
};
Loading