@@ -71,128 +71,6 @@ def callable_list(x, y):
71
71
self .assertEqual (paddle .jit .to_static (callable_list )(1 , 2 ), 3 )
72
72
73
73
74
- class TestConvertShapeCompare (Dy2StTestBase ):
75
- @test_legacy_and_pt_and_pir
76
- def test_non_variable (self ):
77
- self .assertEqual (
78
- paddle .jit .dy2static .convert_shape_compare (1 , "<" , 2 ), True
79
- )
80
- self .assertEqual (
81
- paddle .jit .dy2static .convert_shape_compare (1 , "<" , 2 , "<=" , 3 ), True
82
- )
83
- self .assertEqual (
84
- paddle .jit .dy2static .convert_shape_compare (1 , ">" , 2 , "<=" , 3 ),
85
- False ,
86
- )
87
-
88
- def error_func ():
89
- """
90
- Function used to test that comparison doesn't run after first False
91
- """
92
- raise ValueError ("Used for test" )
93
-
94
- self .assertEqual (
95
- paddle .jit .dy2static .convert_shape_compare (
96
- 1 , ">" , 2 , "<=" , lambda : error_func ()
97
- ),
98
- False ,
99
- )
100
-
101
- self .assertEqual (
102
- paddle .jit .dy2static .convert_shape_compare (
103
- 1 , "<" , 2 , "in" , [1 , 2 , 3 ]
104
- ),
105
- True ,
106
- )
107
- self .assertEqual (
108
- paddle .jit .dy2static .convert_shape_compare (
109
- 1 , "<" , 2 , "not in" , [1 , 2 , 3 ]
110
- ),
111
- False ,
112
- )
113
- self .assertEqual (
114
- paddle .jit .dy2static .convert_shape_compare (1 , "<" , 2 , "is" , 3 ),
115
- False ,
116
- )
117
- self .assertEqual (
118
- paddle .jit .dy2static .convert_shape_compare (
119
- 1 , "<" , 2 , "is not" , [1 , 2 , 3 ]
120
- ),
121
- True ,
122
- )
123
-
124
- self .assertEqual (
125
- paddle .jit .dy2static .convert_shape_compare (
126
- [1 , 2 ], "==" , [1 , 2 ], "!=" , [1 , 2 , 3 ]
127
- ),
128
- True ,
129
- )
130
- self .assertEqual (
131
- paddle .jit .dy2static .convert_shape_compare (
132
- [1 , 2 ], "!=" , [1 , 2 , 3 ], "==" , [1 , 2 ]
133
- ),
134
- False ,
135
- )
136
-
137
- def test_variable (self ):
138
- paddle .enable_static ()
139
- main_program = paddle .static .Program ()
140
- startup_program = paddle .static .Program ()
141
- with paddle .static .program_guard (main_program , startup_program ):
142
- x = paddle .static .data (name = 'x' , shape = [3 , 2 ], dtype = 'float32' )
143
- y = paddle .static .data (name = 'y' , shape = [3 , 2 ], dtype = 'float32' )
144
- self .assertEqual (
145
- paddle .jit .dy2static .convert_shape_compare (
146
- x , "is" , x , "is not" , y
147
- ),
148
- True ,
149
- )
150
- self .assertEqual (
151
- paddle .jit .dy2static .convert_shape_compare (
152
- x , "is not" , x , "is not" , y
153
- ),
154
- False ,
155
- )
156
- self .assertEqual (
157
- paddle .jit .dy2static .convert_shape_compare (x , "is" , x , "is" , y ),
158
- False ,
159
- )
160
-
161
- eq_out = paddle .jit .dy2static .convert_shape_compare (x , "==" , y )
162
- not_eq_out = paddle .jit .dy2static .convert_shape_compare (x , "!=" , y )
163
- long_eq_out = paddle .jit .dy2static .convert_shape_compare (
164
- x , "==" , x , "!=" , y
165
- )
166
-
167
- place = (
168
- paddle .CUDAPlace (0 )
169
- if paddle .is_compiled_with_cuda ()
170
- else paddle .CPUPlace ()
171
- )
172
- exe = paddle .static .Executor (place )
173
- x_y_eq_out = exe .run (
174
- feed = {
175
- "x" : np .ones ([3 , 2 ]).astype (np .float32 ),
176
- "y" : np .ones ([3 , 2 ]).astype (np .float32 ),
177
- },
178
- fetch_list = [eq_out , not_eq_out , long_eq_out ],
179
- )
180
- np .testing .assert_array_equal (
181
- np .array (x_y_eq_out ), np .array ([True , False , False ])
182
- )
183
-
184
- set_a_zero = np .ones ([3 , 2 ]).astype (np .float32 )
185
- set_a_zero [0 ][0 ] = 0.0
186
- x_y_not_eq_out = exe .run (
187
- feed = {"x" : np .ones ([3 , 2 ]).astype (np .float32 ), "y" : set_a_zero },
188
- fetch_list = [eq_out , not_eq_out , long_eq_out ],
189
- )
190
- np .testing .assert_array_equal (
191
- np .array (x_y_not_eq_out ), np .array ([False , True , True ])
192
- )
193
- paddle .disable_static ()
194
-
195
-
196
74
class ShapeLayer (paddle .nn .Layer ):
197
75
def __init__ (self ):
198
76
super ().__init__ ()
0 commit comments