
Commit 313d623

add embedding_lookup_layer

1 parent 4336177
3 files changed: +200 -1 lines changed
include/caffe/layers/embedding_lookup_layer.hpp
Lines changed: 49 additions & 0 deletions
#ifndef CAFFE_EMBEDDING_LOOKUP_LAYER_HPP_
#define CAFFE_EMBEDDING_LOOKUP_LAYER_HPP_

#include <string>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/*
 * Note: Caffe implementation of tf.nn.embedding_lookup
 * https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup
 */

template <typename Dtype>
class EmbeddingLookupLayer : public Layer<Dtype> {
 public:
  explicit EmbeddingLookupLayer(const LayerParameter &param)
      : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype> *> &bottom,
                          const vector<Blob<Dtype> *> &top);
  virtual void Reshape(const vector<Blob<Dtype> *> &bottom,
                       const vector<Blob<Dtype> *> &top);

  virtual inline const char *type() const { return "EmbeddingLookup"; }
  virtual inline int MinBottomBlobs() const { return 1; }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype> *> &bottom,
                           const vector<Blob<Dtype> *> &top);
  /// @brief Not implemented
  virtual void Backward_cpu(const vector<Blob<Dtype> *> &top,
                            const vector<bool> &propagate_down,
                            const vector<Blob<Dtype> *> &bottom) {
    NOT_IMPLEMENTED;
  }

  // int n_top;
  vector<int> ids;        // flat list of row indices to look up
  vector<int> ids_shape;  // shape of the ids tensor; prefixed to the top shape
  string p_strategy;      // partition strategy: "mod" or "div"
  float max_norm;         // rescale looked-up rows whose L2 norm exceeds this
};

}  // namespace caffe

#endif  // CAFFE_EMBEDDING_LOOKUP_LAYER_HPP_
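
As a quick illustration of the shape contract this header implies (and which Reshape in the source file below implements): the params blob's leading row axis is replaced by ids_shape. A standalone sketch with made-up shapes, not part of the commit:

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical shapes: params holds 10 rows of dim 4; ids form a 2x3 tensor.
  std::vector<int> params_shape = {10, 4};
  std::vector<int> ids_shape = {2, 3};
  // Drop the leading (row) axis, then prefix ids_shape -- as Reshape does.
  std::vector<int> top_shape(params_shape.begin() + 1, params_shape.end());
  top_shape.insert(top_shape.begin(), ids_shape.begin(), ids_shape.end());
  for (int d : top_shape) std::printf("%d ", d);  // prints: 2 3 4
  std::printf("\n");
  return 0;
}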
src/caffe/layers/embedding_lookup_layer.cpp
Lines changed: 142 additions & 0 deletions
#include <algorithm>
#include <cmath>
#include <vector>

#include "caffe/layers/embedding_lookup_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void EmbeddingLookupLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype> *> &bottom, const vector<Blob<Dtype> *> &top) {
  const EmbeddingLookupParameter &embedding_lookup_param =
      this->layer_param_.embedding_lookup_param();
  ids.clear();
  std::copy(embedding_lookup_param.ids().begin(),
            embedding_lookup_param.ids().end(), std::back_inserter(ids));
  ids_shape.clear();
  std::copy(embedding_lookup_param.ids_shape().begin(),
            embedding_lookup_param.ids_shape().end(),
            std::back_inserter(ids_shape));
  p_strategy = embedding_lookup_param.partition_strategy();

  // max_norm defaults to 9999999999, which stands in for TensorFlow's
  // max_norm=None (i.e. no norm clipping).
  max_norm = embedding_lookup_param.max_norm();
  // Check that all params blobs agree on every axis except the first.
  const int num_axes = bottom[0]->num_axes();
  vector<int> bottom_shape = bottom[0]->shape();
  for (int i = 1; i < bottom.size(); ++i) {
    CHECK_EQ(num_axes, bottom[i]->num_axes())
        << "All inputs must have the same #axes.";
    for (int j = 1; j < num_axes; ++j) {
      CHECK_EQ(bottom_shape[j], bottom[i]->shape(j))
          << "Dimension " << j - 1 << " in both shapes must be equal, but are "
          << bottom_shape[j] << " and " << bottom[i]->shape(j);
    }
  }
}

template <typename Dtype>
void EmbeddingLookupLayer<Dtype>::Reshape(const vector<Blob<Dtype> *> &bottom,
                                          const vector<Blob<Dtype> *> &top) {
  // Top shape is ids_shape followed by the params blob's trailing axes.
  vector<int> top_shape = bottom[0]->shape();
  top_shape.erase(top_shape.begin());
  top_shape.insert(top_shape.begin(), ids_shape.begin(), ids_shape.end());
  top[0]->Reshape(top_shape);
}

template <typename Dtype>
void EmbeddingLookupLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype> *> &bottom, const vector<Blob<Dtype> *> &top) {
  Dtype *top_data = top[0]->mutable_cpu_data();
  const int copy_num = bottom[0]->count(1);
  // Sentinel matching the proto default; max_norm == none_value means
  // no clipping.
  const float none_value = 9999999999;
  // for one params
  if (bottom.size() == 1) {
    const Dtype *bottom_data = bottom[0]->cpu_data();
    for (int i = 0; i < ids.size(); ++i) {
      CHECK_GE(ids[i], 0) << "ids[" << i << "] = " << ids[i] << " is not in "
                          << "[0, " << bottom[0]->shape(0) << ").";
      CHECK_LT(ids[i], bottom[0]->shape(0))
          << "ids[" << i << "] = " << ids[i] << " is not in "
          << "[0, " << bottom[0]->shape(0) << ").";
      const int t_offset = i * copy_num;
      const int b_offset = ids[i] * copy_num;
      caffe_copy(copy_num, bottom_data + b_offset, top_data + t_offset);
      // Rescale the copied row if its L2 norm exceeds max_norm.
      const Dtype normt = std::sqrt(caffe_cpu_dot(
          copy_num, bottom_data + b_offset, bottom_data + b_offset));
      if (max_norm != none_value && normt > max_norm) {
        const Dtype alpha = max_norm / normt;
        caffe_scal(copy_num, alpha, top_data + t_offset);
      }
    }
  }
  // for multiple params: map each id to a (blob, row) pair according to
  // the partition strategy.
  else {
    for (int i = 0; i < ids.size(); ++i) {
      // strategy = mod: ids are dealt round-robin across the params blobs
      if (p_strategy == "mod") {
        const int bottom_num = ids[i] % bottom.size();
        const int row_num = ids[i] / bottom.size();
        CHECK_GE(row_num, 0) << "ids[" << i << "] is not in "
                             << "[0, " << bottom[bottom_num]->shape(0)
                             << ") for params[" << bottom_num << "]";
        CHECK_LT(row_num, bottom[bottom_num]->shape(0))
            << "ids[" << i << "] is not in "
            << "[0, " << bottom[bottom_num]->shape(0) << ") for params["
            << bottom_num << "]";
        const Dtype *bottom_data = bottom[bottom_num]->cpu_data();
        const int t_offset = i * copy_num;
        const int b_offset = row_num * copy_num;
        caffe_copy(copy_num, bottom_data + b_offset, top_data + t_offset);
        // max_norm part
        const Dtype normt = std::sqrt(caffe_cpu_dot(
            copy_num, bottom_data + b_offset, bottom_data + b_offset));
        if (max_norm != none_value && normt > max_norm) {
          const Dtype alpha = max_norm / normt;
          caffe_scal(copy_num, alpha, top_data + t_offset);
        }
      }
      // strategy = div: ids are assigned in contiguous blocks; the first
      // all_idx % n blobs hold one extra row each
      else if (p_strategy == "div") {
        int all_idx = 0;
        for (int j = 0; j < bottom.size(); ++j) {
          all_idx += bottom[j]->shape(0);
        }
        const int a = all_idx / bottom.size();
        const int b = all_idx % bottom.size();
        const int bottom_num = (ids[i] < b * (a + 1))
                                   ? (ids[i] / (a + 1))
                                   : (b + (ids[i] - b * (a + 1)) / a);
        const int row_num = (ids[i] < b * (a + 1))
                                ? (ids[i] % (a + 1))
                                : ((ids[i] - b * (a + 1)) % a);
        CHECK_GE(row_num, 0) << "ids[" << i << "] is not in "
                             << "[0, " << bottom[bottom_num]->shape(0)
                             << ") for params[" << bottom_num << "]";
        CHECK_LT(row_num, bottom[bottom_num]->shape(0))
            << "ids[" << i << "] is not in "
            << "[0, " << bottom[bottom_num]->shape(0) << ") for params["
            << bottom_num << "]";
        const Dtype *bottom_data = bottom[bottom_num]->cpu_data();
        const int t_offset = i * copy_num;
        const int b_offset = row_num * copy_num;
        caffe_copy(copy_num, bottom_data + b_offset, top_data + t_offset);
        // max_norm part
        const Dtype normt = std::sqrt(caffe_cpu_dot(
            copy_num, bottom_data + b_offset, bottom_data + b_offset));
        if (max_norm != none_value && normt > max_norm) {
          const Dtype alpha = max_norm / normt;
          caffe_scal(copy_num, alpha, top_data + t_offset);
        }
      }
    }
  }
}

INSTANTIATE_CLASS(EmbeddingLookupLayer);
REGISTER_LAYER_CLASS(EmbeddingLookup);

}  // namespace caffe
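
The two partition strategies split one logical id space across the params blobs, following tf.nn.embedding_lookup: "mod" deals ids out round-robin, while "div" assigns contiguous blocks, giving the first all_idx % n blobs one extra row each. A standalone sketch of the mapping arithmetic above, with hypothetical row counts (not part of the commit):

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical params blobs with 4, 3 and 3 rows (10 rows total).
  const std::vector<int> rows = {4, 3, 3};
  const int n = rows.size();
  int all_idx = 0;
  for (int r : rows) all_idx += r;
  const int a = all_idx / n;  // base rows per blob: 3
  const int b = all_idx % n;  // first b blobs get one extra row: 1
  for (int id = 0; id < all_idx; ++id) {
    const int mod_blob = id % n;  // "mod": round-robin over blobs
    const int mod_row = id / n;
    const int div_blob =          // "div": contiguous ranges
        (id < b * (a + 1)) ? id / (a + 1) : b + (id - b * (a + 1)) / a;
    const int div_row =
        (id < b * (a + 1)) ? id % (a + 1) : (id - b * (a + 1)) % a;
    // e.g. id 5: mod -> (blob 2, row 1), div -> (blob 1, row 1)
    std::printf("id %d: mod (blob %d, row %d), div (blob %d, row %d)\n",
                id, mod_blob, mod_row, div_blob, div_row);
  }
  return 0;
}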

src/caffe/proto/caffe.proto

Lines changed: 9 additions & 1 deletion
@@ -460,7 +460,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 247 (last added: sparse_to_dense_param)
+// LayerParameter next available layer-specific ID: 250 (last added: embedding_lookup_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type

@@ -650,6 +650,7 @@ message LayerParameter {
   optional PieceParameter piece_param = 244;
   optional RangeParameter range_param = 245;
   optional SparseToDenseParameter sparse_to_dense_param = 246;
+  optional EmbeddingLookupParameter embedding_lookup_param = 249;
 }

@@ -3115,3 +3116,10 @@ message SparseToDenseParameter {
   optional float default_value = 5;
   optional bool validate_indices = 6;
 }
+
+message EmbeddingLookupParameter {
+  repeated uint32 ids = 1;
+  repeated uint32 ids_shape = 2;
+  optional string partition_strategy = 3 [default = "mod"];
+  optional float max_norm = 4 [default = 9999999999];
+}
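
For illustration, a layer using the new parameter might be declared as in the prototxt sketch below; the blob names are hypothetical, and the embedding table is fed in as a bottom:

layer {
  name: "embed_lookup"
  type: "EmbeddingLookup"
  bottom: "params"          # embedding table, e.g. vocab_size x dim
  top: "embeddings"
  embedding_lookup_param {
    ids: 3                  # flat list of row indices to look up
    ids: 0
    ids: 7
    ids_shape: 3            # top shape becomes 3 x dim
    partition_strategy: "mod"
    max_norm: 1.0           # rescale any looked-up row with L2 norm > 1.0
  }
}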
