Skip to content

Commit 94e69d8

Browse files
committed
add spatial_batching_pooling_layer
1 parent 48deb77 commit 94e69d8

File tree

3 files changed

+524
-0
lines changed

3 files changed

+524
-0
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#ifndef CAFFE_SPATIAL_BATCHING_POOLING_LAYER_HPP_
2+
#define CAFFE_SPATIAL_BATCHING_POOLING_LAYER_HPP_
3+
4+
#include <vector>
5+
6+
#include "caffe/blob.hpp"
7+
#include "caffe/layer.hpp"
8+
#include "caffe/proto/caffe.pb.h"
9+
10+
namespace caffe {
11+
12+
/**
13+
* @brief Pools the input image by taking the max, average, etc. within regions.
14+
*
15+
* TODO(dox): thorough documentation for Forward, Backward, and proto params.
16+
*/
17+
template <typename Dtype>
18+
class SpatialBatchingPoolingLayer : public Layer<Dtype> {
19+
public:
20+
explicit SpatialBatchingPoolingLayer(const LayerParameter &param)
21+
: Layer<Dtype>(param) {}
22+
virtual void LayerSetUp(const vector<Blob<Dtype> *> &bottom,
23+
const vector<Blob<Dtype> *> &top);
24+
virtual void Reshape(const vector<Blob<Dtype> *> &bottom,
25+
const vector<Blob<Dtype> *> &top);
26+
27+
virtual inline const char *type() const { return "SpatialBatchingPooling"; }
28+
virtual inline int ExactNumBottomBlobs() const { return 1; }
29+
virtual inline int MinTopBlobs() const { return 1; }
30+
// MAX POOL layers can output an extra top blob for the mask;
31+
// others can only output the pooled inputs.
32+
virtual inline int MaxTopBlobs() const {
33+
return (this->layer_param_.spatial_batching_pooling_param().pool() ==
34+
SpatialBatchingPoolingParameter_PoolMethod_MAX)
35+
? 2
36+
: 1;
37+
}
38+
39+
protected:
40+
virtual void Forward_cpu(const vector<Blob<Dtype> *> &bottom,
41+
const vector<Blob<Dtype> *> &top);
42+
/// @brief Not implemented
43+
virtual void Backward_cpu(const vector<Blob<Dtype> *> &top,
44+
const vector<bool> &propagate_down,
45+
const vector<Blob<Dtype> *> &bottom) {
46+
NOT_IMPLEMENTED;
47+
}
48+
49+
int kernel_h_, kernel_w_;
50+
int stride_h_, stride_w_;
51+
int pad_h_, pad_w_;
52+
int channels_;
53+
int height_, width_;
54+
int pooled_height_, pooled_width_;
55+
bool global_pooling_;
56+
Blob<Dtype> rand_idx_;
57+
Blob<int> max_idx_;
58+
bool ceil_mode_;
59+
int pad_l_; // CUSTOMIZATION
60+
int pad_r_; // CUSTOMIZATION
61+
int pad_t_; // CUSTOMIZATION
62+
int pad_b_; // CUSTOMIZATION
63+
Dtype saturate_; // CUSTOMIZATION
64+
int spatial_batching_h_;
65+
int spatial_batching_w_;
66+
int skip_h_, skip_w_;
67+
int batch_h_, batch_w_;
68+
int gap_h_, gap_w_;
69+
int pooled_batch_h_;
70+
int pooled_batch_w_;
71+
int pooled_gap_h_;
72+
int pooled_gap_w_;
73+
int s_pooled_height_, s_pooled_width_;
74+
};
75+
76+
} // namespace caffe
77+
78+
#endif // CAFFE_SPATIAL_BATCHING_POOLING_LAYER_HPP_

0 commit comments

Comments
 (0)