
Commit 9ae1523

Merge pull request #7719 from guoshengCS/enhance-lookup_table_op-padidx
Enhance lookup_table_op to support padding_idx.
2 parents 76429f4 + d512044

File tree: 11 files changed, +126 -21 lines changed

paddle/framework/attribute.cc

Lines changed: 3 additions & 0 deletions

@@ -61,6 +61,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
       }
       return val;
     }
+    case proto::AttrType::LONG: {
+      return attr_desc.l();
+    }
     default:
       PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
   }

paddle/framework/attribute.h

Lines changed: 26 additions & 0 deletions

@@ -168,6 +168,32 @@ struct ExtractAttribute<bool> {
   const std::string& attr_name_;
 };

+template <>
+struct ExtractAttribute<int64_t> {
+  explicit ExtractAttribute(const std::string& attr_name)
+      : attr_name_(attr_name) {}
+
+  int64_t* operator()(Attribute& attr) const {
+    if (attr.type() == typeid(int)) {  // NOLINT
+      int val = boost::get<int>(attr);
+      attr = static_cast<int64_t>(val);
+    } else if (attr.type() == typeid(float)) {  // NOLINT
+      float val = boost::get<float>(attr);
+      attr = static_cast<int64_t>(val);
+    }
+    int64_t* attr_value = nullptr;
+    try {
+      attr_value = &boost::get<int64_t>(attr);
+    } catch (boost::bad_get& bad_get) {
+      PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
+                   attr_name_, attr.type().name());
+    }
+    return attr_value;
+  }
+
+  const std::string& attr_name_;
+};
+
 // check whether a certain attribute fit its limits
 // an attribute can have more than one limits
 template <typename T>
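The specialization above lets an attribute that arrives as a plain int or float be coerced in place to int64_t before use, so a Python-side integer literal satisfies the new int64 attribute. A minimal Python sketch of that coercion logic (illustrative only, not the framework code; the helper name is made up):

    def extract_int64_attr(attrs, name):
        # int and float values are converted in place to a 64-bit int,
        # mirroring the static_cast<int64_t>(val) branches above.
        val = attrs[name]
        if isinstance(val, (int, float)):
            attrs[name] = int(val)
        else:
            # mirrors the boost::bad_get error path
            raise TypeError(
                "Cannot get attribute %s by type int64_t, its type is %s"
                % (name, type(val).__name__))
        return attrs[name]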

paddle/framework/framework.proto

Lines changed: 2 additions & 0 deletions

@@ -26,6 +26,7 @@ enum AttrType {
   BOOLEAN = 6;
   BOOLEANS = 7;
   BLOCK = 8;
+  LONG = 9;
 }

 // OpDesc describes an instance of a C++ framework::OperatorBase
@@ -44,6 +45,7 @@ message OpDesc {
     optional bool b = 10;
     repeated bool bools = 11;
     optional int32 block_idx = 12;
+    optional int64 l = 13;
   };

   message Var {

paddle/framework/op_desc.cc

Lines changed: 1 addition & 0 deletions

@@ -283,6 +283,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
     VectorToRepeated(v, attr_->mutable_bools());
   }
   void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
+  void operator()(int64_t v) const { attr_->set_l(v); }
   void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
 };

paddle/framework/type_defs.h

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;

 using Attribute =
     boost::variant<boost::blank, int, float, std::string, std::vector<int>,
                    std::vector<float>, std::vector<std::string>, bool,
-                   std::vector<bool>, BlockDesc*>;
+                   std::vector<bool>, BlockDesc*, int64_t>;

 using AttributeMap = std::unordered_map<std::string, Attribute>;

paddle/operators/lookup_table_op.cc

Lines changed: 6 additions & 0 deletions

@@ -66,6 +66,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(boolean, default false) "
                   "Sparse update")
         .SetDefault(false);
+    AddAttr<int64_t>("padding_idx",
+                     "(int64, default -1) "
+                     "If the value is -1, it has no effect on the lookup. "
+                     "Otherwise the given value indicates padding the output "
+                     "with zeros whenever lookup encounters it in Ids.")
+        .SetDefault(-1);
     AddComment(R"DOC(
 Lookup Table Operator.
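To make the new attribute's forward semantics concrete, here is a small NumPy sketch of what the operator computes (illustrative only, not the kernel code): rows looked up at padding_idx come back as zeros, and every other id indexes the table as before.

    import numpy as np

    def lookup_table(table, ids, padding_idx=-1):
        """Sketch of the lookup_table forward pass with padding_idx."""
        out = table[ids]                    # ordinary embedding lookup
        if padding_idx != -1:
            out[ids == padding_idx] = 0     # zero the padded positions
        return out

    table = np.arange(12, dtype=np.float32).reshape(4, 3)  # vocab 4, dim 3
    ids = np.array([0, 2, 0, 3])
    print(lookup_table(table, ids, padding_idx=0))  # rows for id 0 are zeros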

paddle/operators/lookup_table_op.cu

Lines changed: 26 additions & 7 deletions

@@ -21,9 +21,11 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
+template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
+          bool PaddingFlag>
 __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
-                            const int64_t N, const int64_t K, const int64_t D) {
+                            const int64_t N, const int64_t K, const int64_t D,
+                            const int64_t padding_idx) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * GridDimX;

@@ -34,7 +36,14 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
     T* out = output + idy * D;
     const T* tab = table + id * D;
     for (int i = idx; i < D; i += BlockDimX) {
-      out[i] = tab[i];
+      if (PaddingFlag) {
+        if (id == padding_idx)
+          out[i] = static_cast<T>(0);
+        else
+          out[i] = tab[i];
+      } else {
+        out[i] = tab[i];
+      }
     }
     idy += BlockDimY * GridDimX;
   }
@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
     auto* table_t = context.Input<LoDTensor>("W");
     auto* ids_t = context.Input<LoDTensor>("Ids");
     auto* output_t = context.Output<LoDTensor>("Out");
+    int64_t padding_idx = context.Attr<int64_t>("padding_idx");

     size_t N = table_t->dims()[0];
     size_t D = table_t->dims()[1];
@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {

     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTable<
-        T, 128, 8,
-        8><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
-        output, table, ids, N, K, D);
+
+    if (padding_idx == -1)
+      LookupTable<
+          T, 128, 8, 8,
+          false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
+          output, table, ids, N, K, D, padding_idx);
+    else
+      LookupTable<
+          T, 128, 8, 8,
+          true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
+          output, table, ids, N, K, D, padding_idx);
   }
 };

@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
     auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     bool is_sparse = context.Attr<bool>("is_sparse");
+    // Since paddings are not trainable and fixed in forward, the gradient of
+    // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
       auto* ids = context.Input<LoDTensor>("Ids");
       auto* table = context.Input<LoDTensor>("W");
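A note on the design here: PaddingFlag is a compile-time template parameter rather than a runtime argument, so the host side launches a branch-free instantiation when padding_idx is -1 and the common no-padding path pays no per-element cost. A loose Python analogy of that specialize-once-then-loop pattern (illustrative only):

    # Build two specialized lookup functions up front, mirroring the
    # LookupTable<..., false> / LookupTable<..., true> instantiations:
    # the no-padding variant's loop never tests for padding at all.
    def make_lookup(padding_enabled):
        if padding_enabled:
            return lambda table, ids, pad: [
                [0.0] * len(table[0]) if i == pad else table[i] for i in ids
            ]
        return lambda table, ids, pad: [table[i] for i in ids]

    lookup = make_lookup(padding_enabled=False)  # chosen once, like the dispatch above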

paddle/operators/lookup_table_op.h

Lines changed: 20 additions & 4 deletions

@@ -32,16 +32,30 @@ class LookupTableKernel : public framework::OpKernel<T> {
     auto* table_t = context.Input<LoDTensor>("W");      // float tensor
     auto* ids_t = context.Input<LoDTensor>("Ids");      // int tensor
     auto* output_t = context.Output<LoDTensor>("Out");  // float tensor
+    int64_t padding_idx = context.Attr<int64_t>("padding_idx");

     int N = table_t->dims()[0];
     int D = table_t->dims()[1];
     auto* ids = ids_t->data<int64_t>();
     auto* table = table_t->data<T>();
     auto* output = output_t->mutable_data<T>(context.GetPlace());
-    for (int64_t i = 0; i < ids_t->numel(); ++i) {
-      PADDLE_ENFORCE_LT(ids[i], N);
-      PADDLE_ENFORCE_GE(ids[i], 0);
-      memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
+
+    if (padding_idx == -1) {
+      for (int64_t i = 0; i < ids_t->numel(); ++i) {
+        PADDLE_ENFORCE_LT(ids[i], N);
+        PADDLE_ENFORCE_GE(ids[i], 0);
+        memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
+      }
+    } else {
+      for (int64_t i = 0; i < ids_t->numel(); ++i) {
+        if (ids[i] == padding_idx) {
+          memset(output + i * D, 0, D * sizeof(T));
+        } else {
+          PADDLE_ENFORCE_LT(ids[i], N);
+          PADDLE_ENFORCE_GE(ids[i], 0);
+          memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
+        }
+      }
     }
   }
 };
@@ -51,6 +65,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     bool is_sparse = context.Attr<bool>("is_sparse");
+    // Since paddings are not trainable and fixed in forward, the gradient of
+    // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
       auto* ids = context.Input<LoDTensor>("Ids");
       auto* table = context.Input<LoDTensor>("W");

paddle/pybind/print_operators_doc.cc

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,8 @@ std::string AttrType(paddle::framework::proto::AttrType at) {
       return "bool array";
     case paddle::framework::proto::BLOCK:
       return "block id";
+    case paddle::framework::proto::LONG:
+      return "long";
   }
   return "UNKNOWN";  // not possible
 }

python/paddle/v2/fluid/layers/nn.py

Lines changed: 25 additions & 9 deletions

@@ -185,22 +185,35 @@ def fc(input,
     return helper.append_activation(pre_activation)


-def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
+def embedding(input,
+              size,
+              is_sparse=False,
+              padding_idx=None,
+              param_attr=None,
+              dtype='float32'):
     """
     **Embedding Layer**

-    This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table.
-    The result of this lookup is the embedding of each ID in the *input*.
+    This layer is used to look up embeddings of IDs, provided by :attr:`input`,
+    in a lookup table. The result of this lookup is the embedding of each ID
+    in the :attr:`input`.

     All the input variables are passed in as local variables to the LayerHelper
     constructor.

     Args:
-        input(Variable): Input to the function
-        size(tuple|list|None): Shape of the look up table parameter
-        is_sparse(bool): Boolean flag that specifying whether the input is sparse
-        param_attr(ParamAttr): Parameters for this layer
-        dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
+        input(Variable): The tensor variable containing the IDs.
+        size(tuple|list): The shape of the lookup table parameter. It should
+            have two elements which indicate the size of the dictionary of
+            embeddings and the size of each embedding vector respectively.
+        is_sparse(bool): The flag indicating whether to use sparse update.
+        padding_idx(int|long|None): If :attr:`None`, it has no effect on the
+            lookup. Otherwise the given :attr:`padding_idx` indicates padding
+            the output with zeros whenever lookup encounters it in
+            :attr:`input`. If :math:`padding_idx < 0`, the padding_idx used
+            in the lookup is :math:`size[0] + padding_idx`.
+        param_attr(ParamAttr): Parameters for this layer.
+        dtype(np.dtype|core.DataType|str): The data type: float32, float_16, int etc.

     Returns:
         Variable: The tensor variable storing the embeddings of the \
@@ -218,12 +231,15 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
     w = helper.create_parameter(
         attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
     tmp = helper.create_tmp_variable(dtype)
+    padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
+        size[0] + padding_idx)
     helper.append_op(
         type='lookup_table',
         inputs={'Ids': input,
                 'W': w},
         outputs={'Out': tmp},
-        attrs={'is_sparse': is_sparse})
+        attrs={'is_sparse': is_sparse,
+               'padding_idx': padding_idx})
     return tmp
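A brief usage sketch of the new keyword (assuming the fluid API of this era and a program/executor set up elsewhere; variable names and sizes are illustrative):

    import paddle.v2.fluid as fluid

    # Ids equal to 0 are treated as padding: their output rows are all zeros.
    # Per the comments added to the kernels, padding gradients are not
    # specially handled in backward.
    words = fluid.layers.data(name='words', shape=[1], dtype='int64', lod_level=1)
    emb = fluid.layers.embedding(
        input=words,
        size=[10000, 32],   # [dictionary size, embedding width]
        padding_idx=0,      # may be negative: -1 normalizes to size[0] - 1
        is_sparse=True)

Note that passing padding_idx=None (the default) normalizes to the operator's -1 sentinel, preserving the old no-padding behavior.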
