Skip to content

Commit 3eb020a

Browse files
authored
[CherryPick][AutoParallel] Fix inplace op in grad clip (#71565) (#71584)
* [CherryPick][AutoParallel] Fix inplace op in grad clip (#71565)
* fix inplace op in grad clip
* fix inplace op in grad clip
* [CherryPick]fix grad clip (#71607)
1 parent aa72149 commit 3eb020a

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

python/paddle/nn/clip.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -240,6 +240,18 @@ def _cast_to_mp_type_if_enabled(x):
240240
return x
241241

242242

243+
def _can_inplace_clip_grad(grad: Tensor, clip_input: Tensor):
244+
if not grad._is_initialized() or not clip_input._is_initialized():
245+
return False
246+
247+
# 1. Inplace ops only support DistTensor and DenseTensor.
248+
# 2. Inplace ops do not support 0-D tensor.
249+
if (grad.is_dist() or grad.is_dense()) and len(grad.shape) != 0:
250+
return True
251+
252+
return False
253+
254+
243255
def _squared_l2_norm(x):
244256
r"""
245257
Return the squared L2 norm of a tensor.
@@ -840,7 +852,8 @@ def async_add_n(var_list):
840852
clip_input = paddle.distributed.reshard(
841853
clip_input, g.process_mesh, clip_input.placements
842854
)
843-
if g.is_dist() or g.is_dense():
855+
856+
if _can_inplace_clip_grad(g, clip_input):
844857
g.multiply_(clip_input)
845858
params_and_grads.append((p, g))
846859
else:

0 commit comments

Comments
 (0)