@@ -74,19 +74,29 @@ def test_predict(get_pytorch_detr):
7474 assert list (result [0 ].keys ()) == ["boxes" , "labels" , "scores" ]
7575
7676 assert result [0 ]["boxes" ].shape == (100 , 4 )
77- expected_detection_boxes = np .asarray ([- 5.9490204e-03 , 1.1947733e+01 , 3.1993944e+01 , 3.1925127e+01 ])
77+ expected_detection_boxes = np .asarray ([- 5.9490204e-03 , 1.1947733e01 , 3.1993944e01 , 3.1925127e01 ])
7878 np .testing .assert_array_almost_equal (result [0 ]["boxes" ][2 , :], expected_detection_boxes , decimal = 3 )
7979
8080 assert result [0 ]["scores" ].shape == (100 ,)
8181 expected_detection_scores = np .asarray (
82- [0.00679839 , 0.0250559 , 0.07205943 , 0.01115368 , 0.03321039 ,
83- 0.10407761 , 0.00113309 , 0.01442852 , 0.00527624 , 0.01240906 ]
82+ [
83+ 0.00679839 ,
84+ 0.0250559 ,
85+ 0.07205943 ,
86+ 0.01115368 ,
87+ 0.03321039 ,
88+ 0.10407761 ,
89+ 0.00113309 ,
90+ 0.01442852 ,
91+ 0.00527624 ,
92+ 0.01240906 ,
93+ ]
8494 )
85- np .testing .assert_array_almost_equal (result [0 ]["scores" ][:10 ], expected_detection_scores , decimal = 6 )
95+ np .testing .assert_array_almost_equal (result [0 ]["scores" ][:10 ], expected_detection_scores , decimal = 5 )
8696
8797 assert result [0 ]["labels" ].shape == (100 ,)
8898 expected_detection_classes = np .asarray ([17 , 17 , 33 , 17 , 17 , 17 , 74 , 17 , 17 , 17 ])
89- np .testing .assert_array_almost_equal (result [0 ]["labels" ][:10 ], expected_detection_classes , decimal = 6 )
99+ np .testing .assert_array_almost_equal (result [0 ]["labels" ][:10 ], expected_detection_classes , decimal = 5 )
90100
91101
92102@pytest .mark .only_with_platform ("pytorch" )
@@ -99,26 +109,79 @@ def test_loss_gradient(get_pytorch_detr):
99109 assert grads .shape == (2 , 3 , 800 , 800 )
100110
101111 expected_gradients1 = np .asarray (
102- [- 0.00061366 , 0.00322502 , - 0.00039866 , - 0.00807413 , - 0.00476555 ,
103- 0.00181204 , 0.01007765 , 0.00415828 , - 0.00073114 , 0.00018387 ,
104- - 0.00146992 , - 0.00119636 , - 0.00098966 , - 0.00295517 , - 0.0024271 ,
105- - 0.00131314 , - 0.00149217 , - 0.00104926 , - 0.00154239 , - 0.00110989 ,
106- 0.00092887 , 0.00049146 , - 0.00292508 , - 0.00124526 , 0.00140347 ,
107- 0.00019833 , 0.00191074 , - 0.00117537 , - 0.00080604 , 0.00057427 ,
108- - 0.00061728 , - 0.00206535 ]
112+ [
113+ - 0.00061366 ,
114+ 0.00322502 ,
115+ - 0.00039866 ,
116+ - 0.00807413 ,
117+ - 0.00476555 ,
118+ 0.00181204 ,
119+ 0.01007765 ,
120+ 0.00415828 ,
121+ - 0.00073114 ,
122+ 0.00018387 ,
123+ - 0.00146992 ,
124+ - 0.00119636 ,
125+ - 0.00098966 ,
126+ - 0.00295517 ,
127+ - 0.0024271 ,
128+ - 0.00131314 ,
129+ - 0.00149217 ,
130+ - 0.00104926 ,
131+ - 0.00154239 ,
132+ - 0.00110989 ,
133+ 0.00092887 ,
134+ 0.00049146 ,
135+ - 0.00292508 ,
136+ - 0.00124526 ,
137+ 0.00140347 ,
138+ 0.00019833 ,
139+ 0.00191074 ,
140+ - 0.00117537 ,
141+ - 0.00080604 ,
142+ 0.00057427 ,
143+ - 0.00061728 ,
144+ - 0.00206535 ,
145+ ]
109146 )
110147
111148 np .testing .assert_array_almost_equal (grads [0 , 0 , 10 , :32 ], expected_gradients1 , decimal = 2 )
112149
113150 expected_gradients2 = np .asarray (
114- [- 1.1787530e-03 , - 2.8500680e-03 , 5.0884970e-03 , 6.4504531e-04 ,
115- - 6.8841036e-05 , 2.8184296e-03 , 3.0257765e-03 , 2.8565727e-04 ,
116- - 1.0701057e-04 , 1.2945699e-03 , 7.3593057e-04 , 1.0177144e-03 ,
117- - 2.4692707e-03 , - 1.3801848e-03 , 6.3182280e-04 , - 4.2305476e-04 ,
118- 4.4307750e-04 , 8.5821096e-04 , - 7.1204413e-04 , - 3.1404425e-03 ,
119- - 1.5964351e-03 , - 1.9222996e-03 , - 5.3157361e-04 , - 9.9202688e-04 ,
120- - 1.5815455e-03 , 2.0060266e-04 , - 2.0584739e-03 , 6.6960667e-04 ,
121- 9.7393827e-04 , - 1.6040013e-03 , - 6.9741381e-04 , 1.4657658e-04 ]
151+ [
152+ - 1.1787530e-03 ,
153+ - 2.8500680e-03 ,
154+ 5.0884970e-03 ,
155+ 6.4504531e-04 ,
156+ - 6.8841036e-05 ,
157+ 2.8184296e-03 ,
158+ 3.0257765e-03 ,
159+ 2.8565727e-04 ,
160+ - 1.0701057e-04 ,
161+ 1.2945699e-03 ,
162+ 7.3593057e-04 ,
163+ 1.0177144e-03 ,
164+ - 2.4692707e-03 ,
165+ - 1.3801848e-03 ,
166+ 6.3182280e-04 ,
167+ - 4.2305476e-04 ,
168+ 4.4307750e-04 ,
169+ 8.5821096e-04 ,
170+ - 7.1204413e-04 ,
171+ - 3.1404425e-03 ,
172+ - 1.5964351e-03 ,
173+ - 1.9222996e-03 ,
174+ - 5.3157361e-04 ,
175+ - 9.9202688e-04 ,
176+ - 1.5815455e-03 ,
177+ 2.0060266e-04 ,
178+ - 2.0584739e-03 ,
179+ 6.6960667e-04 ,
180+ 9.7393827e-04 ,
181+ - 1.6040013e-03 ,
182+ - 6.9741381e-04 ,
183+ 1.4657658e-04 ,
184+ ]
122185 )
123186 np .testing .assert_array_almost_equal (grads [1 , 0 , 10 , :32 ], expected_gradients2 , decimal = 2 )
124187
@@ -236,7 +299,7 @@ def test_pgd(get_pytorch_detr):
236299
237300 imgs = []
238301 for i in x_test :
239- img = Image .fromarray ((i * 255 ).astype (np .uint8 ).transpose (1 ,2 , 0 ))
302+ img = Image .fromarray ((i * 255 ).astype (np .uint8 ).transpose (1 , 2 , 0 ))
240303 img = img .resize (size = (800 , 800 ))
241304 imgs .append (np .array (img ))
242305 x_test = np .array (imgs ).transpose (0 , 3 , 1 , 2 )
0 commit comments