@@ -336,27 +336,45 @@ void ComputePropagateScalesMkldnnPass::ComputeWeightScales(
336
336
ComputeLstmWeightScales (graph, scope, " WeightX" , " WeightH" , var_quant_scales);
337
337
}
338
338
339
- void ComputePropagateScalesMkldnnPass::UpdateScaleOpInScale (
339
+ void ComputePropagateScalesMkldnnPass::UpdateScaleOpInOutScales (
340
340
Node* op_node,
341
341
const std::string& input_name,
342
342
const std::string& output_name,
343
343
StringPairMap* var_quant_scales) const {
344
- auto iter = var_quant_scales->find (output_name);
345
- if (iter != var_quant_scales->end ()) {
346
- auto pair = iter->second ;
347
- const auto tensor = pair.second ;
348
-
349
- const auto scale = PADDLE_GET_CONST (float , op_node->Op ()->GetAttr (" scale" ));
350
- Tensor tmp_tensor;
351
- tmp_tensor.Resize (tensor.dims ());
352
- auto * data = tmp_tensor.mutable_data <float >(platform::CPUPlace ());
353
- for (int i = 0 ; i < tensor.numel (); i++) {
354
- data[i] = data[i] * scale;
355
- }
344
+ auto out_iter = var_quant_scales->find (output_name);
345
+ auto input_iter = var_quant_scales->find (input_name);
346
+ // All the input and output have scales
347
+ if (out_iter != var_quant_scales->end () &&
348
+ input_iter != var_quant_scales->end ()) {
349
+ return ;
350
+ }
351
+ const auto scale = PADDLE_GET_CONST (float , op_node->Op ()->GetAttr (" scale" ));
352
+ if (std::abs (scale) < 1e-6 && out_iter != var_quant_scales->end ()) {
353
+ return ;
354
+ }
355
+
356
+ std::string name = input_name;
357
+ auto iter = out_iter;
358
+ if (input_iter != var_quant_scales->end ()) {
359
+ iter = input_iter;
360
+ name = output_name;
361
+ }
356
362
357
- auto new_pair = std::make_pair (pair.first , tmp_tensor);
358
- var_quant_scales->insert (std::make_pair (input_name, new_pair));
363
+ phi::DenseTensor tmp_tensor;
364
+ auto pair = iter->second ;
365
+ const auto tensor = pair.second ;
366
+ tmp_tensor.Resize (tensor.dims ());
367
+ auto * data = tmp_tensor.mutable_data <float >(platform::CPUPlace ());
368
+ auto * src_data = tensor.data <float >();
369
+ for (int i = 0 ; i < tensor.numel (); i++) {
370
+ if (out_iter != var_quant_scales->end ()) {
371
+ data[i] = src_data[i] / scale;
372
+ } else {
373
+ data[i] = src_data[i] * scale;
374
+ }
359
375
}
376
+ auto new_pair = std::make_pair (pair.first , tmp_tensor);
377
+ var_quant_scales->insert (std::make_pair (name, new_pair));
360
378
}
361
379
362
380
std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales (
@@ -403,10 +421,12 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
403
421
}
404
422
} else if (op_name == " scale" ) {
405
423
const std::string output_name = op_node->Op ()->Output (" Out" )[0 ];
424
+ const std::string input_name = op_node->Op ()->Input (" X" )[0 ];
406
425
auto out_iter = var_quant_scales->find (output_name);
407
- if (out_iter != var_quant_scales->end ()) {
408
- const std::string input_name = op_node->Op ()->Input (" X" )[0 ];
409
- UpdateScaleOpInScale (
426
+ auto input_iter = var_quant_scales->find (input_name);
427
+ if (out_iter != var_quant_scales->end () ||
428
+ input_iter != var_quant_scales->end ()) {
429
+ UpdateScaleOpInOutScales (
410
430
op_node, input_name, output_name, var_quant_scales);
411
431
}
412
432
}
0 commit comments