@@ -10,31 +10,53 @@ template <typename Dtype>
1010void AddLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
1111 const vector<Blob<Dtype>*>& top) {
1212 CHECK_NE (top[0 ], bottom[0 ]) << this ->type () << " Layer does not allow in-place computation." ;
13+ is_scalar_=false ;
14+ if (bottom[0 ]->num_axes ()==0 || bottom[1 ]->num_axes ()==0 )
15+ is_scalar_=true ;
1316 dim_diff_ = bottom[0 ]->num_axes () - bottom[1 ]->num_axes ();
1417 dim_ = bottom[0 ]->num_axes () >= bottom[1 ]->num_axes () ? bottom[0 ]->num_axes () : bottom[1 ]->num_axes ();
1518 vector<int > top_shape (dim_, 1 );
1619 if (dim_diff_ == 0 )
1720 {
18- for ( int i= 0 ;i<dim_;i++ )
21+ if (!is_scalar_ )
1922 {
20- CHECK (bottom[0 ]->shape (i)==bottom[1 ]->shape (i) || bottom[0 ]->shape (i)==1 || bottom[1 ]->shape (i)==1 )
21- << " Dimensions must be equal or 1 in the bottoms!" ;
22- top_shape[i] = bottom[0 ]->shape (i) >= bottom[1 ]->shape (i) ? bottom[0 ]->shape (i): bottom[1 ]->shape (i);
23+ for (int i=0 ;i<dim_;i++)
24+ {
25+ CHECK (bottom[0 ]->shape (i)==bottom[1 ]->shape (i) || bottom[0 ]->shape (i)==1 || bottom[1 ]->shape (i)==1 )
26+ << " Dimensions must be equal or 1 in the bottoms!" ;
27+ top_shape[i] = bottom[0 ]->shape (i) >= bottom[1 ]->shape (i) ? bottom[0 ]->shape (i): bottom[1 ]->shape (i);
28+ }
2329 }
2430 }
2531 else if (dim_diff_ > 0 ) // bottom0 has more axes than bottom1
2632 {
27- for (int i=0 ;i<dim_diff_;i++)
28- top_shape[i] = bottom[0 ]->shape (i);
29- for (int i=dim_diff_; i<dim_; i++)
30- top_shape[i] = bottom[0 ]->shape (i) >= bottom[1 ]->shape (i-dim_diff_) ? bottom[0 ]->shape (i): bottom[1 ]->shape (i-dim_diff_);
33+ if (!is_scalar_)
34+ {
35+ for (int i=0 ;i<dim_diff_;i++)
36+ top_shape[i] = bottom[0 ]->shape (i);
37+ for (int i=dim_diff_; i<dim_; i++)
38+ top_shape[i] = bottom[0 ]->shape (i) >= bottom[1 ]->shape (i-dim_diff_) ? bottom[0 ]->shape (i): bottom[1 ]->shape (i-dim_diff_);
39+ }
40+ else // bottom1 is a scalar
41+ {
42+ for (int i=0 ;i<dim_;i++)
43+ top_shape[i] = bottom[0 ]->shape (i);
44+ }
3145 }
3246 else // dim_diff_<0, bottom1 has more axes than bottom0
3347 {
34- for (int i=0 ;i<-dim_diff_;i++)
35- top_shape[i] = bottom[1 ]->shape (i);
36- for (int i=-dim_diff_; i<dim_; i++)
37- top_shape[i] = bottom[0 ]->shape (i+dim_diff_) >= bottom[1 ]->shape (i) ? bottom[0 ]->shape (i+dim_diff_): bottom[1 ]->shape (i);
48+ if (!is_scalar_)
49+ {
50+ for (int i=0 ;i<-dim_diff_;i++)
51+ top_shape[i] = bottom[1 ]->shape (i);
52+ for (int i=-dim_diff_; i<dim_; i++)
53+ top_shape[i] = bottom[0 ]->shape (i+dim_diff_) >= bottom[1 ]->shape (i) ? bottom[0 ]->shape (i+dim_diff_): bottom[1 ]->shape (i);
54+ }
55+ else // bottom0 is a scalar
56+ {
57+ for (int i=0 ;i<dim_;i++)
58+ top_shape[i] = bottom[1 ]->shape (i);
59+ }
3860 }
3961 top[0 ]->Reshape (top_shape);
4062}
@@ -45,8 +67,8 @@ void AddLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
4567 const Dtype* bottom0_data = bottom[0 ]->cpu_data ();
4668 const Dtype* bottom1_data = bottom[1 ]->cpu_data ();
4769 Dtype* top_data = top[0 ]->mutable_cpu_data ();
48-
4970 int count = top[0 ]->count ();
71+
5072 // Assume top index (x,y,z) with top shape (A, B, C)
5173 // top offset d = xBC + yC + z
5274 // So to count the bottom index, should first figure out x, y, z
@@ -55,75 +77,92 @@ void AddLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
5577 // z = d % C
5678 // Then consider bottom shape (A', B', C'), while A = A' or 1
5779 // So bottom offset = x'B'C' + y'C' + z', while x' = x or 0
58- for ( int d= 0 ; d<count; d++ )
80+ if (!is_scalar_ )
5981 {
60- int offset0 = 0 ;
61- int offset1 = 0 ;
62-
63- if (dim_diff_ == 0 )
82+ for (int d=0 ; d<count; d++)
6483 {
65- for (int i=0 ;i<dim_-1 ;i++)
84+ int offset0 = 0 ;
85+ int offset1 = 0 ;
86+
87+ if (dim_diff_ == 0 )
6688 {
67- int num = (d % top[0 ]->count (i)) / top[0 ]->count (i+1 );
68- int n0 = 1 == bottom[0 ]->shape (i) ? 0 : num;
69- int n1 = 1 == bottom[1 ]->shape (i) ? 0 : num;
70- offset0 += n0 * bottom[0 ]->count (i+1 );
71- offset1 += n1 * bottom[1 ]->count (i+1 );
89+ for (int i=0 ;i<dim_-1 ;i++)
90+ {
91+ int num = (d % top[0 ]->count (i)) / top[0 ]->count (i+1 );
92+ int n0 = 1 == bottom[0 ]->shape (i) ? 0 : num;
93+ int n1 = 1 == bottom[1 ]->shape (i) ? 0 : num;
94+ offset0 += n0 * bottom[0 ]->count (i+1 );
95+ offset1 += n1 * bottom[1 ]->count (i+1 );
96+ }
97+ int z = d % top[0 ]->shape (dim_-1 );
98+ int z0 = 1 == bottom[0 ]->shape (dim_-1 ) ? 0 : z;
99+ int z1 = 1 == bottom[1 ]->shape (dim_-1 ) ? 0 : z;
100+ offset0 += z0;
101+ offset1 += z1;
72102 }
73- int z = d % top[0 ]->shape (dim_-1 );
74- int z0 = 1 == bottom[0 ]->shape (dim_-1 ) ? 0 : z;
75- int z1 = 1 == bottom[1 ]->shape (dim_-1 ) ? 0 : z;
76- offset0 += z0;
77- offset1 += z1;
78- }
79- else if (dim_diff_ > 0 ) // bottom0 has more axes than bottom1
80- {
81- for (int i=0 ;i<dim_diff_;i++)
103+ else if (dim_diff_ > 0 ) // bottom0 has more axes than bottom1
82104 {
83- int num = (d % top[0 ]->count (i)) / top[0 ]->count (i+1 );
84- int n0 = 1 == bottom[0 ]->shape (i) ? 0 : num;
85- offset0 += n0 * bottom[0 ]->count (i+1 );
105+ for (int i=0 ;i<dim_diff_;i++)
106+ {
107+ int num = (d % top[0 ]->count (i)) / top[0 ]->count (i+1 );
108+ int n0 = 1 == bottom[0 ]->shape (i) ? 0 : num;
109+ offset0 += n0 * bottom[0 ]->count (i+1 );
110+ }
111+ for (int i=dim_diff_;i<dim_-1 ;i++)
112+ {
113+ int num = (d % top[0 ]->count (i)) / top[0 ]->count (i+1 );
114+ int n0 = 1 == bottom[0 ]->shape (i) ? 0 : num;
115+ int n1 = 1 == bottom[1 ]->shape (i-dim_diff_) ? 0 : num;
116+ offset0 += n0 * bottom[0 ]->count (i+1 );
117+ offset1 += n1 * bottom[1 ]->count (i-dim_diff_+1 );
118+ }
119+ int z = d % top[0 ]->shape (dim_-1 );
120+ int z0 = 1 == bottom[0 ]->shape (dim_-1 ) ? 0 : z;
121+ int z1 = 1 == bottom[1 ]->shape (dim_-dim_diff_-1 ) ? 0 : z;
122+ offset0 += z0;
123+ offset1 += z1;
86124 }
87- for ( int i= dim_diff_;i<dim_- 1 ;i++)
125+ else // dim_diff_<0, bottom1 has more axes than bottom0
88126 {
89- int num = (d % top[0 ]->count (i)) / top[0 ]->count (i+1 );
90- int n0 = 1 == bottom[0 ]->shape (i) ? 0 : num;
91- int n1 = 1 == bottom[1 ]->shape (i-dim_diff_) ? 0 : num;
92- offset0 += n0 * bottom[0 ]->count (i+1 );
93- offset1 += n1 * bottom[1 ]->count (i-dim_diff_+1 );
127+ for (int i=0 ;i<-dim_diff_;i++)
128+ {
129+ int num = (d % top[0 ]->count (i)) / top[0 ]->count (i+1 );
130+ int n1 = 1 == bottom[1 ]->shape (i) ? 0 : num;
131+ offset1 += n1 * bottom[1 ]->count (i+1 );
132+ }
133+ for (int i=-dim_diff_;i<dim_-1 ;i++)
134+ {
135+ int num = (d % top[0 ]->count (i)) / top[0 ]->count (i+1 );
136+ int n0 = 1 == bottom[0 ]->shape (i+dim_diff_) ? 0 : num;
137+ int n1 = 1 == bottom[1 ]->shape (i) ? 0 : num;
138+ offset0 += n0 * bottom[0 ]->count (i+dim_diff_+1 );
139+ offset1 += n1 * bottom[1 ]->count (i+1 );
140+ }
141+ int z = d % top[0 ]->shape (dim_-1 );
142+ int z0 = 1 == bottom[0 ]->shape (dim_+dim_diff_-1 ) ? 0 : z;
143+ int z1 = 1 == bottom[1 ]->shape (dim_-1 ) ? 0 : z;
144+ offset0 += z0;
145+ offset1 += z1;
94146 }
95- int z = d % top[0 ]->shape (dim_-1 );
96- int z0 = 1 == bottom[0 ]->shape (dim_-1 ) ? 0 : z;
97- int z1 = 1 == bottom[1 ]->shape (dim_-dim_diff_-1 ) ? 0 : z;
98- offset0 += z0;
99- offset1 += z1;
147+
148+ top_data[d] = bottom0_data[offset0] + bottom1_data[offset1];
100149 }
101- else // dim_diff_<0, bottom1 has more axes than bottom0
150+ }
151+ else // has scalar
152+ {
153+ if (bottom[1 ]->num_axes ()==0 ) // bottom1 is a scalar
102154 {
103- for (int i=0 ;i<-dim_diff_;i++)
104- {
105- int num = (d % top[0 ]->count (i)) / top[0 ]->count (i+1 );
106- int n1 = 1 == bottom[1 ]->shape (i) ? 0 : num;
107- offset1 += n1 * bottom[1 ]->count (i+1 );
108- }
109- for (int i=-dim_diff_;i<dim_-1 ;i++)
110- {
111- int num = (d % top[0 ]->count (i)) / top[0 ]->count (i+1 );
112- int n0 = 1 == bottom[0 ]->shape (i+dim_diff_) ? 0 : num;
113- int n1 = 1 == bottom[1 ]->shape (i) ? 0 : num;
114- offset0 += n0 * bottom[0 ]->count (i+dim_diff_+1 );
115- offset1 += n1 * bottom[1 ]->count (i+1 );
116- }
117- int z = d % top[0 ]->shape (dim_-1 );
118- int z0 = 1 == bottom[0 ]->shape (dim_+dim_diff_-1 ) ? 0 : z;
119- int z1 = 1 == bottom[1 ]->shape (dim_-1 ) ? 0 : z;
120- offset0 += z0;
121- offset1 += z1;
155+ caffe_copy (count, bottom0_data, top_data);
156+ Dtype scalar = bottom1_data[0 ];
157+ caffe_add_scalar (count, scalar, top_data);
158+ }
159+ else // bottom0 is a scalar
160+ {
161+ caffe_copy (count, bottom1_data, top_data);
162+ Dtype scalar = bottom0_data[0 ];
163+ caffe_add_scalar (count, scalar, top_data);
122164 }
123-
124- top_data[d] = bottom0_data[offset0] + bottom1_data[offset1];
125165 }
126-
127166}
128167
// Invokes the REGISTER_LAYER_CLASS macro with the name Add. Presumably this
// makes AddLayer creatable by its type name through the framework's layer
// registry -- confirm against the macro's definition.
129168REGISTER_LAYER_CLASS (Add);
0 commit comments