Skip to content

Commit caa46d4

Browse files
committed
evquantize: pool signed/unsigned saturate setting
1 parent 83bfbf4 commit caa46d4

File tree

3 files changed

+65
-27
lines changed

3 files changed

+65
-27
lines changed

include/caffe/layers/pooling_layer.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class PoolingLayer : public Layer<Dtype> {
6060
int pad_t_; //CUSTOMIZATION
6161
int pad_b_; //CUSTOMIZATION
6262
int output_shift_instead_division_; //CUSTOMIZATION
63-
bool saturate_; //CUSTOMIZATION
63+
Dtype saturate_; //CUSTOMIZATION
6464
};
6565

6666
} // namespace caffe

src/caffe/layers/pooling_layer.cu

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
#include "caffe/layers/pooling_layer.hpp"
66
#include "caffe/util/math_functions.hpp"
77

8-
#define SATURATE_MAX 4095
9-
#define SATURATE_MIN -4096
8+
#define SIGNED_SATURATE_MAX 2047
9+
#define SIGNED_SATURATE_MIN -2048
10+
#define UNSIGNED_SATURATE_MAX 4095
1011

1112
namespace caffe {
1213

@@ -57,7 +58,7 @@ __global__ void AvePoolForward(const int nthreads,
5758
const int stride_h, const int stride_w,
5859
//const int pad_h, const int pad_w,
5960
const int pad_top, const int pad_left, const int pad_bottom, const int pad_right, //CUSTOMIZATION
60-
Dtype* const top_data, const int output_shift_instead_division, const bool saturate) {
61+
Dtype* const top_data, const int output_shift_instead_division, const Dtype saturate) {
6162
CUDA_KERNEL_LOOP(index, nthreads) {
6263
const int pw = index % pooled_width;
6364
const int ph = (index / pooled_width) % pooled_height;
@@ -89,22 +90,37 @@ __global__ void AvePoolForward(const int nthreads,
8990
if (output_shift_instead_division != Dtype(0)) {
9091
top_data[index] = aveval / output_shift_instead_division;
9192
top_data[index] = rint(top_data[index]);
92-
if(saturate)
93+
if(saturate == PoolingParameter_SaturateMethod_Unsigned)
9394
{
94-
if(top_data[index] > SATURATE_MAX)
95-
top_data[index] = SATURATE_MAX;
96-
if(top_data[index] < SATURATE_MIN)
97-
top_data[index] = SATURATE_MIN;
95+
if(top_data[index] > UNSIGNED_SATURATE_MAX)
96+
top_data[index] = UNSIGNED_SATURATE_MAX;
97+
if(top_data[index] < 0)
98+
top_data[index] = 0;
99+
}
100+
if(saturate == PoolingParameter_SaturateMethod_Signed)
101+
{
102+
if(top_data[index] > SIGNED_SATURATE_MAX)
103+
top_data[index] = SIGNED_SATURATE_MAX;
104+
if(top_data[index] < SIGNED_SATURATE_MIN)
105+
top_data[index] = SIGNED_SATURATE_MIN;
98106
}
99107
}
100108
else{
101-
if(saturate)
109+
if(saturate == PoolingParameter_SaturateMethod_Unsigned)
110+
{
111+
top_data[index] = aveval;
112+
if(top_data[index] > UNSIGNED_SATURATE_MAX)
113+
top_data[index] = UNSIGNED_SATURATE_MAX;
114+
if(top_data[index] < 0)
115+
top_data[index] = 0;
116+
}
117+
else if(saturate == PoolingParameter_SaturateMethod_Signed)
102118
{
103-
top_data[index] = aveval;
104-
if(top_data[index] > SATURATE_MAX)
105-
top_data[index] = SATURATE_MAX;
106-
if(top_data[index] < SATURATE_MIN)
107-
top_data[index] = SATURATE_MIN;
119+
top_data[index] = aveval;
120+
if(top_data[index] > SIGNED_SATURATE_MAX)
121+
top_data[index] = SIGNED_SATURATE_MAX;
122+
if(top_data[index] < SIGNED_SATURATE_MIN)
123+
top_data[index] = SIGNED_SATURATE_MIN;
108124
}
109125
else //original implementation
110126
top_data[index] = aveval / pool_size;
@@ -121,7 +137,7 @@ __global__ void AvePoolForward_TF(const int nthreads,
121137
const int stride_h, const int stride_w,
122138
//const int pad_h, const int pad_w,
123139
const int pad_top, const int pad_left, const int pad_bottom, const int pad_right, //CUSTOMI
124-
Dtype* const top_data, const int output_shift_instead_division, const bool saturate) {
140+
Dtype* const top_data, const int output_shift_instead_division, const Dtype saturate) {
125141
CUDA_KERNEL_LOOP(index, nthreads) {
126142
const int pw = index % pooled_width;
127143
const int ph = (index / pooled_width) % pooled_height;
@@ -225,23 +241,38 @@ __global__ void AvePoolForward_TF(const int nthreads,
225241
top_data[index] = aveval / output_shift_instead_division * full_pool_size / pool_size;
226242
}
227243
top_data[index] = rint(top_data[index]);
228-
if(saturate)
244+
if(saturate == PoolingParameter_SaturateMethod_Unsigned)
229245
{
230-
if(top_data[index] > SATURATE_MAX)
231-
top_data[index] = SATURATE_MAX;
232-
if(top_data[index] < SATURATE_MIN)
233-
top_data[index] = SATURATE_MIN;
246+
if(top_data[index] > UNSIGNED_SATURATE_MAX)
247+
top_data[index] = UNSIGNED_SATURATE_MAX;
248+
if(top_data[index] < 0)
249+
top_data[index] = 0;
250+
}
251+
if(saturate == PoolingParameter_SaturateMethod_Signed)
252+
{
253+
if(top_data[index] > SIGNED_SATURATE_MAX)
254+
top_data[index] = SIGNED_SATURATE_MAX;
255+
if(top_data[index] < SIGNED_SATURATE_MIN)
256+
top_data[index] = SIGNED_SATURATE_MIN;
234257
}
235258
}
236259

237260
else{
238-
if(saturate)
261+
if(saturate == PoolingParameter_SaturateMethod_Unsigned)
239262
{
240263
top_data[index] = aveval;
241-
if(top_data[index] > SATURATE_MAX)
242-
top_data[index] = SATURATE_MAX;
243-
if(top_data[index] < SATURATE_MIN)
244-
top_data[index] = SATURATE_MIN;
264+
if(top_data[index] > UNSIGNED_SATURATE_MAX)
265+
top_data[index] = UNSIGNED_SATURATE_MAX;
266+
if(top_data[index] < 0)
267+
top_data[index] = 0;
268+
}
269+
else if(saturate == PoolingParameter_SaturateMethod_Signed)
270+
{
271+
top_data[index] = aveval;
272+
if(top_data[index] > SIGNED_SATURATE_MAX)
273+
top_data[index] = SIGNED_SATURATE_MAX;
274+
if(top_data[index] < SIGNED_SATURATE_MIN)
275+
top_data[index] = SIGNED_SATURATE_MIN;
245276
}
246277
else //original implementation
247278
top_data[index] = aveval / pool_size;

src/caffe/proto/caffe.proto

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1979,7 +1979,14 @@ message PoolingParameter {
19791979
optional uint32 stride_h = 7; // The stride height
19801980
optional uint32 stride_w = 8; // The stride width
19811981
optional uint32 output_shift_instead_division = 21 [default = 0]; //CUSTOMIZATION, only valid for Average pooling
1982-
optional bool saturate = 22 [default = false]; //CUSTOMIZATION, control the output in range [-4096, 4095]
1982+
//<--CUSTOMIZATION
1983+
enum SaturateMethod {
1984+
None = 0;
1985+
Signed = 1;
1986+
Unsigned = 2;
1987+
}
1988+
optional SaturateMethod saturate = 22 [default = None]; //control the output in certain range
1989+
//CUSTOMIZATION-->
19831990
enum Engine {
19841991
DEFAULT = 0;
19851992
CAFFE = 1;

0 commit comments

Comments
 (0)