Skip to content

Commit c8e6052

Browse files
committed
add the broadcast_to layer
1 parent f2addff commit c8e6052

File tree

4 files changed

+136
-1
lines changed

4 files changed

+136
-1
lines changed

FEATURES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ evconvert (TensorFlow/ONNX/... to Caffe Converter) related
88
----------------------------------------------------------
99
atan_layer
1010
batch_to_space_nd_layer
11+
broadcast_to_layer
1112
crop_and_resize
1213
depth_to_space_layer
1314
div_layer
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#ifndef CAFFE_BROADCAST_TO_LAYER_HPP_
#define CAFFE_BROADCAST_TO_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/**
 * @brief Broadcasts the single input blob to a compatible (larger) shape
 *        given by BroadcastToParameter::shape.
 *
 * Dimensions are right-aligned: each input axis is matched against the
 * trailing axes of the requested output shape. Two corresponding dimensions
 * must have the same value, or the input dimension must be equal to 1
 * (in which case it is broadcast along that axis).
 */
template <typename Dtype>
class BroadcastToLayer : public Layer<Dtype> {
 public:
  explicit BroadcastToLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "BroadcastTo"; }
  virtual inline int ExactNumBottomBlobs() const { return 1; }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  // GPU forward pass is not implemented yet.
  //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
  //    const vector<Blob<Dtype>*>& top);

  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  // GPU backward pass is not implemented yet.
  //virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
  //    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

  // Target output shape, copied from BroadcastToParameter in LayerSetUp.
  vector<int> output_shape_;
};

}  // namespace caffe

#endif  // CAFFE_BROADCAST_TO_LAYER_HPP_
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include <algorithm>
#include <iterator>
#include <vector>

#include "caffe/layers/broadcast_to_layer.hpp"
#include "caffe/util/math_functions.hpp"
5+
6+
namespace caffe {
7+
8+
template <typename Dtype>
9+
void BroadcastToLayer<Dtype>::LayerSetUp(
10+
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
11+
const BroadcastToParameter& broadcast_to_param = this->layer_param_.broadcast_to_param();
12+
output_shape_.clear();
13+
std::copy(broadcast_to_param.shape().begin(),
14+
broadcast_to_param.shape().end(),
15+
std::back_inserter(output_shape_));
16+
17+
CHECK_GE(output_shape_.size(), bottom[0]->num_axes()) << "Output shape should not have less axis than input!";
18+
int dim_diff = output_shape_.size() - bottom[0]->num_axes();
19+
for(int i=output_shape_.size()-1; i>=dim_diff; i--)
20+
{
21+
CHECK_GT(output_shape_[i], 0) << "Values in output shape must be positive!";
22+
CHECK(output_shape_[i]==bottom[0]->shape(i-dim_diff) || bottom[0]->shape(i-dim_diff)==1)
23+
<< "The broadcasting shape is incompatible with the input!";
24+
}
25+
}
26+
27+
28+
template <typename Dtype>
void BroadcastToLayer<Dtype>::Reshape(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // The top shape depends only on broadcast_to_param, already validated
  // against the bottom shape in LayerSetUp, so bottom is not consulted here.
  top[0]->Reshape(output_shape_);
}
33+
34+
template <typename Dtype>
35+
void BroadcastToLayer<Dtype>::Forward_cpu(
36+
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
37+
const Dtype* bottom_data = bottom[0]->cpu_data();
38+
Dtype* top_data = top[0]->mutable_cpu_data();
39+
int count = top[0]->count();
40+
41+
int dim = top[0]->num_axes();
42+
int dim_diff = output_shape_.size() - bottom[0]->num_axes();
43+
// Assume top index (x,y,z) with top shape (A, B, C)
44+
// top offset d = xBC + yC + z
45+
// So to count the bottom index, should first figure out x, y, z
46+
// x = d / BC
47+
// y = (d % BC) / C
48+
// z = d % C
49+
// Then consider bottom shape (A', B', C'), where A' = 1 or A
50+
// So bottom offset = x'B'C' + y'C' + z
51+
for(int d=0; d<count; d++)
52+
{
53+
int offset = 0;
54+
55+
for(int i=dim_diff;i<dim-1;i++)
56+
{
57+
int num = (d % top[0]->count(i)) / top[0]->count(i+1);
58+
int n0 = 1 == bottom[0]->shape(i-dim_diff) ? 0 : num;
59+
offset += n0 * bottom[0]->count(i-dim_diff+1);
60+
}
61+
int z = d % top[0]->shape(dim-1);
62+
int z0 = 1 == bottom[0]->shape(dim-dim_diff-1) ? 0 : z;
63+
offset += z0;
64+
65+
top_data[d] = bottom_data[offset];
66+
}
67+
}
68+
69+
template <typename Dtype>
void BroadcastToLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  // No gradient is implemented for this layer (a correct backward pass would
  // sum the top diff over the broadcast axes).
  // NOTE(review): presumably acceptable for converter/inference-only use —
  // confirm no training path reaches this layer.
  NOT_IMPLEMENTED;
}
74+
75+
// GPU stubs stay disabled because no Forward_gpu/Backward_gpu exist yet.
// (The original commented stub named TileLayer — a copy-paste leftover.)
//#ifdef CPU_ONLY
//STUB_GPU(BroadcastToLayer);
//#endif

INSTANTIATE_CLASS(BroadcastToLayer);
REGISTER_LAYER_CLASS(BroadcastTo);
81+
82+
} // namespace caffe

src/caffe/proto/caffe.proto

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ message ParamSpec {
460460
// NOTE
461461
// Update the next available ID when you add a new LayerParameter field.
462462
//
463-
// LayerParameter next available layer-specific ID: 242 (last added: squeeze_param)
463+
// LayerParameter next available layer-specific ID: 243 (last added: broadcast_to_param)
464464
message LayerParameter {
465465
optional string name = 1; // the layer name
466466
optional string type = 2; // the layer type
@@ -645,6 +645,7 @@ message LayerParameter {
645645
optional TileNDParameter tile_nd_param = 239;
646646
optional ExpandDimsNDParameter expand_dims_nd_param = 240;
647647
optional SqueezeParameter squeeze_param = 241;
648+
optional BroadcastToParameter broadcast_to_param = 242;
648649
}
649650

650651
message AccumParameter {
@@ -3075,3 +3076,7 @@ message SqueezeParameter {
30753076
repeated int32 axis = 1;
30763077
}
30773078

3079+
message BroadcastToParameter {
  // The shape of the desired output. Axes are right-aligned against the
  // input; each input dimension must equal the output dimension or be 1.
  repeated uint32 shape = 1;
}

0 commit comments

Comments
 (0)