@@ -78,11 +78,11 @@ inline bool VarNodeHasDtype(Node* var_node) {
         (type == VarType::VOCAB);
}

- inline bool IsFloatType(VarType::Type type) {
+ inline bool IsFP32AndFP64(VarType::Type type) {
  return (type == VarType::FP64) || (type == VarType::FP32);
}

- inline bool IsHalfType(VarType::Type type) {
+ inline bool IsFP16AndBFP16(VarType::Type type) {
  return (type == VarType::FP16) || (type == VarType::BF16);
}

@@ -159,26 +159,16 @@ bool OpSupportPrecision(const std::string& op_type,
// The set of ops that support fp16 calculation and are considered
// numerically-dangerous, slower and whose effects may also be observed in
// downstream ops.
+ // ref to python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
  black_list_.insert({
      // numerically-dangerous
-       "acos",
-       "asin",
-       "cosh",
-       "tan",
      "exp",
-       "expm1",
      "square",
      "log",
-       "log2",
-       "log10",
-       "log1p",
-       "logsumexp",
      "mean",
-       "rsqrt",
      "sum",
      "cos_sim",
-       "softmax",
      "softmax_with_cross_entropy",
      "sigmoid_cross_entropy_with_logits",
      "c_softmax_with_cross_entropy",
@@ -272,6 +262,9 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
  VLOG(4) << "InsertCastOp done";
  RestoreOpOriginType();
  VLOG(4) << "RestoreOpOriginType done";
+   LOG(INFO) << "The number of ops run at low precision ["
+             << op_run_low_precision_.size() << "/" << op_original_type_.size()
+             << "]";
}

void AutoMixedPrecisionPass::SetOpUniqueType() const {
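
The new summary line reports how many of the graph's (uniquified) ops ended up running at low precision versus the total the pass saw. A toy reproduction of the message format with stand-in containers for op_run_low_precision_ and op_original_type_:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

int main() {
  // Stand-ins for the pass members of the same names.
  std::unordered_set<std::string> op_run_low_precision{"conv2d_0", "relu_1"};
  std::unordered_map<std::string, std::string> op_original_type{
      {"conv2d_0", "conv2d"}, {"relu_1", "relu"}, {"softmax_2", "softmax"}};

  // Prints: The number of ops run at low precision [2/3]
  std::cout << "The number of ops run at low precision ["
            << op_run_low_precision.size() << "/" << op_original_type.size()
            << "]" << std::endl;
  return 0;
}
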
@@ -315,22 +308,36 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto op_type = op_node->Op()->Type();
+
+       if (op_node->Op()->HasAttr("in_dtype")) {
+         auto* var_node = op_node->inputs[0];
+         auto* real_var_node = real_vars_[var_node->Var()->Name()];
+         if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) {
+           op_node->Op()->SetAttr(
+               "in_dtype",
+               static_cast<int>(framework::TransToProtoVarType(low_precision_)));
+           op_node->Op()->Flush();
+           VLOG(4) << "process op with in_dtype attr: " << op_type << " ( "
+                   << static_cast<int>(real_var_node->Var()->GetDataType())
+                   << " --->" << static_cast<int>(low_precision_) << " )";
+         }
+       }
+
      if (op_run_low_precision_.count(op_type) == 0) continue;

      if (op_node->Op()->HasAttr("dtype")) {
        auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
-         if (IsFloatType(static_cast<VarType::Type>(dtype))) {
+         if (IsFP32AndFP64(static_cast<VarType::Type>(dtype))) {
          op_node->Op()->SetAttr(
              "dtype",
              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
          op_node->Op()->Flush();
          VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
                  << " --->" << static_cast<int>(low_precision_) << " )";
        }
-       }
-       if (op_node->Op()->HasAttr("out_dtype")) {
+       } else if (op_node->Op()->HasAttr("out_dtype")) {
        auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
-         if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
+         if (IsFP32AndFP64(static_cast<VarType::Type>(out_dtype))) {
          op_node->Op()->SetAttr(
              "out_dtype",
              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
@@ -359,37 +366,55 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {

      if (op_node->Op()->HasAttr("dtype")) {
        auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
-         support_low_precision = support_low_precision &&
-                                 IsFloatType(static_cast<VarType::Type>(dtype));
+         support_low_precision =
+             support_low_precision &&
+             IsFP32AndFP64(static_cast<VarType::Type>(dtype));
      } else if (op_node->Op()->HasAttr("out_dtype")) {
        auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
        support_low_precision =
            support_low_precision &&
-             IsFloatType(static_cast<VarType::Type>(out_dtype));
-       } else {
-         // if op's input var and output var is not dense tensor, the op should
-         // not run at low precision.
-         for (auto* in_var_node : op_node->inputs) {
-           CHECK_EQ(in_var_node->IsVar(), true);
-           auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
-           if (real_in_var_node->Var()->Persistable()) continue;
+             IsFP32AndFP64(static_cast<VarType::Type>(out_dtype));
+       }

+       // If scale op's "scale" and "bias" attr value exceed the range of fp16
+       // and bf16, it cannot run at low precision.
+       if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
+         auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
+         auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
+         if (low_precision_ == phi::DataType::FLOAT16) {
          support_low_precision =
              support_low_precision &&
-               (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
-         }
-
-         for (auto* out_var_node : op_node->outputs) {
-           CHECK_EQ(out_var_node->IsVar(), true);
-           auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
-           if (real_out_var_node->Var()->Persistable()) continue;
-
+               phi::dtype::isfinite(static_cast<phi::dtype::float16>(scale)) &&
+               phi::dtype::isfinite(static_cast<phi::dtype::float16>(bias));
+         } else if (low_precision_ == phi::DataType::BFLOAT16) {
          support_low_precision =
              support_low_precision &&
-               (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+               phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(scale)) &&
+               phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(bias));
        }
      }

+       // if op's input var and output var is not dense tensor, the op should
+       // not run at low precision.
+       for (auto* in_var_node : op_node->inputs) {
+         CHECK_EQ(in_var_node->IsVar(), true);
+         auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
+         if (real_in_var_node->Var()->Persistable()) continue;
+
+         support_low_precision =
+             support_low_precision &&
+             (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+       }
+       for (auto* out_var_node : op_node->outputs) {
+         CHECK_EQ(out_var_node->IsVar(), true);
+         auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
+         if (real_out_var_node->Var()->Persistable()) continue;
+
+         support_low_precision =
+             support_low_precision &&
+             (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+       }
+
      if (support_low_precision) {
        op_run_low_precision_.insert(op_type);
        VLOG(4) << "support precision: " << op_type << " run at low precision";
@@ -439,7 +464,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
      }

      // when op_1 only support cpu kernel. if op_2's intput var is op_1's
-       // output var, then op_2 should not run half.
+       // output var, then op_2 should not run at low precision.
      if (GetOpOriginalType(op_type) != "feed" &&
          !GpuKernelSupportPrecision(GetOpOriginalType(op_type),
                                     phi::DataType::FLOAT32)) {
@@ -597,7 +622,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
        auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
        auto in_var_name = real_in_var_node->Var()->Name();

-         if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue;
+         if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
        if (!VarNodeHasDtype(real_in_var_node)) continue;
        if (InputVarsNotConvert(op_node, in_var_name)) continue;

@@ -616,7 +641,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
        auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
        auto out_var_name = real_out_var_node->Var()->Name();

-         if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue;
+         if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue;
        if (!VarNodeHasDtype(real_out_var_node)) continue;
        if (OutputVarsNotConvert(op_node, out_var_name)) continue;

@@ -656,7 +681,7 @@ void AutoMixedPrecisionPass::ConvertWeightsData() const {
  auto var_names = scope->LocalVarNames();
  for (const auto& var_name : var_names) {
    if (vars_convert_to_low_precision_.count(var_name)) {
-       VLOG(4) << var_name << "'s data type was convert to half";
+       VLOG(4) << var_name << "'s data type was convert to low precision";

      auto* var = scope->FindLocalVar(var_name);
      CHECK_EQ(var->IsType<phi::DenseTensor>(), true);
@@ -683,16 +708,18 @@ void AutoMixedPrecisionPass::ConvertWeightsData() const {
          }
        }
      } else if (low_precision_ == phi::DataType::BFLOAT16) {
-         auto* half_data =
+         auto* low_precision_data =
            low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
                phi::CPUPlace{});
        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
            auto* origin_data = origin_tensor->data<double>();
-             half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
+             low_precision_data[i] =
+                 static_cast<phi::dtype::bfloat16>(origin_data[i]);
          } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
            auto* origin_data = origin_tensor->data<float>();
-             half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
+             low_precision_data[i] =
+                 static_cast<phi::dtype::bfloat16>(origin_data[i]);
          }
        }
      }
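
The conversion loop narrows each fp32/fp64 weight to bf16, which keeps the fp32 exponent but only 7 stored mantissa bits. A self-contained illustration of that layout using plain bit truncation (hypothetical helpers; conversions such as phi::dtype::bfloat16 typically round to nearest instead of truncating):

#include <cstdint>
#include <cstring>

// fp32 -> bf16 by keeping the high 16 bits: sign, 8 exponent bits and the
// top 7 mantissa bits of the fp32 pattern.
uint16_t FloatToBf16Truncate(float v) {
  uint32_t bits = 0;
  std::memcpy(&bits, &v, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// Widening back to fp32 is exact, since every bf16 value is also an fp32
// value.
float Bf16ToFloat(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float v = 0.0f;
  std::memcpy(&v, &bits, sizeof(bits));
  return v;
}
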
@@ -732,25 +759,44 @@ void AutoMixedPrecisionPass::InsertCastOp() const {
        VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
                << " with type " << in_var_type;

-         if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) {
-           DoInsertCastOp(subgraphes_[i],
-                          in_var_node,
-                          op_node,
-                          in_var_type,
-                          framework::TransToProtoVarType(low_precision_),
-                          block_desc,
-                          &suffix,
-                          &cache);
-         } else if (IsHalfType(in_var_type) &&
+         if (IsFP32AndFP64(in_var_type) &&
+             op_run_low_precision_.count(op_type)) {
+           auto to_type = framework::TransToProtoVarType(low_precision_);
+           auto* prev_op =
+               in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
+           if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
+             in_var_node->Var()->SetDataType(to_type);
+             prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
+             prev_op->Op()->Flush();
+           } else {
+             DoInsertCastOp(subgraphes_[i],
+                            in_var_node,
+                            op_node,
+                            in_var_type,
+                            to_type,
+                            block_desc,
+                            &suffix,
+                            &cache);
+           }
+         } else if (IsFP16AndBFP16(in_var_type) &&
                   op_run_low_precision_.count(op_type) == 0) {
-           DoInsertCastOp(subgraphes_[i],
-                          in_var_node,
-                          op_node,
-                          in_var_type,
-                          VarType::FP32,
-                          block_desc,
-                          &suffix,
-                          &cache);
+           auto to_type = VarType::FP32;
+           auto* prev_op =
+               in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
+           if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
+             in_var_node->Var()->SetDataType(to_type);
+             prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
+             prev_op->Op()->Flush();
+           } else {
+             DoInsertCastOp(subgraphes_[i],
+                            in_var_node,
+                            op_node,
+                            in_var_type,
+                            to_type,
+                            block_desc,
+                            &suffix,
+                            &cache);
+           }
        }
      }

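
Instead of stacking a second cast behind an existing one (for example fp32 -> fp16 immediately after an fp16 -> fp32 cast), the rewritten branches retarget the producing cast's out_dtype whenever the input variable is itself the output of a cast op. A simplified sketch of that decision with stand-in node types (illustration only, not the real ir::Node/OpDesc API):

#include <string>
#include <vector>

struct NodeSketch {
  std::string op_type;              // e.g. "cast"
  int out_dtype = -1;               // dtype code the node produces
  std::vector<NodeSketch*> inputs;  // producer ops of this variable
};

// Returns true when a fresh cast op is still required; otherwise the
// existing producer cast has been retargeted to `to_type` in place.
bool NeedNewCast(NodeSketch* in_var, int to_type) {
  NodeSketch* prev_op = in_var->inputs.empty() ? nullptr : in_var->inputs[0];
  if (prev_op != nullptr && prev_op->op_type == "cast") {
    prev_op->out_dtype = to_type;   // reuse: cast straight to the target dtype
    return false;
  }
  return true;                      // the real pass falls back to DoInsertCastOp
}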