diff --git a/src/ucp/core/ucp_context.c b/src/ucp/core/ucp_context.c
index 778e0945a35..a184527f2e4 100644
--- a/src/ucp/core/ucp_context.c
+++ b/src/ucp/core/ucp_context.c
@@ -384,6 +384,17 @@ static ucs_config_field_t ucp_context_config_table[] = {
    "even if invalidation workflow isn't supported",
    ucs_offsetof(ucp_context_config_t, rndv_errh_ppln_enable), UCS_CONFIG_TYPE_BOOL},
 
+  {"RNDV_MTYPE_WORKER_FC_ENABLE", "n",
+   "Enable worker-level flow control to limit total concurrent mtype fragments\n"
+   "across all requests, preventing memory exhaustion",
+   ucs_offsetof(ucp_context_config_t, rndv_mtype_worker_fc_enable), UCS_CONFIG_TYPE_BOOL},
+
+  {"RNDV_MTYPE_WORKER_MAX_MEM", "4g",
+   "Maximum memory for concurrent mtype fragments per worker.\n"
+   "This value is translated to a fragment count based on RNDV_FRAG_SIZE\n"
+   "for each memory type (only applies when RNDV_MTYPE_WORKER_FC_ENABLE=y)",
+   ucs_offsetof(ucp_context_config_t, rndv_mtype_worker_max_mem), UCS_CONFIG_TYPE_MEMUNITS},
+
   {"FLUSH_WORKER_EPS", "y",
    "Enable flushing the worker by flushing its endpoints. Allows completing\n"
    "the flush operation in a bounded time even if there are new requests on\n"
diff --git a/src/ucp/core/ucp_context.h b/src/ucp/core/ucp_context.h
index c68a80ef263..b8274d8407a 100644
--- a/src/ucp/core/ucp_context.h
+++ b/src/ucp/core/ucp_context.h
@@ -98,6 +98,10 @@ typedef struct ucp_context_config {
     int                                    rndv_shm_ppln_enable;
     /** Enable error handling for rndv pipeline protocol */
     int                                    rndv_errh_ppln_enable;
+    /** Enable flow control for rndv mtype fragments at worker level */
+    int                                    rndv_mtype_worker_fc_enable;
+    /** Maximum memory for concurrent rndv mtype fragments per worker (bytes) */
+    size_t                                 rndv_mtype_worker_max_mem;
     /** Threshold for using tag matching offload capabilities. Smaller buffers
      * will not be posted to the transport.
      */
     size_t                                 tm_thresh;
diff --git a/src/ucp/core/ucp_request.h b/src/ucp/core/ucp_request.h
index 8787b747133..d157a861597 100644
--- a/src/ucp/core/ucp_request.h
+++ b/src/ucp/core/ucp_request.h
@@ -318,7 +318,10 @@ struct ucp_request {
                 /* Used by rndv/send/ppln and rndv/recv/ppln */
                 struct {
                     /* Size to send in ack message */
-                    ssize_t          ack_data_size;
+                    ssize_t          ack_data_size;
+                    /* Element in worker-level pending queue
+                     * for throttled ppln requests */
+                    ucs_queue_elem_t queue_elem;
                 } ppln;
 
                 /* Used by rndv/rkey_ptr */
diff --git a/src/ucp/core/ucp_worker.c b/src/ucp/core/ucp_worker.c
index d0c31f151e6..02d3d8ad250 100644
--- a/src/ucp/core/ucp_worker.c
+++ b/src/ucp/core/ucp_worker.c
@@ -87,18 +87,20 @@ static ucs_stats_class_t ucp_worker_stats_class = {
     .num_counters   = UCP_WORKER_STAT_LAST,
     .class_id       = UCS_STATS_CLASS_ID_INVALID,
     .counter_names  = {
-        [UCP_WORKER_STAT_TAG_RX_EAGER_MSG]         = "tag_rx_eager_msg",
-        [UCP_WORKER_STAT_TAG_RX_EAGER_SYNC_MSG]    = "tag_rx_sync_msg",
-        [UCP_WORKER_STAT_TAG_RX_EAGER_CHUNK_EXP]   = "tag_rx_eager_chunk_exp",
-        [UCP_WORKER_STAT_TAG_RX_EAGER_CHUNK_UNEXP] = "tag_rx_eager_chunk_unexp",
-        [UCP_WORKER_STAT_RNDV_RX_EXP]              = "rndv_rx_exp",
-        [UCP_WORKER_STAT_RNDV_RX_UNEXP]            = "rndv_rx_unexp",
-        [UCP_WORKER_STAT_RNDV_PUT_ZCOPY]           = "rndv_put_zcopy",
-        [UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY]     = "rndv_put_mtype_zcopy",
-        [UCP_WORKER_STAT_RNDV_GET_ZCOPY]           = "rndv_get_zcopy",
-        [UCP_WORKER_STAT_RNDV_RTR]                 = "rndv_rtr",
-        [UCP_WORKER_STAT_RNDV_RTR_MTYPE]           = "rndv_rtr_mtype",
-        [UCP_WORKER_STAT_RNDV_RKEY_PTR]            = "rndv_rkey_ptr"
+        [UCP_WORKER_STAT_TAG_RX_EAGER_MSG]          = "tag_rx_eager_msg",
+        [UCP_WORKER_STAT_TAG_RX_EAGER_SYNC_MSG]     = "tag_rx_sync_msg",
+        [UCP_WORKER_STAT_TAG_RX_EAGER_CHUNK_EXP]    = "tag_rx_eager_chunk_exp",
+        [UCP_WORKER_STAT_TAG_RX_EAGER_CHUNK_UNEXP]  = "tag_rx_eager_chunk_unexp",
+        [UCP_WORKER_STAT_RNDV_RX_EXP]               = "rndv_rx_exp",
+        [UCP_WORKER_STAT_RNDV_RX_UNEXP]             = "rndv_rx_unexp",
+        [UCP_WORKER_STAT_RNDV_PUT_ZCOPY]            = "rndv_put_zcopy",
+        [UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY]      = "rndv_put_mtype_zcopy",
+        [UCP_WORKER_STAT_RNDV_GET_ZCOPY]            = "rndv_get_zcopy",
+        [UCP_WORKER_STAT_RNDV_RTR]                  = "rndv_rtr",
+        [UCP_WORKER_STAT_RNDV_RTR_MTYPE]            = "rndv_rtr_mtype",
+        [UCP_WORKER_STAT_RNDV_RKEY_PTR]             = "rndv_rkey_ptr",
+        [UCP_WORKER_STAT_RNDV_MTYPE_FC_THROTTLED]   = "rndv_mtype_fc_throttled",
+        [UCP_WORKER_STAT_RNDV_MTYPE_FC_INCREMENTED] = "rndv_mtype_fc_incremented"
     }
 };
 #endif
@@ -2524,6 +2526,12 @@ ucs_status_t ucp_worker_create(ucp_context_h context,
     worker->counters.ep_closures = 0;
     worker->counters.ep_failures = 0;
 
+    /* Initialize RNDV mtype flow control */
+    worker->rndv_mtype_fc.active_frags = 0;
+    ucs_queue_head_init(&worker->rndv_mtype_fc.put_pending_q);
+    ucs_queue_head_init(&worker->rndv_mtype_fc.get_pending_q);
+    ucs_queue_head_init(&worker->rndv_mtype_fc.rtr_pending_q);
+
     /* Copy user flags, and mask-out unsupported flags for compatibility */
     worker->flags = UCP_PARAM_VALUE(WORKER, params, flags, FLAGS, 0) &
                     UCS_MASK(UCP_WORKER_INTERNAL_FLAGS_SHIFT);
diff --git a/src/ucp/core/ucp_worker.h b/src/ucp/core/ucp_worker.h
index 561f0ec8043..c78a7bdbcb0 100644
--- a/src/ucp/core/ucp_worker.h
+++ b/src/ucp/core/ucp_worker.h
@@ -154,6 +154,8 @@ enum {
     UCP_WORKER_STAT_RNDV_RTR,
     UCP_WORKER_STAT_RNDV_RTR_MTYPE,
     UCP_WORKER_STAT_RNDV_RKEY_PTR,
+    UCP_WORKER_STAT_RNDV_MTYPE_FC_THROTTLED,
+    UCP_WORKER_STAT_RNDV_MTYPE_FC_INCREMENTED,
     UCP_WORKER_STAT_LAST
 };
 
@@ -393,6 +395,15 @@ typedef struct ucp_worker {
         uint64_t                      ep_failures;
     } counters;
 
+    struct {
+        /* Worker-level mtype fragment flow control */
+        size_t                        active_frags;  /* Current active fragments */
+        /* Separate pending queues by priority (PUT > GET > RTR) */
+        ucs_queue_head_t              put_pending_q; /* Throttled PUT (RNDV_SEND) requests */
+        ucs_queue_head_t              get_pending_q; /* Throttled GET requests */
+        ucs_queue_head_t              rtr_pending_q; /* Throttled RTR requests */
+    } rndv_mtype_fc;
+
     struct {
         /* Usage tracker handle */
         ucs_usage_tracker_h           handle;
diff --git a/src/ucp/rndv/proto_rndv.c b/src/ucp/rndv/proto_rndv.c
index dc516b23c80..c3d1261eea5 100644
--- a/src/ucp/rndv/proto_rndv.c
+++ b/src/ucp/rndv/proto_rndv.c
@@ -9,6 +9,7 @@
 #endif
 
 #include "proto_rndv.inl"
+#include "rndv_mtype.inl"
 
 #include
 #include
@@ -650,6 +651,8 @@ ucp_proto_rndv_bulk_init(const ucp_proto_multi_init_params_t *init_params,
 
     rpriv->frag_mem_type = init_params->super.reg_mem_info.type;
     rpriv->frag_sys_dev  = init_params->super.reg_mem_info.sys_dev;
+    rpriv->fc_max_frags  = ucp_proto_rndv_mtype_fc_max_frags(
+                                   context, rpriv->frag_mem_type);
 
     if (rpriv->super.lane == UCP_NULL_LANE) {
         /* Add perf without ACK in case of pipeline */
diff --git a/src/ucp/rndv/proto_rndv.h b/src/ucp/rndv/proto_rndv.h
index 8299b740491..ca18507b76b 100644
--- a/src/ucp/rndv/proto_rndv.h
+++ b/src/ucp/rndv/proto_rndv.h
@@ -75,6 +75,9 @@ typedef struct {
     ucs_memory_type_t       frag_mem_type;
     ucs_sys_device_t        frag_sys_dev;
 
+    /* max fragments for flow control */
+    size_t                  fc_max_frags;
+
     /* Multi-lane common part. Must be the last field, see
        @ref ucp_proto_multi_priv_t */
     ucp_proto_multi_priv_t  mpriv;
diff --git a/src/ucp/rndv/rndv_get.c b/src/ucp/rndv/rndv_get.c
index 9a8fb8c28f4..b75db7f9490 100644
--- a/src/ucp/rndv/rndv_get.c
+++ b/src/ucp/rndv/rndv_get.c
@@ -262,6 +262,8 @@ ucp_proto_rndv_get_mtype_unpack_completion(uct_completion_t *uct_comp)
                                          send.state.uct_comp);
 
     ucs_mpool_put_inline(req->send.rndv.mdesc);
+    ucp_proto_rndv_mtype_fc_decrement(req);
+
     if (ucp_proto_rndv_request_is_ppln_frag(req)) {
         ucp_proto_rndv_ppln_recv_frag_complete(req, 1, 0);
     } else {
@@ -287,11 +289,24 @@ ucp_proto_rndv_get_mtype_fetch_progress(uct_pending_req_t *uct_req)
     ucp_request_t *req = ucs_container_of(uct_req, ucp_request_t, send.uct);
     const ucp_proto_rndv_bulk_priv_t *rpriv;
     ucs_status_t status;
+    size_t max_frags;
+    ucs_queue_head_t *pending_q;
 
     /* coverity[tainted_data_downcast] */
     rpriv = req->send.proto_config->priv;
 
     if (!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED)) {
+        /* GET priority: 80% of total fragments */
+        max_frags = rpriv->fc_max_frags / 5 * 4;
+
+        /* Check throttling limit. If no resource at the moment, queue the
+         * request in GET pending queue and return UCS_OK. */
+        pending_q = &req->send.ep->worker->rndv_mtype_fc.get_pending_q;
+        if (ucp_proto_rndv_mtype_fc_check(req, max_frags, pending_q) ==
+            UCS_ERR_NO_RESOURCE) {
+            return UCS_OK;
+        }
+
         status = ucp_proto_rndv_mtype_request_init(req, rpriv->frag_mem_type,
                                                    rpriv->frag_sys_dev);
         if (status != UCS_OK) {
@@ -299,6 +314,7 @@
             return UCS_OK;
         }
 
+        ucp_proto_rndv_mtype_fc_increment(req);
         ucp_proto_rndv_get_common_request_init(req);
         ucp_proto_completion_init(&req->send.state.uct_comp,
                                   ucp_proto_rndv_get_mtype_fetch_completion);
@@ -364,6 +380,8 @@ static ucs_status_t ucp_proto_rndv_get_mtype_reset(ucp_request_t *req)
     req->send.rndv.mdesc = NULL;
     req->flags          &= ~UCP_REQUEST_FLAG_PROTO_INITIALIZED;
 
+    ucp_proto_rndv_mtype_fc_decrement(req);
+
     if ((req->send.proto_stage != UCP_PROTO_RNDV_GET_STAGE_FETCH) &&
         (req->send.proto_stage != UCP_PROTO_RNDV_GET_STAGE_ATS)) {
         ucp_proto_fatal_invalid_stage(req, "reset");
diff --git a/src/ucp/rndv/rndv_mtype.inl b/src/ucp/rndv/rndv_mtype.inl
index 93a958be7ac..90c05c74916 100644
--- a/src/ucp/rndv/rndv_mtype.inl
+++ b/src/ucp/rndv/rndv_mtype.inl
@@ -169,6 +169,144 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_mtype_copy(
     return status;
 }
 
+/* Reschedule callback for throttled mtype requests */
+static unsigned ucp_proto_rndv_mtype_fc_reschedule_cb(void *arg)
+{
+    ucp_request_t *req = arg;
+
+    ucp_request_send(req);
+    return 1;
+}
+
+/**
+ * Compute maximum number of fragments allowed based on configured max memory
+ * and fragment size for the given memory type. The result is rounded down to
+ * the allocation chunk granularity (rndv_num_frags).
+ *
+ * @param context        The UCP context.
+ * @param frag_mem_type  Memory type used for fragments.
+ *
+ * @return Maximum number of fragments that fit within the configured memory
+ *         limit, aligned to allocation chunk size.
+ */
+static UCS_F_ALWAYS_INLINE size_t
+ucp_proto_rndv_mtype_fc_max_frags(ucp_context_h context,
+                                  ucs_memory_type_t frag_mem_type)
+{
+    size_t max_mem        = context->config.ext.rndv_mtype_worker_max_mem;
+    size_t frag_size      = context->config.ext.rndv_frag_size[frag_mem_type];
+    size_t frags_in_chunk = context->config.ext.rndv_num_frags[frag_mem_type];
+    size_t max_frags;
+
+    ucs_assert(frag_size > 0);
+
+    /* Compute max fragments and round down to allocation chunk granularity */
+    max_frags = max_mem / frag_size;
+    max_frags = (max_frags / frags_in_chunk) * frags_in_chunk;
+
+    if (max_frags == 0) {
+        ucs_warn("RNDV_MTYPE_WORKER_MAX_MEM (%zu) is too low for %s "
+                 "(frag_size=%zu, frags_per_alloc=%zu), using minimum %zu "
+                 "frags", max_mem, ucs_memory_type_names[frag_mem_type],
+                 frag_size, frags_in_chunk, frags_in_chunk);
+        return frags_in_chunk;
+    }
+
+    return max_frags;
+}
+
+/**
+ * Check if request should be throttled due to flow control limit.
+ * If throttled, the request is queued to the appropriate priority queue.
+ *
+ * @param req        The request to check.
+ * @param max_frags  The maximum number of fragments allowed.
+ * @param pending_q  The queue to add the request to if it is throttled.
+ *
+ * @return UCS_OK if not throttled, UCS_ERR_NO_RESOURCE if throttled and queued.
+ */
+static UCS_F_ALWAYS_INLINE ucs_status_t
+ucp_proto_rndv_mtype_fc_check(ucp_request_t *req, size_t max_frags,
+                              ucs_queue_head_t *pending_q)
+{
+    ucp_worker_h worker   = req->send.ep->worker;
+    ucp_context_h context = worker->context;
+
+    if (!context->config.ext.rndv_mtype_worker_fc_enable) {
+        return UCS_OK;
+    }
+
+    if (worker->rndv_mtype_fc.active_frags >= max_frags) {
+        ucs_trace_req("mtype_fc: fragments throttle limit reached (%zu/%zu)",
+                      worker->rndv_mtype_fc.active_frags, max_frags);
+        UCS_STATS_UPDATE_COUNTER(worker->stats,
+                                 UCP_WORKER_STAT_RNDV_MTYPE_FC_THROTTLED, 1);
+        ucs_queue_push(pending_q, &req->send.rndv.ppln.queue_elem);
+        return UCS_ERR_NO_RESOURCE;
+    }
+
+    return UCS_OK;
+}
+
+/**
+ * Increment active_frags counter after successful mtype allocation.
+ */
+static UCS_F_ALWAYS_INLINE void
+ucp_proto_rndv_mtype_fc_increment(ucp_request_t *req)
+{
+    ucp_worker_h worker = req->send.ep->worker;
+
+    if (worker->context->config.ext.rndv_mtype_worker_fc_enable) {
+        worker->rndv_mtype_fc.active_frags++;
+        UCS_STATS_UPDATE_COUNTER(worker->stats,
+                                 UCP_WORKER_STAT_RNDV_MTYPE_FC_INCREMENTED, 1);
+    }
+}
+
+/**
+ * Decrement active_frags counter and reschedule pending request.
+ * Dequeue priority: PUT > GET > RTR
+ *
+ * Priority rationale:
+ * PUT - Remote is blocked waiting for our data. Scheduling PUT unblocks remote
+ *       as well.
+ * GET - Self-contained fetch operation. Completes without causing remote
+ *       allocations, but scheduling it doesn't unblock another buffer.
+ * RTR - Scheduling RTR triggers a remote PUT allocation, increasing total
+ *       memory pressure.
+ */
+static UCS_F_ALWAYS_INLINE void
+ucp_proto_rndv_mtype_fc_decrement(ucp_request_t *req)
+{
+    ucp_worker_h worker    = req->send.ep->worker;
+    ucp_context_h context  = worker->context;
+    ucs_queue_elem_t *elem = NULL;
+    ucp_request_t *pending_req;
+
+    if (!context->config.ext.rndv_mtype_worker_fc_enable) {
+        return;
+    }
+
+    ucs_assert(worker->rndv_mtype_fc.active_frags > 0);
+    worker->rndv_mtype_fc.active_frags--;
+
+    /* Dequeue with priority: PUT > GET > RTR */
+    if (!ucs_queue_is_empty(&worker->rndv_mtype_fc.put_pending_q)) {
+        elem = ucs_queue_pull(&worker->rndv_mtype_fc.put_pending_q);
+    } else if (!ucs_queue_is_empty(&worker->rndv_mtype_fc.get_pending_q)) {
+        elem = ucs_queue_pull(&worker->rndv_mtype_fc.get_pending_q);
+    } else if (!ucs_queue_is_empty(&worker->rndv_mtype_fc.rtr_pending_q)) {
+        elem = ucs_queue_pull(&worker->rndv_mtype_fc.rtr_pending_q);
+    }
+
+    if (elem != NULL) {
+        pending_req = ucs_container_of(elem, ucp_request_t,
+                                       send.rndv.ppln.queue_elem);
+        ucs_callbackq_add_oneshot(&worker->uct->progress_q, pending_req,
+                                  ucp_proto_rndv_mtype_fc_reschedule_cb,
+                                  pending_req);
+    }
+}
+
 static UCS_F_ALWAYS_INLINE ucs_status_t
 ucp_proto_rndv_mdesc_mtype_copy(ucp_request_t *req,
                                 uct_ep_put_zcopy_func_t copy_func,
diff --git a/src/ucp/rndv/rndv_ppln.c b/src/ucp/rndv/rndv_ppln.c
index 17fc9e8b16b..26282fae8a0 100644
--- a/src/ucp/rndv/rndv_ppln.c
+++ b/src/ucp/rndv/rndv_ppln.c
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include
 
 
 enum {
diff --git a/src/ucp/rndv/rndv_put.c b/src/ucp/rndv/rndv_put.c
index 28948386ffb..b49b825fc6b 100644
--- a/src/ucp/rndv/rndv_put.c
+++ b/src/ucp/rndv/rndv_put.c
@@ -517,14 +517,24 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_put_mtype_send_func(
 static ucs_status_t
 ucp_proto_rndv_put_mtype_copy_progress(uct_pending_req_t *uct_req)
 {
-    ucp_request_t *req = ucs_container_of(uct_req,
-                                          ucp_request_t,
-                                          send.uct);
-    const ucp_proto_rndv_put_priv_t *rpriv = req->send.proto_config->priv;
+    ucp_request_t *req = ucs_container_of(uct_req, ucp_request_t, send.uct);
+    const ucp_proto_rndv_put_priv_t *rpriv;
     ucs_status_t status;
+    size_t max_frags;
+    ucs_queue_head_t *pending_q;
 
     ucs_assert(!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED));
 
+    rpriv = req->send.proto_config->priv;
+
+    /* Check throttling limit. If no resource at the moment, queue the request
+     * in PUT pending queue and return UCS_OK. */
+    max_frags = rpriv->bulk.fc_max_frags;
+    pending_q = &req->send.ep->worker->rndv_mtype_fc.put_pending_q;
+    if (ucp_proto_rndv_mtype_fc_check(req, max_frags, pending_q) ==
+        UCS_ERR_NO_RESOURCE) {
+        return UCS_OK;
+    }
     status = ucp_proto_rndv_mtype_request_init(req, rpriv->bulk.frag_mem_type,
                                                rpriv->bulk.frag_sys_dev);
     if (status != UCS_OK) {
@@ -532,6 +542,7 @@ ucp_proto_rndv_put_mtype_copy_progress(uct_pending_req_t *uct_req)
         return UCS_OK;
     }
 
+    ucp_proto_rndv_mtype_fc_increment(req);
     ucp_proto_rndv_put_common_request_init(req);
     req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED;
     ucp_proto_rndv_mdesc_mtype_copy(req, uct_ep_get_zcopy,
@@ -563,6 +574,7 @@ static void ucp_proto_rndv_put_mtype_completion(uct_completion_t *uct_comp)
 
     ucp_trace_req(req, "rndv_put_mtype_completion");
     ucs_mpool_put(req->send.rndv.mdesc);
+    ucp_proto_rndv_mtype_fc_decrement(req);
     ucp_proto_rndv_put_common_complete(req);
 }
 
@@ -573,6 +585,7 @@ static void ucp_proto_rndv_put_mtype_frag_completion(uct_completion_t *uct_comp)
 
     ucp_trace_req(req, "rndv_put_mtype_frag_completion");
     ucs_mpool_put(req->send.rndv.mdesc);
+    ucp_proto_rndv_mtype_fc_decrement(req);
     ucp_proto_rndv_ppln_send_frag_complete(req, 1);
 }
 
diff --git a/src/ucp/rndv/rndv_rtr.c b/src/ucp/rndv/rndv_rtr.c
index 11493a511c9..bbc5e1f100f 100644
--- a/src/ucp/rndv/rndv_rtr.c
+++ b/src/ucp/rndv/rndv_rtr.c
@@ -36,6 +36,8 @@ typedef struct {
     ucp_proto_rndv_rtr_priv_t super;
     ucs_memory_type_t         frag_mem_type;
     ucs_sys_device_t          frag_sys_dev;
+    /* max fragments for flow control */
+    size_t                    fc_max_frags;
 } ucp_proto_rndv_rtr_mtype_priv_t;
 
 static UCS_F_ALWAYS_INLINE void
@@ -286,6 +288,9 @@ ucp_proto_rndv_rtr_mtype_complete(ucp_request_t *req, int abort)
     if (!abort || (req->send.rndv.mdesc != NULL)) {
         ucs_mpool_put_inline(req->send.rndv.mdesc);
     }
+
+    ucp_proto_rndv_mtype_fc_decrement(req);
+
     if (ucp_proto_rndv_request_is_ppln_frag(req)) {
         ucp_proto_rndv_ppln_recv_frag_complete(req, 0, abort);
     } else {
@@ -316,6 +321,7 @@ static ucs_status_t ucp_proto_rndv_rtr_mtype_reset(ucp_request_t *req)
     if (req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED) {
         ucs_mpool_put_inline(req->send.rndv.mdesc);
         req->send.rndv.mdesc = NULL;
+        ucp_proto_rndv_mtype_fc_decrement(req);
     }
 
     return ucp_proto_request_zcopy_id_reset(req);
@@ -348,10 +354,24 @@ ucp_proto_rndv_rtr_mtype_data_received(ucp_request_t *req, int in_buffer)
 static ucs_status_t ucp_proto_rndv_rtr_mtype_progress(uct_pending_req_t *self)
 {
     ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct);
-    const ucp_proto_rndv_rtr_mtype_priv_t *rpriv = req->send.proto_config->priv;
+    const ucp_proto_rndv_rtr_mtype_priv_t *rpriv;
+    size_t max_frags;
+    ucs_queue_head_t *pending_q;
     ucs_status_t status;
 
     if (!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED)) {
+        rpriv = req->send.proto_config->priv;
+
+        /* RTR priority: 60% of total fragments */
+        max_frags = rpriv->fc_max_frags / 5 * 3;
+
+        /* Check throttling limit. If no resource at the moment, queue the
+         * request in RTR pending queue and return UCS_OK. */
+        pending_q = &req->send.ep->worker->rndv_mtype_fc.rtr_pending_q;
+        if (ucp_proto_rndv_mtype_fc_check(req, max_frags, pending_q) ==
+            UCS_ERR_NO_RESOURCE) {
+            return UCS_OK;
+        }
         status = ucp_proto_rndv_mtype_request_init(req, rpriv->frag_mem_type,
                                                    rpriv->frag_sys_dev);
         if (status != UCS_OK) {
@@ -359,6 +379,7 @@
             return UCS_OK;
         }
 
+        ucp_proto_rndv_mtype_fc_increment(req);
         ucp_proto_rtr_common_request_init(req);
         req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED;
     }
@@ -448,6 +469,8 @@ ucp_proto_rndv_rtr_mtype_probe(const ucp_proto_init_params_t *init_params)
     rpriv.super.data_received = ucp_proto_rndv_rtr_mtype_data_received;
     rpriv.frag_mem_type       = frag_mem_type;
     rpriv.frag_sys_dev        = params.super.reg_mem_info.sys_dev;
+    rpriv.fc_max_frags        = ucp_proto_rndv_mtype_fc_max_frags(
+                                        context, frag_mem_type);
 
     ucp_proto_rndv_ctrl_probe(&params, &rpriv, sizeof(rpriv));
 
 out_unpack_perf_destroy:
diff --git a/src/ucs/stats/libstats.h b/src/ucs/stats/libstats.h
index 3b72c5fb7eb..63e5df1a533 100644
--- a/src/ucs/stats/libstats.h
+++ b/src/ucs/stats/libstats.h
@@ -56,7 +56,9 @@ struct ucs_stats_class {
     const char           *name;
     unsigned              num_counters;
     unsigned              class_id;
-    const char*           counter_names[];
+    const char*           counter_names[14]; /* Must be increased if any stats
+                                                class ever defines more than
+                                                14 counters */
 };
 
 /*
diff --git a/test/gtest/ucp/test_ucp_am.cc b/test/gtest/ucp/test_ucp_am.cc
index 45dc9edd2dc..4852cc075f3 100644
--- a/test/gtest/ucp/test_ucp_am.cc
+++ b/test/gtest/ucp/test_ucp_am.cc
@@ -2187,4 +2187,146 @@ UCS_TEST_P(test_ucp_am_nbx_rndv_ppln, cuda_managed_buff,
 
 UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(test_ucp_am_nbx_rndv_ppln);
 
+
+class test_ucp_am_nbx_rndv_mtype_fc : public test_ucp_am_nbx_rndv_ppln {
+protected:
+    void check_stats_ge(entity &e, uint64_t cntr, uint64_t min_value)
+    {
+        auto stats_node = e.worker()->stats;
+        auto value      = UCS_STATS_GET_COUNTER(stats_node, cntr);
+
+        EXPECT_GE(value, min_value) << "counter "
+                                    << stats_node->cls->counter_names[cntr]
+                                    << " expected >= " << min_value
+                                    << " but got " << value;
+    }
+
+    void check_active_frags_zero(entity &e)
+    {
+        EXPECT_EQ(0u, e.worker()->rndv_mtype_fc.active_frags)
+                << "active_frags should be 0 after completion";
+    }
+
+    void check_pending_queues_empty(entity &e)
+    {
+        ucp_worker_h worker = e.worker();
+
+        EXPECT_TRUE(ucs_queue_is_empty(&worker->rndv_mtype_fc.put_pending_q))
+                << "put_pending_q should be empty";
+        EXPECT_TRUE(ucs_queue_is_empty(&worker->rndv_mtype_fc.get_pending_q))
+                << "get_pending_q should be empty";
+        EXPECT_TRUE(ucs_queue_is_empty(&worker->rndv_mtype_fc.rtr_pending_q))
+                << "rtr_pending_q should be empty";
+    }
+
+    void send_message(size_t num_frags)
+    {
+        set_mem_type(UCS_MEMORY_TYPE_CUDA_MANAGED);
+        test_am_send_recv(get_rndv_frag_size(UCS_MEMORY_TYPE_CUDA) * num_frags);
+    }
+
+    void verify_clean_fc_state()
+    {
+        check_active_frags_zero(sender());
+        check_active_frags_zero(receiver());
+        check_pending_queues_empty(sender());
+        check_pending_queues_empty(receiver());
+    }
+};
+
+UCS_TEST_P(test_ucp_am_nbx_rndv_mtype_fc, fc_enabled_under_cap,
+           "RNDV_MTYPE_WORKER_FC_ENABLE=y",
+           "RNDV_MTYPE_WORKER_MAX_MEM=1g",
+           "RNDV_FRAG_MEM_TYPE=cuda")
+{
+    if (!sender().is_rndv_put_ppln_supported()) {
+        UCS_TEST_SKIP_R("RNDV is not supported");
+    }
+
+    send_message(8);
+
+    check_stats_ge(sender(), UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY, 1);
+    check_stats_ge(receiver(), UCP_WORKER_STAT_RNDV_RTR_MTYPE, 1);
+
+    /* Throttling should NOT have occurred */
+    auto sender_throttled   = UCS_STATS_GET_COUNTER(sender().worker()->stats,
+                                  UCP_WORKER_STAT_RNDV_MTYPE_FC_THROTTLED);
+    auto receiver_throttled = UCS_STATS_GET_COUNTER(receiver().worker()->stats,
+                                  UCP_WORKER_STAT_RNDV_MTYPE_FC_THROTTLED);
+    EXPECT_EQ(0u, sender_throttled) << "sender should not be throttled";
+    EXPECT_EQ(0u, receiver_throttled) << "receiver should not be throttled";
+
+    /* FC should have been active (incremented) even though no throttling */
+    auto sender_incremented   = UCS_STATS_GET_COUNTER(
+            sender().worker()->stats,
+            UCP_WORKER_STAT_RNDV_MTYPE_FC_INCREMENTED);
+    auto receiver_incremented = UCS_STATS_GET_COUNTER(
+            receiver().worker()->stats,
+            UCP_WORKER_STAT_RNDV_MTYPE_FC_INCREMENTED);
+    EXPECT_GT(sender_incremented + receiver_incremented, 0u)
+            << "FC should be active and tracking fragments";
+
+    verify_clean_fc_state();
+}
+
+UCS_TEST_P(test_ucp_am_nbx_rndv_mtype_fc, fc_enabled_cap_reached,
+           "RNDV_MTYPE_WORKER_FC_ENABLE=y",
+           "RNDV_MTYPE_WORKER_MAX_MEM=600mb",
+           "RNDV_FRAG_MEM_TYPE=cuda")
+{
+    if (!sender().is_rndv_put_ppln_supported()) {
+        UCS_TEST_SKIP_R("RNDV is not supported");
+    }
+
+    send_message(200);
+
+    /* Verify mtype protocols were used */
+    check_stats_ge(sender(), UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY, 1);
+    check_stats_ge(receiver(), UCP_WORKER_STAT_RNDV_RTR_MTYPE, 1);
+
+    /* Throttling SHOULD have occurred (at least once on sender or receiver) */
+    auto sender_throttled   = UCS_STATS_GET_COUNTER(sender().worker()->stats,
+                                  UCP_WORKER_STAT_RNDV_MTYPE_FC_THROTTLED);
+    auto receiver_throttled = UCS_STATS_GET_COUNTER(receiver().worker()->stats,
+                                  UCP_WORKER_STAT_RNDV_MTYPE_FC_THROTTLED);
+    EXPECT_GT(sender_throttled + receiver_throttled, 0u)
+            << "throttling should have occurred with MAX_MEM=600mb";
+
+    verify_clean_fc_state();
+}
+
+UCS_TEST_P(test_ucp_am_nbx_rndv_mtype_fc, fc_disabled,
+           "RNDV_MTYPE_WORKER_FC_ENABLE=n",
+           "RNDV_FRAG_MEM_TYPE=cuda")
+{
+    if (!sender().is_rndv_put_ppln_supported()) {
+        UCS_TEST_SKIP_R("RNDV is not supported");
+    }
+
+    send_message(8);
+
+    check_stats_ge(sender(), UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY, 1);
+    check_stats_ge(receiver(), UCP_WORKER_STAT_RNDV_RTR_MTYPE, 1);
+
+    /* No throttling should have occurred (FC is disabled) */
+    auto sender_throttled   = UCS_STATS_GET_COUNTER(sender().worker()->stats,
+                                  UCP_WORKER_STAT_RNDV_MTYPE_FC_THROTTLED);
+    auto receiver_throttled = UCS_STATS_GET_COUNTER(receiver().worker()->stats,
+                                  UCP_WORKER_STAT_RNDV_MTYPE_FC_THROTTLED);
+    EXPECT_EQ(0u, sender_throttled) << "FC disabled - no throttling expected";
+    EXPECT_EQ(0u, receiver_throttled) << "FC disabled - no throttling expected";
+
+    /* With FC disabled, increment should NOT be called */
+    auto sender_incremented   = UCS_STATS_GET_COUNTER(sender().worker()->stats,
+                                  UCP_WORKER_STAT_RNDV_MTYPE_FC_INCREMENTED);
+    auto receiver_incremented = UCS_STATS_GET_COUNTER(receiver().worker()->stats,
+                                  UCP_WORKER_STAT_RNDV_MTYPE_FC_INCREMENTED);
+    EXPECT_EQ(0u, sender_incremented) << "FC disabled - no increment expected";
+    EXPECT_EQ(0u, receiver_incremented) << "FC disabled - no increment expected";
+
+    verify_clean_fc_state();
+}
+
+
+UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(test_ucp_am_nbx_rndv_mtype_fc);
+
 #endif
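Note (editorial, not part of the patch): RNDV_MTYPE_WORKER_MAX_MEM is mapped to a per-worker fragment budget by ucp_proto_rndv_mtype_fc_max_frags(), and the budget is then split per protocol: PUT uses the full budget, GET uses 4/5 of it, RTR uses 3/5 of it. A minimal standalone sketch of that arithmetic follows; the 4mb fragment size and the 128-fragment allocation chunk are illustrative assumptions, not values taken from this patch or from UCX defaults.

/* Sketch of the fragment-cap math used by the flow control above. */
#include <stdio.h>
#include <stddef.h>

static size_t fc_max_frags(size_t max_mem, size_t frag_size, size_t frags_in_chunk)
{
    size_t max_frags = max_mem / frag_size;

    /* Round down to the allocation chunk granularity, with a one-chunk floor,
     * mirroring ucp_proto_rndv_mtype_fc_max_frags() */
    max_frags = (max_frags / frags_in_chunk) * frags_in_chunk;
    return (max_frags == 0) ? frags_in_chunk : max_frags;
}

int main(void)
{
    size_t max_mem        = (size_t)4 << 30; /* RNDV_MTYPE_WORKER_MAX_MEM=4g */
    size_t frag_size      = (size_t)4 << 20; /* assumed fragment size (4mb) */
    size_t frags_in_chunk = 128;             /* assumed rndv_num_frags value */
    size_t total          = fc_max_frags(max_mem, frag_size, frags_in_chunk);

    printf("PUT cap: %zu frags\n", total);         /* full budget, rndv_put.c */
    printf("GET cap: %zu frags\n", total / 5 * 4); /* 80% of budget, rndv_get.c */
    printf("RTR cap: %zu frags\n", total / 5 * 3); /* 60% of budget, rndv_rtr.c */
    return 0;
}

With these assumed numbers the budget is 1024 fragments, so PUT, GET and RTR are capped at 1024, 816 and 612 fragments respectively; because the caps differ, PUT traffic can still make progress after GET and RTR have been throttled.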
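Note (editorial, not part of the patch): the pending queues do not store request pointers; they store the ucs_queue_elem_t embedded in send.rndv.ppln.queue_elem, and ucp_proto_rndv_mtype_fc_decrement() recovers the owning request with ucs_container_of() before rescheduling it. Below is a self-contained sketch of that intrusive-queue pattern in plain C, using a hypothetical container_of macro and toy types instead of the UCS queue API.

/* Intrusive-element pattern: the request embeds the queue element, the queue
 * stores only the element, and the consumer recovers the request from it. */
#include <stddef.h>
#include <stdio.h>

#define container_of(ptr, type, member) \
    ((type*)((char*)(ptr) - offsetof(type, member)))

typedef struct queue_elem {
    struct queue_elem *next;
} queue_elem_t;

typedef struct request {
    int          id;
    queue_elem_t queue_elem; /* analogous to send.rndv.ppln.queue_elem */
} request_t;

int main(void)
{
    request_t req = {.id = 42, .queue_elem = {NULL}};

    /* What a pending queue would hold for a throttled request */
    queue_elem_t *elem = &req.queue_elem;

    /* What the decrement path does after pulling the element */
    request_t *recovered = container_of(elem, request_t, queue_elem);

    printf("recovered request id: %d\n", recovered->id); /* prints 42 */
    return 0;
}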
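Note (editorial, not part of the patch): a toy simulation of the flow-control lifecycle under the same assumptions, showing the invariant the patch relies on: submissions beyond the cap are parked (fc_check), and each completion (fc_decrement) releases at most one parked request, so the number of active fragments never exceeds the cap.

#include <stdio.h>

#define MAX_FRAGS 4   /* illustrative cap, cf. fc_max_frags */
#define NUM_REQS  10

int main(void)
{
    int pending[NUM_REQS], head = 0, tail = 0; /* toy FIFO of throttled ids */
    int active_frags = 0, i;

    /* Submission burst: mirrors fc_check()/fc_increment() */
    for (i = 0; i < NUM_REQS; i++) {
        if (active_frags >= MAX_FRAGS) {
            pending[tail++] = i;  /* throttled: parked instead of allocating */
        } else {
            active_frags++;       /* fragment allocated */
        }
    }
    printf("after burst: active=%d queued=%d\n", active_frags, tail - head);

    /* Completions: mirrors fc_decrement(); one parked request per release */
    while (active_frags > 0) {
        active_frags--;
        if (head < tail) {
            (void)pending[head++]; /* rescheduled request allocates now */
            active_frags++;
        }
        printf("active=%d queued=%d\n", active_frags, tail - head);
    }
    return 0;
}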