Skip to content

Commit cf64aa0

Browse files
Change Static AMP List (#54135) (#54591)
AMP 动静态图黑名单统一
1 parent 7487592 commit cf64aa0

File tree

1 file changed

+11
-42
lines changed

1 file changed

+11
-42
lines changed

python/paddle/static/amp/fp16_lists.py

Lines changed: 11 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,21 @@
1515
import copy
1616
import logging
1717

18+
from paddle.amp.amp_lists import (
19+
FP16_BLACK_LIST,
20+
FP16_EXTRA_BLACK_LIST,
21+
FP16_WHITE_LIST,
22+
)
1823
from paddle.fluid import core
1924
from paddle.fluid.log_helper import get_logger
2025

2126
_logger = get_logger(
2227
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
2328
)
2429

25-
# lookup_table fp16 is slower than fp32, though fp16 is supported.
26-
_extra_black_list = {
27-
'lookup_table',
28-
'lookup_table_v2',
29-
'scatter',
30-
'linear_interp_v2',
31-
'nearest_interp_v2',
32-
'bilinear_interp_v2',
33-
'bicubic_interp_v2',
34-
'trilinear_interp_v2',
35-
}
30+
black_list = FP16_BLACK_LIST
31+
_extra_black_list = FP16_EXTRA_BLACK_LIST
32+
white_list = FP16_WHITE_LIST
3633

3734

3835
def check_amp_dtype(dtype):
@@ -131,45 +128,17 @@ def _get_unsupported_list(dtype):
131128

132129
_only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'}
133130

134-
white_list = {
135-
'conv2d',
136-
'einsum',
137-
'matmul',
138-
'matmul_v2',
139-
'mul',
140-
}
141-
142131

143132
def _get_white_list(dtype):
144-
white_list_for_dtype = copy.copy(white_list)
133+
white_list_for_dtype = copy.copy(FP16_WHITE_LIST)
145134
if dtype == 'float16':
146135
white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list
147136
return white_list_for_dtype
148137

149138

150-
# The set of ops that support fp16 calculation and are considered numerically-
151-
# dangerous and whose effects may also be observed in downstream ops.
152-
black_list = {
153-
'exp',
154-
'square',
155-
'log',
156-
'mean',
157-
'sum',
158-
'cos_sim',
159-
'softmax',
160-
'softmax_with_cross_entropy',
161-
'sigmoid_cross_entropy_with_logits',
162-
'c_softmax_with_cross_entropy',
163-
'cross_entropy',
164-
'cross_entropy2',
165-
# default fp32 can avoid return inf when the sum value large than 65504
166-
'reduce_sum',
167-
}
168-
169-
170139
def _get_black_list():
171-
_black_list = copy.copy(black_list)
172-
_black_list = _black_list | _extra_black_list
140+
_black_list = copy.copy(FP16_BLACK_LIST)
141+
_black_list = _black_list | FP16_EXTRA_BLACK_LIST
173142
return _black_list
174143

175144

0 commit comments

Comments
 (0)