Skip to content

Commit fbba94f

Browse files
authored
[Dy2St]Fix BUG with Potential security vulnerabilities (#60100)
1 parent e8ee704 commit fbba94f

File tree

3 files changed

+0
-194
lines changed

3 files changed

+0
-194
lines changed

python/paddle/jit/dy2static/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
convert_logical_or as Or,
2727
convert_pop as Pop,
2828
convert_shape as Shape,
29-
convert_shape_compare,
3029
convert_var_dtype as AsDtype,
3130
convert_while_loop as While,
3231
indexable as Indexable,

python/paddle/jit/dy2static/convert_operators.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -693,77 +693,6 @@ def has_negative(list_shape):
693693
return x.shape
694694

695695

696-
def convert_shape_compare(left, *args):
697-
"""
698-
A function handles comparison difference between Paddle and Python.
699-
For example, if x and y are Tensors, x.shape == y.shape will return single
700-
boolean Value (True/False). However, paddle.shape(x) == paddle.shape(y) is
701-
an element-wise comparison. The difference can cause dy2stat error. So we
702-
create this function to handle the difference.
703-
704-
Args:
705-
left: variable
706-
*args: compare_op(str), variable, compare_op(str), variable, where
707-
compare_op means "<", ">", "==", "!=", etc.
708-
Returns:
709-
If the variables to compare are NOT Paddle Variables, we will return as
710-
Python like "a op1 b and b op2 c and ... ".
711-
If the variables to compare are Paddle Variables, we will do elementwise
712-
comparsion first and then reduce to a boolean whose numel is 1.
713-
714-
"""
715-
args_len = len(args)
716-
assert (
717-
args_len >= 2
718-
), "convert_shape_compare needs at least one right compare variable"
719-
assert (
720-
args_len % 2 == 0
721-
), "Illegal input for convert_shape_compare, *args should be op(str), var, op(str), var ..."
722-
num_cmp = args_len // 2
723-
if isinstance(left, (Variable, Value)):
724-
725-
def reduce_compare(x, op_str, y):
726-
element_wise_result = eval("x " + op_str + " y")
727-
if op_str == "!=":
728-
return paddle.any(element_wise_result)
729-
elif (
730-
op_str == "is"
731-
or op_str == "is not"
732-
or op_str == "in"
733-
or op_str == "not in"
734-
):
735-
return element_wise_result
736-
else:
737-
return paddle.all(element_wise_result)
738-
739-
final_result = reduce_compare(left, args[0], args[1])
740-
for i in range(1, num_cmp):
741-
cmp_left = args[i * 2 - 1]
742-
cmp_op = args[i * 2]
743-
cmp_right = args[i * 2 + 1]
744-
cur_result = reduce_compare(cmp_left, cmp_op, cmp_right)
745-
final_result = convert_logical_and(
746-
lambda: final_result, lambda: cur_result
747-
)
748-
return final_result
749-
else:
750-
cmp_left = left
751-
final_result = None
752-
for i in range(num_cmp):
753-
cmp_op = args[i * 2]
754-
cmp_right = args[i * 2 + 1]
755-
cur_result = eval("cmp_left " + cmp_op + " cmp_right")
756-
if final_result is None:
757-
final_result = cur_result
758-
else:
759-
final_result = final_result and cur_result
760-
761-
if final_result is False:
762-
return False
763-
cmp_left = cmp_right
764-
return final_result
765-
766-
767696
def cast_bool_if_necessary(var):
768697
assert isinstance(var, (Variable, Value))
769698
if convert_dtype(var.dtype) not in ['bool']:

test/dygraph_to_static/test_convert_operators.py

Lines changed: 0 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -71,128 +71,6 @@ def callable_list(x, y):
7171
self.assertEqual(paddle.jit.to_static(callable_list)(1, 2), 3)
7272

7373

74-
class TestConvertShapeCompare(Dy2StTestBase):
75-
@test_legacy_and_pt_and_pir
76-
def test_non_variable(self):
77-
self.assertEqual(
78-
paddle.jit.dy2static.convert_shape_compare(1, "<", 2), True
79-
)
80-
self.assertEqual(
81-
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "<=", 3), True
82-
)
83-
self.assertEqual(
84-
paddle.jit.dy2static.convert_shape_compare(1, ">", 2, "<=", 3),
85-
False,
86-
)
87-
88-
def error_func():
89-
"""
90-
Function used to test that comparison doesn't run after first False
91-
"""
92-
raise ValueError("Used for test")
93-
94-
self.assertEqual(
95-
paddle.jit.dy2static.convert_shape_compare(
96-
1, ">", 2, "<=", lambda: error_func()
97-
),
98-
False,
99-
)
100-
101-
self.assertEqual(
102-
paddle.jit.dy2static.convert_shape_compare(
103-
1, "<", 2, "in", [1, 2, 3]
104-
),
105-
True,
106-
)
107-
self.assertEqual(
108-
paddle.jit.dy2static.convert_shape_compare(
109-
1, "<", 2, "not in", [1, 2, 3]
110-
),
111-
False,
112-
)
113-
self.assertEqual(
114-
paddle.jit.dy2static.convert_shape_compare(1, "<", 2, "is", 3),
115-
False,
116-
)
117-
self.assertEqual(
118-
paddle.jit.dy2static.convert_shape_compare(
119-
1, "<", 2, "is not", [1, 2, 3]
120-
),
121-
True,
122-
)
123-
124-
self.assertEqual(
125-
paddle.jit.dy2static.convert_shape_compare(
126-
[1, 2], "==", [1, 2], "!=", [1, 2, 3]
127-
),
128-
True,
129-
)
130-
self.assertEqual(
131-
paddle.jit.dy2static.convert_shape_compare(
132-
[1, 2], "!=", [1, 2, 3], "==", [1, 2]
133-
),
134-
False,
135-
)
136-
137-
def test_variable(self):
138-
paddle.enable_static()
139-
main_program = paddle.static.Program()
140-
startup_program = paddle.static.Program()
141-
with paddle.static.program_guard(main_program, startup_program):
142-
x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
143-
y = paddle.static.data(name='y', shape=[3, 2], dtype='float32')
144-
self.assertEqual(
145-
paddle.jit.dy2static.convert_shape_compare(
146-
x, "is", x, "is not", y
147-
),
148-
True,
149-
)
150-
self.assertEqual(
151-
paddle.jit.dy2static.convert_shape_compare(
152-
x, "is not", x, "is not", y
153-
),
154-
False,
155-
)
156-
self.assertEqual(
157-
paddle.jit.dy2static.convert_shape_compare(x, "is", x, "is", y),
158-
False,
159-
)
160-
161-
eq_out = paddle.jit.dy2static.convert_shape_compare(x, "==", y)
162-
not_eq_out = paddle.jit.dy2static.convert_shape_compare(x, "!=", y)
163-
long_eq_out = paddle.jit.dy2static.convert_shape_compare(
164-
x, "==", x, "!=", y
165-
)
166-
167-
place = (
168-
paddle.CUDAPlace(0)
169-
if paddle.is_compiled_with_cuda()
170-
else paddle.CPUPlace()
171-
)
172-
exe = paddle.static.Executor(place)
173-
x_y_eq_out = exe.run(
174-
feed={
175-
"x": np.ones([3, 2]).astype(np.float32),
176-
"y": np.ones([3, 2]).astype(np.float32),
177-
},
178-
fetch_list=[eq_out, not_eq_out, long_eq_out],
179-
)
180-
np.testing.assert_array_equal(
181-
np.array(x_y_eq_out), np.array([True, False, False])
182-
)
183-
184-
set_a_zero = np.ones([3, 2]).astype(np.float32)
185-
set_a_zero[0][0] = 0.0
186-
x_y_not_eq_out = exe.run(
187-
feed={"x": np.ones([3, 2]).astype(np.float32), "y": set_a_zero},
188-
fetch_list=[eq_out, not_eq_out, long_eq_out],
189-
)
190-
np.testing.assert_array_equal(
191-
np.array(x_y_not_eq_out), np.array([False, True, True])
192-
)
193-
paddle.disable_static()
194-
195-
19674
class ShapeLayer(paddle.nn.Layer):
19775
def __init__(self):
19876
super().__init__()

0 commit comments

Comments
 (0)