Skip to content

Commit de4013f

Browse files
ZiyueXu77, nvidianz, trivialfis
authored
Implement secure horizontal scheme for federated learning (dmlc#10231)
* Add additional data split mode to cover the secure vertical pipeline * Add IsSecure info and update corresponding functions * Modify evaluate_splits to block non-label owners to perform hist compute under secure scenario * Continue using Allgather for best split sync for secure vertical, equivalent to broadcast * Modify histogram sync scheme for secure vertical case, can identify global best split, but need to further apply split correctly * Sync cut information across clients, full pipeline works for testing case * Code cleanup, phase 1 of alternative vertical pipeline finished * Code clean * change kColS to kColSecure to avoid confusion with kCols * Replace allreduce with allgather, functional but inefficient version * Update AllGather behavior from individual pair to bulk by adopting new histogram transmission data structure of a flat vector * comment out the record printing * fix pointer bug for histsync with allgather * identify the HE adding locations * revise and simplify template code * revise and simplify template code * prepare aggregator for gh broadcast * prepare histogram for histindex and row index for allgather * fix conflicts * fix conflicts * fix format * fix allgather logic and update unit test * fix linting * fix linting and other unit test issues * fix linting and other unit test issues * integration with interface initial attempt * integration with interface initial attempt * integration with interface initial attempt * functional integration with interface * remove debugging prints * remove processor from another PR * Update the processor functions according to new processor implementation * Move processor interface init from learner to communicator * Move processor interface init from learner to communicator functional * switch to allgatherV for encrypted message with varying lengths * consolidate with processor interface PR * remove prints and fix format * fix linting over reference pass * fix undefined symbol issue * fix processor test * 
secure vertical relies on processor, move the unit test * type correction * type correction * extra linting from last change * Added Windows support * fix for cstdint types * fix for cstdint types * Added support for horizontal secure XGBoost * update with mock plugin * secure horizontal fully functional with mock plugin * linting fix * linting fix * linting fix * fix type * change loader and proc params input pattern to align with std map * update with secure vertical incorporation * Update mock_processor to enable nvflare usage * [backport] Fix compiling with the latest CTX. (dmlc#10263) * fix secure horizontal inference * initialized aggr context only once * Added support for multiple plugins in a single lib * remove redundant condition * Added support for boolean in proc_params * free buffer * CUDA. * Fix clean build. * Fix include. * tidy. * lint. * nolint. * disable. * disable sanitizer. --------- Co-authored-by: Zhihong Zhang <zhihongz@nvidia.com> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
1 parent 09bc2c7 commit de4013f

File tree

15 files changed

+832
-79
lines changed

15 files changed

+832
-79
lines changed

cmake/Utils.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ macro(xgboost_target_link_libraries target)
227227
else()
228228
target_link_libraries(${target} PRIVATE Threads::Threads ${CMAKE_THREAD_LIBS_INIT})
229229
endif()
230+
target_link_libraries(${target} PRIVATE ${CMAKE_DL_LIBS})
230231

231232
if(USE_OPENMP)
232233
if(BUILD_STATIC_LIB)

include/xgboost/data.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ enum class DataType : uint8_t {
4040

4141
enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };
4242

43-
enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2 };
43+
enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2, kRowSecure = 3 };
4444

4545
/*!
4646
* \brief Meta information about dataset, always sit in memory.
@@ -181,16 +181,16 @@ class MetaInfo {
181181
void SynchronizeNumberOfColumns(Context const* ctx);
182182

183183
/*! \brief Whether the data is split row-wise. */
184-
bool IsRowSplit() const {
185-
return data_split_mode == DataSplitMode::kRow;
186-
}
184+
bool IsRowSplit() const { return (data_split_mode == DataSplitMode::kRow)
185+
|| (data_split_mode == DataSplitMode::kRowSecure); }
187186

188187
/** @brief Whether the data is split column-wise. */
189188
bool IsColumnSplit() const { return (data_split_mode == DataSplitMode::kCol)
190189
|| (data_split_mode == DataSplitMode::kColSecure); }
191190

192191
/** @brief Whether the data is split column-wise with secure computation. */
193-
bool IsSecure() const { return data_split_mode == DataSplitMode::kColSecure; }
192+
bool IsSecure() const { return (data_split_mode == DataSplitMode::kColSecure)
193+
|| (data_split_mode == DataSplitMode::kRowSecure); }
194194

