Skip to content

Commit 90349ee

Browse files
committed
invoke saturate when Concat/ResizeBilinear/Pooling does forward_cpu
1 parent 4b174d8 commit 90349ee

File tree

6 files changed

+44
-3
lines changed

6 files changed

+44
-3
lines changed

include/caffe/layers/concat_layer.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class ConcatLayer : public Layer<Dtype> {
8484
vector<int> input_zero_point_; //CUSTOMIZATION
8585
double output_scale_; //CUSTOMIZATION
8686
int output_zero_point_; //CUSTOMIZATION
87-
//Dtype saturate_; //CUSTOMIZATION
87+
Dtype saturate_; //CUSTOMIZATION
8888
};
8989

9090
} // namespace caffe

include/caffe/layers/resize_bilinear_layer.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class ResizeBilinearLayer : public Layer<Dtype> {
5555
bool pytorch_half_pixel_;
5656
double output_scale_; //CUSTOMIZATION
5757
int output_zero_point_; //CUSTOMIZATION
58+
Dtype saturate_; //CUSTOMIZATION
5859

5960
// Compute the interpolation indices only once.
6061
struct CachedInterpolation {

src/caffe/layers/concat_layer.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ void ConcatLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
2828
}
2929
output_scale_ = concat_param.output_scale();
3030
output_zero_point_ = concat_param.output_zero_point();
31+
saturate_ = concat_param.saturate();
3132
}
3233

3334
template <typename Dtype>
@@ -97,8 +98,12 @@ void ConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
9798
}
9899
offset_concat_axis += bottom_concat_axis;
99100
}
100-
if (is_quant) // CUSTOMIZATION
101-
caffe_cpu_quantize<Dtype>(top[0]->count(), top[0]->mutable_cpu_data(), output_scale_, output_zero_point_);
101+
if (is_quant){ // CUSTOMIZATION
102+
const int count_t = top[0]->count();
103+
top_data = top[0]->mutable_cpu_data();
104+
caffe_cpu_quantize(count_t, top_data, output_scale_, output_zero_point_);
105+
caffe_cpu_saturate(count_t, top_data, saturate_); // if None nothing happens
106+
}
102107

103108
if (is_quant) {
104109
for (int i = 0; i < bottom.size(); ++i) {

src/caffe/layers/pooling_layer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
498498
default:
499499
LOG(FATAL) << "Unknown pooling method.";
500500
}
501+
caffe_cpu_saturate(top[0]->count(), top[0]->mutable_cpu_data(), saturate_); // if None nothing happens
501502
}
502503

503504
template <typename Dtype>

src/caffe/layers/resize_bilinear_layer.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ void ResizeBilinearLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
2626
"If pytorch_half_pixel_ is True, half_pixel_centers_ must be False.";
2727
output_scale_ = this->layer_param_.resize_bilinear_param().output_scale(); //CUSTOMIZATION
2828
output_zero_point_ = this->layer_param_.resize_bilinear_param().output_zero_point(); //CUSTOMIZATION
29+
saturate_ = this->layer_param_.resize_bilinear_param().saturate(); //CUSTOMIZATION
2930
}
3031

3132
template <typename Dtype>
@@ -327,6 +328,7 @@ void ResizeBilinearLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
327328
}
328329
}
329330
}
331+
caffe_cpu_saturate(top[0]->count(), top[0]->mutable_cpu_data(), saturate_); // if None nothing happens
330332
}
331333

332334
INSTANTIATE_CLASS(ResizeBilinearLayer);

src/caffe/proto/caffe.proto

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,10 +1348,26 @@ message ConcatParameter {
13481348
// DEPRECATED: alias for "axis" -- does not support negative indexing.
13491349
optional uint32 concat_dim = 1 [default = 1];
13501350

1351+
//<--CUSTOMIZATION
13511352
repeated double input_scale = 3; // CUSTOMIZATION, blob-wise for quantization
13521353
optional double output_scale = 4 [default = 1]; //CUSTOMIZATION
13531354
repeated int32 input_zero_point = 5; //CUSTOMIZATION, blob-wise for quantization
13541355
optional int32 output_zero_point = 6 [default = 0]; //CUSTOMIZATION
1356+
enum SaturateMethod {
1357+
None = 0;
1358+
Signed = 1;
1359+
Unsigned = 2;
1360+
Signed_8bit = 3;
1361+
Unsigned_8bit = 4;
1362+
}
1363+
optional SaturateMethod saturate = 7 [default = None]; //control the output in certain range
1364+
enum QuantizeMethod {
1365+
tflite = 0;
1366+
ONNX = 1;
1367+
Caffe2 = 2;
1368+
}
1369+
optional QuantizeMethod quantize_method = 8 [default = tflite];
1370+
//CUSTOMIZATION-->
13551371
}
13561372

13571373
message ContrastiveLossParameter {
@@ -3286,10 +3302,26 @@ message ResizeBilinearParameter {
32863302
optional bool pytorch_half_pixel = 8 [default = false];
32873303
optional int32 output_depth = 9;
32883304
optional float scale_depth = 10 [default = 1];
3305+
//<--CUSTOMIZATION
32893306
optional double input_scale = 20 [default = 1]; //CUSTOMIZATION
32903307
optional double output_scale = 21 [default = 1]; //CUSTOMIZATION
32913308
optional int32 input_zero_point = 22 [default = 0]; //CUSTOMIZATION
32923309
optional int32 output_zero_point = 23 [default = 0]; //CUSTOMIZATION
3310+
enum SaturateMethod {
3311+
None = 0;
3312+
Signed = 1;
3313+
Unsigned = 2;
3314+
Signed_8bit = 3;
3315+
Unsigned_8bit = 4;
3316+
}
3317+
optional SaturateMethod saturate = 24 [default = None]; //control the output in certain range
3318+
enum QuantizeMethod {
3319+
tflite = 0;
3320+
ONNX = 1;
3321+
Caffe2 = 2;
3322+
}
3323+
optional QuantizeMethod quantize_method = 25 [default = tflite];
3324+
//CUSTOMIZATION-->
32933325
}
32943326

32953327
message ReduceSumParameter {

0 commit comments

Comments
 (0)