
Commit 3639d99

Fix save and load lookup table/optimizer vars (#14301)
* fix mkdir conflict
* fix load/save lookup tables test=develop
* add lookup_table_utils
* fix load optimize vars on pserver
* delete lookup table utils
* fix save and load lookup tables
* fix load optimizer var
* fix load optimizer var, test=develop
* fix python 3 style, test=develop
* move lookup_table_utils to contrib utils
1 parent 2fc32b1 commit 3639d99

File tree

7 files changed: +342 -22 lines


paddle/fluid/operators/lookup_sparse_table_op.cc

Lines changed: 1 addition & 0 deletions

@@ -67,6 +67,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
                      framework::proto::VarType::FP32,
                      "The sparse table only support FP32");
    w_t->Get(ids_t, out_t, true, is_test);
+   out_t->set_lod(ids_t.lod());
  }
};

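The added set_lod call copies the sequence information (LoD) from the ids input onto the lookup output. A minimal sketch of why this matters, written against the public Fluid layers rather than the raw op (standard Fluid APIs, not code from this commit): sequence-aware ops downstream of an embedding locate sequence boundaries through the LoD carried by the embedding output, so a lookup that dropped it would break variable-length batches.

import paddle.fluid as fluid

# `word` holds variable-length id sequences (lod_level=1).
word = fluid.layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
emb = fluid.layers.embedding(input=word, size=[10000, 64], is_sparse=True)
# sequence_pool reads sequence boundaries from the LoD on `emb`;
# that LoD is exactly what the lookup output must carry.
pooled = fluid.layers.sequence_pool(input=emb, pool_type='sum')
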
paddle/fluid/operators/sum_op.h

Lines changed: 3 additions & 0 deletions

@@ -127,6 +127,9 @@ class SumKernel : public framework::OpKernel<T> {
      math::scatter::MergeAdd<DeviceContext, T> merge_add;
      merge_add(context.template device_context<DeviceContext>(), inputs,
                out);
+
+     out->SyncIndex();
+
    } else {
      // no data, just set an empty out tensor.
      out->mutable_value()->mutable_data<T>(framework::make_ddim({0}),

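MergeAdd combines duplicate rows of the SelectedRows inputs; the new SyncIndex() call then rebuilds the output's row-id index so that later lookups by id see the merged rows. For readers unfamiliar with SelectedRows, a rough illustration of the structure via the Python bindings (this follows the pattern in Paddle's own unit tests; it is a sketch, not code from this commit):

import numpy as np
import paddle.fluid.core as core

scope = core.Scope()
place = core.CPUPlace()

# A SelectedRows pairs a dense value tensor with the ids of the rows
# it actually stores; SyncIndex() keeps the id->offset index in step.
x = scope.var("X").get_selected_rows()
x.set_height(10)                 # logical row count
x.set_rows([0, 4, 7])            # row ids actually present
x.get_tensor().set(np.ones((3, 2)).astype("float32"), place)
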
python/paddle/fluid/contrib/utils/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -13,8 +13,10 @@
# limitations under the License.

from __future__ import print_function
-
+from . import lookup_table_utils
+from .lookup_table_utils import *
from . import hdfs_utils
from .hdfs_utils import *

+__all__ = lookup_table_utils.__all__
__all__ = hdfs_utils.__all__
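
With these exports in place, the helpers become importable from the contrib utils package; a minimal hedged usage sketch (assuming a Paddle build containing this change):

from paddle.fluid.contrib.utils import convert_dist_to_sparse_program
from paddle.fluid.contrib.utils import load_inference_model
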
python/paddle/fluid/contrib/utils/lookup_table_utils.py

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import time
import logging

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import io
from paddle.fluid import Program

__all__ = [
    "load_inference_model", "load_persistable_vars",
    "convert_dist_to_sparse_program"
]

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
_logger = logging.getLogger("lookup_table_utils")
_logger.setLevel(logging.INFO)

model_filename = "__model__"
lookup_table_dir = "__lookup_table__"


def __insert_lookup_sparse_table_op(main_program, idx, ids, w, out):
    main_program.global_block()._insert_op(
        index=idx,
        type="lookup_sparse_table",
        inputs={"Ids": [ids],
                "W": [w]},
        outputs={"Out": [out]},
        attrs={
            "is_distributed": False,
            "is_sparse": True,
            "grad_inplace": False
        })


def __get_prefetch_op_tuples(main_program):
    # current lookup tables op is split_ids->prefetch->merge_ids
    prefetch_op_tuples = None
    op_types = [op.type for op in main_program.global_block().ops]

    for i in range(len(op_types)):
        if op_types[i] == "prefetch":
            if op_types[i - 1] == "split_ids" and op_types[i + 1] == "merge_ids":
                split_ids_op_id = i - 1
                split_ids_inputs = main_program.global_block().ops[i - 1].input("Ids")
                prefetch_op_inputs = main_program.global_block().ops[i].input("X")
                prefetch_op_outputs = main_program.global_block().ops[i].output("Out")
                merge_ids_outputs = main_program.global_block().ops[i + 1].output("Out")

                need_delete_vars = []
                need_delete_vars.extend(prefetch_op_inputs)
                need_delete_vars.extend(prefetch_op_outputs)

                prefetch_op_tuples = (split_ids_op_id, split_ids_inputs,
                                      merge_ids_outputs, need_delete_vars)
                break
    return prefetch_op_tuples


def convert_dist_to_sparse_program(main_program):
    if not main_program._distributed_lookup_table:
        _logger.warn(
            "There are no distributed lookup tables need to be converted")
        return

    # create table param and grad var in pserver program
    origin_emb_var = "{}.origin".format(main_program._distributed_lookup_table)
    emb_var = main_program._distributed_lookup_table
    main_program.global_block()._rename_var(emb_var, origin_emb_var)
    origin_param_var = main_program.global_block().vars[origin_emb_var]

    param_var = main_program.global_block().create_var(
        name=emb_var,
        shape=origin_param_var.shape,
        dtype=origin_param_var.dtype,
        type=core.VarDesc.VarType.SELECTED_ROWS,
        persistable=True)
    # parameter must be selected rows
    param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
    main_program._sync_with_cpp()

    prefetch_op_tuples = __get_prefetch_op_tuples(main_program)

    split_ids_id = prefetch_op_tuples[0]

    for idx in range(split_ids_id + 2, split_ids_id - 1, -1):
        main_program.global_block()._remove_op(idx)
    main_program.desc.flush()

    in_out_pairs = zip(prefetch_op_tuples[1], prefetch_op_tuples[2])

    for in_out_pair in in_out_pairs:
        idx = split_ids_id
        ids = main_program.global_block().vars[in_out_pair[0]]
        out = main_program.global_block().vars[in_out_pair[1]]
        __insert_lookup_sparse_table_op(main_program, idx, ids, param_var, out)
        main_program.desc.flush()
    return main_program


def load_persistable_vars(executor, dirname, program, lookup_table_var):
    def _is_checkpoint_var(exclude_fluid_vars=None):
        """
        the checkpoint will not save or load all the variables.
        var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.

        : param var(Variable)
        """

        if exclude_fluid_vars is None:
            exclude_fluid_vars = []

        def is_valid(var):
            if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
                    var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
                    var.desc.type() == core.VarDesc.VarType.RAW:
                return False
            # @GRAD are named for gradient variables, checkpoint will not save it.
            if "@GRAD" in var.name:
                return False
            # .trainer_ are named for distribute train variables, checkpoint will not save it.
            if ".trainer_" in var.name:
                return False

            # .block is named for distribute train variables, checkpoint will not save it.
            if ".block" in var.name:
                return False

            if "tmp_" in var.name:
                return False

            if var.name in exclude_fluid_vars:
                return False

            return var.persistable

        return is_valid

    def _load_lookup_table_vars(executor, dirname, main_program,
                                lookup_table_vars):
        if not os.path.isdir(dirname):
            raise ValueError("There is no directory named '%s'", dirname)

        lookup_table_dirname = os.path.join(dirname, lookup_table_dir)

        emb_var_name = lookup_table_vars[0]
        emb_var = main_program.global_block().var(emb_var_name)

        emb_files = []
        for emb_name in os.listdir(lookup_table_dirname):
            if emb_var_name in emb_name:
                emb_files.append(emb_name)

        convert_program = Program()
        global_block = convert_program.global_block()

        emb_var = global_block.create_var(
            name=emb_var.name,
            shape=emb_var.shape,
            dtype=emb_var.dtype,
            type=core.VarDesc.VarType.SELECTED_ROWS,
            persistable=True)
        emb_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)

        sums = []

        for i, emb_file in enumerate(emb_files):
            var_name = "{}_{}".format(emb_var.name, i)
            param_var = global_block.create_var(
                name=var_name,
                shape=emb_var.shape,
                dtype=emb_var.dtype,
                type=core.VarDesc.VarType.SELECTED_ROWS,
                persistable=True)
            param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
            global_block.append_op(
                type='load',
                inputs={},
                outputs={'Out': [param_var]},
                attrs={
                    'file_path': os.path.join(lookup_table_dirname, var_name)
                })
            sums.append(param_var)
        global_block.append_op(
            type='sum', inputs={"X": sums}, outputs={'Out': emb_var}, attrs={})
        global_block.append_op(type='delete_var', inputs={'X': sums})
        executor.run(convert_program)

    _logger.info("Start Load Sparse Program With "
                 "Distributed Lookup Table Vars from {}, time = {}".format(
                     dirname, time.ctime()))

    lookup_table_vars = [lookup_table_var]

    io.load_vars(
        executor,
        dirname=dirname,
        main_program=program,
        predicate=_is_checkpoint_var(lookup_table_vars),
        filename=None)

    _load_lookup_table_vars(executor, dirname, program, lookup_table_vars)

    _logger.info("Finish Load Sparse Program With "
                 "Distributed Lookup Table Vars from {}, time = {}".format(
                     dirname, time.ctime()))


def load_inference_model(dirname, executor, lookup_table_var_name):
    if not os.path.isdir(dirname):
        raise ValueError("There is no directory named '%s'", dirname)

    local_model = os.path.join(dirname, model_filename)

    with open(local_model, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)

    if not core._is_program_version_supported(program._version()):
        raise ValueError("Unsupported program version: %d\n" %
                         program._version())

    # Binary data also need version.
    load_persistable_vars(executor, dirname, program, lookup_table_var_name)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]

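Taken together, the module gives a load path for models trained with a distributed lookup table: convert_dist_to_sparse_program rewrites a pserver program's split_ids -> prefetch -> merge_ids triple into a single lookup_sparse_table op, and load_inference_model restores the serialized program along with its sharded table. A hedged usage sketch follows; the directory name and the "emb" table name are assumptions, not from the commit:

import paddle.fluid as fluid
from paddle.fluid.contrib.utils import lookup_table_utils

exe = fluid.Executor(fluid.CPUPlace())
model_dir = "./dist_model"  # assumed: produced by a distributed save

program, feed_names, fetch_targets = lookup_table_utils.load_inference_model(
    model_dir, exe, lookup_table_var_name="emb")

# The restored program then behaves like any inference program:
# results = exe.run(program, feed={...}, fetch_list=fetch_targets)
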
python/paddle/fluid/framework.py

Lines changed: 20 additions & 0 deletions

@@ -1698,6 +1698,7 @@ def clone(self, for_test=False):

        p._copy_param_info_from(self)
        p._copy_data_info_from(self)
+       p._copy_dist_param_info_from(self)
        return p

    def _prune(self, targets):

@@ -1938,6 +1939,25 @@ def _copy_param_info_from(self, other):
                       "program, with represent the same topology")
        self.global_block()._copy_param_info_from(other.global_block())

+   def _copy_dist_param_info_from(self, other):
+       """
+       Copy the distributed information from the other program.
+
+       Args:
+           other(Program): Other program
+
+       Returns:
+           None
+       """
+       if not isinstance(other, Program):
+           raise TypeError("_copy_dist_param_info_from should be invoked with "
+                           "Program")
+       self._is_distributed = other._is_distributed
+       self._is_chief = other._is_chief
+       self._slice_vars_and_attrs = other._slice_vars_and_attrs
+       self._endpoints = other._endpoints
+       self._distributed_lookup_table = other._distributed_lookup_table
+
    def _copy_data_info_from(self, other):
        """
        Copy the information of data variables from other program.

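Program.clone previously dropped the distributed metadata recorded by the transpiler; with _copy_dist_param_info_from wired into clone, a test-mode clone now keeps it. A hedged sketch of the effect (the attribute names are taken from the diff above):

import paddle.fluid as fluid

main = fluid.default_main_program()
test_prog = main.clone(for_test=True)

# These private attributes are now carried over by the clone.
assert test_prog._is_distributed == main._is_distributed
assert test_prog._distributed_lookup_table == main._distributed_lookup_table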