Commit 92f1fb7

[bugfix] fix amp bugs and remove topo sort in base/backward.py (#60039)
* fix and remove topo order effect (#59996)
* [AMP/SOT/PIR] fix amp bugs in yolo_v5 and add unittest (#59896)
* fix amp bugs in yolo_v5 and add unittest
* add bf16
* fix-amp-bugs
* Update fp16_utils.py
1 parent 1368016 · commit 92f1fb7

3 files changed: 85 additions & 12 deletions

python/paddle/base/backward.py

Lines changed: 3 additions & 1 deletion
@@ -539,6 +539,8 @@ def _addup_repetitive_outputs_(
     var_device = collections.defaultdict(str)

     def _change_order_by_topo_order(var_name):
+        if topo_order_for_backward is None:
+            return
         origin_names = renamed_vars[var_name]
         origin_names.sort(key=lambda x: topo_order_for_grad_name[x])

@@ -1596,12 +1598,12 @@ def find_op_index(block_desc, cur_op_desc):
         program._appending_grad_times
     ]
     # sum parameter's gradients' var given multiple var gradient
-    topo_order = _topo_order_map(block, target_vars)
     if os.environ.get("FLAGS_program_topo_reorder", "False") in [
         'True',
         '1',
         'true',
     ]:
+        topo_order = _topo_order_map(block, target_vars)
         topo_order_for_backward = _topo_bwd_order_map(
             topo_order, get_backward_op_desc
         )
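
For context, a minimal standalone sketch of the control flow these two hunks establish (toy gradient names and a plain dict standing in for Paddle's _topo_order_map / _topo_bwd_order_map, not the actual backward machinery): the topological order is only built when FLAGS_program_topo_reorder is enabled, so _change_order_by_topo_order must tolerate a missing order and leave the gradient accumulation order untouched.

import os

topo_order_for_backward = None
if os.environ.get("FLAGS_program_topo_reorder", "False") in ['True', '1', 'true']:
    # Stand-in for _topo_order_map(block, target_vars) / _topo_bwd_order_map(...):
    # maps each renamed grad var name to its position in topological order.
    topo_order_for_backward = {"x@GRAD@RENAME@1": 0, "x@GRAD@RENAME@0": 1}


def change_order_by_topo_order(names):
    # Mirrors the patched helper: with the flag off no order exists, so return
    # early and keep the original accumulation order.
    if topo_order_for_backward is None:
        return
    names.sort(key=lambda x: topo_order_for_backward[x])


renamed = ["x@GRAD@RENAME@0", "x@GRAD@RENAME@1"]
change_order_by_topo_order(renamed)
print(renamed)  # unchanged unless FLAGS_program_topo_reorder is set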

python/paddle/static/amp/fp16_utils.py

Lines changed: 34 additions & 11 deletions
@@ -262,7 +262,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 _rename_arg(op, in_var.name, out_var.name)

     for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
-        if op.has_attr(attr_name) and is_float_dtype(op.attr(attr_name)):
+        if op.has_attr(attr_name) and op.attr(attr_name) in FLOAT_TYPES:
             op._set_attr(attr_name, dest_dtype)

     return num_cast_ops
@@ -405,13 +405,18 @@ def fp16_guard():
         yield


-def is_float_dtype(dtype):
-    return (
-        dtype == core.VarDesc.VarType.FP32
-        or dtype == core.VarDesc.VarType.FP16
-        or dtype == core.VarDesc.VarType.BF16
-        or dtype == core.VarDesc.VarType.FP64
-    )
+FLOAT_TYPES = {
+    core.VarDesc.VarType.FP32,
+    core.VarDesc.VarType.FP16,
+    core.VarDesc.VarType.BF16,
+    core.VarDesc.VarType.FP64,
+}
+
+SUPPORT_FLOAT_TYPES = {
+    core.VarDesc.VarType.FP32,
+    core.VarDesc.VarType.FP16,
+    core.VarDesc.VarType.BF16,
+}


 def set_var_dst_dtype(
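
As a side note, a small sketch of how the two new sets are meant to be used (assuming an installed paddle where paddle.base.core exposes VarDesc.VarType, as the patched module does): plain set membership replaces the removed is_float_dtype() helper, and SUPPORT_FLOAT_TYPES deliberately leaves FP64 out so fp64 data is not treated as a low-precision candidate.

from paddle.base import core

FLOAT_TYPES = {
    core.VarDesc.VarType.FP32,
    core.VarDesc.VarType.FP16,
    core.VarDesc.VarType.BF16,
    core.VarDesc.VarType.FP64,
}
SUPPORT_FLOAT_TYPES = FLOAT_TYPES - {core.VarDesc.VarType.FP64}

dtype = core.VarDesc.VarType.FP64
print(dtype in FLOAT_TYPES)          # True: fp64 dtype attrs still get rewritten in _insert_cast_op
print(dtype in SUPPORT_FLOAT_TYPES)  # False: fp64 inputs are not candidates for AMP casting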
@@ -433,7 +438,7 @@ def set_var_dst_dtype(
         if var is None or var.type not in _valid_types:
             continue

-        if is_float_dtype(var.dtype):
+        if var.dtype in FLOAT_TYPES:
             low_precison_var_names.add(var_name)
             if need_set_dtype:
                 var.desc.set_dtype(dtype)
@@ -700,6 +705,25 @@ def cast_model_to_fp16(

     def need_process(op):
         need_process = True
+
+        def is_support_type(name):
+            if not op.block._find_var_recursive(
+                name
+            ):  # a special case for lod_tensor_blocking_queue_0
+                return True
+            if (
+                op.block._var_recursive(name).type
+                != core.VarDesc.VarType.LOD_TENSOR
+            ):
+                return False
+            return op.block._var_recursive(name).dtype in SUPPORT_FLOAT_TYPES
+
+        if len(op.input_arg_names) > 0 and all(
+            not is_support_type(name) for name in op.input_arg_names
+        ):
+            return False
+
+        # if input type of op is fp64, we just skip it.
         if op.type in ["set_value"]:
             # NOTE(zoooo0820): OP set_value has attribute "dtype", but its output type is
             # determined by the input.dtype instead of attribute. So, here we still process it.
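
A standalone sketch of the skip rule this hunk adds to need_process(), using toy dicts and string dtypes instead of Paddle's Operator/Block API: an op is left untouched by cast_model_to_fp16 when it has inputs and none of them is a LoD tensor of a supported float dtype (for example, when all inputs are fp64).

SUPPORT_FLOAT_TYPES = {"fp32", "fp16", "bf16"}  # stand-ins for core.VarDesc.VarType members


def is_support_type(var):
    # None plays the role of a var that _find_var_recursive cannot locate
    # (e.g. lod_tensor_blocking_queue_0): treat it as supported.
    if var is None:
        return True
    if var["type"] != "lod_tensor":
        return False
    return var["dtype"] in SUPPORT_FLOAT_TYPES


def need_process(input_vars):
    # New rule: skip the op entirely when it has inputs and none is supported.
    if len(input_vars) > 0 and all(not is_support_type(v) for v in input_vars):
        return False
    return True


fp64_only = [{"type": "lod_tensor", "dtype": "fp64"}]
mixed = fp64_only + [{"type": "lod_tensor", "dtype": "fp32"}]
print(need_process(fp64_only))  # False: the op keeps running in fp64
print(need_process(mixed))      # True: still considered for low-precision casting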
@@ -711,8 +735,7 @@ def need_process(op):
             # output type of some operators such as fill_constant will be determined by the attribute value.
             #
             if not op.has_attr('in_dtype') and (
-                op.has_attr(attr_name)
-                and is_float_dtype(op.attr(attr_name))
+                op.has_attr(attr_name) and op.attr(attr_name) in FLOAT_TYPES
             ):
                 need_process = False

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from dygraph_to_static_utils import (
+    Dy2StTestBase,
+    test_legacy_and_pt,
+)
+
+import paddle
+
+np.random.seed(1)
+
+
+def func(x):
+    y = x[0:3].astype("float32")
+    return y
+
+
+class TestAmp64Case(Dy2StTestBase):
+    def _run_static(self):
+        static_func = paddle.jit.to_static(func)
+        x = paddle.randn((10, 10)).astype("float64")
+        with paddle.amp.auto_cast(True, level="O2"):
+            dy_out = func(x)
+            st_out = static_func(x)
+        np.testing.assert_allclose(dy_out.numpy(), st_out.numpy())
+
+    @test_legacy_and_pt
+    def test_ast_to_func(self):
+        self._run_static()
+
+
+if __name__ == '__main__':
+    unittest.main()
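
For reference, a minimal standalone reproduction of the scenario the new test covers, assuming only an installed paddle (the Dy2StTestBase harness and the test_legacy_and_pt decorator come from the test-suite-local dygraph_to_static_utils module and are omitted here): slice a float64 tensor and cast it to float32 inside an AMP O2 region, then check that the dygraph and to_static results agree.

import numpy as np

import paddle


def func(x):
    return x[0:3].astype("float32")


x = paddle.randn((10, 10)).astype("float64")
with paddle.amp.auto_cast(True, level="O2"):
    dy_out = func(x)                        # eager (dygraph) result
    st_out = paddle.jit.to_static(func)(x)  # converted static-graph result
np.testing.assert_allclose(dy_out.numpy(), st_out.numpy())
print("dygraph and to_static outputs match under AMP O2")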
