/**
 * Extract the base name of a device: the part of @a dev_name that precedes
 * the first ':' delimiter (e.g. "mlx5_0" for "mlx5_0:1").
 *
 * @param [in]  dev_name        NUL-terminated device name.
 * @param [out] dev_basename_p  Buffer that receives the NUL-terminated base
 *                              name. Set to an empty string when @a dev_name
 *                              contains no ':' delimiter.
 * @param [in]  max             Size of @a dev_basename_p in bytes, including
 *                              room for the terminating NUL.
 */
static void
ucp_get_dev_basename(const char *dev_name, char *dev_basename_p, size_t max)
{
    const char *delimiter = strchr(dev_name, ':');
    size_t basename_len, copy_size;

    if (delimiter == NULL) {
        /* No delimiter: report "no base name" as an empty string, but never
         * write into a zero-sized buffer */
        if (max > 0) {
            dev_basename_p[0] = '\0';
        }
        return;
    }

    basename_len = UCS_PTR_BYTE_DIFF(dev_name, delimiter);
    ucs_assertv(basename_len < max, "basename_len=%zu max=%zu", basename_len,
                max);

    /* Do not rely on the assertion alone to bound the copy (NOTE(review):
     * presumably it is compiled out in builds without assertions - confirm).
     * The size passed below counts the terminating NUL, so clamp it to the
     * destination buffer size; an oversized base name is truncated. */
    copy_size = basename_len + 1;
    if (copy_size > max) {
        copy_size = max;
    }

    if (copy_size > 0) {
        ucs_strncpy_zero(dev_basename_p, dev_name, copy_size);
    }
}
+1245,6 @@ static int ucp_is_resource_enabled(const uct_tl_resource_desc_t *resource, resource, config->devices, &dev_cfg_masks[resource->dev_type], resource->dev_type); - /* Find the enabled UCTs */ *rsc_flags = 0; tl_enabled = ucp_is_resource_in_transports_list(resource->tl_name, diff --git a/test/gtest/ucp/test_ucp_context.cc b/test/gtest/ucp/test_ucp_context.cc index a415ef806a2..18bb2faf4bc 100644 --- a/test/gtest/ucp/test_ucp_context.cc +++ b/test/gtest/ucp/test_ucp_context.cc @@ -5,7 +5,11 @@ */ #include "ucp_test.h" + +#include + extern "C" { +#include #include } @@ -118,3 +122,201 @@ UCS_TEST_P(test_ucp_version, version_string) { } UCP_INSTANTIATE_TEST_CASE_TLS(test_ucp_version, all, "all") + +class test_ucp_net_devices_config : public ucp_test { +public: + static void get_test_variants(std::vector &variants) { + add_variant(variants, UCP_FEATURE_TAG); + } + +protected: + static const char DELIMITER = ':'; + + /* Iterate over all network devices and apply action to each */ + template + static void for_each_net_device(const entity &e, Action action) { + ucp_context_h ctx = e.ucph(); + for (ucp_rsc_index_t i = 0; i < ctx->num_tls; ++i) { + const uct_tl_resource_desc_t *rsc = &ctx->tl_rscs[i].tl_rsc; + if (rsc->dev_type == UCT_DEVICE_TYPE_NET) { + action(rsc); + } + } + } + + /* Check if a specific device name exists in the set */ + static bool has_device(const std::set &devices, + const std::string &dev_name) + { + return devices.find(dev_name) != devices.end(); + } + + + /* Get all network device names from the context */ + static std::set get_net_device_names(const entity &e) + { + std::set device_names; + for_each_net_device(e, [&](const uct_tl_resource_desc_t *rsc) { + device_names.insert(rsc->dev_name); + }); + return device_names; + } + + /* Get all network device names from the context with delimiter */ + static std::set + get_net_device_names_with_delimiter(const entity &e) + { + std::set device_names; + for_each_net_device(e, [&](const 
uct_tl_resource_desc_t *rsc) { + std::string dev_name(rsc->dev_name); + size_t delimiter_pos = dev_name.find(DELIMITER); + if (delimiter_pos != std::string::npos) { + device_names.insert(dev_name); + } + }); + return device_names; + } + + static std::set + get_device_base_names(const std::set &dev_names) + { + std::set base_names; + + for (const std::string &dev_name : dev_names) { + size_t delimiter_pos = dev_name.find(DELIMITER); + if (delimiter_pos != std::string::npos) { + base_names.insert(dev_name.substr(0, delimiter_pos)); + } else { + base_names.insert(dev_name); + } + } + + return base_names; + } + + /* Join strings with a delimiter */ + static std::string + join(const std::set &strings, const std::string &delimiter) + { + std::string result; + for (auto it = strings.begin(); it != strings.end(); ++it) { + if (it != strings.begin()) { + result += delimiter; + } + result += *it; + } + return result; + } + + /* Test that net device selection works correctly */ + void + test_net_device_selection(const std::set &test_net_devices, + const std::set &expected_net_devices) + { + std::string net_devices_config = join(test_net_devices, ","); + modify_config("NET_DEVICES", net_devices_config.c_str()); + entity *e = create_entity(); + + std::set selected_devices = get_net_device_names(*e); + EXPECT_EQ(selected_devices.size(), expected_net_devices.size()); + + for (const std::string &net_device : expected_net_devices) { + EXPECT_TRUE(has_device(selected_devices, net_device)) + << "Device '" << net_device << "' should be selected when " + << "UCX_NET_DEVICES=" << net_devices_config; + } + } + + /* Test that a device config triggers duplicate device warning */ + void test_duplicate_device_warning(const std::string &required_dev_name, + const std::string &devices_config, + const std::string &duplicate_dev_name) + { + entity *e = create_entity(); + + std::set net_devices = get_net_device_names(*e); + ASSERT_FALSE(net_devices.empty()); + + if (!has_device(net_devices, 
required_dev_name)) { + UCS_TEST_SKIP_R(required_dev_name + " device not available"); + } + + m_entities.clear(); + + modify_config("NET_DEVICES", devices_config.c_str()); + + size_t warn_count; + { + scoped_log_handler slh(hide_warns_logger); + warn_count = m_warnings.size(); + create_entity(); + } + + EXPECT_EQ(m_warnings.size() - warn_count, 1) + << "Expected exactly one warning"; + + /* Check that the warning about duplicate device was printed */ + std::string expected_warn = "device '" + duplicate_dev_name + + "' is specified multiple times"; + EXPECT_NE(m_warnings[warn_count].find(expected_warn), std::string::npos) + << "Expected warning about duplicate device '" + << duplicate_dev_name << "' with config '" << devices_config + << "'"; + } +}; + +/* + * Test that when UCX_NET_DEVICES is set to a base name (e.g., "mlx5_0"), + * devices with the same base name are selected (e.g., "mlx5_0:1"). + */ +UCS_TEST_P(test_ucp_net_devices_config, base_name_selects_device) +{ + entity *e = create_entity(); + + std::set net_devices = get_net_device_names_with_delimiter(*e); + if (net_devices.empty()) { + UCS_TEST_SKIP_R("No network devices available with delimiter"); + } + + m_entities.clear(); + + std::set base_names = get_device_base_names(net_devices); + test_net_device_selection(base_names, net_devices); +} + +/* + * Test that explicit suffix specification works correctly. 
+ */ +UCS_TEST_P(test_ucp_net_devices_config, explicit_suffix) +{ + entity *e = create_entity(); + + std::set net_devices = get_net_device_names_with_delimiter(*e); + if (net_devices.empty()) { + UCS_TEST_SKIP_R("No network devices available with delimiter"); + } + + m_entities.clear(); + + test_net_device_selection(net_devices, net_devices); +} + +/* + * Test that specifying a device multiple times produces a warning + */ +UCS_TEST_P(test_ucp_net_devices_config, duplicate_device_warning_simple) +{ + test_duplicate_device_warning("mlx5_0:1", "mlx5_0:1,mlx5_0:1", "mlx5_0:1"); +} + +UCS_TEST_P(test_ucp_net_devices_config, duplicate_device_warning_base_name) +{ + test_duplicate_device_warning("mlx5_0:1", "mlx5_0:1,mlx5_0", "mlx5_0"); +} + +UCS_TEST_P(test_ucp_net_devices_config, duplicate_device_warning_two_base_name) +{ + test_duplicate_device_warning("mlx5_0:1", "mlx5_0,mlx5_0", "mlx5_0"); +} + +UCP_INSTANTIATE_TEST_CASE_TLS(test_ucp_net_devices_config, all, "all")