Commit f729aff
[Bug Fix] Fix view grad when out_grad is not contiguous (PaddlePaddle#76679)
* Fix view grad when out_grad is not contiguous

* feat(phi): add no_need_buffer for view_shape_grad in backward.yaml

  Add `no_need_buffer: input` to the view_shape_grad backward operator
  configuration. This indicates that the gradient computation for
  view_shape does not require the input tensor buffer, optimizing memory
  usage during backward passes.

* feat(test): add gradient test for view and transpose operations

  Add a new test case `TestViewGrad.test_dygraph` to verify gradient
  correctness when using `view` followed by `transpose` in dynamic graph
  mode. The test ensures that the computed gradients match expected
  values, improving coverage for tensor manipulation operations.
1 parent b6ebea4 commit f729aff
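For context, here is a minimal dygraph sketch of the scenario this commit fixes (hedged: it uses the stock `view`/`transpose` APIs and assumes FLAGS_use_stride_kernel is enabled so both ops return strided tensors; the variable names are illustrative, not from the commit):

    import paddle

    x = paddle.randn([2, 12])
    x.stop_gradient = False

    y = x.view([2, 3, 4])            # stride-preserving view of x
    z = y.transpose(perm=[0, 2, 1])  # strided result, not contiguous

    # The gradient flowing back into view's backward (out_grad) is
    # non-contiguous here; before this fix the strided kernel mishandled it.
    z.sum().backward()

    print(x.grad)  # expected: all ones with shape [2, 12]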

3 files changed (+46 −2 lines)


paddle/phi/kernels/stride/view_grad_kernel.cc

Lines changed: 30 additions & 2 deletions
@@ -15,6 +15,8 @@
 #include "paddle/common/flags.h"
 #include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/strided_reshape_utils.h"
+#include "paddle/phi/kernels/funcs/strided_utils.h"
 #include "paddle/phi/kernels/view_kernel.h"
 
 COMMON_DECLARE_bool(use_stride_kernel);
@@ -32,8 +34,34 @@ void ViewShapeGradKernel(const Context& dev_ctx,
                          "FLAGS_use_stride_kernel is closed. Strided kernel "
                          "be called, something wrong has happened!"));
   }
-  ViewShapeStridedKernel<Context>(
-      dev_ctx, out_grad, common::vectorize<int64_t>(input.dims()), input_grad);
+  DDim target_dims = input.dims();
+  DDim target_stride;
+
+  if (ReshapeStride(
+          out_grad.dims(), out_grad.strides(), target_dims, target_stride)) {
+    input_grad->set_meta(out_grad.meta());
+    input_grad->Resize(target_dims);
+    input_grad->set_strides(target_stride);
+    input_grad->set_offset(out_grad.offset());
+    input_grad->ResetHolder(out_grad.Holder());
+    input_grad->ShareInplaceVersionCounterWith(out_grad);
+  } else {
+    DenseTensor contiguous_tmp;
+    DenseTensor tmp_out_grad = out_grad;
+
+    contiguous_tmp.set_meta(tmp_out_grad.meta());
+
+    PD_VISIT_ALL_TYPES(out_grad.dtype(), "ViewShapeGradKernel", ([&] {
+                         phi::StridedTensorContiguous<data_t>(tmp_out_grad,
+                                                              &contiguous_tmp);
+                       }));
+
+    input_grad->set_meta(contiguous_tmp.meta());
+    input_grad->Resize(target_dims);
+    input_grad->set_strides(DenseTensorMeta::calc_strides(target_dims));
+    input_grad->set_offset(0);
+    input_grad->ResetHolder(contiguous_tmp.Holder());
+  }
 }
 
 template <typename Context>
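In short, the new kernel tries `ReshapeStride` first: if `out_grad`'s strides can be re-expressed for the target shape, `input_grad` simply aliases `out_grad`'s holder with adjusted metadata and no copy. Otherwise it falls back to materializing a contiguous copy of `out_grad` and reshaping that. A rough Python-level sketch of the fallback semantics (illustrative only, not the kernel itself):

    import paddle

    def view_grad_fallback_sketch(out_grad, target_shape):
        # out_grad's strides cannot be reinterpreted for target_shape,
        # so first copy it into a contiguous buffer...
        dense = out_grad.contiguous()
        # ...after which a plain reshape is safe: the buffer's row-major
        # strides now match its shape.
        return dense.reshape(target_shape)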

paddle/phi/ops/yaml/backward.yaml

Lines changed: 1 addition & 0 deletions
@@ -4044,6 +4044,7 @@
   kernel :
     func : view_shape_grad
   backward : view_shape_double_grad
+  no_need_buffer: input
 
 - backward_op : warpctc_grad
   forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank = 0, bool norm_by_times = false) -> Tensor(loss), Tensor(warpctcgrad)
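The added `no_need_buffer: input` declares that `view_shape_grad` only reads `input`'s metadata (its dims, used as the target shape), never its data, so the framework may free the input buffer after the forward pass. A small sketch of why that is safe (illustrative shapes):

    import paddle

    x = paddle.randn([2, 12])
    x.stop_gradient = False
    y = x.view([2, 3, 4])

    # The gradient of a view is just out_grad laid back out in x's shape;
    # it depends on out_grad and the shape [2, 12] alone, not on x's values.
    y.sum().backward()
    assert x.grad.shape == [2, 12]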

test/legacy_test/test_stride.py

Lines changed: 15 additions & 0 deletions
@@ -1092,5 +1092,20 @@ def func2():
         func2()
 
 
+class TestViewGrad(unittest.TestCase):
+    def test_dygraph(self):
+        paddle.disable_static()
+        x = paddle.randn(2, 12, requires_grad=True)
+
+        y = x.view(2, 3, 4)
+        z = y.transpose(1, 2)
+
+        loss = z.sum()
+        loss.backward()
+
+        x_grad_expected = paddle.full_like(x, 1.0)
+        self.assertEqual((x.grad == x_grad_expected).all(), True)
+
+
 if __name__ == '__main__':
     unittest.main()
