
Commit 7fa6687

wanghuancoder authored and maxiaolong001 committed
dygraph support input a out Tensor (PaddlePaddle#74484)
* dygraph support input a out Tensor
* refine
* refine
* refine
* refine
* refine
* refine
* refine
* refine
1 parent a292656 commit 7fa6687

23 files changed: +504, -122 lines
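In short, this change threads an optional trailing `input_out` parameter through the manually written dygraph forward AD functions and the eager code generator, so a caller can hand in a pre-allocated output Tensor instead of always receiving a freshly created one. A minimal call-shape sketch, assuming a built Paddle tree (the surrounding setup is illustrative, not from the PR; the signatures and includes come from the diff below):

```cpp
// Sketch only: the call shapes this PR enables for multiply_ad_func.
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
#include "paddle/utils/optional.h"

void demo(const paddle::Tensor& x, const paddle::Tensor& y) {
  // Existing behaviour is untouched: input_out defaults to paddle::none,
  // so the function returns a newly allocated Tensor as before.
  paddle::Tensor z = multiply_ad_func(x, y);

  // New path: pass a pointer to a caller-owned Tensor. Presumably the
  // result is placed in *out rather than in a brand-new Tensor.
  paddle::Tensor out;
  multiply_ad_func(x, y, paddle::optional<paddle::Tensor*>(&out));
}
```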

paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h

Lines changed: 27 additions & 16 deletions
@@ -19,21 +19,29 @@
 #include "paddle/phi/core/distributed/auto_parallel/placement_types.h"
 #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
 
-paddle::Tensor add_n_ad_func(const std::vector<paddle::Tensor>& x);
+paddle::Tensor add_n_ad_func(
+    const std::vector<paddle::Tensor>& x,
+    paddle::optional<paddle::Tensor*> input_out = paddle::none);
 
-paddle::Tensor conv2d_ad_func(const paddle::Tensor& input,
-                              const paddle::Tensor& filter,
-                              std::vector<int> strides,
-                              std::vector<int> paddings,
-                              std::string padding_algorithm,
-                              std::vector<int> dilations,
-                              int groups,
-                              std::string data_format);
+paddle::Tensor conv2d_ad_func(
+    const paddle::Tensor& input,
+    const paddle::Tensor& filter,
+    std::vector<int> strides,
+    std::vector<int> paddings,
+    std::string padding_algorithm,
+    std::vector<int> dilations,
+    int groups,
+    std::string data_format,
+    paddle::optional<paddle::Tensor*> input_out = paddle::none);
 
-paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
-                                const paddle::Tensor& y);
-paddle::Tensor& multiply__ad_func(paddle::Tensor& x,  // NOLINT
-                                  const paddle::Tensor& y);
+paddle::Tensor multiply_ad_func(
+    const paddle::Tensor& x,
+    const paddle::Tensor& y,
+    paddle::optional<paddle::Tensor*> input_out = paddle::none);
+paddle::Tensor& multiply__ad_func(
+    paddle::Tensor& x,  // NOLINT
+    const paddle::Tensor& y,
+    paddle::optional<paddle::Tensor*> input_out = paddle::none);
 
 std::tuple<paddle::Tensor,
            paddle::Tensor&,
@@ -55,17 +63,20 @@ sync_batch_norm__ad_func(const paddle::Tensor& x,
 
 paddle::Tensor reshard_ad_function(
     const paddle::Tensor& tensor,
-    const phi::distributed::TensorDistAttr dist_attr);
+    const phi::distributed::TensorDistAttr dist_attr,
+    paddle::optional<paddle::Tensor*> input_out = paddle::none);
 
 paddle::Tensor dtensor_to_local_ad_function(
     const paddle::Tensor& input,
     const phi::distributed::ProcessMesh& processmesh,
-    const phi::distributed::Placements& placements);
+    const phi::distributed::Placements& placements,
+    paddle::optional<paddle::Tensor*> input_out = paddle::none);
 
 paddle::Tensor dtensor_from_local_ad_function(
     const paddle::Tensor& input,
     const phi::distributed::ProcessMesh& processmesh,
-    const phi::distributed::Placements& placements);
+    const phi::distributed::Placements& placements,
+    paddle::optional<paddle::Tensor*> input_out = paddle::none);
 
 namespace sparse {
 std::tuple<paddle::Tensor,
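Every declaration above follows the same pattern: a trailing `paddle::optional<paddle::Tensor*> input_out = paddle::none`, which keeps all existing call sites source-compatible. How the definitions consume it lives in the .cc diffs below; a hedged sketch of the general shape (an assumption, not the PR's actual body):

```cpp
// Hypothetical consumption pattern (assumption; the real handling is in the
// forward function bodies changed below).
paddle::Tensor some_ad_func_sketch(
    const paddle::Tensor& x,
    paddle::optional<paddle::Tensor*> input_out = paddle::none) {
  paddle::Tensor result;
  // ... run the kernel, producing `result` ...
  if (input_out) {
    // Caller supplied a Tensor: hand the result over through the pointer.
    *(*input_out) = result;
    return *(*input_out);
  }
  return result;
}
```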

paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@
 COMMON_DECLARE_bool(check_nan_inf);
 COMMON_DECLARE_bool(check_cuda_error);
 
-paddle::Tensor add_n_ad_func(const std::vector<paddle::Tensor>& x) {
+paddle::Tensor add_n_ad_func(const std::vector<paddle::Tensor>& x,
+                             paddle::optional<paddle::Tensor*> input_out) {
   VLOG(3) << "Running AD API: "
           << "add_n";
   if (FLAGS_check_cuda_error) [[unlikely]] {

paddle/fluid/eager/api/manual/eager_manual/forwards/conv2d_fwd_function.cc

Lines changed: 2 additions & 1 deletion
@@ -31,7 +31,8 @@ paddle::Tensor conv2d_ad_func(const paddle::Tensor& input,
                               std::string padding_algorithm,
                               std::vector<int> dilations,
                               int groups,
-                              std::string data_format) {
+                              std::string data_format,
+                              paddle::optional<paddle::Tensor*> input_out) {
   VLOG(3) << "Running AD API: "
           << "conv2d";
   if (FLAGS_check_cuda_error) [[unlikely]] {

paddle/fluid/eager/api/manual/eager_manual/forwards/dtensor_from_local_fwd_func.cc

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,8 @@ COMMON_DECLARE_bool(check_cuda_error);
 paddle::Tensor dtensor_from_local_ad_function(
     const paddle::Tensor& input,
     const phi::distributed::ProcessMesh& process_mesh,
-    const phi::distributed::Placements& placements) {
+    const phi::distributed::Placements& placements,
+    paddle::optional<paddle::Tensor*> input_out) {
 #ifdef PADDLE_WITH_DISTRIBUTE
   VLOG(3) << "Running AD API: "
           << "dtensor_from_local dygraph";

paddle/fluid/eager/api/manual/eager_manual/forwards/dtensor_to_local_fwd_func.cc

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@ COMMON_DECLARE_bool(check_cuda_error);
 paddle::Tensor dtensor_to_local_ad_function(
     const paddle::Tensor& input,
     const phi::distributed::ProcessMesh& process_mesh,
-    const phi::distributed::Placements& placements) {
+    const phi::distributed::Placements& placements,
+    paddle::optional<paddle::Tensor*> input_out) {
 #ifdef PADDLE_WITH_DISTRIBUTE
   VLOG(3) << "Running AD API: "
           << "dtensor_to_local dygraph";

paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc

Lines changed: 4 additions & 2 deletions
@@ -38,7 +38,8 @@ bool check_if_support_elementwise_mul_mem_opt(const std::string& device_type) {
 }
 
 paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
-                                const paddle::Tensor& y) {
+                                const paddle::Tensor& y,
+                                paddle::optional<paddle::Tensor*> input_out) {
   FLAGS_tensor_operants_mode = "eager";
   VLOG(3) << "Running AD API: "
           << "multiply";
@@ -241,7 +242,8 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
 }
 
 paddle::Tensor& multiply__ad_func(paddle::Tensor& x,  // NOLINT
-                                  const paddle::Tensor& y) {
+                                  const paddle::Tensor& y,
+                                  paddle::optional<paddle::Tensor*> input_out) {
   FLAGS_tensor_operants_mode = "eager";
   VLOG(3) << "Running AD API: "
           << "multiply_";

paddle/fluid/eager/api/manual/eager_manual/forwards/reshard_fwd_func.cc

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ COMMON_DECLARE_bool(check_cuda_error);
 
 paddle::Tensor reshard_ad_function(
     const paddle::Tensor& input,
-    const phi::distributed::TensorDistAttr dist_attr) {
+    const phi::distributed::TensorDistAttr dist_attr,
+    paddle::optional<paddle::Tensor*> input_out) {
 #ifdef PADDLE_WITH_DISTRIBUTE
   VLOG(3) << "Running AD API: "
           << "reshard dygraph";

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 86 additions & 17 deletions
@@ -660,6 +660,7 @@ class {} : public egr::GradNodeBase {{
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/utils/test_macros.h"
 #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
+#include "paddle/utils/optional.h"
 using CPUPlace = phi::CPUPlace;
 {}
 {}
@@ -1496,7 +1497,7 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):
 
         self.grad_node_out_list = grad_node_out_list
 
-    def run(self):
+    def run(self, append_input_out=False):
         # Basic Validation Check
         self.DygraphYamlValidationCheck()
 
@@ -1684,7 +1685,9 @@ def GenerateForwardLayoutAutotune(
 
         return layout_logic_str
 
-    def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
+    def GenerateForwardDefinitionAndDeclaration(
+        self, is_inplaced, grad_flag, append_input_out
+    ):
         namespace = self.namespace
         if self.forward_api_name[-1] == '_' and not is_inplaced:
             return
@@ -1881,6 +1884,24 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
 
         inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
         inputs_args_definition_str = ", ".join(inputs_args_definition_list)
+        if (
+            append_input_out
+            and not grad_flag
+            and not is_inplaced
+            and len(self.forward_outputs_position_map) == 1
+            and next(iter(self.forward_outputs_position_map.values()))[0]
+            == "Tensor"
+            and forward_api_name != "empty_like"
+        ):
+            inputs_args_declaration_str = (
+                inputs_args_declaration_str
+                + ", paddle::optional<paddle::Tensor*> input_out = paddle::none"
+            )
+            inputs_args_definition_str = (
+                inputs_args_definition_str
+                + ", paddle::optional<paddle::Tensor*> input_out"
+            )
+            inputs_call_list.append("input_out")
         inputs_call_args_str = ", ".join(inputs_call_list)
         self.inputs_call_list = inputs_call_list
 
@@ -2135,6 +2156,16 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
             + " ".join(amp_autocast_optional_list)
         )
         amp_inputs_call_args_str = ", ".join(amp_inputs_call_list)
+        if (
+            append_input_out
+            and not grad_flag
+            and not is_inplaced
+            and len(self.forward_outputs_position_map) == 1
+            and next(iter(self.forward_outputs_position_map.values()))[0]
+            == "Tensor"
+            and forward_api_name != "empty_like"
+        ):
+            amp_inputs_call_args_str = amp_inputs_call_args_str + ", input_out"
         amp_call_str = (
             f"return {forward_ad_function_name}({amp_inputs_call_args_str});"
         )
@@ -2158,6 +2189,18 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
             type_promote_inputs_call_args_str = ", ".join(
                 type_promote_inputs_call_list
            )
+            if (
+                append_input_out
+                and not grad_flag
+                and not is_inplaced
+                and len(self.forward_outputs_position_map) == 1
+                and next(iter(self.forward_outputs_position_map.values()))[0]
+                == "Tensor"
+                and forward_api_name != "empty_like"
+            ):
+                type_promote_inputs_call_args_str = (
+                    type_promote_inputs_call_args_str + ", input_out"
+                )
             type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"
 
             x_cast = (
@@ -2180,6 +2223,19 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
             type_promote_inputs_call_args_str = ", ".join(
                 type_promote_inputs_call_list
             )
+            if (
+                append_input_out
+                and not grad_flag
+                and not is_inplaced
+                and len(self.forward_outputs_position_map) == 1
+                and next(iter(self.forward_outputs_position_map.values()))[0]
+                == "Tensor"
+                and forward_api_name != "empty_like"
+            ):
+                type_promote_inputs_call_args_str = (
+                    type_promote_inputs_call_args_str + ", input_out"
+                )
+
             type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"
 
             x_cast = (
@@ -2323,15 +2379,19 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
 
         self.forward_declaration_str += f"TEST_API {returns_type_str} {forward_ad_function_name}({inputs_args_declaration_str});\n"
 
-    def GenerateInplacedForwardDygraphFunctions(self, grad_flag):
+    def GenerateInplacedForwardDygraphFunctions(
+        self, grad_flag, append_input_out
+    ):
         # Inplaced Version Dygraph Function Generation
         forward_api_name = self.forward_api_name
         forward_api_contents = self.forward_api_contents
 
         if forward_api_name != "sum" and "inplace" in forward_api_contents:
             # Function Definition and Declaration Generation
             self.GenerateForwardDefinitionAndDeclaration(
-                is_inplaced=True, grad_flag=grad_flag
+                is_inplaced=True,
+                grad_flag=grad_flag,
+                append_input_out=append_input_out,
             )
             self.UpdateCoreOpsInformation(is_inplaced=True)
 
@@ -2367,21 +2427,25 @@ def UpdateCoreOpsInformation(self, is_inplaced):
         for name, (ttype, pos) in forward_outputs_position_map.items():
             core_ops_returns_info[fwd_api_name][pos] = name
 
-    def run(self, grad_flag=False):
-        super().run()
+    def run(self, grad_flag=False, append_input_out=False):
+        super().run(append_input_out=append_input_out)
 
         ###################
         # Code Generation #
         ###################
 
         # Definition And Declaration
         self.GenerateForwardDefinitionAndDeclaration(
-            is_inplaced=False, grad_flag=grad_flag
+            is_inplaced=False,
+            grad_flag=grad_flag,
+            append_input_out=append_input_out,
         )
 
         self.UpdateCoreOpsInformation(is_inplaced=False)
 
-        self.GenerateInplacedForwardDygraphFunctions(grad_flag)
+        self.GenerateInplacedForwardDygraphFunctions(
+            grad_flag, append_input_out=append_input_out
+        )
 
 
 class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
@@ -3214,8 +3278,8 @@ def _gen_api_call_code_block(
             returns_str,
         )
 
-    def run(self):
-        super().run()
+    def run(self, append_input_out=False):
+        super().run(append_input_out=append_input_out)
 
         self.ResetOptionalInputs()
 
@@ -3299,7 +3363,7 @@ def GetBackwardAPIContents(self, forward_api_contents):
 
         return backward_api_contents
 
-    def GenerateCode(self, grad_flag=False):
+    def GenerateCode(self, grad_flag=False, append_input_out=True):
         if grad_flag:
             op_string = 'backward_op'
         else:
@@ -3347,7 +3411,9 @@ def GenerateCode(self, grad_flag=False):
                     forward_apis_dict,
                     namespace,
                 )
-                function_generator.run(grad_flag)
+                function_generator.run(
+                    grad_flag, append_input_out=append_input_out
+                )
 
                 self.forward_definition_str += (
                     function_generator.forward_definition_str + "\n"
@@ -3372,7 +3438,7 @@
                     namespace,
                     next_grad_api_contents,
                 )
-                node_generator.run()
+                node_generator.run(append_input_out=append_input_out)
                 self.node_declaration_str += (
                     node_generator.node_declaration_str + "\n"
                 )
@@ -3407,12 +3473,12 @@
                 namespace, self.node_definition_str
             )
 
-    def run(self, grad_flag=False):
+    def run(self, grad_flag=False, append_input_out=False):
         self.ParseYamlContents()
 
         self.InferNameSpace()
 
-        self.GenerateCode(grad_flag)
+        self.GenerateCode(grad_flag, append_input_out=append_input_out)
 
 
 ################
@@ -3521,7 +3587,10 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str, grad_flag):
     generator = DygraphForwardAndNodesGenerator(
        api_yaml_path, backward_yaml_path
    )
-    generator.run()
+    append_input_out = (
+        "string" not in api_yaml_path and "sparse" not in api_yaml_path
+    )
+    generator.run(append_input_out=append_input_out)
 
     node_declaration_str += generator.node_declaration_str + "\n"
     node_definition_str += generator.node_definition_str + "\n"
@@ -3556,7 +3625,7 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str, grad_flag):
         backward_yaml_path, backward_yaml_path
     )
 
-    generator_grad.run(True)
+    generator_grad.run(True, append_input_out=False)
 
     backward_declaration_str += (
         generator_grad.forward_declaration_str + "\n"
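Taken together, the generator hunks gate the new parameter tightly: it is only appended for non-grad, non-inplace ops with exactly one `Tensor` output, skipping `empty_like` and any op from the string/sparse yaml files, and the AMP and type-promotion fast paths forward `input_out` unchanged. For a qualifying op, the emitted `TEST_API` declaration would presumably look like the following (hypothetical example; `relu` is assumed to qualify, and the line is not copied from generated output):

```cpp
// Hypothetical generated declaration for a qualifying single-output op:
TEST_API paddle::Tensor relu_ad_func(
    const paddle::Tensor& x,
    paddle::optional<paddle::Tensor*> input_out = paddle::none);
```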