195195
/** @brief Whether this is a learning to rank data. */
196196
bool IsRanking() const { return !group_ptr_.empty(); }

src/collective/aggregator.h

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "communicator-inl.h"
1515
#include "xgboost/collective/result.h" // for Result
1616
#include "xgboost/data.h" // for MetaINfo
17+
#include "../processing/processor.h" // for Processor
1718

1819
namespace xgboost::collective {
1920

@@ -69,7 +70,7 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si
6970
* @param result The HostDeviceVector storing the results.
7071
* @param function The function used to calculate the results.
7172
*/
72-
template <typename T, typename Function>
73+
template <bool is_gpair, typename T, typename Function>
7374
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
7475
Function&& function) {
7576
if (info.IsVerticalFederated()) {
@@ -96,8 +97,49 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
9697
}
9798
collective::Broadcast(&size, sizeof(std::size_t), 0);
9899

99-
result->Resize(size);
100-
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
100+
if (info.IsSecure() && is_gpair) {
101+
// Under secure mode, gpairs will be processed to vector and encrypt
102+
// information only available on rank 0
103+
std::size_t buffer_size{};
104+
std::int8_t *buffer;
105+
if (collective::GetRank() == 0) {
106+
std::vector<double> vector_gh;
107+
for (std::size_t i = 0; i < size; i++) {
108+
auto gpair = result->HostVector()[i];
109+
// cast from GradientPair to float pointer
110+
auto gpair_ptr = reinterpret_cast<float*>(&gpair);
111+
// save to vector
112+
vector_gh.push_back(gpair_ptr[0]);
113+
vector_gh.push_back(gpair_ptr[1]);
114+
}
115+
// provide the vectors to the processor interface
116+
size_t size;
117+
auto buf = processor_instance->ProcessGHPairs(&size, vector_gh);
118+
buffer_size = size;
119+
buffer = reinterpret_cast<std::int8_t *>(buf);
120+
}
121+
122+
// broadcast the buffer size for other ranks to prepare
123+
collective::Broadcast(&buffer_size, sizeof(std::size_t), 0);
124+
// prepare buffer on passive parties for satisfying broadcast mpi call
125+
if (collective::GetRank() != 0) {
126+
buffer = reinterpret_cast<std::int8_t *>(malloc(buffer_size));
127+
}
128+
129+
// broadcast the data buffer holding processed gpairs
130+
collective::Broadcast(buffer, buffer_size, 0);
131+
132+
// call HandleGHPairs
133+
size_t size;
134+
processor_instance->HandleGHPairs(&size, buffer, buffer_size);
135+
136+
// free the buffer
137+
free(buffer);
138+
} else {
139+
// clear text mode, broadcast the data directly
140+
result->Resize(size);
141+
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
142+
}
101143
} else {
102144
std::forward<Function>(function)();
103145
}

src/collective/communicator.cc

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/*!
22
* Copyright 2022 XGBoost contributors
33
*/
4+
#include <map>
45
#include "communicator.h"
56

67
#include "comm.h"
@@ -9,14 +10,41 @@
910
#include "rabit_communicator.h"
1011

1112
#if defined(XGBOOST_USE_FEDERATED)
12-
#include "../../plugin/federated/federated_communicator.h"
13+
#include "../../plugin/federated/federated_communicator.h"
1314
#endif
1415

