@@ -81,6 +81,30 @@ struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
 
 template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
 
+template <typename T>
+struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& in_accum,
+                  const framework::Tensor& in_state, const T* cur_scale,
+                  const float rate, framework::Tensor* out_state,
+                  framework::Tensor* out_accum, framework::Tensor* out_scale) {
+    T accum = in_accum.data<T>()[0];
+    T state = in_state.data<T>()[0];
+    T scale = cur_scale[0];
+
+    state = rate * state + 1;
+    accum = rate * accum + scale;
+    scale = accum / state;
+
+    out_state->mutable_data<T>(ctx.GetPlace())[0] = state;
+    out_accum->mutable_data<T>(ctx.GetPlace())[0] = accum;
+    out_scale->mutable_data<T>(ctx.GetPlace())[0] = scale;
+  }
+};
+
+template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
+                                               float>;
+
 class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
  public:
  FakeQuantizeAbsMaxOp(const std::string& type,
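The functor above maintains what is, in effect, a bias-corrected exponential moving average of per-batch scales: `accum` is a decayed sum of batch maxima and `state` is the decayed sum of the weights, so `accum / state` is a properly normalized average even in the first few steps. A minimal standalone sketch of the same update, outside the patch (the batch scale values are made up for illustration):

```cpp
#include <cstdio>

int main() {
  const float rate = 0.9f;  // corresponds to the moving_rate attribute
  float accum = 0.0f, state = 0.0f;
  // Per-batch max(abs(x)) values, chosen arbitrarily for illustration.
  const float cur_scales[] = {1.0f, 0.5f, 2.0f, 1.5f};
  for (float cur_scale : cur_scales) {
    state = rate * state + 1;          // decayed weight sum: 1, 1.9, 2.71, ...
    accum = rate * accum + cur_scale;  // decayed sum of batch scales
    float scale = accum / state;       // smoothed scale used for quantization
    std::printf("scale = %f\n", scale);
  }
  return 0;
}
```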
@@ -255,6 +279,78 @@ FakeQuantize operator is used in static quantization.
 }
 };
 
+class FakeQuantizeMovingAverageAbsMaxOp : public framework::OperatorWithKernel {
+ public:
+  FakeQuantizeMovingAverageAbsMaxOp(const std::string& type,
+                                    const framework::VariableNameMap& inputs,
+                                    const framework::VariableNameMap& outputs,
+                                    const framework::AttributeMap& attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(
+        ctx->HasInput("X"),
+        "Input(X) of FakeQuantizeMovingAverageAbsMaxOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("Out"),
+        "Output(Out) of FakeQuantizeMovingAverageAbsMaxOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
+                   "Output(OutScale) of FakeQuantizeMovingAverageAbsMaxOp "
+                   "should not be null.");
+    if (ctx->HasOutput("OutState")) {
+      ctx->SetOutputDim("OutState", {1});
+    }
+    if (ctx->HasOutput("OutAccum")) {
+      ctx->SetOutputDim("OutAccum", {1});
+    }
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->SetOutputDim("OutScale", {1});
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
+  }
+};
+
+class FakeQuantizeMovingAverageAbsMaxOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) Input is float data type.");
+    AddInput("InScale", "Last scale.");
+    AddInput("InAccum", "Last accum.").AsDispensable();
+    AddInput("InState", "Last state.").AsDispensable();
+    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
+    AddOutput("OutScale", "Current scale.");
+    AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
+    AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
+    AddAttr<float>("moving_rate", "(float, default 0.9) moving rate.")
+        .SetDefault(0.9);
+    AddAttr<int>("bit_length", "(int, default 8) quantization bit number.")
+        .SetDefault(8)
+        .AddCustomChecker([](const int& bit_length) {
+          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
+                         "'bit_length' should be between 1 and 16.");
+        });
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+FakeQuantize operator is used in static quantization.
+
+$$scale = (0.9*accum + max(abs(x)))/(0.9*state + 1)$$
+$$range = 2^{bit_length - 1} - 1$$
+$$Out = round(X/scale * range)$$
+
+)DOC");
+  }
+};
+
 } // namespace operators
 } // namespace paddle
 
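Note on the DOC block: the scale formula is written here to match what FindMovingAverageAbsMaxFunctor actually computes, `scale = (rate*accum + max(abs(x)))/(rate*state + 1)` with `rate = 0.9`. As a quick worked example (not from the patch) of the remaining two formulas, with the default bit_length = 8:

$$range = 2^{8-1} - 1 = 127$$

and an input value x = 0.5 with scale = 1.0 quantizes to

$$Out = round(0.5/1.0 * 127) = round(63.5) = 64$$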
@@ -273,6 +369,12 @@ REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
 REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
                        ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
 
+REGISTER_OPERATOR(fake_quantize_moving_average_abs_max,
+                  ops::FakeQuantizeMovingAverageAbsMaxOp,
+                  ops::FakeQuantizeMovingAverageAbsMaxOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
+                       ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
 REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
                   ops::FakeChannelWiseQuantizeAbsMaxOp,
                   ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
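This hunk registers only the CPU kernel; a matching GPU kernel would presumably be registered in fake_quantize_op.cu. A hypothetical sketch of that registration, mirroring the CPU line above (the file location and CUDA alias are assumptions, not shown in this diff):

```cpp
// Hypothetical sketch, not part of this hunk: GPU registration mirroring
// the CPU one, as would typically appear in fake_quantize_op.cu.
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_moving_average_abs_max,
    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
```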