@@ -87,7 +87,7 @@ def FindParsingFunctionFromAttributeType(atype):
87
87
88
88
89
89
PYTHON_C_FUNCTION_TEMPLATE = """
90
- static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
90
+ PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
91
91
{}
92
92
PyThreadState *tstate = nullptr;
93
93
try {{
@@ -173,6 +173,7 @@ def FindParsingFunctionFromAttributeType(atype):
173
173
#include "paddle/fluid/pybind/eager.h"
174
174
#include "paddle/fluid/eager/amp_utils.h"
175
175
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
176
+ #include "paddle/fluid/pybind/eager_op_function.h"
176
177
namespace paddle {{
177
178
namespace pybind {{
178
179
@@ -253,6 +254,29 @@ def FindParsingFunctionFromAttributeType(atype):
253
254
}}
254
255
"""
255
256
257
+ PYTHON_C_H_TEMPLATE = """
258
+ #pragma once
259
+
260
+ #include <Python.h>
261
+
262
+ // Avoid a problem with copysign defined in pyconfig.h on Windows.
263
+ #ifdef copysign
264
+ #undef copysign
265
+ #endif
266
+
267
+ namespace paddle {{
268
+ namespace pybind {{
269
+
270
+ {body}
271
+
272
+ }} // namespace pybind
273
+ }} // namespace paddle
274
+ """
275
+
276
+ PYTHON_C_FUNCTION_DECLARE_TEMPLATE = """
277
+ PyObject *eager_api_{name}(PyObject *self, PyObject *args, PyObject *kwargs);
278
+ """
279
+
256
280
257
281
#####################
258
282
# Generator Classes #
@@ -279,6 +303,7 @@ def __init__(self, forward_api_contents, namespace):
279
303
# Generated Results
280
304
self .python_c_function_str = ""
281
305
self .python_c_function_reg_str = ""
306
+ self .python_c_funcion_declare_str = ""
282
307
283
308
def CollectIsForwardOnly (self ):
284
309
forward_api_contents = self .forward_api_contents
@@ -428,6 +453,9 @@ def GeneratePythonCFunction(self):
428
453
noamp_dygraph_function_str ,
429
454
return_str ,
430
455
)
456
+ self .python_c_funcion_declare_str = (
457
+ PYTHON_C_FUNCTION_DECLARE_TEMPLATE .format (name = forward_api_name )
458
+ )
431
459
432
460
# Set prefix of forward_api_name to avoid conflicts
433
461
prefix = self .namespace .strip ("::" )
@@ -483,6 +511,12 @@ def GeneratePythonCFunction(self):
483
511
return_str ,
484
512
)
485
513
514
+ python_c_funcion_declare_str = (
515
+ PYTHON_C_FUNCTION_DECLARE_TEMPLATE .format (
516
+ name = inplaced_forward_api_name
517
+ )
518
+ )
519
+
486
520
python_c_inplace_func_reg_str = (
487
521
PYTHON_C_FUNCTION_REG_TEMPLATE .format (
488
522
forward_api_name_prefix ,
@@ -496,10 +530,14 @@ def GeneratePythonCFunction(self):
496
530
# self.forward_api_name ending with '_' means it only has inplace api
497
531
if self .forward_api_name [- 1 ] == '_' :
498
532
self .python_c_function_str = python_c_inplace_func_str
533
+ self .python_c_funcion_declare_str = python_c_funcion_declare_str
499
534
# Generate Python-C Function Registration
500
535
self .python_c_function_reg_str = python_c_inplace_func_reg_str
501
536
else :
502
537
self .python_c_function_str += python_c_inplace_func_str
538
+ self .python_c_funcion_declare_str += (
539
+ python_c_funcion_declare_str
540
+ )
503
541
# Generate Python-C Function Registration
504
542
self .python_c_function_reg_str += python_c_inplace_func_reg_str
505
543
@@ -541,6 +579,7 @@ def __init__(self, path):
541
579
# Generated Result
542
580
self .python_c_functions_str = ""
543
581
self .python_c_functions_reg_str = ""
582
+ self .python_c_funcion_declare_str = ""
544
583
545
584
def GeneratePythonCFunctions (self ):
546
585
namespace = self .namespace
@@ -559,6 +598,9 @@ def GeneratePythonCFunctions(self):
559
598
self .python_c_functions_reg_str += (
560
599
f_generator .python_c_function_reg_str
561
600
)
601
+ self .python_c_funcion_declare_str += (
602
+ f_generator .python_c_funcion_declare_str
603
+ )
562
604
563
605
def AttachNamespace (self ):
564
606
namespace = self .namespace
@@ -570,6 +612,11 @@ def AttachNamespace(self):
570
612
self .python_c_functions_str = NAMESPACE_WRAPPER_TEMPLATE .format (
571
613
namespace , python_c_functions_str
572
614
)
615
+ self .python_c_funcion_declare_str = (
616
+ NAMESPACE_WRAPPER_TEMPLATE .format (
617
+ namespace , self .python_c_funcion_declare_str
618
+ )
619
+ )
573
620
574
621
def run (self ):
575
622
# Infer namespace from yaml_path
@@ -593,7 +640,8 @@ def ParseArguments():
593
640
description = 'Eager Code Generator Args Parser'
594
641
)
595
642
parser .add_argument ('--api_yaml_path' , type = str )
596
- parser .add_argument ('--output_path' , type = str )
643
+ parser .add_argument ('--source_path' , type = str )
644
+ parser .add_argument ('--header_path' , type = str )
597
645
598
646
args = parser .parse_args ()
599
647
return args
@@ -631,6 +679,7 @@ def GeneratePythonCFile(filepath, python_c_str):
631
679
632
680
generated_python_c_functions = ""
633
681
generated_python_c_registration = ""
682
+ generated_python_c_functions_header = ""
634
683
for i in range (len (api_yaml_paths )):
635
684
api_yaml_path = api_yaml_paths [i ]
636
685
@@ -643,14 +692,22 @@ def GeneratePythonCFile(filepath, python_c_str):
643
692
generated_python_c_registration += (
644
693
py_c_generator .python_c_functions_reg_str
645
694
)
695
+ generated_python_c_functions_header += (
696
+ py_c_generator .python_c_funcion_declare_str
697
+ )
646
698
647
699
python_c_str = GeneratePythonCWrappers (
648
700
generated_python_c_functions , generated_python_c_registration
649
701
)
650
702
651
- output_path = args .output_path
652
- for path in [output_path ]:
703
+ soucre_path = args .source_path
704
+ header_path = args .header_path
705
+ for path in [soucre_path , header_path ]:
653
706
if os .path .exists (path ):
654
707
os .remove (path )
655
708
656
- GeneratePythonCFile (output_path , python_c_str )
709
+ GeneratePythonCFile (soucre_path , python_c_str )
710
+ GeneratePythonCFile (
711
+ header_path ,
712
+ PYTHON_C_H_TEMPLATE .format (body = generated_python_c_functions_header ),
713
+ )
0 commit comments