@@ -262,7 +262,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
262
262
_rename_arg (op , in_var .name , out_var .name )
263
263
264
264
for attr_name in ['in_dtype' , 'out_dtype' , 'dtype' ]:
265
- if op .has_attr (attr_name ) and is_float_dtype ( op .attr (attr_name )) :
265
+ if op .has_attr (attr_name ) and op .attr (attr_name ) in FLOAT_TYPES :
266
266
op ._set_attr (attr_name , dest_dtype )
267
267
268
268
return num_cast_ops
@@ -405,13 +405,18 @@ def fp16_guard():
405
405
yield
406
406
407
407
408
- def is_float_dtype (dtype ):
409
- return (
410
- dtype == core .VarDesc .VarType .FP32
411
- or dtype == core .VarDesc .VarType .FP16
412
- or dtype == core .VarDesc .VarType .BF16
413
- or dtype == core .VarDesc .VarType .FP64
414
- )
408
+ FLOAT_TYPES = {
409
+ core .VarDesc .VarType .FP32 ,
410
+ core .VarDesc .VarType .FP16 ,
411
+ core .VarDesc .VarType .BF16 ,
412
+ core .VarDesc .VarType .FP64 ,
413
+ }
414
+
415
+ SUPPORT_FLOAT_TYPES = {
416
+ core .VarDesc .VarType .FP32 ,
417
+ core .VarDesc .VarType .FP16 ,
418
+ core .VarDesc .VarType .BF16 ,
419
+ }
415
420
416
421
417
422
def set_var_dst_dtype (
@@ -433,7 +438,7 @@ def set_var_dst_dtype(
433
438
if var is None or var .type not in _valid_types :
434
439
continue
435
440
436
- if is_float_dtype ( var .dtype ) :
441
+ if var .dtype in FLOAT_TYPES :
437
442
low_precison_var_names .add (var_name )
438
443
if need_set_dtype :
439
444
var .desc .set_dtype (dtype )
@@ -700,6 +705,25 @@ def cast_model_to_fp16(
700
705
701
706
def need_process (op ):
702
707
need_process = True
708
+
709
+ def is_support_type (name ):
710
+ if not op .block ._find_var_recursive (
711
+ name
712
+ ): # a special case for lod_tensor_blocking_queue_0
713
+ return True
714
+ if (
715
+ op .block ._var_recursive (name ).type
716
+ != core .VarDesc .VarType .LOD_TENSOR
717
+ ):
718
+ return False
719
+ return op .block ._var_recursive (name ).dtype in SUPPORT_FLOAT_TYPES
720
+
721
+ if len (op .input_arg_names ) > 0 and all (
722
+ not is_support_type (name ) for name in op .input_arg_names
723
+ ):
724
+ return False
725
+
726
+ # if input type of op is fp64, we just skip it.
703
727
if op .type in ["set_value" ]:
704
728
# NOTE(zoooo0820): OP set_value has attribute "dtype", but its output type is
705
729
# determined by the input.dtype instead of attribute. So, here we still process it.
@@ -711,8 +735,7 @@ def need_process(op):
711
735
# output type of some operators such as fill_constant will be determined by the attribute value.
712
736
#
713
737
if not op .has_attr ('in_dtype' ) and (
714
- op .has_attr (attr_name )
715
- and is_float_dtype (op .attr (attr_name ))
738
+ op .has_attr (attr_name ) and op .attr (attr_name ) in FLOAT_TYPES
716
739
):
717
740
need_process = False
718
741
0 commit comments