Skip to content

Commit 11e7dc0

Browse files
[Comp] Add view_shape_double_grad for eager mode (PaddlePaddle#76667)
* add view_shape_double_grad for eager mode
1 parent 6009ebc commit 11e7dc0

File tree

10 files changed

+48
-23
lines changed

10 files changed

+48
-23
lines changed

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
"put_along_axis_double_grad",
9595
"masked_fill_double_grad",
9696
"index_elementwise_put_with_tensor_double_grad",
97+
"view_shape_double_grad",
9798
]
9899

99100
# white ops list whose kernel can automatically do type promotion.
@@ -3170,12 +3171,7 @@ def GenerateNodeDefinition(
31703171
)
31713172

31723173
grad_api_args[grad_api_position] = name
3173-
if (
3174-
not is_invoke_forward_api
3175-
or name in self.grad_api_contents['invoke']
3176-
):
3177-
# NOTE: attr 'dims' is not necessary for 'invoke: view_shape(out_grad, input.shape())'
3178-
get_grad_in_args_list.append(get_attr_str)
3174+
get_grad_in_args_list.append(get_attr_str)
31793175

31803176
get_grad_in_args_str = "\n".join(get_grad_in_args_list)
31813177

paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,5 @@
4343
'masked_select_grad',
4444
'index_elementwise_get_grad',
4545
'index_elementwise_put_with_tensor_grad',
46+
'view_shape_grad',
4647
]

paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,5 +1296,16 @@ void index_elementwise_put_with_tensor_double_grad(
12961296
}
12971297
}
12981298

1299+
template <typename T>
1300+
void view_shape_double_grad(const Tensor& grad_input_grad,
1301+
const std::vector<int64_t> dims,
1302+
Tensor* grad_out_grad) {
1303+
if (grad_out_grad) {
1304+
Tensor grad_out_grad_tmp;
1305+
grad_out_grad_tmp = reshape<T>(grad_input_grad, dims);
1306+
set_output<T>(grad_out_grad_tmp, grad_out_grad);
1307+
}
1308+
}
1309+
12991310
} // namespace prim
13001311
} // namespace paddle

paddle/phi/ops/yaml/backward.yaml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4025,6 +4025,26 @@
40254025
data_type : out_grad
40264026
no_need_buffer: input
40274027

4028+
- backward_op : view_shape_double_grad
4029+
forward : view_shape_grad (Tensor input, Tensor grad_out, int64_t[] dims) -> Tensor(grad_input)
4030+
args : (Tensor grad_input_grad, int64_t[] dims)
4031+
output : Tensor(grad_out_grad)
4032+
infer_meta :
4033+
func : StridedUnChangedInferMeta
4034+
param : [grad_input_grad]
4035+
composite: view_shape_double_grad(grad_input_grad, dims, grad_out_grad)
4036+
4037+
- backward_op : view_shape_grad
4038+
forward : view_shape (Tensor input, int64_t[] dims = {}) -> Tensor(out)
4039+
args : (Tensor input, Tensor out_grad, int64_t[] dims = {})
4040+
output : Tensor(input_grad)
4041+
infer_meta :
4042+
func : StridedUnChangedInferMeta
4043+
param : [input]
4044+
kernel :
4045+
func : view_shape_grad
4046+
backward : view_shape_double_grad
4047+
40284048
- backward_op : warpctc_grad
40294049
forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank = 0, bool norm_by_times = false) -> Tensor(loss), Tensor(warpctcgrad)
40304050
args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times)

paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -361,13 +361,6 @@
361361
composite : tile_grad(x, out_grad, repeat_times, x_grad)
362362
backward : tile_double_grad
363363

364-
- backward_op : view_shape_grad
365-
forward : view_shape (Tensor input, int64_t[] dims = {}) -> Tensor(out)
366-
args : (Tensor input, Tensor out_grad, int64_t[] dims = {})
367-
output : Tensor(input_grad)
368-
invoke: view_shape(out_grad, input.shape())
369-
no_need_buffer: input
370-
371364
- backward_op: fused_gemm_epilogue_grad
372365
forward : fused_gemm_epilogue(Tensor x, Tensor y, Tensor bias, bool trans_x, bool trans_y, str activation) -> Tensor(out), Tensor(reserve_space)
373366
args : (Tensor x, Tensor y, Tensor reserve_space, Tensor out_grad, bool trans_x, bool trans_y, str activation)

paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -404,12 +404,3 @@
404404
data_type : x
405405
optional : indices, inverse, counts
406406
traits : paddle::dialect::ForwardOnlyTrait
407-
408-
- op : view_shape
409-
args : (Tensor input, int64_t[] dims = {})
410-
output : Tensor(out)
411-
infer_meta :
412-
func : ViewShapeInferMeta
413-
kernel :
414-
func : view_shape
415-
backward : view_shape_grad

paddle/phi/ops/yaml/op_compat.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4044,6 +4044,9 @@
40444044
get_expected_kernel_type :
40454045
update_loss_scaling_ : GetUpdateLossScalingExpectedKernelType
40464046

4047+
- op : view_shape
4048+
backward : view_shape_grad, view_shape_double_grad
4049+
40474050
- op : viterbi_decode
40484051
inputs :
40494052
{potentials : Input, transition_params : Transition, lengths : Length}

paddle/phi/ops/yaml/ops.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5817,6 +5817,15 @@
58175817
no_need_buffer : input
58185818
interfaces : paddle::dialect::InferSymbolicShapeInterface
58195819

5820+
- op : view_shape
5821+
args : (Tensor input, int64_t[] dims = {})
5822+
output : Tensor(out)
5823+
infer_meta :
5824+
func : ViewShapeInferMeta
5825+
kernel :
5826+
func : view_shape
5827+
backward : view_shape_grad
5828+
58205829
- op : view_slice
58215830
args : (Tensor input, int64_t begin_idx, int64_t end_idx)
58225831
output : Tensor

test/indexing/test_setitem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ def test_boolean_mask_tensor_broadcast_v(self):
10241024
np.testing.assert_allclose(res, tensor_np)
10251025

10261026
def test_index_elementwise_put_with_tensor(self):
1027-
with dygraph_guard(), paddle.device("cpu"):
1027+
with dygraph_guard():
10281028
x = paddle.randn(10, 4, requires_grad=True)
10291029
xx = x + 0
10301030

test/legacy_test/test_stride.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,7 @@ def call_stride(self):
971971
self.call_view5()
972972
self.call_view6()
973973
self.call_view7()
974+
self.call_view8()
974975
self.call_view9()
975976
self.call_view10()
976977
self.call_view11()

0 commit comments

Comments
 (0)