Skip to content

Commit 11a526a

Browse files
authored
[NewIR]Call _C_ops.xx in both dygraph and static mode (#56809)
1 parent 179d426 commit 11a526a

File tree

15 files changed

+203
-28
lines changed

15 files changed

+203
-28
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ tools/nvcc_lazy
7272

7373
# This file is automatically generated.
7474
# TODO(zhiqiang) Move this file to build directory.
75-
paddle/fluid/pybind/eager_op_function.cc
75+
paddle/fluid/pybind/eager_op_function.*
7676
tools/nvcc_lazy
7777
paddle/phi/kernels/sparse/gpu/cutlass_generator/all_gemm_operations.h
7878
paddle/phi/kernels/sparse/gpu/cutlass_generator/configurations.h

paddle/fluid/eager/api/utils/global_utils.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,7 @@
1818
namespace egr {
1919

2020
Controller* Controller::controller_ = new Controller();
21+
thread_local std::shared_ptr<paddle::imperative::Tracer> Controller::tracer_ =
22+
std::make_shared<paddle::imperative::Tracer>();
2123

2224
} // namespace egr

paddle/fluid/eager/api/utils/global_utils.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,7 @@ class Controller {
145145
private:
146146
Controller() = default;
147147
static Controller* controller_;
148-
std::shared_ptr<paddle::imperative::Tracer> tracer_{
149-
new paddle::imperative::Tracer()};
148+
static thread_local std::shared_ptr<paddle::imperative::Tracer> tracer_;
150149
std::unordered_map<std::string, std::vector<paddle::OpMetaInfo>>
151150
op_meta_info_map_;
152151
/* op_type : {{{grad_outputs}, {grad_inputs}, {input}, {output}, {attrs}},

paddle/fluid/eager/auto_code_generator/generator/CMakeLists.txt

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,25 @@ add_custom_target(
5353
${nodes_h_path}
5454
VERBATIM)
5555

56-
set(tmp_python_c_output_path
56+
set(tmp_python_c_source_path
5757
"${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_op_function.cc.tmp")
58-
set(python_c_output_path
58+
set(python_c_source_path
5959
"${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_op_function.cc")
60+
set(tmp_python_c_header_path
61+
"${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_op_function.h.tmp")
62+
set(python_c_header_path
63+
"${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_op_function.h")
6064

6165
add_custom_target(
6266
eager_python_c_codegen
6367
COMMAND
6468
"${PYTHON_EXECUTABLE}"
6569
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py"
6670
"--api_yaml_path=${api_yaml_path},${fwd_api_yaml_path}"
67-
"--output_path=${tmp_python_c_output_path}"
68-
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_output_path}
69-
${python_c_output_path}
71+
"--source_path=${tmp_python_c_source_path}"
72+
"--header_path=${tmp_python_c_header_path}"
73+
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_source_path}
74+
${python_c_source_path}
75+
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_header_path}
76+
${python_c_header_path}
7077
VERBATIM)

paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def FindParsingFunctionFromAttributeType(atype):
8787

8888

8989
PYTHON_C_FUNCTION_TEMPLATE = """
90-
static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
90+
PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
9191
{}
9292
PyThreadState *tstate = nullptr;
9393
try {{
@@ -173,6 +173,7 @@ def FindParsingFunctionFromAttributeType(atype):
173173
#include "paddle/fluid/pybind/eager.h"
174174
#include "paddle/fluid/eager/amp_utils.h"
175175
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
176+
#include "paddle/fluid/pybind/eager_op_function.h"
176177
namespace paddle {{
177178
namespace pybind {{
178179
@@ -253,6 +254,29 @@ def FindParsingFunctionFromAttributeType(atype):
253254
}}
254255
"""
255256

257+
PYTHON_C_H_TEMPLATE = """
258+
#pragma once
259+
260+
#include <Python.h>
261+
262+
// Avoid a problem with copysign defined in pyconfig.h on Windows.
263+
#ifdef copysign
264+
#undef copysign
265+
#endif
266+
267+
namespace paddle {{
268+
namespace pybind {{
269+
270+
{body}
271+
272+
}} // namespace pybind
273+
}} // namespace paddle
274+
"""
275+
276+
PYTHON_C_FUNCTION_DECLARE_TEMPLATE = """
277+
PyObject *eager_api_{name}(PyObject *self, PyObject *args, PyObject *kwargs);
278+
"""
279+
256280

257281
#####################
258282
# Generator Classes #
@@ -279,6 +303,7 @@ def __init__(self, forward_api_contents, namespace):
279303
# Generated Results
280304
self.python_c_function_str = ""
281305
self.python_c_function_reg_str = ""
306+
self.python_c_funcion_declare_str = ""
282307

283308
def CollectIsForwardOnly(self):
284309
forward_api_contents = self.forward_api_contents
@@ -428,6 +453,9 @@ def GeneratePythonCFunction(self):
428453
noamp_dygraph_function_str,
429454
return_str,
430455
)
456+
self.python_c_funcion_declare_str = (
457+
PYTHON_C_FUNCTION_DECLARE_TEMPLATE.format(name=forward_api_name)
458+
)
431459

432460
# Set prefix of forward_api_name to avoid conflicts
433461
prefix = self.namespace.strip("::")
@@ -483,6 +511,12 @@ def GeneratePythonCFunction(self):
483511
return_str,
484512
)
485513

514+
python_c_funcion_declare_str = (
515+
PYTHON_C_FUNCTION_DECLARE_TEMPLATE.format(
516+
name=inplaced_forward_api_name
517+
)
518+
)
519+
486520
python_c_inplace_func_reg_str = (
487521
PYTHON_C_FUNCTION_REG_TEMPLATE.format(
488522
forward_api_name_prefix,
@@ -496,10 +530,14 @@ def GeneratePythonCFunction(self):
496530
# self.forward_api_name ending with '_' means it only has inplace api
497531
if self.forward_api_name[-1] == '_':
498532
self.python_c_function_str = python_c_inplace_func_str
533+
self.python_c_funcion_declare_str = python_c_funcion_declare_str
499534
# Generate Python-C Function Registration
500535
self.python_c_function_reg_str = python_c_inplace_func_reg_str
501536
else:
502537
self.python_c_function_str += python_c_inplace_func_str
538+
self.python_c_funcion_declare_str += (
539+
python_c_funcion_declare_str
540+
)
503541
# Generate Python-C Function Registration
504542
self.python_c_function_reg_str += python_c_inplace_func_reg_str
505543

@@ -541,6 +579,7 @@ def __init__(self, path):
541579
# Generated Result
542580
self.python_c_functions_str = ""
543581
self.python_c_functions_reg_str = ""
582+
self.python_c_funcion_declare_str = ""
544583

545584
def GeneratePythonCFunctions(self):
546585
namespace = self.namespace
@@ -559,6 +598,9 @@ def GeneratePythonCFunctions(self):
559598
self.python_c_functions_reg_str += (
560599
f_generator.python_c_function_reg_str
561600
)
601+
self.python_c_funcion_declare_str += (
602+
f_generator.python_c_funcion_declare_str
603+
)
562604

563605
def AttachNamespace(self):
564606
namespace = self.namespace
@@ -570,6 +612,11 @@ def AttachNamespace(self):
570612
self.python_c_functions_str = NAMESPACE_WRAPPER_TEMPLATE.format(
571613
namespace, python_c_functions_str
572614
)
615+
self.python_c_funcion_declare_str = (
616+
NAMESPACE_WRAPPER_TEMPLATE.format(
617+
namespace, self.python_c_funcion_declare_str
618+
)
619+
)
573620

574621
def run(self):
575622
# Infer namespace from yaml_path
@@ -593,7 +640,8 @@ def ParseArguments():
593640
description='Eager Code Generator Args Parser'
594641
)
595642
parser.add_argument('--api_yaml_path', type=str)
596-
parser.add_argument('--output_path', type=str)
643+
parser.add_argument('--source_path', type=str)
644+
parser.add_argument('--header_path', type=str)
597645

598646
args = parser.parse_args()
599647
return args
@@ -631,6 +679,7 @@ def GeneratePythonCFile(filepath, python_c_str):
631679

632680
generated_python_c_functions = ""
633681
generated_python_c_registration = ""
682+
generated_python_c_functions_header = ""
634683
for i in range(len(api_yaml_paths)):
635684
api_yaml_path = api_yaml_paths[i]
636685

@@ -643,14 +692,22 @@ def GeneratePythonCFile(filepath, python_c_str):
643692
generated_python_c_registration += (
644693
py_c_generator.python_c_functions_reg_str
645694
)
695+
generated_python_c_functions_header += (
696+
py_c_generator.python_c_funcion_declare_str
697+
)
646698

647699
python_c_str = GeneratePythonCWrappers(
648700
generated_python_c_functions, generated_python_c_registration
649701
)
650702

651-
output_path = args.output_path
652-
for path in [output_path]:
703+
soucre_path = args.source_path
704+
header_path = args.header_path
705+
for path in [soucre_path, header_path]:
653706
if os.path.exists(path):
654707
os.remove(path)
655708

656-
GeneratePythonCFile(output_path, python_c_str)
709+
GeneratePythonCFile(soucre_path, python_c_str)
710+
GeneratePythonCFile(
711+
header_path,
712+
PYTHON_C_H_TEMPLATE.format(body=generated_python_c_functions_header),
713+
)

paddle/fluid/imperative/tracer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;
5656

5757
thread_local phi::DataType Tracer::amp_dtype_ = phi::DataType::FLOAT32;
5858

59-
static std::shared_ptr<Tracer> g_current_tracer(nullptr);
59+
static thread_local std::shared_ptr<Tracer> g_current_tracer(nullptr);
6060

6161
const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }
6262

paddle/fluid/ir/dialect/op_generator/ops_api_gen.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
#include <pybind11/pybind11.h>
2222
2323
#include "paddle/fluid/pybind/static_op_function.h"
24+
#include "paddle/fluid/pybind/eager_op_function.h"
2425
#include "paddle/phi/core/enforce.h"
26+
#include "paddle/fluid/eager/api/utils/global_utils.h"
2527
2628
{body}
2729
@@ -44,19 +46,63 @@
4446

4547
FUNCTION_IMPL_TEMPLATE = """
4648
static PyObject *{name}(PyObject *self, PyObject *args, PyObject *kwargs) {{
49+
if (egr::Controller::Instance().GetCurrentTracer() == nullptr) {{
50+
VLOG(6) << "Call static_api_{name}";
51+
return static_api_{name}(self, args, kwargs);
52+
}} else {{
53+
VLOG(6) << "Call eager_api_{name}";
54+
return eager_api_{name}(self, args, kwargs);
55+
}}
56+
}}"""
57+
58+
NO_DY_FUNCTION_IMPL_TEMPLATE = """
59+
static PyObject *{name}(PyObject *self, PyObject *args, PyObject *kwargs) {{
60+
VLOG(6) << "Call static_api_{name}";
4761
return static_api_{name}(self, args, kwargs);
4862
}}"""
4963

5064
OPS_API_TEMPLATE = """
5165
{{"{name}", (PyCFunction)(void (*)(void)){name}, METH_VARARGS | METH_KEYWORDS, "C++ interface function for {name}."}},"""
5266

67+
SPECIAL_STATIC_ONLY_APIS = [
68+
'fetch',
69+
'set_value_with_tensor',
70+
'set_value_with_tensor_',
71+
'fused_bn_add_activation_',
72+
'fused_batch_norm_act_',
73+
'add_n_',
74+
'set_value',
75+
'assign_value',
76+
'set_value_',
77+
'embedding_grad_sparse',
78+
'add_n_with_kernel',
79+
'print',
80+
'send_v2',
81+
'shadow_feed',
82+
'recv_v2',
83+
'rnn_',
84+
'fused_scale_bias_relu_conv_bnstats',
85+
'batch_norm_',
86+
'c_allreduce_sum',
87+
'c_embedding',
88+
'c_identity',
89+
]
90+
5391

5492
class OpsAPIGen(CodeGen):
5593
def __init__(self) -> None:
5694
super().__init__()
5795

5896
def _gen_one_function_impl(self, name):
59-
return FUNCTION_IMPL_TEMPLATE.format(name=name)
97+
if (
98+
name.endswith('grad')
99+
or name.endswith('grad_')
100+
or name.endswith('xpu')
101+
or name in SPECIAL_STATIC_ONLY_APIS
102+
):
103+
return NO_DY_FUNCTION_IMPL_TEMPLATE.format(name=name)
104+
else:
105+
return FUNCTION_IMPL_TEMPLATE.format(name=name)
60106

61107
def _gen_one_ops_api(self, name):
62108
return OPS_API_TEMPLATE.format(name=name)

paddle/fluid/pybind/.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
pybind.h
2-
eager_op_function.cc
2+
eager_op_function.*
33
eager_legacy_op_function.cc

python/paddle/_C_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@
1616

1717
__all__ = []
1818

19+
UNIFIED_APIS = ['mean']
20+
1921
for name in dir(core.eager.ops):
2022
globals()[name] = getattr(core.eager.ops, name)
2123
__all__.append(name)
24+
25+
for name in dir(core.ir.ops):
26+
if name in UNIFIED_APIS:
27+
globals()[name] = getattr(core.ir.ops, name)
28+
if name not in __all__:
29+
__all__.append(name)

python/paddle/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@
459459
from . import linalg # noqa: F401
460460
from . import fft # noqa: F401
461461
from . import signal # noqa: F401
462+
from . import _ir_ops # noqa: F401
462463

463464
import paddle.text # noqa: F401
464465
import paddle.vision # noqa: F401

0 commit comments

Comments
 (0)