Skip to content

Commit 7776754

Browse files
authored
[CherryPick][Auto Parallel] fix loss scale in xpu (#71698) (#71765)
1 parent 0aa35fa commit 7776754

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

python/paddle/distributed/auto_parallel/api.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,17 @@ def unscale_method(self, optimizer):
19331933
temp_param_grads_half,
19341934
temp_scale,
19351935
)
1936+
1937+
# AllReduce for "bool" is not supported on XPU
1938+
if "xpu" in paddle.device.get_device():
1939+
temp_param_grads_half = paddle.cast(
1940+
temp_param_grads_half, "int32"
1941+
)
1942+
temp_param_grads_half = paddle.sum(temp_param_grads_half)
1943+
temp_param_grads_half = paddle.cast(
1944+
temp_param_grads_half, "bool"
1945+
)
1946+
19361947
temp_found_inf = _C_ops.bitwise_or(
19371948
temp_found_inf, temp_found_inf_half
19381949
)
@@ -1941,6 +1952,17 @@ def unscale_method(self, optimizer):
19411952
temp_param_grads_fp32,
19421953
temp_scale,
19431954
)
1955+
1956+
# AllReduce for "bool" is not supported on XPU
1957+
if "xpu" in paddle.device.get_device():
1958+
temp_found_inf_fp32 = paddle.cast(
1959+
temp_found_inf_fp32, "int32"
1960+
)
1961+
temp_found_inf_fp32 = paddle.sum(temp_found_inf_fp32)
1962+
temp_found_inf_fp32 = paddle.cast(
1963+
temp_found_inf_fp32, "bool"
1964+
)
1965+
19441966
temp_found_inf = _C_ops.bitwise_or(
19451967
temp_found_inf, temp_found_inf_fp32
19461968
)

python/paddle/optimizer/optimizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,12 @@ def _create_optimization_pass(
13361336

13371337
if framework.in_dygraph_mode():
13381338
found_inf = self._get_auxiliary_var('found_inf')
1339+
if (
1340+
"xpu" in paddle.device.get_device()
1341+
and found_inf is not None
1342+
and found_inf.is_dist()
1343+
):
1344+
found_inf = found_inf._local_value()
13391345
if found_inf:
13401346
if isinstance(found_inf, core.eager.Tensor):
13411347
self._set_auxiliary_var('found_inf', True)

0 commit comments

Comments
 (0)