|
15 | 15 | import copy
|
16 | 16 | import logging
|
17 | 17 |
|
| 18 | +from paddle.amp.amp_lists import ( |
| 19 | + FP16_BLACK_LIST, |
| 20 | + FP16_EXTRA_BLACK_LIST, |
| 21 | + FP16_WHITE_LIST, |
| 22 | +) |
18 | 23 | from paddle.fluid import core
|
19 | 24 | from paddle.fluid.log_helper import get_logger
|
20 | 25 |
|
21 | 26 | _logger = get_logger(
|
22 | 27 | __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
|
23 | 28 | )
|
24 | 29 |
|
25 |
| -# lookup_table fp16 is slower than fp32, though fp16 is supported. |
26 |
| -_extra_black_list = { |
27 |
| - 'lookup_table', |
28 |
| - 'lookup_table_v2', |
29 |
| - 'scatter', |
30 |
| - 'linear_interp_v2', |
31 |
| - 'nearest_interp_v2', |
32 |
| - 'bilinear_interp_v2', |
33 |
| - 'bicubic_interp_v2', |
34 |
| - 'trilinear_interp_v2', |
35 |
| -} |
| 30 | +black_list = FP16_BLACK_LIST |
| 31 | +_extra_black_list = FP16_EXTRA_BLACK_LIST |
| 32 | +white_list = FP16_WHITE_LIST |
36 | 33 |
|
37 | 34 |
|
38 | 35 | def check_amp_dtype(dtype):
|
@@ -131,45 +128,17 @@ def _get_unsupported_list(dtype):
|
131 | 128 |
|
132 | 129 | _only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'}
|
133 | 130 |
|
134 |
| -white_list = { |
135 |
| - 'conv2d', |
136 |
| - 'einsum', |
137 |
| - 'matmul', |
138 |
| - 'matmul_v2', |
139 |
| - 'mul', |
140 |
| -} |
141 |
| - |
142 | 131 |
|
143 | 132 | def _get_white_list(dtype):
|
144 |
| - white_list_for_dtype = copy.copy(white_list) |
| 133 | + white_list_for_dtype = copy.copy(FP16_WHITE_LIST) |
145 | 134 | if dtype == 'float16':
|
146 | 135 | white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list
|
147 | 136 | return white_list_for_dtype
|
148 | 137 |
|
149 | 138 |
|
150 |
| -# The set of ops that support fp16 calculation and are considered numerically- |
151 |
| -# dangerous and whose effects may also be observed in downstream ops. |
152 |
| -black_list = { |
153 |
| - 'exp', |
154 |
| - 'square', |
155 |
| - 'log', |
156 |
| - 'mean', |
157 |
| - 'sum', |
158 |
| - 'cos_sim', |
159 |
| - 'softmax', |
160 |
| - 'softmax_with_cross_entropy', |
161 |
| - 'sigmoid_cross_entropy_with_logits', |
162 |
| - 'c_softmax_with_cross_entropy', |
163 |
| - 'cross_entropy', |
164 |
| - 'cross_entropy2', |
165 |
| - # default fp32 can avoid return inf when the sum value large than 65504 |
166 |
| - 'reduce_sum', |
167 |
| -} |
168 |
| - |
169 |
| - |
170 | 139 | def _get_black_list():
|
171 |
| - _black_list = copy.copy(black_list) |
172 |
| - _black_list = _black_list | _extra_black_list |
| 140 | + _black_list = copy.copy(FP16_BLACK_LIST) |
| 141 | + _black_list = _black_list | FP16_EXTRA_BLACK_LIST |
173 | 142 | return _black_list
|
174 | 143 |
|
175 | 144 |
|
|
0 commit comments