Skip to content

Commit f3a1ae1

Browse files
UCP/CORE: Allow specifying mlx devices without a port in UCX_NET_DEVICES (#11142)
1 parent 6b25652 commit f3a1ae1

File tree

2 files changed

+236
-3
lines changed

2 files changed

+236
-3
lines changed

src/ucp/core/ucp_context.c

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,19 +1031,51 @@ ucp_config_is_tl_name_present(const ucs_config_names_array_t *tl_array,
10311031
tl_cfg_mask));
10321032
}
10331033

1034+
static void
1035+
ucp_get_dev_basename(const char *dev_name, char *dev_basename_p, size_t max)
1036+
{
1037+
const char *delimiter = strchr(dev_name, ':');
1038+
size_t basename_len;
1039+
1040+
if (delimiter != NULL) {
1041+
basename_len = UCS_PTR_BYTE_DIFF(dev_name, delimiter);
1042+
ucs_assertv(basename_len < max, "basename_len=%zu max=%zu",
1043+
basename_len, max);
1044+
1045+
ucs_strncpy_zero(dev_basename_p, dev_name, basename_len + 1);
1046+
} else {
1047+
dev_basename_p[0] = '\0';
1048+
}
1049+
}
1050+
1051+
/* go over the device list from the user and check (against the available resources)
1052+
* which can be satisfied */
10341053
static int ucp_is_resource_in_device_list(const uct_tl_resource_desc_t *resource,
10351054
const ucs_config_names_array_t *devices,
10361055
uint64_t *dev_cfg_mask,
10371056
uct_device_type_t dev_type)
10381057
{
1058+
char dev_basename[UCT_DEVICE_NAME_MAX];
10391059
uint64_t mask, exclusive_mask;
10401060

1041-
/* go over the device list from the user and check (against the available resources)
1042-
* which can be satisfied */
10431061
ucs_assert_always(devices[dev_type].count <= 64); /* Using uint64_t bitmap */
1062+
1063+
/* search for the full device name */
10441064
mask = ucp_str_array_search((const char**)devices[dev_type].names,
10451065
devices[dev_type].count, resource->dev_name,
10461066
NULL);
1067+
1068+
/* for network devices, also search for the base name (before the delimiter) */
1069+
if (dev_type == UCT_DEVICE_TYPE_NET) {
1070+
ucp_get_dev_basename(resource->dev_name, dev_basename,
1071+
sizeof(dev_basename));
1072+
if (!ucs_string_is_empty(dev_basename)) {
1073+
mask |= ucp_str_array_search((const char**)devices[dev_type].names,
1074+
devices[dev_type].count, dev_basename,
1075+
NULL);
1076+
}
1077+
}
1078+
10471079
if (!mask) {
10481080
/* if the user's list is 'all', use all the available resources */
10491081
mask = ucp_str_array_search((const char**)devices[dev_type].names,
@@ -1213,7 +1245,6 @@ static int ucp_is_resource_enabled(const uct_tl_resource_desc_t *resource,
12131245
resource, config->devices, &dev_cfg_masks[resource->dev_type],
12141246
resource->dev_type);
12151247

1216-
12171248
/* Find the enabled UCTs */
12181249
*rsc_flags = 0;
12191250
tl_enabled = ucp_is_resource_in_transports_list(resource->tl_name,

test/gtest/ucp/test_ucp_context.cc

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
*/
66

77
#include "ucp_test.h"
8+
9+
#include <set>
10+
811
extern "C" {
12+
#include <ucp/core/ucp_context.h>
913
#include <ucs/sys/sys.h>
1014
}
1115

@@ -118,3 +122,201 @@ UCS_TEST_P(test_ucp_version, version_string) {
118122
}
119123

120124
UCP_INSTANTIATE_TEST_CASE_TLS(test_ucp_version, all, "all")
125+
126+
class test_ucp_net_devices_config : public ucp_test {
127+
public:
128+
static void get_test_variants(std::vector<ucp_test_variant> &variants) {
129+
add_variant(variants, UCP_FEATURE_TAG);
130+
}
131+
132+
protected:
133+
static const char DELIMITER = ':';
134+
135+
/* Iterate over all network devices and apply action to each */
136+
template<typename Action>
137+
static void for_each_net_device(const entity &e, Action action) {
138+
ucp_context_h ctx = e.ucph();
139+
for (ucp_rsc_index_t i = 0; i < ctx->num_tls; ++i) {
140+
const uct_tl_resource_desc_t *rsc = &ctx->tl_rscs[i].tl_rsc;
141+
if (rsc->dev_type == UCT_DEVICE_TYPE_NET) {
142+
action(rsc);
143+
}
144+
}
145+
}
146+
147+
/* Check if a specific device name exists in the set */
148+
static bool has_device(const std::set<std::string> &devices,
149+
const std::string &dev_name)
150+
{
151+
return devices.find(dev_name) != devices.end();
152+
}
153+
154+
155+
/* Get all network device names from the context */
156+
static std::set<std::string> get_net_device_names(const entity &e)
157+
{
158+
std::set<std::string> device_names;
159+
for_each_net_device(e, [&](const uct_tl_resource_desc_t *rsc) {
160+
device_names.insert(rsc->dev_name);
161+
});
162+
return device_names;
163+
}
164+
165+
/* Get all network device names from the context with delimiter */
166+
static std::set<std::string>
167+
get_net_device_names_with_delimiter(const entity &e)
168+
{
169+
std::set<std::string> device_names;
170+
for_each_net_device(e, [&](const uct_tl_resource_desc_t *rsc) {
171+
std::string dev_name(rsc->dev_name);
172+
size_t delimiter_pos = dev_name.find(DELIMITER);
173+
if (delimiter_pos != std::string::npos) {
174+
device_names.insert(dev_name);
175+
}
176+
});
177+
return device_names;
178+
}
179+
180+
static std::set<std::string>
181+
get_device_base_names(const std::set<std::string> &dev_names)
182+
{
183+
std::set<std::string> base_names;
184+
185+
for (const std::string &dev_name : dev_names) {
186+
size_t delimiter_pos = dev_name.find(DELIMITER);
187+
if (delimiter_pos != std::string::npos) {
188+
base_names.insert(dev_name.substr(0, delimiter_pos));
189+
} else {
190+
base_names.insert(dev_name);
191+
}
192+
}
193+
194+
return base_names;
195+
}
196+
197+
/* Join strings with a delimiter */
198+
static std::string
199+
join(const std::set<std::string> &strings, const std::string &delimiter)
200+
{
201+
std::string result;
202+
for (auto it = strings.begin(); it != strings.end(); ++it) {
203+
if (it != strings.begin()) {
204+
result += delimiter;
205+
}
206+
result += *it;
207+
}
208+
return result;
209+
}
210+
211+
/* Test that net device selection works correctly */
212+
void
213+
test_net_device_selection(const std::set<std::string> &test_net_devices,
214+
const std::set<std::string> &expected_net_devices)
215+
{
216+
std::string net_devices_config = join(test_net_devices, ",");
217+
modify_config("NET_DEVICES", net_devices_config.c_str());
218+
entity *e = create_entity();
219+
220+
std::set<std::string> selected_devices = get_net_device_names(*e);
221+
EXPECT_EQ(selected_devices.size(), expected_net_devices.size());
222+
223+
for (const std::string &net_device : expected_net_devices) {
224+
EXPECT_TRUE(has_device(selected_devices, net_device))
225+
<< "Device '" << net_device << "' should be selected when "
226+
<< "UCX_NET_DEVICES=" << net_devices_config;
227+
}
228+
}
229+
230+
/* Test that a device config triggers duplicate device warning */
231+
void test_duplicate_device_warning(const std::string &required_dev_name,
232+
const std::string &devices_config,
233+
const std::string &duplicate_dev_name)
234+
{
235+
entity *e = create_entity();
236+
237+
std::set<std::string> net_devices = get_net_device_names(*e);
238+
ASSERT_FALSE(net_devices.empty());
239+
240+
if (!has_device(net_devices, required_dev_name)) {
241+
UCS_TEST_SKIP_R(required_dev_name + " device not available");
242+
}
243+
244+
m_entities.clear();
245+
246+
modify_config("NET_DEVICES", devices_config.c_str());
247+
248+
size_t warn_count;
249+
{
250+
scoped_log_handler slh(hide_warns_logger);
251+
warn_count = m_warnings.size();
252+
create_entity();
253+
}
254+
255+
EXPECT_EQ(m_warnings.size() - warn_count, 1)
256+
<< "Expected exactly one warning";
257+
258+
/* Check that the warning about duplicate device was printed */
259+
std::string expected_warn = "device '" + duplicate_dev_name +
260+
"' is specified multiple times";
261+
EXPECT_NE(m_warnings[warn_count].find(expected_warn), std::string::npos)
262+
<< "Expected warning about duplicate device '"
263+
<< duplicate_dev_name << "' with config '" << devices_config
264+
<< "'";
265+
}
266+
};
267+
268+
/*
269+
* Test that when UCX_NET_DEVICES is set to a base name (e.g., "mlx5_0"),
270+
* devices with the same base name are selected (e.g., "mlx5_0:1").
271+
*/
272+
UCS_TEST_P(test_ucp_net_devices_config, base_name_selects_device)
273+
{
274+
entity *e = create_entity();
275+
276+
std::set<std::string> net_devices = get_net_device_names_with_delimiter(*e);
277+
if (net_devices.empty()) {
278+
UCS_TEST_SKIP_R("No network devices available with delimiter");
279+
}
280+
281+
m_entities.clear();
282+
283+
std::set<std::string> base_names = get_device_base_names(net_devices);
284+
test_net_device_selection(base_names, net_devices);
285+
}
286+
287+
/*
288+
* Test that explicit suffix specification works correctly.
289+
*/
290+
UCS_TEST_P(test_ucp_net_devices_config, explicit_suffix)
291+
{
292+
entity *e = create_entity();
293+
294+
std::set<std::string> net_devices = get_net_device_names_with_delimiter(*e);
295+
if (net_devices.empty()) {
296+
UCS_TEST_SKIP_R("No network devices available with delimiter");
297+
}
298+
299+
m_entities.clear();
300+
301+
test_net_device_selection(net_devices, net_devices);
302+
}
303+
304+
/*
305+
* Test that specifying a device multiple times produces a warning
306+
*/
307+
UCS_TEST_P(test_ucp_net_devices_config, duplicate_device_warning_simple)
308+
{
309+
test_duplicate_device_warning("mlx5_0:1", "mlx5_0:1,mlx5_0:1", "mlx5_0:1");
310+
}
311+
312+
UCS_TEST_P(test_ucp_net_devices_config, duplicate_device_warning_base_name)
313+
{
314+
test_duplicate_device_warning("mlx5_0:1", "mlx5_0:1,mlx5_0", "mlx5_0");
315+
}
316+
317+
UCS_TEST_P(test_ucp_net_devices_config, duplicate_device_warning_two_base_name)
318+
{
319+
test_duplicate_device_warning("mlx5_0:1", "mlx5_0,mlx5_0", "mlx5_0");
320+
}
321+
322+
UCP_INSTANTIATE_TEST_CASE_TLS(test_ucp_net_devices_config, all, "all")

0 commit comments

Comments
 (0)