Skip to content

Commit 20219bd

Browse files
committed
ch4/ucx: Use UCX datatypes in AM path
The nbx active message interfaces can support noncontig data, so use it. We already do the same in the tagged send and recv path.
1 parent d7ee40e commit 20219bd

File tree

2 files changed

+40
-43
lines changed

2 files changed

+40
-43
lines changed

src/mpid/ch4/netmod/ucx/ucx_am.c

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,6 @@ int MPIDI_UCX_do_am_recv(MPIR_Request * rreq)
7878
MPI_Aint data_sz, in_data_sz;
7979
int vci = MPIDI_Request_get_vci(rreq);
8080

81-
MPIDIG_get_recv_buffer(&recv_buf, &data_sz, &is_contig, &in_data_sz, rreq);
82-
if (!is_contig || in_data_sz > data_sz) {
83-
/* non-contig datatype, need receive into pack buffer */
84-
/* ucx will error out if buffer size is less than the promised data size,
85-
* also use a pack buffer in this case */
86-
recv_buf = MPL_malloc(in_data_sz, MPL_MEM_OTHER);
87-
MPIR_Assert(recv_buf);
88-
MPIDI_UCX_AM_RECV_REQUEST(rreq, pack_buffer) = recv_buf;
89-
} else {
90-
MPIDI_UCX_AM_RECV_REQUEST(rreq, pack_buffer) = NULL;
91-
}
92-
9381
MPIDI_UCX_ucp_request_t *ucp_request;
9482
size_t received_length;
9583
ucp_request_param_t param = {
@@ -99,6 +87,26 @@ int MPIDI_UCX_do_am_recv(MPIR_Request * rreq)
9987
.recv_info.length = &received_length,
10088
.user_data = rreq,
10189
};
90+
91+
MPIDIG_get_recv_buffer(&recv_buf, &data_sz, &is_contig, &in_data_sz, rreq);
92+
if (in_data_sz > data_sz) {
93+
/* ucx will error out if buffer size is less than the promised data size,
94+
* use a pack buffer in this case */
95+
/* FIXME: what? how does UCX know the buffer size differs? */
96+
recv_buf = MPL_malloc(in_data_sz, MPL_MEM_OTHER);
97+
MPIR_Assert(recv_buf);
98+
MPIDI_UCX_AM_RECV_REQUEST(rreq, pack_buffer) = recv_buf;
99+
} else {
100+
MPIDI_UCX_AM_RECV_REQUEST(rreq, pack_buffer) = NULL;
101+
if (!is_contig) {
102+
MPIR_Datatype *dt_ptr;
103+
MPIR_Datatype_get_ptr(MPIDIG_REQUEST(rreq, datatype), dt_ptr);
104+
param.op_attr_mask |= UCP_OP_ATTR_FIELD_DATATYPE;
105+
param.datatype = dt_ptr->dev.netmod.ucx.ucp_datatype;
106+
MPIR_Datatype_ptr_add_ref(dt_ptr);
107+
}
108+
}
109+
102110
void *data_desc = MPIDI_UCX_AM_RECV_REQUEST(rreq, data_desc);
103111
/* note: use in_data_sz to match promised data size */
104112
ucp_request = ucp_am_recv_data_nbx(MPIDI_UCX_global.ctx[vci].worker,

src/mpid/ch4/netmod/ucx/ucx_am.h

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -42,52 +42,41 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_am_isend(int rank,
4242

4343
#ifdef HAVE_UCP_AM_NBX
4444
size_t header_size = sizeof(ucx_hdr) + am_hdr_sz;
45-
void *send_buf, *header, *data_ptr;
46-
/* note: since we are not copying large contig gpu data, it is less useful
47-
* to use MPIR_gpu_malloc_host */
48-
if (dt_contig) {
49-
/* only need copy headers */
50-
send_buf = MPL_malloc(header_size, MPL_MEM_OTHER);
51-
MPIR_Assert(send_buf);
52-
header = send_buf;
53-
54-
MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr));
55-
MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz);
56-
57-
data_ptr = (char *) data + dt_true_lb;
58-
} else {
59-
/* need copy headers and pack data */
60-
send_buf = MPL_malloc(header_size + data_sz, MPL_MEM_OTHER);
61-
MPIR_Assert(send_buf);
62-
header = send_buf;
63-
data_ptr = (char *) send_buf + header_size;
64-
65-
MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr));
66-
MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz);
67-
68-
MPI_Aint actual_pack_bytes;
69-
mpi_errno = MPIR_Typerep_pack(data, count, datatype, 0, data_ptr, data_sz,
70-
&actual_pack_bytes, MPIR_TYPEREP_FLAG_NONE);
71-
MPIR_ERR_CHECK(mpi_errno);
72-
MPIR_Assert(actual_pack_bytes == data_sz);
73-
}
45+
void *header;
46+
const void *data_ptr;
7447
ucp_request_param_t param = {
7548
.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA,
7649
.cb.send = &MPIDI_UCX_am_isend_callback_nbx,
7750
.user_data = sreq,
7851
};
52+
53+
header = MPL_malloc(header_size, MPL_MEM_OTHER);
54+
MPIR_Assert(header);
55+
56+
MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr));
57+
MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz);
58+
59+
if (dt_contig) {
60+
data_ptr = (char *) data + dt_true_lb;
61+
} else {
62+
param.op_attr_mask |= UCP_OP_ATTR_FIELD_DATATYPE;
63+
param.datatype = dt_ptr->dev.netmod.ucx.ucp_datatype;
64+
MPIR_Datatype_ptr_add_ref(dt_ptr);
65+
data_ptr = data;
66+
data_sz = count;
67+
}
7968
ucp_request = (MPIDI_UCX_ucp_request_t *) ucp_am_send_nbx(ep, MPIDI_UCX_AM_NBX_HANDLER_ID,
8069
header, header_size,
8170
data_ptr, data_sz, &param);
8271
MPIDI_UCX_CHK_REQUEST(ucp_request);
8372
/* if send is done, free all resources and complete the request */
8473
if (ucp_request == NULL) {
85-
MPL_free(send_buf);
74+
MPL_free(header);
8675
MPIDIG_global.origin_cbs[handler_id] (sreq);
8776
goto fn_exit;
8877
}
8978

90-
MPIDI_UCX_AM_SEND_REQUEST(sreq, pack_buffer) = send_buf;
79+
MPIDI_UCX_AM_SEND_REQUEST(sreq, pack_buffer) = header;
9180
MPIDI_UCX_AM_SEND_REQUEST(sreq, handler_id) = handler_id;
9281
ucp_request_release(ucp_request);
9382

0 commit comments

Comments
 (0)