File tree Expand file tree Collapse file tree 2 files changed +28
-0
lines changed
distributed/auto_parallel Expand file tree Collapse file tree 2 files changed +28
-0
lines changed Original file line number Diff line number Diff line change @@ -1933,6 +1933,17 @@ def unscale_method(self, optimizer):
1933
1933
temp_param_grads_half ,
1934
1934
temp_scale ,
1935
1935
)
1936
+
1937
+ # AllReduce for "bool" is not supported on XPU
1938
+ if "xpu" in paddle .device .get_device ():
1939
+ temp_param_grads_half = paddle .cast (
1940
+ temp_param_grads_half , "int32"
1941
+ )
1942
+ temp_param_grads_half = paddle .sum (temp_param_grads_half )
1943
+ temp_param_grads_half = paddle .cast (
1944
+ temp_param_grads_half , "bool"
1945
+ )
1946
+
1936
1947
temp_found_inf = _C_ops .bitwise_or (
1937
1948
temp_found_inf , temp_found_inf_half
1938
1949
)
@@ -1941,6 +1952,17 @@ def unscale_method(self, optimizer):
1941
1952
temp_param_grads_fp32 ,
1942
1953
temp_scale ,
1943
1954
)
1955
+
1956
+ # AllReduce for "bool" is not supported on XPU
1957
+ if "xpu" in paddle .device .get_device ():
1958
+ temp_found_inf_fp32 = paddle .cast (
1959
+ temp_found_inf_fp32 , "int32"
1960
+ )
1961
+ temp_found_inf_fp32 = paddle .sum (temp_found_inf_fp32 )
1962
+ temp_found_inf_fp32 = paddle .cast (
1963
+ temp_found_inf_fp32 , "bool"
1964
+ )
1965
+
1944
1966
temp_found_inf = _C_ops .bitwise_or (
1945
1967
temp_found_inf , temp_found_inf_fp32
1946
1968
)
Original file line number Diff line number Diff line change @@ -1336,6 +1336,12 @@ def _create_optimization_pass(
1336
1336
1337
1337
if framework .in_dygraph_mode ():
1338
1338
found_inf = self ._get_auxiliary_var ('found_inf' )
1339
+ if (
1340
+ "xpu" in paddle .device .get_device ()
1341
+ and found_inf is not None
1342
+ and found_inf .is_dist ()
1343
+ ):
1344
+ found_inf = found_inf ._local_value ()
1339
1345
if found_inf :
1340
1346
if isinstance (found_inf , core .eager .Tensor ):
1341
1347
self ._set_auxiliary_var ('found_inf' , True )
You can’t perform that action at this time.
0 commit comments