Skip to content

Commit 2d0f849

Browse files
authored
[Dy2Stat] Add assert for ProgramTranslator (#24492)
Add assert grammar for ProgramTranslator
1 parent 53e3c53 commit 2d0f849

File tree

4 files changed

+138
-2
lines changed

4 files changed

+138
-2
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import gast
18+
19+
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
20+
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
21+
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
22+
23+
24+
class AssertTransformer(gast.NodeTransformer):
25+
"""
26+
A class transforms python assert to fluid.layers.Assert.
27+
"""
28+
29+
def __init__(self, wrapper_root):
30+
assert isinstance(
31+
wrapper_root, AstNodeWrapper
32+
), "Input non-AstNodeWrapper node for the initialization of AssertTransformer."
33+
self.wrapper_root = wrapper_root
34+
self.root = wrapper_root.node
35+
36+
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
37+
38+
def transform(self):
39+
self.visit(self.root)
40+
41+
def visit_Assert(self, node):
42+
if not self.static_analysis_visitor.is_tensor_node(node.test):
43+
return node
44+
cast_node = gast.Call(
45+
func=gast.parse("fluid.layers.cast").body[0].value,
46+
args=[node.test, gast.Constant(
47+
value="bool", kind=None)],
48+
keywords=[])
49+
assert_node = gast.Call(
50+
func=gast.parse("fluid.layers.Assert").body[0].value,
51+
args=[cast_node],
52+
keywords=[])
53+
return gast.Expr(value=assert_node)

python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@
2222
import inspect
2323
import textwrap
2424

25+
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
26+
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
2527
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
2628
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
2729
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
2830
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
2931
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
30-
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
31-
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
3232
from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
33+
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
3334

3435
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
3536
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
@@ -80,6 +81,9 @@ def transfer_from_node_type(self, node_wrapper):
8081
# Transform all if/else statement of Dygraph into Static Graph.
8182
IfElseTransformer(node_wrapper).transform()
8283

84+
# Transform python assert statement
85+
AssertTransformer(node_wrapper).transform()
86+
8387
# Transform all python print statement
8488
PrintTransformer(node_wrapper).transform()
8589

python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,14 @@ def get_node_to_wrapper_map(self):
256256
def get_var_env(self):
257257
return self.var_env
258258

259+
def is_tensor_node(self, node):
260+
tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
261+
node_wrapper = self.node_to_wrapper_map.get(node, None)
262+
if node_wrapper is None:
263+
return False
264+
if node_wrapper.node_var_type & tensor_types:
265+
return True
266+
259267
def _get_constant_node_type(self, node):
260268
assert isinstance(node, gast.Constant), \
261269
"Type of input node should be gast.Constant, but received %s" % type(node)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import numpy
18+
import unittest
19+
20+
import paddle.fluid as fluid
21+
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
22+
from paddle.fluid.dygraph.jit import declarative
23+
24+
25+
@declarative
26+
def dyfunc_assert_variable(x):
27+
x_v = fluid.dygraph.to_variable(x)
28+
assert x_v
29+
30+
31+
@declarative
32+
def dyfunc_assert_non_variable(x=True):
33+
assert x
34+
35+
36+
class TestAssertVariable(unittest.TestCase):
37+
def _run(self, func, x, with_exception, to_static):
38+
ProgramTranslator().enable(to_static)
39+
if with_exception:
40+
with self.assertRaises(BaseException):
41+
with fluid.dygraph.guard():
42+
func(x)
43+
else:
44+
with fluid.dygraph.guard():
45+
func(x)
46+
47+
def _run_dy_static(self, func, x, with_exception):
48+
self._run(func, x, with_exception, True)
49+
self._run(func, x, with_exception, False)
50+
51+
def test_non_variable(self):
52+
self._run_dy_static(
53+
dyfunc_assert_non_variable, x=False, with_exception=True)
54+
self._run_dy_static(
55+
dyfunc_assert_non_variable, x=True, with_exception=False)
56+
57+
def test_bool_variable(self):
58+
self._run_dy_static(
59+
dyfunc_assert_variable, x=numpy.array([False]), with_exception=True)
60+
self._run_dy_static(
61+
dyfunc_assert_variable, x=numpy.array([True]), with_exception=False)
62+
63+
def test_int_variable(self):
64+
self._run_dy_static(
65+
dyfunc_assert_variable, x=numpy.array([0]), with_exception=True)
66+
self._run_dy_static(
67+
dyfunc_assert_variable, x=numpy.array([1]), with_exception=False)
68+
69+
70+
if __name__ == '__main__':
71+
unittest.main()

0 commit comments

Comments
 (0)