Skip to content

Commit 14c2dae

Browse files
committed
add support for scalar input in add layer
1 parent 77f6d41 commit 14c2dae

File tree

3 files changed

+110
-73
lines changed

3 files changed

+110
-73
lines changed

include/caffe/layers/add_layer.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@ class AddLayer : public Layer<Dtype> {
4040

4141
int dim_diff_;
4242
int dim_;
43+
bool is_scalar_;
4344
};
4445

4546
} // namespace caffe

include/caffe/layers/tile_nd_layer.hpp

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -38,9 +38,6 @@ class TileNDLayer : public Layer<Dtype> {
3838
// const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
3939

4040
vector<int> multiples_;
41-
vector<int> outer_dim_, inner_dim_;
42-
vector<int> top_inner_dim_;
43-
Blob<Dtype> top_temp_;
4441
};
4542

4643
} // namespace caffe

src/caffe/layers/add_layer.cpp

Lines changed: 109 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -10,31 +10,53 @@ template <typename Dtype>
1010
void AddLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
1111
const vector<Blob<Dtype>*>& top) {
1212
CHECK_NE(top[0], bottom[0]) << this->type() << " Layer does not allow in-place computation.";
13+
is_scalar_=false;
14+
if(bottom[0]->num_axes()==0 || bottom[1]->num_axes()==0)
15+
is_scalar_=true;
1316
dim_diff_ = bottom[0]->num_axes() - bottom[1]->num_axes();
1417
dim_ = bottom[0]->num_axes() >= bottom[1]->num_axes() ? bottom[0]->num_axes() : bottom[1]->num_axes();
1518
vector<int> top_shape(dim_, 1);
1619
if(dim_diff_ == 0)
1720
{
18-
for(int i=0;i<dim_;i++)
21+
if(!is_scalar_)
1922
{
20-
CHECK(bottom[0]->shape(i)==bottom[1]->shape(i) || bottom[0]->shape(i)==1 || bottom[1]->shape(i)==1)
21-
<< "Dimensions must be equal or 1 in the bottoms!";
22-
top_shape[i] = bottom[0]->shape(i) >= bottom[1]->shape(i) ? bottom[0]->shape(i): bottom[1]->shape(i);
23+
for(int i=0;i<dim_;i++)
24+
{
25+
CHECK(bottom[0]->shape(i)==bottom[1]->shape(i) || bottom[0]->shape(i)==1 || bottom[1]->shape(i)==1)
26+
<< "Dimensions must be equal or 1 in the bottoms!";
27+
top_shape[i] = bottom[0]->shape(i) >= bottom[1]->shape(i) ? bottom[0]->shape(i): bottom[1]->shape(i);
28+
}
2329
}
2430
}
2531
else if(dim_diff_ > 0) //bottom0 has more axes than bottom1
2632
{
27-
for(int i=0;i<dim_diff_;i++)
28-
top_shape[i] = bottom[0]->shape(i);
29-
for(int i=dim_diff_; i<dim_; i++)
30-
top_shape[i] = bottom[0]->shape(i) >= bottom[1]->shape(i-dim_diff_) ? bottom[0]->shape(i): bottom[1]->shape(i-dim_diff_);
33+
if(!is_scalar_)
34+
{
35+
for(int i=0;i<dim_diff_;i++)
36+
top_shape[i] = bottom[0]->shape(i);
37+
for(int i=dim_diff_; i<dim_; i++)
38+
top_shape[i] = bottom[0]->shape(i) >= bottom[1]->shape(i-dim_diff_) ? bottom[0]->shape(i): bottom[1]->shape(i-dim_diff_);
39+
}
40+
else //bottom1 is a scalar
41+
{
42+
for(int i=0;i<dim_;i++)
43+
top_shape[i] = bottom[0]->shape(i);
44+
}
3145
}
3246
else //dim_diff_<0, bottom1 has more axes than bottom0
3347
{
34-
for(int i=0;i<-dim_diff_;i++)
35-
top_shape[i] = bottom[1]->shape(i);
36-
for(int i=-dim_diff_; i<dim_; i++)
37-
top_shape[i] = bottom[0]->shape(i+dim_diff_) >= bottom[1]->shape(i) ? bottom[0]->shape(i+dim_diff_): bottom[1]->shape(i);
48+
if(!is_scalar_)
49+
{
50+
for(int i=0;i<-dim_diff_;i++)
51+
top_shape[i] = bottom[1]->shape(i);
52+
for(int i=-dim_diff_; i<dim_; i++)
53+
top_shape[i] = bottom[0]->shape(i+dim_diff_) >= bottom[1]->shape(i) ? bottom[0]->shape(i+dim_diff_): bottom[1]->shape(i);
54+
}
55+
else //bottom0 is a scalar
56+
{
57+
for(int i=0;i<dim_;i++)
58+
top_shape[i] = bottom[1]->shape(i);
59+
}
3860
}
3961
top[0]->Reshape(top_shape);
4062
}
@@ -45,8 +67,8 @@ void AddLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
4567
const Dtype* bottom0_data = bottom[0]->cpu_data();
4668
const Dtype* bottom1_data = bottom[1]->cpu_data();
4769
Dtype* top_data = top[0]->mutable_cpu_data();
48-
4970
int count = top[0]->count();
71+
5072
// Assume top index (x,y,z) with top shape (A, B, C)
5173
// top offset d = xBC + yC + z
5274
// So to count the bottom index, should first figure out x, y, z
@@ -55,75 +77,92 @@ void AddLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
5577
// z = d % C
5678
// Then consider bottom shape (A', B', C'), while A = A' or 1
5779
// So bottom offset = x'B'C' + y'C' + z', while x' = x or 0
58-
for(int d=0; d<count; d++)
80+
if(!is_scalar_)
5981
{
60-
int offset0 = 0;
61-
int offset1 = 0;
62-
63-
if(dim_diff_ == 0)
82+
for(int d=0; d<count; d++)
6483
{
65-
for(int i=0;i<dim_-1;i++)
84+
int offset0 = 0;
85+
int offset1 = 0;
86+
87+
if(dim_diff_ == 0)
6688
{
67-
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
68-
int n0 = 1 == bottom[0]->shape(i) ? 0 : num;
69-
int n1 = 1 == bottom[1]->shape(i) ? 0 : num;
70-
offset0 += n0 * bottom[0]->count(i+1);
71-
offset1 += n1 * bottom[1]->count(i+1);
89+
for(int i=0;i<dim_-1;i++)
90+
{
91+
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
92+
int n0 = 1 == bottom[0]->shape(i) ? 0 : num;
93+
int n1 = 1 == bottom[1]->shape(i) ? 0 : num;
94+
offset0 += n0 * bottom[0]->count(i+1);
95+
offset1 += n1 * bottom[1]->count(i+1);
96+
}
97+
int z = d % top[0]->shape(dim_-1);
98+
int z0 = 1 == bottom[0]->shape(dim_-1) ? 0 : z;
99+
int z1 = 1 == bottom[1]->shape(dim_-1) ? 0 : z;
100+
offset0 += z0;
101+
offset1 += z1;
72102
}
73-
int z = d % top[0]->shape(dim_-1);
74-
int z0 = 1 == bottom[0]->shape(dim_-1) ? 0 : z;
75-
int z1 = 1 == bottom[1]->shape(dim_-1) ? 0 : z;
76-
offset0 += z0;
77-
offset1 += z1;
78-
}
79-
else if(dim_diff_ > 0) //bottom0 has more axes than bottom1
80-
{
81-
for(int i=0;i<dim_diff_;i++)
103+
else if(dim_diff_ > 0) //bottom0 has more axes than bottom1
82104
{
83-
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
84-
int n0 = 1 == bottom[0]->shape(i) ? 0 : num;
85-
offset0 += n0 * bottom[0]->count(i+1);
105+
for(int i=0;i<dim_diff_;i++)
106+
{
107+
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
108+
int n0 = 1 == bottom[0]->shape(i) ? 0 : num;
109+
offset0 += n0 * bottom[0]->count(i+1);
110+
}
111+
for(int i=dim_diff_;i<dim_-1;i++)
112+
{
113+
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
114+
int n0 = 1 == bottom[0]->shape(i) ? 0 : num;
115+
int n1 = 1 == bottom[1]->shape(i-dim_diff_) ? 0 : num;
116+
offset0 += n0 * bottom[0]->count(i+1);
117+
offset1 += n1 * bottom[1]->count(i-dim_diff_+1);
118+
}
119+
int z = d % top[0]->shape(dim_-1);
120+
int z0 = 1 == bottom[0]->shape(dim_-1) ? 0 : z;
121+
int z1 = 1 == bottom[1]->shape(dim_-dim_diff_-1) ? 0 : z;
122+
offset0 += z0;
123+
offset1 += z1;
86124
}
87-
for(int i=dim_diff_;i<dim_-1;i++)
125+
else //dim_diff_<0, bottom1 has more axes than bottom0
88126
{
89-
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
90-
int n0 = 1 == bottom[0]->shape(i) ? 0 : num;
91-
int n1 = 1 == bottom[1]->shape(i-dim_diff_) ? 0 : num;
92-
offset0 += n0 * bottom[0]->count(i+1);
93-
offset1 += n1 * bottom[1]->count(i-dim_diff_+1);
127+
for(int i=0;i<-dim_diff_;i++)
128+
{
129+
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
130+
int n1 = 1 == bottom[1]->shape(i) ? 0 : num;
131+
offset1 += n1 * bottom[1]->count(i+1);
132+
}
133+
for(int i=-dim_diff_;i<dim_-1;i++)
134+
{
135+
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
136+
int n0 = 1 == bottom[0]->shape(i+dim_diff_) ? 0 : num;
137+
int n1 = 1 == bottom[1]->shape(i) ? 0 : num;
138+
offset0 += n0 * bottom[0]->count(i+dim_diff_+1);
139+
offset1 += n1 * bottom[1]->count(i+1);
140+
}
141+
int z = d % top[0]->shape(dim_-1);
142+
int z0 = 1 == bottom[0]->shape(dim_+dim_diff_-1) ? 0 : z;
143+
int z1 = 1 == bottom[1]->shape(dim_-1) ? 0 : z;
144+
offset0 += z0;
145+
offset1 += z1;
94146
}
95-
int z = d % top[0]->shape(dim_-1);
96-
int z0 = 1 == bottom[0]->shape(dim_-1) ? 0 : z;
97-
int z1 = 1 == bottom[1]->shape(dim_-dim_diff_-1) ? 0 : z;
98-
offset0 += z0;
99-
offset1 += z1;
147+
148+
top_data[d] = bottom0_data[offset0] + bottom1_data[offset1];
100149
}
101-
else //dim_diff_<0, bottom1 has more axes than bottom0
150+
}
151+
else //has scalar
152+
{
153+
if(bottom[1]->num_axes()==0) //bottom1 is a scalar
102154
{
103-
for(int i=0;i<-dim_diff_;i++)
104-
{
105-
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
106-
int n1 = 1 == bottom[1]->shape(i) ? 0 : num;
107-
offset1 += n1 * bottom[1]->count(i+1);
108-
}
109-
for(int i=-dim_diff_;i<dim_-1;i++)
110-
{
111-
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
112-
int n0 = 1 == bottom[0]->shape(i+dim_diff_) ? 0 : num;
113-
int n1 = 1 == bottom[1]->shape(i) ? 0 : num;
114-
offset0 += n0 * bottom[0]->count(i+dim_diff_+1);
115-
offset1 += n1 * bottom[1]->count(i+1);
116-
}
117-
int z = d % top[0]->shape(dim_-1);
118-
int z0 = 1 == bottom[0]->shape(dim_+dim_diff_-1) ? 0 : z;
119-
int z1 = 1 == bottom[1]->shape(dim_-1) ? 0 : z;
120-
offset0 += z0;
121-
offset1 += z1;
155+
caffe_copy(count, bottom0_data, top_data);
156+
Dtype scalar = bottom1_data[0];
157+
caffe_add_scalar(count, scalar, top_data);
158+
}
159+
else //bottom0 is a scalar
160+
{
161+
caffe_copy(count, bottom1_data, top_data);
162+
Dtype scalar = bottom0_data[0];
163+
caffe_add_scalar(count, scalar, top_data);
122164
}
123-
124-
top_data[d] = bottom0_data[offset0] + bottom1_data[offset1];
125165
}
126-
127166
}
128167

129168
REGISTER_LAYER_CLASS(Add);

0 commit comments

Comments
 (0)