@@ -660,6 +660,7 @@ class {} : public egr::GradNodeBase {{
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/utils/test_macros.h"
 #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
+#include "paddle/utils/optional.h"
 using CPUPlace = phi::CPUPlace;
 {}
 {}
@@ -1496,7 +1497,7 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):

         self.grad_node_out_list = grad_node_out_list

-    def run(self):
+    def run(self, append_input_out=False):
         # Basic Validation Check
         self.DygraphYamlValidationCheck()

@@ -1684,7 +1685,9 @@ def GenerateForwardLayoutAutotune(

         return layout_logic_str

-    def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
+    def GenerateForwardDefinitionAndDeclaration(
+        self, is_inplaced, grad_flag, append_input_out
+    ):
         namespace = self.namespace
         if self.forward_api_name[-1] == '_' and not is_inplaced:
             return
@@ -1881,6 +1884,24 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):

         inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
         inputs_args_definition_str = ", ".join(inputs_args_definition_list)
+        if (
+            append_input_out
+            and not grad_flag
+            and not is_inplaced
+            and len(self.forward_outputs_position_map) == 1
+            and next(iter(self.forward_outputs_position_map.values()))[0]
+            == "Tensor"
+            and forward_api_name != "empty_like"
+        ):
+            inputs_args_declaration_str = (
+                inputs_args_declaration_str
+                + ", paddle::optional<paddle::Tensor*> input_out = paddle::none"
+            )
+            inputs_args_definition_str = (
+                inputs_args_definition_str
+                + ", paddle::optional<paddle::Tensor*> input_out"
+            )
+            inputs_call_list.append("input_out")
         inputs_call_args_str = ", ".join(inputs_call_list)
         self.inputs_call_list = inputs_call_list

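
(For context, a minimal Python sketch, outside the generator itself, of what the string append in the hunk above does to a generated declaration. The existing argument list "const paddle::Tensor& x" is a hypothetical example; only the appended fragment comes from the diff.)

# Hypothetical existing C++ parameter list for a single-Tensor-output API.
inputs_args_declaration_str = "const paddle::Tensor& x"
# Mirror the append performed by the hunk above.
inputs_args_declaration_str = (
    inputs_args_declaration_str
    + ", paddle::optional<paddle::Tensor*> input_out = paddle::none"
)
print(inputs_args_declaration_str)
# const paddle::Tensor& x, paddle::optional<paddle::Tensor*> input_out = paddle::none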
@@ -2135,6 +2156,16 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
             + " ".join(amp_autocast_optional_list)
         )
         amp_inputs_call_args_str = ", ".join(amp_inputs_call_list)
+        if (
+            append_input_out
+            and not grad_flag
+            and not is_inplaced
+            and len(self.forward_outputs_position_map) == 1
+            and next(iter(self.forward_outputs_position_map.values()))[0]
+            == "Tensor"
+            and forward_api_name != "empty_like"
+        ):
+            amp_inputs_call_args_str = amp_inputs_call_args_str + ", input_out"
         amp_call_str = (
             f"return {forward_ad_function_name}({amp_inputs_call_args_str});"
         )
@@ -2158,6 +2189,18 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
             type_promote_inputs_call_args_str = ", ".join(
                 type_promote_inputs_call_list
             )
+            if (
+                append_input_out
+                and not grad_flag
+                and not is_inplaced
+                and len(self.forward_outputs_position_map) == 1
+                and next(iter(self.forward_outputs_position_map.values()))[0]
+                == "Tensor"
+                and forward_api_name != "empty_like"
+            ):
+                type_promote_inputs_call_args_str = (
+                    type_promote_inputs_call_args_str + ", input_out"
+                )
             type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"

             x_cast = (
@@ -2180,6 +2223,19 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):
             type_promote_inputs_call_args_str = ", ".join(
                 type_promote_inputs_call_list
             )
+            if (
+                append_input_out
+                and not grad_flag
+                and not is_inplaced
+                and len(self.forward_outputs_position_map) == 1
+                and next(iter(self.forward_outputs_position_map.values()))[0]
+                == "Tensor"
+                and forward_api_name != "empty_like"
+            ):
+                type_promote_inputs_call_args_str = (
+                    type_promote_inputs_call_args_str + ", input_out"
+                )
+
             type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"

             x_cast = (
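
(The same guard is repeated in the four hunks above. As an illustrative sketch only, not code in this PR, it amounts to the predicate below; "relu" is a hypothetical op name, and the (type, position) tuple layout follows the generator's forward_outputs_position_map as seen elsewhere in this diff.)

def should_append_input_out(
    forward_outputs_position_map, forward_api_name, grad_flag, is_inplaced
):
    # True only for non-grad, non-inplace forward APIs with exactly one
    # plain Tensor output, with empty_like explicitly excluded.
    return (
        not grad_flag
        and not is_inplaced
        and len(forward_outputs_position_map) == 1
        and next(iter(forward_outputs_position_map.values()))[0] == "Tensor"
        and forward_api_name != "empty_like"
    )

print(should_append_input_out({"out": ("Tensor", 0)}, "relu", False, False))        # True
print(should_append_input_out({"out": ("Tensor", 0)}, "empty_like", False, False))  # False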
@@ -2323,15 +2379,19 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced, grad_flag):

         self.forward_declaration_str += f"TEST_API {returns_type_str} {forward_ad_function_name}({inputs_args_declaration_str});\n"

-    def GenerateInplacedForwardDygraphFunctions(self, grad_flag):
+    def GenerateInplacedForwardDygraphFunctions(
+        self, grad_flag, append_input_out
+    ):
         # Inplaced Version Dygraph Function Generation
         forward_api_name = self.forward_api_name
         forward_api_contents = self.forward_api_contents

         if forward_api_name != "sum" and "inplace" in forward_api_contents:
             # Function Definition and Declaration Generation
             self.GenerateForwardDefinitionAndDeclaration(
-                is_inplaced=True, grad_flag=grad_flag
+                is_inplaced=True,
+                grad_flag=grad_flag,
+                append_input_out=append_input_out,
             )
             self.UpdateCoreOpsInformation(is_inplaced=True)

@@ -2367,21 +2427,25 @@ def UpdateCoreOpsInformation(self, is_inplaced):
         for name, (ttype, pos) in forward_outputs_position_map.items():
             core_ops_returns_info[fwd_api_name][pos] = name

-    def run(self, grad_flag=False):
-        super().run()
+    def run(self, grad_flag=False, append_input_out=False):
+        super().run(append_input_out=append_input_out)

         ###################
         # Code Generation #
         ###################

         # Definition And Declaration
         self.GenerateForwardDefinitionAndDeclaration(
-            is_inplaced=False, grad_flag=grad_flag
+            is_inplaced=False,
+            grad_flag=grad_flag,
+            append_input_out=append_input_out,
         )

         self.UpdateCoreOpsInformation(is_inplaced=False)

-        self.GenerateInplacedForwardDygraphFunctions(grad_flag)
+        self.GenerateInplacedForwardDygraphFunctions(
+            grad_flag, append_input_out=append_input_out
+        )


 class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
@@ -3214,8 +3278,8 @@ def _gen_api_call_code_block(
             returns_str,
         )

-    def run(self):
-        super().run()
+    def run(self, append_input_out=False):
+        super().run(append_input_out=append_input_out)

         self.ResetOptionalInputs()

@@ -3299,7 +3363,7 @@ def GetBackwardAPIContents(self, forward_api_contents):

         return backward_api_contents

-    def GenerateCode(self, grad_flag=False):
+    def GenerateCode(self, grad_flag=False, append_input_out=True):
         if grad_flag:
             op_string = 'backward_op'
         else:
@@ -3347,7 +3411,9 @@ def GenerateCode(self, grad_flag=False):
                 forward_apis_dict,
                 namespace,
             )
-            function_generator.run(grad_flag)
+            function_generator.run(
+                grad_flag, append_input_out=append_input_out
+            )

             self.forward_definition_str += (
                 function_generator.forward_definition_str + "\n"
@@ -3372,7 +3438,7 @@ def GenerateCode(self, grad_flag=False):
                 namespace,
                 next_grad_api_contents,
             )
-            node_generator.run()
+            node_generator.run(append_input_out=append_input_out)

             self.node_declaration_str += (
                 node_generator.node_declaration_str + "\n"
             )
@@ -3407,12 +3473,12 @@ def GenerateCode(self, grad_flag=False):
             namespace, self.node_definition_str
         )

-    def run(self, grad_flag=False):
+    def run(self, grad_flag=False, append_input_out=False):
         self.ParseYamlContents()

         self.InferNameSpace()

-        self.GenerateCode(grad_flag)
+        self.GenerateCode(grad_flag, append_input_out=append_input_out)


 ################
@@ -3521,7 +3587,10 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str, grad_flag):
         generator = DygraphForwardAndNodesGenerator(
             api_yaml_path, backward_yaml_path
         )
-        generator.run()
+        append_input_out = (
+            "string" not in api_yaml_path and "sparse" not in api_yaml_path
+        )
+        generator.run(append_input_out=append_input_out)

         node_declaration_str += generator.node_declaration_str + "\n"
         node_definition_str += generator.node_definition_str + "\n"
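
(Illustrative sketch of the per-YAML gate added above; the file names below are hypothetical and only the substring checks come from the diff. String and sparse op definitions keep their original signatures.)

for api_yaml_path in ["ops.yaml", "strings_ops.yaml", "sparse_ops.yaml"]:
    append_input_out = (
        "string" not in api_yaml_path and "sparse" not in api_yaml_path
    )
    print(api_yaml_path, append_input_out)
# ops.yaml True
# strings_ops.yaml False
# sparse_ops.yaml False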
@@ -3556,7 +3625,7 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str, grad_flag):
         backward_yaml_path, backward_yaml_path
     )

-    generator_grad.run(True)
+    generator_grad.run(True, append_input_out=False)

     backward_declaration_str += (
         generator_grad.forward_declaration_str + "\n"