
Commit 58a5113

cherry pick save/load in the_one_ps (#37461)
* save/load in ps runtime (the_one_ps) (#36097)
  * add trainer desc config to distributed strategy
  * code style modified
  * data_feed set lod
  * fix bug
  * code style
  * fix bug
  * save load
  * save load
  * save unittest
  * add unittest of the_one_ps
  * unittest
  * add todo in communicator sendsparse
* fix bug in save_inference_model (#37362)
1 parent d5e73f0 commit 58a5113
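For orientation, a minimal sketch of the save/load round trip this commit enables, following the updated unit test (the "/tmp" path and mode=0 are illustrative values, not requirements):

    import paddle.distributed.fleet as fleet

    # after fleet.init(role) and optimizer.minimize(...) in parameter-server mode:
    fleet.init_worker()                    # start the communicator
    fleet.fleet.save(dirname="/tmp")       # pull dense params, save dense + sparse
    fleet.load_model(path="/tmp", mode=0)  # restore via _load_distributed_persistables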

File tree

7 files changed (+116 -12 lines)


paddle/fluid/distributed/service/communicator.cc

Lines changed: 23 additions & 0 deletions
@@ -283,6 +283,18 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
     push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
   }
 
+  // TODO(wangguanqun): padding_idx is not ignored, this is a bug.
+  // if padding_idx == padding in datareader, the server will core.
+  /*
+  for (size_t i = 0; i < tensor->rows().size(); ++i) {
+    uint64_t real_id = static_cast<uint64_t>(tensor->rows()[i]);
+    if (real_id != 0) {
+      sparse_push_keys.push_back(real_id);
+      push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
+    }
+  }
+  */
+
   ++_async_call_num;
   DownpourBrpcClosure *closure = new DownpourBrpcClosure(
       request_call_num, [this, request_call_num](void *done) {
 
@@ -353,6 +365,17 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
   return;
 }
 
+void Communicator::PullDense(const RecvCtxMap &recv_varname_to_ctx) {
+  for (auto &iter : recv_varname_to_ctx) {
+    auto &table_id = iter.first;
+    auto &varnames = iter.second;
+    RpcRecvDense(varnames, table_id, recv_scope_);
+    VLOG(1) << "pull dense param to table " << table_id
+            << " from 0' trainer done";
+  }
+  return;
+}
+
 void Communicator::RpcProfilerControl() {
   if (trainer_id_ == 0) {
     if (!do_server_profiler_ && platform::IsProfileEnabled()) {
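Note: the new PullDense walks the dense recv context (a map from table id to the dense variable names stored in that table) and fetches the current server-side values into the worker scope via RpcRecvDense. On the Python side it is reached through the new pull_dense binding; a sketch of the call shape, mirroring how _save_distributed_persistables uses it further down:

    # context shape: {table_id: [varname, ...]} for dense tables
    denses = self.compiled_strategy.get_the_one_recv_context(
        is_dense=True,
        split_dense_table=self.role_maker._is_heter_parameter_server_mode,
        use_origin_program=True)
    self._communicator.pull_dense(denses)  # refresh worker copies before saving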

paddle/fluid/distributed/service/communicator.h

Lines changed: 2 additions & 0 deletions
@@ -271,6 +271,8 @@ class Communicator {
 
   virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);
 
+  virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);
+
   virtual void Start() = 0;
 
   virtual void Stop() = 0;

paddle/fluid/distributed/table/common_sparse_table.cc

Lines changed: 12 additions & 5 deletions
@@ -279,18 +279,25 @@ int32_t CommonSparseTable::set_global_lr(float* lr) {
   return 0;
 }
 
-int32_t CommonSparseTable::load(const std::string& path,
+int32_t CommonSparseTable::load(const std::string& dirname,
                                 const std::string& param) {
   auto begin = GetCurrentUS();
   rwlock_->WRLock();
-  LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_,
+  auto varname = _config.common().table_name();
+  std::string var_store =
+      string::Sprintf("%s/%s%s", dirname, varname, PSERVER_SAVE_SUFFIX);
+  std::string shard_var_pre =
+      string::Sprintf("%s.block%d", varname, _shard_idx);
+  std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre);
+  std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre);
+
+  LoadFromText(value_, meta_, _shard_idx, _shard_num, task_pool_size_,
                &shard_values_);
   rwlock_->UNLock();
   auto end = GetCurrentUS();
 
-  auto varname = _config.common().table_name();
-  VLOG(0) << "load " << varname << " with value: " << path
-          << " , meta: " << param
+  VLOG(0) << "load " << varname << " with value: " << value_
+          << " , meta: " << meta_
           << " using: " << std::to_string((end - begin) / 1e+6) << " seconds";
 
   return 0;
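Note: load now takes the save directory rather than explicit value/meta paths and derives the per-shard file names itself. A sketch of the resulting on-disk layout in Python; PSERVER_SAVE_SUFFIX is the runtime's save-directory suffix, and its value below is an assumption for illustration only:

    import os

    PSERVER_SAVE_SUFFIX = "_txt"  # assumed value, for illustration only

    def sparse_shard_paths(dirname, varname, shard_idx):
        # mirrors the string::Sprintf calls in CommonSparseTable::load above
        var_store = os.path.join(dirname, varname + PSERVER_SAVE_SUFFIX)
        prefix = "{}.block{}".format(varname, shard_idx)
        return (os.path.join(var_store, prefix + ".txt"),   # values
                os.path.join(var_store, prefix + ".meta"))  # metadata

    # sparse_shard_paths("/tmp", "emb.w_0", 0)
    # -> ("/tmp/emb.w_0_txt/emb.w_0.block0.txt", "/tmp/emb.w_0_txt/emb.w_0.block0.meta")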

paddle/fluid/pybind/fleet_py.cc

Lines changed: 2 additions & 1 deletion
@@ -158,7 +158,8 @@ void BindDistCommunicator(py::module* m) {
       .def("start", &Communicator::Start)
       .def("push_sparse_param", &Communicator::RpcSendSparseParam)
       .def("is_running", &Communicator::IsRunning)
-      .def("init_params", &Communicator::InitParams);
+      .def("init_params", &Communicator::InitParams)
+      .def("pull_dense", &Communicator::PullDense);
   // .def("recv", &Communicator::RecvNoBarrier);
 }

python/paddle/distributed/fleet/runtime/the_one_ps.py

Lines changed: 66 additions & 6 deletions
@@ -868,11 +868,11 @@ def _init_server(self, dirname=None, var_names=None, **kwargs):
 
         for var_name in load_varnames:
             table_id = sparse_table_maps[var_name]
-            path = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
-                                "{}.block{}.txt".format(var_name, pserver_id))
-            meta = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
-                                "{}.block{}.meta".format(var_name, pserver_id))
-            self._server.load_sparse(path, meta, table_id)
+            # path = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
+            #                     "{}.block{}.txt".format(var_name, pserver_id))
+            # meta = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
+            #                     "{}.block{}.meta".format(var_name, pserver_id))
+            self._server.load_sparse(dirname, "0", table_id)
 
     def _run_server(self):
         if self.role_maker._is_heter_worker():
 
@@ -967,8 +967,12 @@ def _save_distributed_persistables(self,
                 TheOnePSRuntime.__exclude_vars(saved_varnames),
                 main_program.list_vars()))
 
+        self._communicator.pull_dense(denses)
+
         import paddle
         for var in remaining_vars:
+            # if var.name not in recv_dense_varnames:
+            #     continue
             tensor = var.get_value()
             paddle.save(
                 tensor, os.path.join(dirname, var.name), use_binary_format=True)
 
@@ -1063,8 +1067,64 @@ def _save_inference_model(self, *args, **kwargs):
     def _save_persistables(self, *args, **kwargs):
         self._ps_inference_save_persistables(*args, **kwargs)
 
+    def _load_sparse_params(self, dirname, context, main_program, mode):
+        from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
+        distributed_varnames = get_sparse_tablenames(
+            self.compiled_strategy.origin_main_program, True)
+        values = []
+        for id, names in context.items():
+            if names[0] not in distributed_varnames:
+                # TODO: only load sparse param from local
+                warnings.warn("varname is not in distributed_varnames, pass")
+            # load sparse & distributed param on server
+            self._worker.load_one_table(id, dirname, mode)
+            values.extend(names)
+        return values
+
+    def _load_distributed_persistables(self, dirname, main_program=None,
+                                       mode=0):
+        if main_program is None:
+            main_program = self.compiled_strategy.get_origin_ps_main_program()
+
+        if isinstance(main_program, CompiledProgram):
+            raise TypeError(
+                "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
+            )
+
+        denses = self.compiled_strategy.get_the_one_recv_context(
+            is_dense=True,
+            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
+            use_origin_program=True)
+        sparses = self.compiled_strategy.get_the_one_recv_context(
+            is_dense=False,
+            split_dense_table=self.role_maker._is_heter_parameter_server_mode,
+            use_origin_program=True)
+
+        sparse_varnames = self._load_sparse_params(dirname, sparses,
+                                                   main_program, mode)
+
+        recv_dense_varnames = []
+        for id, names in denses.items():
+            recv_dense_varnames.extend(names)
+
+        loaded_varnames = sparse_varnames
+
+        remaining_vars = list(
+            filter(
+                TheOnePSRuntime.__exclude_vars(loaded_varnames),
+                main_program.list_vars()))
+
+        import paddle
+        for var in remaining_vars:
+            if var.name not in recv_dense_varnames:
+                continue
+            tensor = paddle.load(os.path.join(dirname, var.name))
+            var.set_value(tensor)
+
+        self._communicator.init_params(denses)
+
     def load_model(self, path, mode):
-        self._worker.load_model(path, mode)
+        self._load_distributed_persistables(path, mode=mode)
 
     def _shrink(self, threshold):
         import paddle.distributed.fleet as fleet
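Note: load_model now routes through _load_distributed_persistables, which splits the work between server and worker: sparse tables are loaded server-side through load_one_table, while dense variables are restored on the worker with paddle.load/set_value and then pushed to the servers via init_params. A minimal usage sketch, with arguments taken from the updated unit test:

    # 1. servers load sparse tables:   self._worker.load_one_table(id, dirname, mode)
    # 2. worker restores dense vars:   var.set_value(paddle.load(os.path.join(dirname, var.name)))
    # 3. dense values sent to servers: self._communicator.init_params(denses)
    fleet.load_model(path="/tmp", mode=0)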

python/paddle/fluid/communicator.py

Lines changed: 3 additions & 0 deletions
@@ -161,6 +161,9 @@ def recv(self):
     def init_params(self, context):
         self.communicator_.init_params(context)
 
+    def pull_dense(self, context):
+        self.communicator_.pull_dense(context)
+
     def push_sparse_param(self, var_name, table_id=-1, scope=global_scope()):
         if not self.is_running():
             raise ValueError(

python/paddle/fluid/tests/unittests/test_fleet_base_2.py

Lines changed: 8 additions & 0 deletions
@@ -36,8 +36,13 @@ def test_ps_minimize(self):
 
         input_x = paddle.fluid.layers.data(
             name="x", shape=[32], dtype='float32')
+        input_slot = paddle.fluid.layers.data(
+            name="slot", shape=[1], dtype='int64')
         input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
 
+        emb = paddle.fluid.layers.embedding(
+            input=input_slot, size=[10, 9], is_sparse=True)
+        input_x = paddle.concat(x=[input_x, emb], axis=1)
         fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
         fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
         prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
 
@@ -63,11 +68,14 @@ def test_ps_minimize(self):
         compiled_prog = fluid.compiler.CompiledProgram(
             fluid.default_main_program())
 
+        fleet.init_worker()
         fleet.fleet.save(dirname="/tmp", feed=['x', 'y'], fetch=[avg_cost])
         fleet.fleet.save(
             dirname="/tmp", feed=[input_x, input_y], fetch=[avg_cost])
         fleet.fleet.save(dirname="/tmp")
 
+        fleet.load_model(path="/tmp", mode=0)
+
         self.assertRaises(
             Exception,
             fleet.save_inference_model,