16+
#include "../processing/processor.h"
17+
processing::Processor *processor_instance;
18+
1519
namespace xgboost::collective {
1620
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
1721
thread_local CommunicatorType Communicator::type_{};
1822
thread_local std::string Communicator::nccl_path_{};
1923

24+
std::map<std::string, std::string> JsonToMap(xgboost::Json const& config, std::string key) {
25+
auto json_map = xgboost::OptionalArg<xgboost::Object>(config, key, xgboost::JsonObject::Map{});
26+
std::map<std::string, std::string> params{};
27+
for (auto entry : json_map) {
28+
std::string text;
29+
xgboost::Value* value = &(entry.second.GetValue());
30+
if (value->Type() == xgboost::Value::ValueKind::kString) {
31+
text = reinterpret_cast<xgboost::String *>(value)->GetString();
32+
} else if (value->Type() == xgboost::Value::ValueKind::kInteger) {
33+
auto num = reinterpret_cast<xgboost::Integer *>(value)->GetInteger();
34+
text = std::to_string(num);
35+
} else if (value->Type() == xgboost::Value::ValueKind::kNumber) {
36+
auto num = reinterpret_cast<xgboost::Number *>(value)->GetNumber();
37+
text = std::to_string(num);
38+
} else if (value->Type() == xgboost::Value::ValueKind::kBoolean) {
39+
text = reinterpret_cast<xgboost::Boolean *>(value)->GetBoolean() ? "true" : "false";
40+
} else {
41+
text = "Unsupported type";
42+
}
43+
params[entry.first] = text;
44+
}
45+
return params;
46+
}
47+
2048
void Communicator::Init(Json const& config) {
2149
auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
2250
nccl_path_ = nccl;
@@ -38,26 +66,46 @@ void Communicator::Init(Json const& config) {
3866
}
3967
case CommunicatorType::kFederated: {
4068
#if defined(XGBOOST_USE_FEDERATED)
41-
communicator_.reset(FederatedCommunicator::Create(config));
69+
communicator_.reset(FederatedCommunicator::Create(config));
70+
// Get processor configs
71+
std::string plugin_name{};
72+
std::string loader_params_key{};
73+
std::string loader_params_map{};
74+
std::string proc_params_key{};
75+
std::string proc_params_map{};
76+
plugin_name = OptionalArg<String>(config, "plugin_name", plugin_name);
77+
// Initialize processor if plugin_name is provided
78+
if (!plugin_name.empty()) {
79+
std::map<std::string, std::string> loader_params = JsonToMap(config, "loader_params");
80+
std::map<std::string, std::string> proc_params = JsonToMap(config, "proc_params");
81+
processing::ProcessorLoader loader(loader_params);
82+
processor_instance = loader.Load(plugin_name);
83+
processor_instance->Initialize(collective::GetRank() == 0, proc_params);
84+
}
4285
#else
43-
LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
86+
LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
4487
#endif
45-
break;
46-
}
47-
case CommunicatorType::kInMemory:
48-
case CommunicatorType::kInMemoryNccl: {
49-
communicator_.reset(InMemoryCommunicator::Create(config));
50-
break;
51-
}
52-
case CommunicatorType::kUnknown:
53-
LOG(FATAL) << "Unknown communicator type.";
88+
break;
89+
}
90+
91+
case CommunicatorType::kInMemory:
92+
case CommunicatorType::kInMemoryNccl: {
93+
communicator_.reset(InMemoryCommunicator::Create(config));
94+
break;
95+
}
96+
case CommunicatorType::kUnknown:
97+
LOG(FATAL) << "Unknown communicator type.";
5498
}
5599
}
56100

57101
#ifndef XGBOOST_USE_CUDA
58102
void Communicator::Finalize() {
59103
communicator_->Shutdown();
60104
communicator_.reset(new NoOpCommunicator());
105+
if (processor_instance != nullptr) {
106+
processor_instance->Shutdown();
107+
processor_instance = nullptr;
108+
}
61109
}
62110
#endif
63111
} // namespace xgboost::collective

src/learner.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,7 @@ class LearnerConfiguration : public Learner {
846846

847847
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
848848
base_score->Reshape(1);
849-
collective::ApplyWithLabels(this->Ctx(), info, base_score->Data(),
849+
collective::ApplyWithLabels<false>(this->Ctx(), info, base_score->Data(),
850850
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
851851
}
852852
};
@@ -1472,8 +1472,9 @@ class LearnerImpl : public LearnerIO {
14721472
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
14731473
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
14741474
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
1475-
collective::ApplyWithLabels(&ctx_, info, out_gpair->Data(),
1476-
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
1475+
// calculate gradient and communicate
1476+
collective::ApplyWithLabels<true>(&ctx_, info, out_gpair->Data(),
1477+
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
14771478
}
14781479

14791480
/*! \brief random number transformation seed. */

src/objective/adaptive.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
160160
auto t_predt = d_predt.Slice(linalg::All(), group_idx);
161161

162162
HostDeviceVector<float> quantiles;
163-
collective::ApplyWithLabels(ctx, info, &quantiles, [&] {
163+
collective::ApplyWithLabels<false>(ctx, info, &quantiles, [&] {
164164
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
165165
auto d_row_index = dh::ToSpan(ridx);
166166
auto seg_beg = nptr.DevicePointer();

0 commit comments

Comments
 (0)