Skip to content

Commit 8177ece

Browse files
fix entry (#31079) (#31182)
* fix entry * fix distributed lookup table fuse case * fix entry bug at first time * move entry from paddle.fluid -> paddle.distributed * fix ut with paddle.enable_static() Co-authored-by: malin10 <[email protected]> Co-authored-by: malin10 <[email protected]>
1 parent fe00d32 commit 8177ece

File tree

12 files changed

+242
-43
lines changed

12 files changed

+242
-43
lines changed

paddle/fluid/distributed/ps.proto

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ message CommonAccessorParameter {
140140
repeated string params = 4;
141141
repeated uint32 dims = 5;
142142
repeated string initializers = 6;
143-
optional int32 trainer_num = 7;
144-
optional bool sync = 8;
143+
optional string entry = 7;
144+
optional int32 trainer_num = 8;
145+
optional bool sync = 9;
145146
}
146147

147148
message TableAccessorSaveParameter {

paddle/fluid/distributed/table/common_sparse_table.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,13 @@ int32_t CommonSparseTable::initialize() {
238238
int32_t CommonSparseTable::initialize_recorder() { return 0; }
239239

240240
int32_t CommonSparseTable::initialize_value() {
241+
auto common = _config.common();
241242
shard_values_.reserve(task_pool_size_);
242243

243244
for (int x = 0; x < task_pool_size_; ++x) {
244-
auto shard =
245-
std::make_shared<ValueBlock>(value_names_, value_dims_, value_offsets_,
246-
value_idx_, initializer_attrs_, "none");
245+
auto shard = std::make_shared<ValueBlock>(
246+
value_names_, value_dims_, value_offsets_, value_idx_,
247+
initializer_attrs_, common.entry());
247248

248249
shard_values_.emplace_back(shard);
249250
}

paddle/fluid/distributed/table/depends/large_scale_kv.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ inline bool count_entry(std::shared_ptr<VALUE> value, int threshold) {
7171
}
7272

7373
inline bool probility_entry(std::shared_ptr<VALUE> value, float threshold) {
74-
UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
74+
UniformInitializer uniform = UniformInitializer({"uniform", "0", "0", "1"});
7575
return uniform.GetValue() >= threshold;
7676
}
7777

@@ -93,20 +93,20 @@ class ValueBlock {
9393

9494
// for Entry
9595
{
96-
auto slices = string::split_string<std::string>(entry_attr, "&");
96+
auto slices = string::split_string<std::string>(entry_attr, ":");
9797
if (slices[0] == "none") {
9898
entry_func_ = std::bind(&count_entry, std::placeholders::_1, 0);
99-
} else if (slices[0] == "count_filter") {
99+
} else if (slices[0] == "count_filter_entry") {
100100
int threshold = std::stoi(slices[1]);
101101
entry_func_ = std::bind(&count_entry, std::placeholders::_1, threshold);
102-
} else if (slices[0] == "probability") {
102+
} else if (slices[0] == "probability_entry") {
103103
float threshold = std::stof(slices[1]);
104104
entry_func_ =
105105
std::bind(&probility_entry, std::placeholders::_1, threshold);
106106
} else {
107107
PADDLE_THROW(platform::errors::InvalidArgument(
108-
"Not supported Entry Type : %s, Only support [count_filter, "
109-
"probability]",
108+
"Not supported Entry Type : %s, Only support [CountFilterEntry, "
109+
"ProbabilityEntry]",
110110
slices[0]));
111111
}
112112
}
@@ -179,10 +179,12 @@ class ValueBlock {
179179
initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
180180
value_dims_[x]);
181181
}
182+
value->need_save_ = true;
182183
}
184+
} else {
185+
value->need_save_ = true;
183186
}
184187

185-
value->need_save_ = true;
186188
return;
187189
}
188190

paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ void GetDownpourSparseTableProto(
7878
common_proto->set_table_name("MergedDense");
7979
common_proto->set_trainer_num(1);
8080
common_proto->set_sync(false);
81+
common_proto->set_entry("none");
8182
common_proto->add_params("Param");
8283
common_proto->add_dims(10);
8384
common_proto->add_initializers("uniform_random&0&-1.0&1.0");

python/paddle/distributed/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
from . import collective
2626
from .collective import *
2727

28+
from .entry_attr import ProbabilityEntry
29+
from .entry_attr import CountFilterEntry
30+
2831
# start multiprocess apis
2932
__all__ = ["spawn"]
3033

@@ -38,5 +41,17 @@
3841
"QueueDataset",
3942
]
4043

44+
# dataset reader
45+
__all__ += [
46+
"InMemoryDataset",
47+
"QueueDataset",
48+
]
49+
50+
# entry for embedding
51+
__all__ += [
52+
"ProbabilityEntry",
53+
"CountFilterEntry",
54+
]
55+
4156
# collective apis
4257
__all__ += collective.__all__
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
__all__ = ['ProbabilityEntry', 'CountFilterEntry']
18+
19+
20+
class EntryAttr(object):
21+
"""
22+
Entry Config for paddle.static.nn.sparse_embedding with Parameter Server.
23+
24+
Examples:
25+
.. code-block:: python
26+
27+
import paddle
28+
29+
sparse_feature_dim = 1024
30+
embedding_size = 64
31+
32+
entry = paddle.distributed.ProbabilityEntry(0.1)
33+
34+
input = paddle.static.data(name='ins', shape=[1], dtype='int64')
35+
36+
emb = paddle.static.nn.sparse_embedding((
37+
input=input,
38+
size=[sparse_feature_dim, embedding_size],
39+
is_test=False,
40+
entry=entry,
41+
param_attr=paddle.ParamAttr(name="SparseFeatFactors",
42+
initializer=paddle.nn.initializer.Uniform()))
43+
44+
"""
45+
46+
def __init__(self):
47+
self._name = None
48+
49+
def _to_attr(self):
50+
"""
51+
Returns the attributes of this parameter.
52+
53+
Returns:
54+
Parameter attributes(map): The attributes of this parameter.
55+
"""
56+
raise NotImplementedError("EntryAttr is base class")
57+
58+
59+
class ProbabilityEntry(EntryAttr):
60+
"""
61+
Examples:
62+
.. code-block:: python
63+
64+
import paddle
65+
66+
sparse_feature_dim = 1024
67+
embedding_size = 64
68+
69+
entry = paddle.distributed.ProbabilityEntry(0.1)
70+
71+
input = paddle.static.data(name='ins', shape=[1], dtype='int64')
72+
73+
emb = paddle.static.nn.sparse_embedding((
74+
input=input,
75+
size=[sparse_feature_dim, embedding_size],
76+
is_test=False,
77+
entry=entry,
78+
param_attr=paddle.ParamAttr(name="SparseFeatFactors",
79+
initializer=paddle.nn.initializer.Uniform()))
80+
81+
82+
"""
83+
84+
def __init__(self, probability):
85+
super(EntryAttr, self).__init__()
86+
87+
if not isinstance(probability, float):
88+
raise ValueError("probability must be a float in (0,1)")
89+
90+
if probability <= 0 or probability >= 1:
91+
raise ValueError("probability must be a float in (0,1)")
92+
93+
self._name = "probability_entry"
94+
self._probability = probability
95+
96+
def _to_attr(self):
97+
return ":".join([self._name, str(self._probability)])
98+
99+
100+
class CountFilterEntry(EntryAttr):
101+
"""
102+
Examples:
103+
.. code-block:: python
104+
105+
import paddle
106+
107+
sparse_feature_dim = 1024
108+
embedding_size = 64
109+
110+
entry = paddle.distributed.CountFilterEntry(10)
111+
112+
input = paddle.static.data(name='ins', shape=[1], dtype='int64')
113+
114+
emb = paddle.static.nn.sparse_embedding((
115+
input=input,
116+
size=[sparse_feature_dim, embedding_size],
117+
is_test=False,
118+
entry=entry,
119+
param_attr=paddle.ParamAttr(name="SparseFeatFactors",
120+
initializer=paddle.nn.initializer.Uniform()))
121+
122+
"""
123+
124+
def __init__(self, count_filter):
125+
super(EntryAttr, self).__init__()
126+
127+
if not isinstance(count_filter, int):
128+
raise ValueError(
129+
"count_filter must be a valid integer greater than 0")
130+
131+
if count_filter < 0:
132+
raise ValueError(
133+
"count_filter must be a valid integer greater or equal than 0")
134+
135+
self._name = "count_filter_entry"
136+
self._count_filter = count_filter
137+
138+
def _to_attr(self):
139+
return ":".join([self._name, str(self._count_filter)])

python/paddle/distributed/fleet/runtime/the_one_ps.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class CommonAccessor:
5858
def __init__(self):
5959
self.accessor_class = ""
6060
self.table_name = None
61+
self.entry = None
6162
self.attrs = []
6263
self.params = []
6364
self.dims = []
@@ -93,6 +94,24 @@ def define_optimize_map(self):
9394
self.opt_input_map = opt_input_map
9495
self.opt_init_map = opt_init_map
9596

97+
def parse_entry(self, varname, o_main_program):
98+
from paddle.fluid.incubate.fleet.parameter_server.ir.public import is_distributed_sparse_op
99+
from paddle.fluid.incubate.fleet.parameter_server.ir.public import is_sparse_op
100+
101+
for op in o_main_program.global_block().ops:
102+
if not is_distributed_sparse_op(op) and not is_sparse_op(op):
103+
continue
104+
105+
param_name = op.input("W")[0]
106+
107+
if param_name == varname and op.type == "lookup_table":
108+
self.entry = op.attr('entry')
109+
break
110+
111+
if param_name == varname and op.type == "lookup_table_v2":
112+
self.entry = "none"
113+
break
114+
96115
def get_shard(self, total_dim, shard_num, pserver_id):
97116
# remainder = total_dim % shard_num
98117
blocksize = int(total_dim / shard_num + 1)
@@ -188,6 +207,8 @@ def to_string(self, indent):
188207
if self.table_name:
189208
attrs += "table_name: \"{}\" ".format(self.table_name)
190209

210+
if self.entry:
211+
attrs += "entry: \"{}\" ".format(self.entry)
191212
attrs += "trainer_num: {} ".format(self.trainer_num)
192213
attrs += "sync: {} ".format(self.sync)
193214

@@ -655,36 +676,31 @@ def _get_tables():
655676
use_origin_program=True,
656677
split_dense_table=self.role_maker.
657678
_is_heter_parameter_server_mode)
679+
658680
tables = []
659681
for idx, (name, ctx) in enumerate(send_ctx.items()):
682+
if ctx.is_tensor_table() or len(ctx.origin_varnames()) < 1:
683+
continue
684+
660685
table = Table()
661686
table.id = ctx.table_id()
662-
663-
if ctx.is_tensor_table():
664-
continue
687+
common = CommonAccessor()
665688

