Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions src/ucp/core/ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -1031,19 +1031,51 @@ ucp_config_is_tl_name_present(const ucs_config_names_array_t *tl_array,
tl_cfg_mask));
}

static void
ucp_get_dev_basename(const char *dev_name, char *dev_basename_p, size_t max)
{
const char *delimiter = strchr(dev_name, ':');
size_t basename_len;

if (delimiter != NULL) {
basename_len = UCS_PTR_BYTE_DIFF(dev_name, delimiter);
ucs_assertv(basename_len < max, "basename_len=%zu max=%zu",
basename_len, max);

ucs_strncpy_zero(dev_basename_p, dev_name, basename_len + 1);
} else {
dev_basename_p[0] = '\0';
}
}

/* go over the device list from the user and check (against the available resources)
* which can be satisfied */
static int ucp_is_resource_in_device_list(const uct_tl_resource_desc_t *resource,
const ucs_config_names_array_t *devices,
uint64_t *dev_cfg_mask,
uct_device_type_t dev_type)
{
char dev_basename[UCT_DEVICE_NAME_MAX];
uint64_t mask, exclusive_mask;

/* go over the device list from the user and check (against the available resources)
* which can be satisfied */
ucs_assert_always(devices[dev_type].count <= 64); /* Using uint64_t bitmap */

/* search for the full device name */
mask = ucp_str_array_search((const char**)devices[dev_type].names,
devices[dev_type].count, resource->dev_name,
NULL);

/* for network devices, also search for the base name (before the delimiter) */
if (dev_type == UCT_DEVICE_TYPE_NET) {
ucp_get_dev_basename(resource->dev_name, dev_basename,
sizeof(dev_basename));
if (!ucs_string_is_empty(dev_basename)) {
mask |= ucp_str_array_search((const char**)devices[dev_type].names,
devices[dev_type].count, dev_basename,
NULL);
}
}

if (!mask) {
/* if the user's list is 'all', use all the available resources */
mask = ucp_str_array_search((const char**)devices[dev_type].names,
Expand Down Expand Up @@ -1213,7 +1245,6 @@ static int ucp_is_resource_enabled(const uct_tl_resource_desc_t *resource,
resource, config->devices, &dev_cfg_masks[resource->dev_type],
resource->dev_type);


/* Find the enabled UCTs */
*rsc_flags = 0;
tl_enabled = ucp_is_resource_in_transports_list(resource->tl_name,
Expand Down
202 changes: 202 additions & 0 deletions test/gtest/ucp/test_ucp_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
*/

#include "ucp_test.h"

#include <set>

extern "C" {
#include <ucp/core/ucp_context.h>
#include <ucs/sys/sys.h>
}

Expand Down Expand Up @@ -118,3 +122,201 @@ UCS_TEST_P(test_ucp_version, version_string) {
}

UCP_INSTANTIATE_TEST_CASE_TLS(test_ucp_version, all, "all")

class test_ucp_net_devices_config : public ucp_test {
public:
static void get_test_variants(std::vector<ucp_test_variant> &variants) {
add_variant(variants, UCP_FEATURE_TAG);
}

protected:
static const char DELIMITER = ':';

/* Iterate over all network devices and apply action to each */
template<typename Action>
static void for_each_net_device(const entity &e, Action action) {
ucp_context_h ctx = e.ucph();
for (ucp_rsc_index_t i = 0; i < ctx->num_tls; ++i) {
const uct_tl_resource_desc_t *rsc = &ctx->tl_rscs[i].tl_rsc;
if (rsc->dev_type == UCT_DEVICE_TYPE_NET) {
action(rsc);
}
}
}

/* Check if a specific device name exists in the set */
static bool has_device(const std::set<std::string> &devices,
const std::string &dev_name)
{
return devices.find(dev_name) != devices.end();
}


/* Get all network device names from the context */
static std::set<std::string> get_net_device_names(const entity &e)
{
std::set<std::string> device_names;
for_each_net_device(e, [&](const uct_tl_resource_desc_t *rsc) {
device_names.insert(rsc->dev_name);
});
return device_names;
}

/* Get all network device names from the context with delimiter */
static std::set<std::string>
get_net_device_names_with_delimiter(const entity &e)
{
std::set<std::string> device_names;
for_each_net_device(e, [&](const uct_tl_resource_desc_t *rsc) {
std::string dev_name(rsc->dev_name);
size_t delimiter_pos = dev_name.find(DELIMITER);
if (delimiter_pos != std::string::npos) {
device_names.insert(dev_name);
}
});
return device_names;
}

static std::set<std::string>
get_device_base_names(const std::set<std::string> &dev_names)
{
std::set<std::string> base_names;

for (const std::string &dev_name : dev_names) {
size_t delimiter_pos = dev_name.find(DELIMITER);
if (delimiter_pos != std::string::npos) {
base_names.insert(dev_name.substr(0, delimiter_pos));
} else {
base_names.insert(dev_name);
}
}

return base_names;
}

/* Join strings with a delimiter */
static std::string
join(const std::set<std::string> &strings, const std::string &delimiter)
{
std::string result;
for (auto it = strings.begin(); it != strings.end(); ++it) {
if (it != strings.begin()) {
result += delimiter;
}
result += *it;
}
return result;
}

/* Test that net device selection works correctly */
void
test_net_device_selection(const std::set<std::string> &test_net_devices,
const std::set<std::string> &expected_net_devices)
{
std::string net_devices_config = join(test_net_devices, ",");
modify_config("NET_DEVICES", net_devices_config.c_str());
entity *e = create_entity();

std::set<std::string> selected_devices = get_net_device_names(*e);
EXPECT_EQ(selected_devices.size(), expected_net_devices.size());

for (const std::string &net_device : expected_net_devices) {
EXPECT_TRUE(has_device(selected_devices, net_device))
<< "Device '" << net_device << "' should be selected when "
<< "UCX_NET_DEVICES=" << net_devices_config;
}
}

/* Test that a device config triggers duplicate device warning */
void test_duplicate_device_warning(const std::string &required_dev_name,
const std::string &devices_config,
const std::string &duplicate_dev_name)
{
entity *e = create_entity();

std::set<std::string> net_devices = get_net_device_names(*e);
ASSERT_FALSE(net_devices.empty());

if (!has_device(net_devices, required_dev_name)) {
UCS_TEST_SKIP_R(required_dev_name + " device not available");
}

m_entities.clear();

modify_config("NET_DEVICES", devices_config.c_str());

size_t warn_count;
{
scoped_log_handler slh(hide_warns_logger);
warn_count = m_warnings.size();
create_entity();
}

EXPECT_EQ(m_warnings.size() - warn_count, 1)
<< "Expected exactly one warning";

/* Check that the warning about duplicate device was printed */
std::string expected_warn = "device '" + duplicate_dev_name +
"' is specified multiple times";
EXPECT_NE(m_warnings[warn_count].find(expected_warn), std::string::npos)
<< "Expected warning about duplicate device '"
<< duplicate_dev_name << "' with config '" << devices_config
<< "'";
}
Comment on lines 255 to 265
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about to add a wrapper for hide_warns_logger and handle all these checks just in time?

Copy link
Contributor Author

@guy-ealey-morag guy-ealey-morag Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to to match the warning string with the device name duplicate_dev_name.
It's not trivial to pass it to a warpper and the solutions I found make it too complicated for what it is.

};

/*
* Test that when UCX_NET_DEVICES is set to a base name (e.g., "mlx5_0"),
* devices with the same base name are selected (e.g., "mlx5_0:1").
*/
UCS_TEST_P(test_ucp_net_devices_config, base_name_selects_device)
{
entity *e = create_entity();

std::set<std::string> net_devices = get_net_device_names_with_delimiter(*e);
if (net_devices.empty()) {
UCS_TEST_SKIP_R("No network devices available with delimiter");
}

m_entities.clear();

std::set<std::string> base_names = get_device_base_names(net_devices);
test_net_device_selection(base_names, net_devices);
}

/*
* Test that explicit suffix specification works correctly.
*/
UCS_TEST_P(test_ucp_net_devices_config, explicit_suffix)
{
entity *e = create_entity();

std::set<std::string> net_devices = get_net_device_names_with_delimiter(*e);
if (net_devices.empty()) {
UCS_TEST_SKIP_R("No network devices available with delimiter");
}

m_entities.clear();

test_net_device_selection(net_devices, net_devices);
}

/*
* Test that specifying a device multiple times produces a warning
*/
UCS_TEST_P(test_ucp_net_devices_config, duplicate_device_warning_simple)
{
test_duplicate_device_warning("mlx5_0:1", "mlx5_0:1,mlx5_0:1", "mlx5_0:1");
}

UCS_TEST_P(test_ucp_net_devices_config, duplicate_device_warning_base_name)
{
test_duplicate_device_warning("mlx5_0:1", "mlx5_0:1,mlx5_0", "mlx5_0");
}

UCS_TEST_P(test_ucp_net_devices_config, duplicate_device_warning_two_base_name)
{
test_duplicate_device_warning("mlx5_0:1", "mlx5_0,mlx5_0", "mlx5_0");
}

UCP_INSTANTIATE_TEST_CASE_TLS(test_ucp_net_devices_config, all, "all")