@@ -71,128 +71,6 @@ def callable_list(x, y):
7171 self .assertEqual (paddle .jit .to_static (callable_list )(1 , 2 ), 3 )
7272
7373
74- class TestConvertShapeCompare (Dy2StTestBase ):
75- @test_legacy_and_pt_and_pir
76- def test_non_variable (self ):
77- self .assertEqual (
78- paddle .jit .dy2static .convert_shape_compare (1 , "<" , 2 ), True
79- )
80- self .assertEqual (
81- paddle .jit .dy2static .convert_shape_compare (1 , "<" , 2 , "<=" , 3 ), True
82- )
83- self .assertEqual (
84- paddle .jit .dy2static .convert_shape_compare (1 , ">" , 2 , "<=" , 3 ),
85- False ,
86- )
87-
88- def error_func ():
89- """
90- Function used to test that comparison doesn't run after first False
91- """
92- raise ValueError ("Used for test" )
93-
94- self .assertEqual (
95- paddle .jit .dy2static .convert_shape_compare (
96- 1 , ">" , 2 , "<=" , lambda : error_func ()
97- ),
98- False ,
99- )
100-
101- self .assertEqual (
102- paddle .jit .dy2static .convert_shape_compare (
103- 1 , "<" , 2 , "in" , [1 , 2 , 3 ]
104- ),
105- True ,
106- )
107- self .assertEqual (
108- paddle .jit .dy2static .convert_shape_compare (
109- 1 , "<" , 2 , "not in" , [1 , 2 , 3 ]
110- ),
111- False ,
112- )
113- self .assertEqual (
114- paddle .jit .dy2static .convert_shape_compare (1 , "<" , 2 , "is" , 3 ),
115- False ,
116- )
117- self .assertEqual (
118- paddle .jit .dy2static .convert_shape_compare (
119- 1 , "<" , 2 , "is not" , [1 , 2 , 3 ]
120- ),
121- True ,
122- )
123-
124- self .assertEqual (
125- paddle .jit .dy2static .convert_shape_compare (
126- [1 , 2 ], "==" , [1 , 2 ], "!=" , [1 , 2 , 3 ]
127- ),
128- True ,
129- )
130- self .assertEqual (
131- paddle .jit .dy2static .convert_shape_compare (
132- [1 , 2 ], "!=" , [1 , 2 , 3 ], "==" , [1 , 2 ]
133- ),
134- False ,
135- )
136-
137- def test_variable (self ):
138- paddle .enable_static ()
139- main_program = paddle .static .Program ()
140- startup_program = paddle .static .Program ()
141- with paddle .static .program_guard (main_program , startup_program ):
142- x = paddle .static .data (name = 'x' , shape = [3 , 2 ], dtype = 'float32' )
143- y = paddle .static .data (name = 'y' , shape = [3 , 2 ], dtype = 'float32' )
144- self .assertEqual (
145- paddle .jit .dy2static .convert_shape_compare (
146- x , "is" , x , "is not" , y
147- ),
148- True ,
149- )
150- self .assertEqual (
151- paddle .jit .dy2static .convert_shape_compare (
152- x , "is not" , x , "is not" , y
153- ),
154- False ,
155- )
156- self .assertEqual (
157- paddle .jit .dy2static .convert_shape_compare (x , "is" , x , "is" , y ),
158- False ,
159- )
160-
161- eq_out = paddle .jit .dy2static .convert_shape_compare (x , "==" , y )
162- not_eq_out = paddle .jit .dy2static .convert_shape_compare (x , "!=" , y )
163- long_eq_out = paddle .jit .dy2static .convert_shape_compare (
164- x , "==" , x , "!=" , y
165- )
166-
167- place = (
168- paddle .CUDAPlace (0 )
169- if paddle .is_compiled_with_cuda ()
170- else paddle .CPUPlace ()
171- )
172- exe = paddle .static .Executor (place )
173- x_y_eq_out = exe .run (
174- feed = {
175- "x" : np .ones ([3 , 2 ]).astype (np .float32 ),
176- "y" : np .ones ([3 , 2 ]).astype (np .float32 ),
177- },
178- fetch_list = [eq_out , not_eq_out , long_eq_out ],
179- )
180- np .testing .assert_array_equal (
181- np .array (x_y_eq_out ), np .array ([True , False , False ])
182- )
183-
184- set_a_zero = np .ones ([3 , 2 ]).astype (np .float32 )
185- set_a_zero [0 ][0 ] = 0.0
186- x_y_not_eq_out = exe .run (
187- feed = {"x" : np .ones ([3 , 2 ]).astype (np .float32 ), "y" : set_a_zero },
188- fetch_list = [eq_out , not_eq_out , long_eq_out ],
189- )
190- np .testing .assert_array_equal (
191- np .array (x_y_not_eq_out ), np .array ([False , True , True ])
192- )
193- paddle .disable_static ()
194-
195-
19674class ShapeLayer (paddle .nn .Layer ):
19775 def __init__ (self ):
19876 super ().__init__ ()
0 commit comments