@@ -63,6 +63,76 @@ class PyLayerContext:
     not_inplace_tensors: tuple[Tensor, ...]
     non_differentiable: tuple[Tensor, ...]
     materialize_grads: bool
+    grad_in_dtype_consistent: bool
+
+    def set_grad_in_dtype_consistent(self, flag: bool) -> None:
+        """
+        Set whether the backward input gradient keeps the same dtype as the corresponding forward output.
+
+        Note:
+            This API should be called only inside `forward`.
+            By default, backward input gradients are automatically cast to the dtype of the corresponding forward outputs.
+            Set the flag to `False` to disable this automatic cast and keep the original gradient dtype in `backward`.
+
+        Args:
+            flag (bool): Whether to cast the backward input gradient to the dtype of the forward output.
+                - `True`: Cast the backward input gradient to the forward output dtype (default behavior).
+                - `False`: Preserve the original dtype of the backward input gradient.
+
+        Returns:
+            None
+
+        Examples:
+            .. code-block:: python
+
+                >>> import paddle
+                >>> from paddle.autograd import PyLayer
+                >>> paddle.seed(2025)
+                >>> class cus_tanh(PyLayer):
+                ...     @staticmethod
+                ...     def forward(ctx, x):
+                ...         y = paddle.tanh(x)
+                ...         # Pass tensors to backward.
+                ...         ctx.save_for_backward(y)
+                ...         # The gradient input in the backward process
+                ...         # will not be automatically cast to the dtype of the forward output.
+                ...         ctx.set_grad_in_dtype_consistent(False)
+                ...         return y
+                ...
+                ...     @staticmethod
+                ...     def backward(ctx, dy):
+                ...
+                ...         # Get the tensors passed by forward.
+                ...         y, = ctx.saved_tensor()
+                ...         grad = dy * (1 - paddle.square(y))
+                ...         return grad
+                ...
+                >>> class cus_tanh_cast_grad(PyLayer):
+                ...     @staticmethod
+                ...     def forward(ctx, x):
+                ...         y = paddle.tanh(x)
+                ...         # Pass tensors to backward.
+                ...         ctx.save_for_backward(y)
+                ...         return y
+                ...
+                ...     @staticmethod
+                ...     def backward(ctx, dy):
+                ...         # Get the tensors passed by forward.
+                ...         y, = ctx.saved_tensor()
+                ...         grad = dy * (1 - paddle.square(y))
+                ...         # The gradient passed back to cus_tanh is cast to float16 manually here,
+                ...         # and cus_tanh will not cast it back to the dtype of its forward output.
+                ...         grad = paddle.cast(grad, paddle.float16)
+                ...         return grad
+                ...
+                >>> x = paddle.randn([3, 3]).astype("float32")
+                >>> x.stop_gradient = False
+                >>> y = cus_tanh.apply(x)
+                >>> z = cus_tanh_cast_grad.apply(y)
+                >>> z.sum().backward()
+
+        """
+        self.grad_in_dtype_consistent = flag
 
     def save_for_backward(self, *tensors: Tensor) -> None:
         """
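Beyond the doctest above, here is a minimal sketch of how the new flag could be exercised with an explicit upstream gradient dtype. `KeepGradDtype` is a hypothetical layer written only for illustration, and the dtype comments describe the expected behavior with this patch applied, not verified output.

import paddle
from paddle.autograd import PyLayer

class KeepGradDtype(PyLayer):
    @staticmethod
    def forward(ctx, x):
        # Keep the incoming gradient dtype untouched in backward (the API added in this patch).
        ctx.set_grad_in_dtype_consistent(False)
        return x * 1.0

    @staticmethod
    def backward(ctx, dy):
        # With the flag disabled, dy is expected to keep the float16 dtype supplied
        # below instead of being cast to float32 (the dtype of the forward output).
        print(dy.dtype)
        # Cast back so the gradient returned for the float32 leaf matches its dtype.
        return dy.astype("float32")

x = paddle.randn([3, 3], dtype="float32")
x.stop_gradient = False
y = KeepGradDtype.apply(x)
# Supply the upstream gradient in float16 to observe whether its dtype is preserved.
paddle.autograd.backward([y], [paddle.ones([3, 3], dtype="float16")])
print(x.grad.dtype)  # expected: paddle.float32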