Commit 6996029

Support set_grad_in_dtype_consistent for pylayer (PaddlePaddle#76537)
* support dtype consistent for pylayer
* fix
* fix for static check
1 parent 626981e commit 6996029
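In short: calling `ctx.set_grad_in_dtype_consistent(False)` inside `forward` tells the PyLayer's grad node not to cast incoming backward gradients to the dtype of the corresponding forward outputs. A minimal usage sketch (the class names are illustrative, not part of this commit; the behavior mirrors the docstring and test added below):

import paddle
from paddle.autograd import PyLayer

class KeepGradDtype(PyLayer):  # illustrative name, not part of this commit
    @staticmethod
    def forward(ctx, x):
        # Opt out of the default cast: backward will see the gradient
        # in whatever dtype the downstream layer produced.
        ctx.set_grad_in_dtype_consistent(False)
        return x * 2.0

    @staticmethod
    def backward(ctx, dy):
        print(dy.dtype)  # paddle.float16 here; paddle.float32 with the default setting
        return paddle.cast(dy, paddle.float32) * 2.0

class CastGradToFP16(PyLayer):  # illustrative name, not part of this commit
    @staticmethod
    def forward(ctx, x):
        return x * 1.0

    @staticmethod
    def backward(ctx, dy):
        # Hand a float16 gradient to the upstream layer on purpose.
        return paddle.cast(dy, paddle.float16)

x = paddle.randn([3, 3], dtype="float32")
x.stop_gradient = False
z = CastGradToFP16.apply(KeepGradDtype.apply(x))
z.sum().backward()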

File tree

7 files changed: 167 additions, 7 deletions

paddle/fluid/eager/backward.cc
paddle/fluid/eager/grad_tensor_holder.h
paddle/fluid/eager/pylayer/py_layer_node.h
paddle/fluid/pybind/eager.h
paddle/fluid/pybind/eager_py_layer.cc
python/paddle/autograd/py_layer.py
test/legacy_test/test_pylayer_op.py

paddle/fluid/eager/backward.cc
Lines changed: 25 additions & 5 deletions

@@ -15,6 +15,7 @@
 #include "paddle/fluid/eager/backward.h"

 #include "paddle/fluid/eager/general_grad.h"
+#include "paddle/fluid/eager/pylayer/py_layer_node.h"
 #include "paddle/fluid/eager/utils.h"
 #include "paddle/fluid/inference/analysis/dot.h"
 #include "paddle/phi/core/memory/stats.h"

@@ -321,8 +322,15 @@ std::vector<paddle::Tensor> RunBackward(
       VLOG(4) << "RunBackward: Create Value for grad input tensor " << i
               << " of grad node: " << grad_node->name() << "(" << grad_node
               << ")";
-      node_input_buffers_dict[grad_node] =
-          std::make_unique<GradTensorHolder>(grad_node->InputMeta());
+
+      if (typeid(*grad_node) == typeid(GradNodePyLayer)) {
+        auto pylayer_gradnode = dynamic_cast<GradNodePyLayer*>(grad_node);
+        node_input_buffers_dict[grad_node] = std::make_unique<GradTensorHolder>(
+            grad_node->InputMeta(), pylayer_gradnode->GradInDtypeConsistent());
+      } else {
+        node_input_buffers_dict[grad_node] =
+            std::make_unique<GradTensorHolder>(grad_node->InputMeta());
+      }
     }

     // copy grad tensor since we should totally run grad without affect forward

@@ -589,11 +597,23 @@ std::vector<paddle::Tensor> RunBackward(

       if (!node_input_buffers_dict.count(next_node)) {
         const auto& input_meta = next_node->InputMeta();
-        auto grad_tensor_holder =
-            std::make_unique<GradTensorHolder>(input_meta);
+
         VLOG(6) << "RunBackward: Construct GradTensorHolder for grad node: "
                 << next_node->name() << "(" << next_node << ") ";
-        node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
+
+        if (typeid(*next_node) == typeid(GradNodePyLayer)) {
+          auto pylayer_gradnode = dynamic_cast<GradNodePyLayer*>(next_node);
+          auto grad_tensor_holder = std::make_unique<GradTensorHolder>(
+              next_node->InputMeta(),
+              pylayer_gradnode->GradInDtypeConsistent());
+          node_input_buffers_dict[next_node] =
+              std::move(grad_tensor_holder);
+        } else {
+          auto grad_tensor_holder =
+              std::make_unique<GradTensorHolder>(input_meta);
+          node_input_buffers_dict[next_node] =
+              std::move(grad_tensor_holder);
+        }
       }

       VLOG(7) << "RunBackward: Sum or Move grad inputs for edge slot: "

paddle/fluid/eager/grad_tensor_holder.h
Lines changed: 4 additions & 2 deletions

@@ -29,7 +29,8 @@ class GradTensorHolder {
  public:
   explicit GradTensorHolder(
       const paddle::small_vector<std::vector<GradSlotMeta>,
-                                 kSlotSmallVectorSize>& metas) {
+                                 kSlotSmallVectorSize>& metas,
+      bool record_input_dtypes = true) {
     VLOG(7) << "Init GradTensorHolder with meta size: " << metas.size();
     buffer_.resize(metas.size());
     input_dtypes_.resize(metas.size());

@@ -41,12 +42,13 @@ class GradTensorHolder {
       // Extract only dtype information from metas
       for (size_t j = 0; j < metas[i].size(); j++) {
         const auto& meta = metas[i][j];
-        if (meta.HasTensorMeta()) {
+        if (meta.HasTensorMeta() && record_input_dtypes) {
           const auto& tensor_meta = meta.GetTensorMeta();
           input_dtypes_[i][j] = tensor_meta.dtype;
           VLOG(7) << "Init GradTensorHolder with dtype: "
                   << phi::DataTypeToString(tensor_meta.dtype);
         } else {
+          VLOG(7) << "Init GradTensorHolder with UNDEFINED";
           input_dtypes_[i][j] = phi::DataType::UNDEFINED;
         }
       }

paddle/fluid/eager/pylayer/py_layer_node.h
Lines changed: 5 additions & 0 deletions

@@ -106,6 +106,10 @@ class GradNodePyLayer : public GradNodeBase {
         std::shared_ptr<GradNodePyLayer>(new GradNodePyLayer(*this));
     return copied_node;
   }
+  bool GradInDtypeConsistent() { return grad_in_dtype_consistent_; }
+  void SetGradInDtypeConsistent(bool value) {
+    grad_in_dtype_consistent_ = value;
+  }

  private:
   PyObject* ctx_{nullptr};

@@ -116,6 +120,7 @@ class GradNodePyLayer : public GradNodeBase {
       forward_outputs_dist_attr_;
   std::vector<std::vector<phi::DDim>> forward_outputs_global_dims_;
   std::vector<std::vector<bool>> forward_outputs_is_dist_meta_;
+  bool grad_in_dtype_consistent_;
 };

 }  // namespace egr

paddle/fluid/pybind/eager.h
Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ typedef struct {
   PyObject* non_differentiable;
   PyObject* not_inplace_tensors;
   bool materialize_grads;
+  bool grad_in_dtype_consistent;
   std::vector<bool> forward_input_tensor_is_duplicable;
   std::vector<bool> forward_output_tensor_is_duplicable;
   std::weak_ptr<egr::GradNodePyLayer> grad_node;

paddle/fluid/pybind/eager_py_layer.cc
Lines changed: 15 additions & 0 deletions

@@ -85,6 +85,7 @@ PyObject* PyLayerNew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
   v->container = nullptr;
   v->materialize_grads = true;
   v->container_be_packed = false;
+  v->grad_in_dtype_consistent = true;
   new (&v->grad_node) std::weak_ptr<egr::GradNodePyLayer>();
   new (&v->forward_input_tensor_is_duplicable) std::vector<bool>();
   new (&v->forward_output_tensor_is_duplicable) std::vector<bool>();

@@ -575,6 +576,7 @@ PyObject* pylayer_method_apply(PyObject* cls,
   if (ctx->materialize_grads) {
     grad_node->SaveForwardOutputsMeta(outputs_tensor);
   }
+  grad_node->SetGradInDtypeConsistent(ctx->grad_in_dtype_consistent);

   for (size_t i = 0; i < inputs_autograd_meta.size(); i++) {
     if (ctx->forward_input_tensor_is_duplicable[i]) {

@@ -858,6 +860,14 @@ int tensor_properties_set_materialize_grads(PyLayerObject* self,
   return 0;
   EAGER_CATCH_AND_THROW_RETURN_NEG
 }
+int tensor_properties_set_grad_in_dtype_consistent(PyLayerObject* self,
+                                                   PyObject* value,
+                                                   void* closure) {
+  EAGER_TRY
+  self->grad_in_dtype_consistent = CastPyArg2AttrBoolean(value, 0);
+  return 0;
+  EAGER_CATCH_AND_THROW_RETURN_NEG
+}

 PyMethodDef pylayer_methods[] = {{"name",  // NOLINT
                                   (PyCFunction)(void (*)())pylayer_method_name,

@@ -890,6 +900,11 @@ struct PyGetSetDef pylayer_properties[] {  // NOLINT
      (setter)tensor_properties_set_materialize_grads,
      nullptr,
      nullptr},
+    {"grad_in_dtype_consistent",
+     nullptr,
+     (setter)tensor_properties_set_grad_in_dtype_consistent,
+     nullptr,
+     nullptr},
     {
         nullptr, nullptr, nullptr, nullptr, nullptr
     }
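Note on the plumbing above: `PyLayerNew` defaults the flag to `true`, `pylayer_method_apply` copies it onto the `GradNodePyLayer`, and the new setter exposes it to Python as a write-only `grad_in_dtype_consistent` property on the context object. The public helper added in python/paddle/autograd/py_layer.py (next file) simply assigns to that property. A minimal sketch, with an illustrative class name:

import paddle
from paddle.autograd import PyLayer

class Identity(PyLayer):  # illustrative name, not part of this commit
    @staticmethod
    def forward(ctx, x):
        # Both lines write the same pybind-backed field registered above;
        # the helper's body is just `self.grad_in_dtype_consistent = flag`.
        ctx.grad_in_dtype_consistent = False
        # ctx.set_grad_in_dtype_consistent(False)  # equivalent
        return x * 1.0

    @staticmethod
    def backward(ctx, dy):
        return dy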

python/paddle/autograd/py_layer.py
Lines changed: 70 additions & 0 deletions

@@ -63,6 +63,76 @@ class PyLayerContext:
     not_inplace_tensors: tuple[Tensor, ...]
     non_differentiable: tuple[Tensor, ...]
     materialize_grads: bool
+    grad_in_dtype_consistent: bool
+
+    def set_grad_in_dtype_consistent(self, flag: bool) -> None:
+        """
+        Set whether to keep the gradient input dtype in backward consistent with the dtype of the corresponding forward output.
+
+        Note:
+            This API should be called only inside `forward`.
+            By default, backward input gradients are automatically cast to match the dtype of the forward outputs.
+            Set this to `False` to disable the automatic cast and keep the original gradient dtypes in backward.
+
+        Args:
+            flag (bool): Whether to enable automatic dtype casting in backward.
+                - `True`: Cast the backward input gradient to match the forward output dtype (default behavior).
+                - `False`: Preserve the original dtype of the backward input gradient.
+
+        Returns:
+            None
+
+        Examples:
+            .. code-block:: python
+
+                >>> import paddle
+                >>> from paddle.autograd import PyLayer
+                >>> paddle.seed(2025)
+                >>> class cus_tanh(PyLayer):
+                ...     @staticmethod
+                ...     def forward(ctx, x):
+                ...         y = paddle.tanh(x)
+                ...         # Pass tensors to backward.
+                ...         ctx.save_for_backward(y)
+                ...         # The gradient input in the backward process
+                ...         # will not be automatically cast to the dtype of the forward output.
+                ...         ctx.set_grad_in_dtype_consistent(False)
+                ...         return y
+                ...
+                ...     @staticmethod
+                ...     def backward(ctx, dy):
+                ...         # Get the tensors passed by forward.
+                ...         y, = ctx.saved_tensor()
+                ...         grad = dy * (1 - paddle.square(y))
+                ...         return grad
+                ...
+                >>> class cus_tanh_cast_grad(PyLayer):
+                ...     @staticmethod
+                ...     def forward(ctx, x):
+                ...         y = paddle.tanh(x)
+                ...         # Pass tensors to backward.
+                ...         ctx.save_for_backward(y)
+                ...         return y
+                ...
+                ...     @staticmethod
+                ...     def backward(ctx, dy):
+                ...         # Get the tensors passed by forward.
+                ...         y, = ctx.saved_tensor()
+                ...         grad = dy * (1 - paddle.square(y))
+                ...         # The gradient is cast to float16 manually here; it becomes the
+                ...         # backward input of cus_tanh, which will not cast it back to the
+                ...         # dtype of its forward output.
+                ...         grad = paddle.cast(grad, paddle.float16)
+                ...         return grad
+                ...
+                >>> x = paddle.randn([3, 3]).astype("float32")
+                >>> x.stop_gradient = False
+                >>> y = cus_tanh.apply(x)
+                >>> z = cus_tanh_cast_grad.apply(y)
+                >>> z.sum().backward()
+
+        """
+        self.grad_in_dtype_consistent = flag

     def save_for_backward(self, *tensors: Tensor) -> None:
         """

test/legacy_test/test_pylayer_op.py
Lines changed: 47 additions & 0 deletions

@@ -732,6 +732,53 @@ def test_nest_backward_error(self):
             expect_msg, err_msg, expect_msg + " should in error message "
         )

+    def test_set_grad_in_dtype_consistent(self):
+        paddle.seed(2025)
+        cus_tanh_backward_input = paddle.empty([])
+
+        class cus_tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x):
+                y = paddle.tanh(x)
+                # Pass tensors to backward.
+                ctx.save_for_backward(y)
+                # The gradient input in the backward process
+                # will not be automatically cast to the dtype of the forward output.
+                ctx.set_grad_in_dtype_consistent(False)
+                return y
+
+            @staticmethod
+            def backward(ctx, dy):
+                nonlocal cus_tanh_backward_input
+                cus_tanh_backward_input = dy
+                # Get the tensors passed by forward.
+                (y,) = ctx.saved_tensor()
+                grad = dy * (1 - paddle.square(y))
+                return grad
+
+        class cus_tanh_cast_grad(PyLayer):
+            @staticmethod
+            def forward(ctx, x):
+                y = paddle.tanh(x)
+                # Pass tensors to backward.
+                ctx.save_for_backward(y)
+                return y
+
+            @staticmethod
+            def backward(ctx, dy):
+                # Get the tensors passed by forward.
+                (y,) = ctx.saved_tensor()
+                grad = dy * (1 - paddle.square(y))
+                grad = paddle.cast(grad, paddle.float16)
+                return grad
+
+        x = paddle.randn([3, 3]).astype("float32")
+        x.stop_gradient = False
+        y = cus_tanh.apply(x)
+        z = cus_tanh_cast_grad.apply(y)
+        z.backward()
+        self.assertEqual(cus_tanh_backward_input.dtype, paddle.float16)
+

 if __name__ == '__main__':
     unittest.main()
