
Commit a20ce3e

Authored by Aurelius84, liym27, and chenwhql
[cherry-pick][Dy2stat] training with @declarative decorator and save_inference_model (#24557)
* [Dy2Stat] Add test for ptb model. (#24076)
  * [Dy2Stat] Add test for ptb model. test=develop
  * Simplify code for gast.If in is_control_flow_to_transform. test=develop
  * Move IsControlFlowVisitor to file utils. test=develop
  * Don't use convert_call for built-in func in CallTransformer. test=develop
  * Optimize api is_control_flow_to_transform. test=develop
  * Polish the document of IsControlFlowVisitor. test=develop
  * Use declarative instead of dygraph_to_static_func. test=develop
* [dy2static] Add print transformer and unify print format (#24068)
  * add print transformer & unify print format, test=develop
  * remove using of dygraph_to_static_func, test=develop
  * remove python stdout capture, test=develop
  * fix compatibility problems for PY2, test=develop
  * fix detail error, test=develop
  * fix type analysis bug, test=develop
  * fix print tuple compatible error in PY2, test=develop
  * replace get_func to declarative, test=develop
  * fix detail bug, test=develop
  * fix some detail problems, test=develop
  * change visit_call in print transformer, test=develop
* [dy2static] Support for static graph training with @declarative decorator (#24259)
  * support to train in static
  * support to independent decorator
  * remove in_dygraph_mode condition in ProgramTranslator
  * fix import param_guard and add train/eval test=develop
  * Modify into ShareVarsFromScope and rm __all__ in partial_program test=develop
* [Dy2Stat] Optimize loop cond (#24049)
  * Simplify code for gast.If in is_control_flow_to_transform.
  * Move IsControlFlowVisitor to file utils.
  * Don't use convert_call for built-in func in CallTransformer.
  * Optimize api is_control_flow_to_transform.
  * Polish the document of IsControlFlowVisitor.
  * revert modification from #24259
* [dy2stat] Support save_inference_model in program_translator (#24353)
  * support save_inference_model in program_translator test=develop
  * fix compatibility with OrderedDict.values() in python3 test=develop
  * synchronized random_seed test=develop
  * Polish Error Message test=develop
* Fix bug with `if Tensor` in is_control_flow (#24433)
  * fix bug with `if Tensor` in is_control_flow test=develop
  * remove continue test=develop
* Revert "[dy2static] Add print transformer and unify print format (#24068)". This reverts commit 09dd019.
* Revert "[dy2static] Add print transformer and unify print format (#24068)". This reverts commit 09dd019.
* fix sample code in save_inference_model test=develop

Co-authored-by: liym27 <[email protected]>
Co-authored-by: Chen Weihang <[email protected]>
1 parent bc1e17e commit a20ce3e

38 files changed: +1710 / -1108 lines
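Taken together, these patches let a dygraph Layer whose forward is decorated with @declarative train against the traced static program and then export it for inference. A minimal usage sketch, assuming the 1.8-era fluid.dygraph API added in #24259/#24353; SimpleNet and the output directory are illustrative, not from this commit:

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph import Linear, ProgramTranslator, declarative

    class SimpleNet(fluid.dygraph.Layer):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.linear = Linear(10, 3)

        @declarative  # traces forward into a static program on first call
        def forward(self, x):
            return self.linear(x)

    with fluid.dygraph.guard():
        net = SimpleNet()
        adam = fluid.optimizer.AdamOptimizer(
            learning_rate=0.01, parameter_list=net.parameters())
        for _ in range(5):
            x = fluid.dygraph.to_variable(
                np.random.random((4, 10)).astype('float32'))
            loss = fluid.layers.mean(net(x))
            loss.backward()
            adam.minimize(loss)
            net.clear_gradients()
        # Export the traced program and persistable parameters for inference.
        ProgramTranslator().save_inference_model('./simple_net_infer')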

paddle/fluid/operators/run_program_op.h

Lines changed: 16 additions & 42 deletions
@@ -102,74 +102,50 @@ static void CheckOutputVarStatus(const Variable &src_var,
 }
 
 static void VariableShare(const Variable &src_var, Variable *dst_var) {
-  // The previous check ensures that the variable type can only be LoDTensor
-  auto *lod_tensor = dst_var->GetMutable<LoDTensor>();
-  lod_tensor->ShareDataWith(src_var.Get<LoDTensor>());
-  lod_tensor->set_lod(src_var.Get<LoDTensor>().lod());
-}
-
-static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
-                               const std::vector<std::string> &var_names,
-                               framework::Scope *scope) {
-  for (size_t i = 0; i < vars.size(); ++i) {
-    auto *var = scope->Var(var_names[i]);
-    CheckInputVarStatus(*vars[i], var_names[i]);
-    VariableShare(*vars[i], var);
-  }
-}
-
-static void VariableCopy(const Variable &src_var,
-                         const platform::Place &dst_place, Variable *dst_var) {
   // The previous check ensures that the variable type can only be LoDTensor or
-  // SelectedRows
+  // SelectedRows.
   if (src_var.IsType<LoDTensor>()) {
     auto *lod_tensor = dst_var->GetMutable<LoDTensor>();
-    TensorCopySync(src_var.Get<LoDTensor>(), dst_place, lod_tensor);
+    lod_tensor->ShareDataWith(src_var.Get<LoDTensor>());
    lod_tensor->set_lod(src_var.Get<LoDTensor>().lod());
   } else if (src_var.IsType<SelectedRows>()) {
     auto *selected_rows = dst_var->GetMutable<SelectedRows>();
-    TensorCopySync(src_var.Get<SelectedRows>().value(), dst_place,
-                   selected_rows->mutable_value());
+    selected_rows->mutable_value()->ShareDataWith(
+        src_var.Get<SelectedRows>().value());
     selected_rows->set_rows(src_var.Get<SelectedRows>().rows());
     selected_rows->set_height(src_var.Get<SelectedRows>().height());
   }
 }
 
-static void ShareVarsFromScope(const std::vector<Variable *> &vars,
+static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
                                const std::vector<std::string> &var_names,
                                framework::Scope *scope) {
   for (size_t i = 0; i < vars.size(); ++i) {
-    auto *var = scope->FindVar(var_names[i]);
-    PADDLE_ENFORCE_NOT_NULL(
-        var, platform::errors::NotFound("The output variable %s is not in "
-                                        "RunProgram(Grad)Op(StaticModelRunner)'"
-                                        "s internal scope.",
-                                        var_names[i]));
-    CheckOutputVarStatus(*var, *vars[i], var_names[i]);
-    VariableShare(*var, vars[i]);
+    auto *var = scope->Var(var_names[i]);
+    CheckInputVarStatus(*vars[i], var_names[i]);
+    VariableShare(*vars[i], var);
   }
 }
 
-static void CopyVarsFromScope(const std::vector<Variable *> &vars,
-                              const std::vector<std::string> &var_names,
-                              const platform::Place &dst_place,
-                              framework::Scope *scope) {
+static void ShareVarsFromScope(const std::vector<Variable *> &vars,
+                               const std::vector<std::string> &var_names,
+                               framework::Scope *scope) {
   for (size_t i = 0; i < vars.size(); ++i) {
     if (var_names[i] == framework::kEmptyVarName) {
       VLOG(2) << "find variable name is " << framework::kEmptyVarName
               << ", skip it!";
       continue;
     }
-    auto *var = scope->FindVar(var_names[i]);
     // NOTE: Here skip not found var is dangerous, if a bug is caused here,
     // the result is grad calculation error, which will be very hidden!
+    auto *var = scope->FindVar(var_names[i]);
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::NotFound("The output variable %s is not in "
                                         "RunProgram(Grad)Op(StaticModelRunner)'"
                                         "s internal scope.",
                                         var_names[i]));
     CheckOutputVarStatus(*var, *vars[i], var_names[i]);
-    VariableCopy(*var, dst_place, vars[i]);
+    VariableShare(*var, vars[i]);
   }
 }
 
@@ -306,11 +282,9 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
                  end_op_index, /*create_local_scope=*/false,
                  /*create_vars=*/true, /*keep_kids=*/false);
 
-    // Step 4. copy outputs
-    details::CopyVarsFromScope(input_grad_vars, input_grad_var_names,
-                               ctx.GetPlace(), &scope);
-    details::CopyVarsFromScope(param_grad_vars, param_grad_names,
-                               ctx.GetPlace(), &scope);
+    // Step 4. get outputs
+    details::ShareVarsFromScope(input_grad_vars, input_grad_var_names, &scope);
+    details::ShareVarsFromScope(param_grad_vars, param_grad_names, &scope);
   }
 };
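The refactor above replaces TensorCopySync-based copies with ShareDataWith, so the op's inputs and outputs alias the tensors held in RunProgramOp's internal scope instead of duplicating them. A rough Python analogy of share-versus-copy semantics (numpy stands in for LoDTensor; this is an illustration, not Paddle internals):

    import numpy as np

    scope_tensor = np.zeros(3, dtype='float32')

    shared = scope_tensor            # like ShareDataWith: same underlying buffer
    copied = scope_tensor.copy()     # like TensorCopySync: a separate buffer

    scope_tensor[0] = 1.0
    print(shared[0])   # 1.0 -- sharing sees later writes and costs no memcpy
    print(copied[0])   # 0.0 -- the copy is frozen at copy time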

python/paddle/fluid/data_feeder.py

Lines changed: 2 additions & 1 deletion
@@ -76,7 +76,8 @@ def check_variable_and_dtype(input,
                              expected_dtype,
                              op_name,
                              extra_message=''):
-    check_type(input, input_name, Variable, op_name, extra_message)
+    check_type(input, input_name, (Variable, core.VarBase), op_name,
+               extra_message)
     check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
 
 
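The widened check simply passes a tuple of accepted types to isinstance, so dygraph VarBase inputs pass the same validation as static Variables. A minimal sketch of the pattern (this check_type is a stand-in, not the real Paddle helper):

    def check_type(input, input_name, expected_type, op_name):
        # isinstance accepts a tuple of types, so one check admits
        # both static Variables and dygraph VarBases.
        if not isinstance(input, expected_type):
            raise TypeError(
                "The type of '%s' in %s must be %s, but received %s." %
                (input_name, op_name, expected_type, type(input)))

    check_type(3.14, 'x', (int, float), 'demo_op')  # passes: float is accepted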

python/paddle/fluid/dygraph/base.py

Lines changed: 20 additions & 0 deletions
@@ -61,6 +61,26 @@ def program_desc_tracing_guard(enable):
 _functional_dygraph_context_manager = None
 
 
+@signature_safe_contextmanager
+def param_guard(parameters):
+    # Note: parameters is a reference of self._parameters
+    if not framework.in_dygraph_mode() and parameters:
+        origin_parameters = parameters.copy()
+        for name, var_base in parameters.items():
+            if isinstance(var_base, core.VarBase):
+                new_var = framework.Parameter(
+                    var_base.block,
+                    var_base.shape,
+                    var_base.dtype,
+                    var_base.type,
+                    name=var_base.name)
+                parameters[name] = new_var
+        yield
+        parameters.update(origin_parameters)
+    else:
+        yield
+
+
 def enabled():
     """
     This function checks whether the program runs in dynamic graph mode or not.
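param_guard exists because a dygraph Layer holds core.VarBase parameters, which static-graph tracing cannot consume directly; it swaps each one for an equivalent framework.Parameter for the duration of the trace, then restores the originals. A self-contained sketch of the same swap-and-restore pattern (plain strings stand in for VarBase/Parameter):

    from contextlib import contextmanager

    @contextmanager
    def swap_values(mapping, convert):
        """Temporarily replace every value in `mapping` with convert(value)."""
        if mapping:
            originals = mapping.copy()
            for key, value in mapping.items():
                mapping[key] = convert(value)
            try:
                yield
            finally:
                mapping.update(originals)   # restore the original parameters
        else:
            yield

    params = {'w': 'VarBase(w)', 'b': 'VarBase(b)'}
    with swap_values(params, lambda v: v.replace('VarBase', 'Parameter')):
        print(params)   # {'w': 'Parameter(w)', 'b': 'Parameter(b)'}
    print(params)       # originals restored

Unlike the diff, the sketch restores under finally, so an exception raised while tracing cannot leak the swapped values.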

python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py

Lines changed: 3 additions & 9 deletions
@@ -14,8 +14,6 @@
 
 from __future__ import print_function
 
-import astor
-import copy
 # gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
 # It provides a compatibility layer between the AST of various Python versions,
 # as produced by ast.parse from the standard ast module.
@@ -24,8 +22,6 @@
 import inspect
 import textwrap
 
-from paddle.fluid import unique_name
-
 from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
 from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
 from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
@@ -34,14 +30,9 @@
 from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
 from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
 
-from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
-from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
 from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
-from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
 from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
-from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
-from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
 
 __all__ = ['DygraphToStaticAst', 'convert_to_static']
 
@@ -142,6 +133,9 @@ def convert_to_static(dyfunc):
     Converts dygraph function into static function.
     """
     # Get AST from dygraph function
+    # Note: In Python2, it will raise OSError when inspect function
+    # with decorator directly and dyfunc.__wrapped__ holds the actual function.
+    dyfunc = getattr(dyfunc, '__wrapped__', dyfunc)
     raw_code = inspect.getsource(dyfunc)
     code = textwrap.dedent(raw_code)
     root = gast.parse(code)
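The __wrapped__ fallback works because functools.wraps records the original function on the wrapper, so unwrapping first lets inspect.getsource retrieve the undecorated body (and avoids the Python 2 OSError mentioned in the note). A quick self-contained demonstration (run as a script, since getsource needs a source file):

    import functools
    import inspect

    def my_decorator(fn):
        @functools.wraps(fn)          # sets wrapper.__wrapped__ = fn
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        return wrapper

    @my_decorator
    def dyfunc(x):
        return x + 1

    target = getattr(dyfunc, '__wrapped__', dyfunc)   # unwrap if possible
    print(target is dyfunc.__wrapped__)               # True
    print(inspect.getsource(target))                  # source of the real dyfunc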

python/paddle/fluid/dygraph/dygraph_to_static/break_continue_transformer.py

Lines changed: 0 additions & 3 deletions
@@ -17,9 +17,6 @@
 import gast
 
 from paddle.fluid import unique_name
-from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import NodeTestTransformer
-from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
-from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
 from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
 from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
 from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node

python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py

Lines changed: 13 additions & 0 deletions
@@ -32,13 +32,26 @@ def __init__(self, wrapper_root):
         self.wrapper_root = wrapper_root
         self.root = wrapper_root.node
 
+    def _is_builtin_call(self, node):
+        assert isinstance(node, gast.Call)
+        func_str = ast_to_source_code(node.func).strip()
+        try:
+            from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin
+            return eval("is_builtin({})".format(func_str))
+        except Exception:
+            return False
+
     def transform(self):
         self.visit(self.root)
 
     def visit_Call(self, node):
         self.generic_visit(node)
         if is_paddle_api(node):
             return node
+
+        if self._is_builtin_call(node):
+            return node
+
         func_str = ast_to_source_code(node.func).strip()
         new_func_str = "fluid.dygraph.dygraph_to_static.convert_call({})".format(
             func_str)
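_is_builtin_call renders the callee back to source and evaluates is_builtin over it inside try/except because the name may not resolve at transform time. A simplified stand-in for the check (the real is_builtin lives in convert_call_func.py and covers more cases):

    import types

    def is_builtin(func):
        # Simplified: treat C-level builtins as "leave the call untouched".
        return isinstance(func, types.BuiltinFunctionType)

    for func_str in ["len", "print", "sorted", "my_model"]:
        try:
            print(func_str, eval("is_builtin({})".format(func_str)))
        except Exception:
            # Unresolvable names fall through to convert_call instead.
            print(func_str, False)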

python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py

Lines changed: 10 additions & 39 deletions
@@ -103,63 +103,34 @@ def dyfunc(x):
             return func
         try:
             if func in func.__globals__.values():
-                if six.PY3:
-                    source_code = inspect.getsource(func)
-                    if any(decorator in source_code
-                           for decorator in DECORATOR_NAMES):
-                        converted_call = None
-                    else:
-                        converted_call = to_static_func(func)
-                    func_self = getattr(func, '__self__', None)
-                else:
-                    converted_call = to_static_func(func)
-                    func_self = getattr(func, '__self__', None)
+                converted_call = to_static_func(func)
+                func_self = getattr(func, '__self__', None)
         except AttributeError:
             # NOTE:
             # If func is not in __globals__, it does not need to be transformed
             # because it has been transformed before.
             converted_call = None
         except (IOError, OSError):
             # NOTE:
-            # If func has beed decorated, its source code can not be get
+            # If func has been decorated, its source code can not be get
             # so that it can not be transformed to static function.
             converted_call = None
     elif inspect.ismethod(func):
         try:
-            if six.PY3:
-                source_code = inspect.getsource(func)
-                if any(decorator in source_code
-                       for decorator in DECORATOR_NAMES):
-                    converted_call = None
-                else:
-                    converted_call = to_static_func(func)
-                func_self = getattr(func, '__self__', None)
-            else:
-                converted_call = to_static_func(func)
-                func_self = getattr(func, '__self__', None)
+            converted_call = to_static_func(func)
+            func_self = getattr(func, '__self__', None)
         except (IOError, OSError):
-            # NOTE: func may have beed decorated.
+            # NOTE: func may have been decorated.
             converted_call = None
 
     elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
         if hasattr(func, 'forward') and isinstance(func, Layer):
             try:
-                if six.PY3:
-                    source_code = inspect.getsource(func.forward)
-                    if any(decorator in source_code
-                           for decorator in DECORATOR_NAMES):
-                        converted_call = None
-                    else:
-                        forward_func = to_static_func(func.forward)
-                        setattr(func, 'forward', forward_func)
-                        func_self = func
-                else:
-                    forward_func = to_static_func(func.forward)
-                    setattr(func, 'forward', forward_func)
-                    func_self = func
-
+                forward_func = to_static_func(func.forward)
+                setattr(func, 'forward', forward_func)
+                func_self = func
             except Exception:
-                # NOTE: func.forward may have beed decorated.
+                # NOTE: func.forward may have been decorated.
                 func_self = None if func_self else func_self
             converted_call = func
         else: