diff --git a/src/ucp/core/ucp_am.c b/src/ucp/core/ucp_am.c index 46d739b0ffa..c1f70eedf51 100644 --- a/src/ucp/core/ucp_am.c +++ b/src/ucp/core/ucp_am.c @@ -915,7 +915,7 @@ ucp_am_try_send_short(ucp_ep_h ep, uint16_t id, uint32_t flags, } if (ucp_proto_is_inline(ep, max_eager_short, - header_length + length, param)) { + header_length + length, buffer, param)) { return ucp_am_send_short(ep, id, flags, header, header_length, buffer, length, flags & UCP_AM_SEND_FLAG_REPLY); } diff --git a/src/ucp/proto/proto_am.inl b/src/ucp/proto/proto_am.inl index 5315cfceb5a..7b12c2bfbfd 100644 --- a/src/ucp/proto/proto_am.inl +++ b/src/ucp/proto/proto_am.inl @@ -649,13 +649,34 @@ ucp_proto_get_short_max(const ucp_request_t *req, static UCS_F_ALWAYS_INLINE int ucp_proto_is_inline(ucp_ep_h ep, const ucp_memtype_thresh_t *max_eager_short, - ssize_t length, const ucp_request_param_t *param) + ssize_t length, const void *buffer, + const ucp_request_param_t *param) { - return (ucs_likely(length <= max_eager_short->memtype_off) || - ((length <= max_eager_short->memtype_on) && - (ucs_memtype_cache_is_empty() || - ((param->op_attr_mask & UCP_OP_ATTR_FIELD_MEMORY_TYPE) && - (param->memory_type == UCS_MEMORY_TYPE_HOST))))); + ucs_memory_info_t mem_info; + ucs_status_t status; + + if (ucs_likely(length <= max_eager_short->memtype_off)) { + return 1; + } + + if (length > max_eager_short->memtype_on) { + return 0; + } + + if (ucs_memtype_cache_is_empty()) { + return 1; + } + + if ((param->op_attr_mask & UCP_OP_ATTR_FIELD_MEMORY_TYPE) && + (param->memory_type == UCS_MEMORY_TYPE_HOST)) { + return 1; + } + + /* Look up the buffer in the memory type cache to determine if it is host + * memory. UCS_ERR_NO_ELEM (not cached) means host; other errors do not. 
*/ + status = ucs_memtype_cache_lookup(buffer, length, &mem_info); + return (status == UCS_ERR_NO_ELEM) || + ((status == UCS_OK) && (mem_info.type == UCS_MEMORY_TYPE_HOST)); } static UCS_F_ALWAYS_INLINE ucp_request_t* diff --git a/src/ucp/tag/tag_send.c b/src/ucp/tag/tag_send.c index 5237f2d0ce2..aaaab13d101 100644 --- a/src/ucp/tag/tag_send.c +++ b/src/ucp/tag/tag_send.c @@ -154,14 +154,14 @@ ucp_tag_send_inline(ucp_ep_h ep, const void *buffer, size_t length, ucs_status_t status; if (ucp_proto_is_inline(ep, &ucp_ep_config(ep)->tag.max_eager_short, - length, param)) { + length, buffer, param)) { UCS_STATIC_ASSERT(sizeof(ucp_tag_t) == sizeof(ucp_eager_hdr_t)); UCS_STATIC_ASSERT(sizeof(ucp_tag_t) == sizeof(uint64_t)); status = uct_ep_am_short(ucp_ep_get_am_uct_ep(ep), UCP_AM_ID_EAGER_ONLY, tag, buffer, length); } else if (ucp_proto_is_inline(ep, &ucp_ep_config(ep)->tag.offload.max_eager_short, - length, param)) { + length, buffer, param)) { UCS_STATIC_ASSERT(sizeof(ucp_tag_t) == sizeof(uct_tag_t)); status = uct_ep_tag_eager_short(ucp_ep_get_tag_uct_ep(ep), tag, buffer, length);