666689
if ctx.is_sparse():
667-
if len(ctx.origin_varnames()) < 1:
668-
continue
669690
table.type = "PS_SPARSE_TABLE"
691+
table.shard_num = 256
670692

671693
if self.compiled_strategy.is_geo_mode():
672694
table.table_class = "SparseGeoTable"
673695
else:
674696
table.table_class = "CommonSparseTable"
675-
table.shard_num = 256
676-
else:
677-
if len(ctx.origin_varnames()) < 1:
678-
continue
679-
table.type = "PS_DENSE_TABLE"
680-
table.table_class = "CommonDenseTable"
681-
table.shard_num = 256
682697

683-
common = CommonAccessor()
684-
if ctx.is_sparse():
685698
common.table_name = self.compiled_strategy.grad_name_to_param_name[
686699
ctx.origin_varnames()[0]]
687700
else:
701+
table.type = "PS_DENSE_TABLE"
702+
table.table_class = "CommonDenseTable"
703+
table.shard_num = 256
688704
common.table_name = "MergedDense"
689705

690706
common.parse_by_optimizer(ctx.origin_varnames()[0],
@@ -693,6 +709,10 @@ def _get_tables():
693709
else ctx.sections()[0],
694710
self.compiled_strategy)
695711

712+
if ctx.is_sparse():
713+
common.parse_entry(common.table_name,
714+
self.origin_main_program)
715+
696716
if is_sync:
697717
common.sync = "true"
698718
else:

python/paddle/fluid/contrib/layers/nn.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646

4747
from paddle.fluid import core
4848
from paddle.fluid.param_attr import ParamAttr
49-
from paddle.fluid.entry_attr import ProbabilityEntry, CountFilterEntry
5049

5150
from paddle.fluid.framework import Variable, convert_np_dtype_to_dtype_
5251
from paddle.fluid.layers import slice, reshape
@@ -993,11 +992,13 @@ def sparse_embedding(input,
993992
entry_str = "none"
994993

995994
if entry is not None:
996-
if not isinstance(entry, ProbabilityEntry) and not isinstance(
997-
entry, CountFilterEntry):
995+
if entry.__class__.__name__ not in [
996+
"ProbabilityEntry", "CountFilterEntry"
997+
]:
998998
raise ValueError(
999-
"entry must be instance in [ProbabilityEntry, CountFilterEntry]")
1000-
entry_str = entry.to_attr()
999+
"entry must be instance in [paddle.distributed.ProbabilityEntry, paddle.distributed.CountFilterEntry]"
1000+
)
1001+
entry_str = entry._to_attr()
10011002

10021003
helper.append_op(
10031004
type='lookup_table',

0 commit comments

Comments
 (0)