Skip to content

Commit 39afdac

Browse files
rongoutrivialfis
andauthored
Better error message when world size and rank are set as strings (dmlc#8316)
Co-authored-by: jiamingy <[email protected]>
1 parent 210915c commit 39afdac

File tree

6 files changed

+79
-31
lines changed

6 files changed

+79
-31
lines changed

plugin/federated/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ find_package(Threads)
77
add_library(federated_proto federated.proto)
88
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
99
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
10-
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON)
10+
xgboost_target_properties(federated_proto)
1111

1212
get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
1313
protobuf_generate(TARGET federated_proto LANGUAGE cpp)

plugin/federated/federated_communicator.h

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#pragma once
55
#include <xgboost/json.h>
66

7+
#include "../../src/c_api/c_api_utils.h"
78
#include "../../src/collective/communicator.h"
89
#include "../../src/common/io.h"
910
#include "federated_client.h"
@@ -89,31 +90,14 @@ class FederatedCommunicator : public Communicator {
8990
client_cert = value;
9091
}
9192

92-
// Runtime configuration overrides.
93-
auto const &j_server_address = config["federated_server_address"];
94-
if (IsA<String const>(j_server_address)) {
95-
server_address = get<String const>(j_server_address);
96-
}
97-
auto const &j_world_size = config["federated_world_size"];
98-
if (IsA<Integer const>(j_world_size)) {
99-
world_size = static_cast<int>(get<Integer const>(j_world_size));
100-
}
101-
auto const &j_rank = config["federated_rank"];
102-
if (IsA<Integer const>(j_rank)) {
103-
rank = static_cast<int>(get<Integer const>(j_rank));
104-
}
105-
auto const &j_server_cert = config["federated_server_cert"];
106-
if (IsA<String const>(j_server_cert)) {
107-
server_cert = get<String const>(j_server_cert);
108-
}
109-
auto const &j_client_key = config["federated_client_key"];
110-
if (IsA<String const>(j_client_key)) {
111-
client_key = get<String const>(j_client_key);
112-
}
113-
auto const &j_client_cert = config["federated_client_cert"];
114-
if (IsA<String const>(j_client_cert)) {
115-
client_cert = get<String const>(j_client_cert);
116-
}
93+
// Runtime configuration overrides, optional as users can specify them as env vars.
94+
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
95+
world_size =
96+
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
97+
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
98+
server_cert = OptionalArg<String>(config, "federated_server_cert", server_cert);
99+
client_key = OptionalArg<String>(config, "federated_client_key", client_key);
100+
client_cert = OptionalArg<String>(config, "federated_client_cert", client_cert);
117101

118102
if (server_address.empty()) {
119103
LOG(FATAL) << "Federated server address must be set.";

src/c_api/c_api_utils.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,21 +248,32 @@ inline void GenerateFeatureMap(Learner const *learner,
248248

249249
void XGBBuildInfoDevice(Json* p_info);
250250

251+
template <typename JT>
252+
void TypeCheck(Json const &value, StringView name) {
253+
using T = std::remove_const_t<JT> const;
254+
if (!IsA<T>(value)) {
255+
LOG(FATAL) << "Incorrect type for: `" << name << "`, expecting: `" << T{}.TypeStr()
256+
<< "`, got: `" << value.GetValue().TypeStr() << "`.";
257+
}
258+
}
259+
251260
template <typename JT>
252261
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
253262
auto const &obj = get<Object const>(in);
254263
auto it = obj.find(key);
255264
if (it == obj.cend() || IsA<Null>(it->second)) {
256-
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`";
265+
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`.";
257266
}
267+
TypeCheck<JT>(it->second, StringView{key});
258268
return get<std::remove_const_t<JT> const>(it->second);
259269
}
260270

261271
template <typename JT, typename T>
262272
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
263273
auto const &obj = get<Object const>(in);
264274
auto it = obj.find(key);
265-
if (it != obj.cend()) {
275+
if (it != obj.cend() && !IsA<Null>(it->second)) {
276+
TypeCheck<JT>(it->second, StringView{key});
266277
return get<std::remove_const_t<JT> const>(it->second);
267278
}
268279
return dft;

src/collective/noop_communicator.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@ class NoOpCommunicator : public Communicator {
1717
NoOpCommunicator() : Communicator(1, 0) {}
1818
bool IsDistributed() const override { return false; }
1919
bool IsFederated() const override { return false; }
20-
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
21-
Operation op) override {}
22-
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
20+
void AllReduce(void *, std::size_t, DataType, Operation) override {}
21+
void Broadcast(void *, std::size_t, int) override {}
2322
std::string GetProcessorName() override { return ""; }
2423
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
2524

tests/cpp/c_api/test_c_api.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,4 +324,36 @@ TEST(CAPI, NullPtr) {
324324
ASSERT_NE(pos, std::string::npos);
325325
XGBAPISetLastError("");
326326
}
327+
328+
TEST(CAPI, JArgs) {
329+
{
330+
Json args{Object{}};
331+
args["key"] = String{"value"};
332+
args["null"] = Null{};
333+
auto value = OptionalArg<String>(args, "key", std::string{"foo"});
334+
ASSERT_EQ(value, "value");
335+
value = OptionalArg<String const>(args, "key", std::string{"foo"});
336+
ASSERT_EQ(value, "value");
337+
338+
ASSERT_THROW({ OptionalArg<Number>(args, "key", 0.0f); }, dmlc::Error);
339+
value = OptionalArg<String const>(args, "bar", std::string{"foo"});
340+
ASSERT_EQ(value, "foo");
341+
value = OptionalArg<String const>(args, "null", std::string{"foo"});
342+
ASSERT_EQ(value, "foo");
343+
}
344+
345+
{
346+
Json args{Object{}};
347+
args["key"] = String{"value"};
348+
args["null"] = Null{};
349+
auto value = RequiredArg<String>(args, "key", __func__);
350+
ASSERT_EQ(value, "value");
351+
value = RequiredArg<String const>(args, "key", __func__);
352+
ASSERT_EQ(value, "value");
353+
354+
ASSERT_THROW({ RequiredArg<Integer>(args, "key", __func__); }, dmlc::Error);
355+
ASSERT_THROW({ RequiredArg<String const>(args, "foo", __func__); }, dmlc::Error);
356+
ASSERT_THROW({ RequiredArg<String>(args, "null", __func__); }, dmlc::Error);
357+
}
358+
}
327359
} // namespace xgboost

tests/cpp/plugin/test_federated_communicator.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,28 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) {
8585
EXPECT_THROW(construct(), dmlc::Error);
8686
}
8787

88+
TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
89+
auto construct = []() {
90+
Json config{JsonObject()};
91+
config["federated_server_address"] = kServerAddress;
92+
config["federated_world_size"] = std::string("1");
93+
config["federated_rank"] = Integer(0);
94+
auto *comm = FederatedCommunicator::Create(config);
95+
};
96+
EXPECT_THROW(construct(), dmlc::Error);
97+
}
98+
99+
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
100+
auto construct = []() {
101+
Json config{JsonObject()};
102+
config["federated_server_address"] = kServerAddress;
103+
config["federated_world_size"] = 1;
104+
config["federated_rank"] = std::string("0");
105+
auto *comm = FederatedCommunicator::Create(config);
106+
};
107+
EXPECT_THROW(construct(), dmlc::Error);
108+
}
109+
88110
TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) {
89111
FederatedCommunicator comm{6, 3, kServerAddress};
90112
EXPECT_EQ(comm.GetWorldSize(), 6);

0 commit comments

Comments
 (0)