3030@pytest .mark .only_with_platform ("pytorch" )
3131def test_predict (art_warning , get_pytorch_detr ):
3232 try :
33- object_detector , x_test , _ = get_pytorch_detr
34-
35- result = object_detector .predict (x = x_test )
33+ from art .utils import non_maximum_suppression
3634
37- assert list ( result [ 0 ]. keys ()) == [ "boxes" , "labels" , "scores" ]
35+ object_detector , x_test , _ = get_pytorch_detr
3836
39- assert result [ 0 ][ "boxes" ]. shape == ( 100 , 4 )
40- expected_detection_boxes = np . asarray ([ - 0.12423098 , 361.80136 , 82.385345 , 795.50305 ] )
41- np . testing . assert_array_almost_equal (result [ 0 ][ "boxes" ][ 2 , :], expected_detection_boxes , decimal = 1 )
37+ preds = object_detector . predict ( x_test )
38+ result = non_maximum_suppression ( preds [ 0 ], iou_threshold = 0.4 , confidence_threshold = 0.3 )
39+ assert list (result . keys ()) == [ "boxes" , "labels" , "scores" ]
4240
43- assert result [0 ][ "scores " ].shape == (100 , )
44- expected_detection_scores = np .asarray (
41+ assert result ["boxes " ].shape == (3 , 4 )
42+ expected_detection_boxes = np .asarray (
4543 [
46- 0.00105285 ,
47- 0.00261505 ,
48- 0.00060220 ,
49- 0.00121928 ,
50- 0.00154554 ,
51- 0.00021678 ,
52- 0.00077083 ,
53- 0.00045684 ,
54- 0.00180561 ,
55- 0.00067704 ,
44+ [1.0126123 , 25.658852 , 412.70746 , 379.12537 ],
45+ [- 0.089400 , 272.08664 , 415.90994 , 416.25930 ],
46+ [0.1522941 , 75.882440 , 99.139565 , 335.11273 ],
5647 ]
5748 )
58- np .testing .assert_array_almost_equal (result [0 ]["scores" ][:10 ], expected_detection_scores , decimal = 1 )
49+ np .testing .assert_array_almost_equal (result ["boxes" ], expected_detection_boxes , decimal = 3 )
50+
51+ assert result ["scores" ].shape == (3 ,)
52+ expected_detection_scores = np .asarray ([0.8424455 , 0.7796526 , 0.35387915 ])
53+ np .testing .assert_array_almost_equal (result ["scores" ], expected_detection_scores , decimal = 3 )
5954
60- assert result [0 ][ "labels" ].shape == (100 ,)
61- expected_detection_classes = np .asarray ([1 , 23 , 23 , 1 , 1 , 23 , 23 , 23 , 1 , 1 ])
62- np .testing .assert_array_almost_equal (result [0 ][ "labels" ][: 10 ] , expected_detection_classes , decimal = 1 )
55+ assert result ["labels" ].shape == (3 ,)
56+ expected_detection_classes = np .asarray ([17 , 65 , 17 ])
57+ np .testing .assert_array_equal (result ["labels" ], expected_detection_classes )
6358
6459 except ARTTestException as e :
6560 art_warning (e )
@@ -68,15 +63,8 @@ def test_predict(art_warning, get_pytorch_detr):
6863@pytest .mark .only_with_platform ("pytorch" )
6964def test_fit (art_warning , get_pytorch_detr ):
7065 try :
71- import torch
72-
7366 object_detector , x_test , y_test = get_pytorch_detr
7467
75- # Create optimizer
76- params = [p for p in object_detector .model .parameters () if p .requires_grad ]
77- optimizer = torch .optim .SGD (params , lr = 0.01 )
78- object_detector .set_params (optimizer = optimizer )
79-
8068 # Compute loss before training
8169 loss1 = object_detector .compute_loss (x = x_test , y = y_test )
8270
@@ -99,84 +87,83 @@ def test_loss_gradient(art_warning, get_pytorch_detr):
9987
10088 grads = object_detector .loss_gradient (x = x_test , y = y_test )
10189
102- assert grads .shape == (2 , 3 , 800 , 800 )
90+ assert grads .shape == (1 , 3 , 416 , 416 )
10391
10492 expected_gradients1 = np .asarray (
10593 [
106- - 0.00757495 ,
107- - 0.00101332 ,
108- 0.00368362 ,
109- 0.00283334 ,
110- - 0.00096027 ,
111- 0.00873749 ,
112- 0.00546095 ,
113- - 0.00823532 ,
114- - 0.00710872 ,
115- 0.00389713 ,
116- - 0.00966289 ,
117- 0.00448294 ,
118- 0.00754991 ,
119- - 0.00934104 ,
120- - 0.00350194 ,
121- - 0.00541577 ,
122- - 0.00395624 ,
123- 0.00147651 ,
124- 0.0105616 ,
125- 0.01231265 ,
126- - 0.00148831 ,
127- - 0.0043609 ,
128- 0.00093031 ,
129- 0.00884939 ,
130- - 0.00356749 ,
131- 0.00093475 ,
132- - 0.00353712 ,
133- - 0.0060132 ,
134- - 0.00067899 ,
135- - 0.00886974 ,
136- 0.00108483 ,
137- - 0.00052412 ,
94+ 0.02891439 ,
95+ 0.0055933 ,
96+ - 0.00687808 ,
97+ 0.0095074 ,
98+ 0.00247894 ,
99+ 0.00122704 ,
100+ - 0.00482378 ,
101+ - 0.00924361 ,
102+ - 0.02870164 ,
103+ - 0.00683936 ,
104+ 0.00904205 ,
105+ - 0.01315971 ,
106+ - 0.0151937 ,
107+ - 0.00156442 ,
108+ 0.00775309 ,
109+ 0.01946152 ,
110+ 0.00523211 ,
111+ - 0.01682214 ,
112+ 0.00079588 ,
113+ 0.01627164 ,
114+ - 0.01347653 ,
115+ - 0.00512358 ,
116+ 0.00610363 ,
117+ 0.02831643 ,
118+ 0.00742467 ,
119+ 0.00293561 ,
120+ 0.01380033 ,
121+ 0.02112359 ,
122+ 0.01725711 ,
123+ - 0.00431877 ,
124+ - 0.01007722 ,
125+ - 0.00526983 ,
138126 ]
139127 )
140-
141- np .testing .assert_array_almost_equal (grads [0 , 0 , 10 , :32 ], expected_gradients1 , decimal = 1 )
128+ np .testing .assert_array_almost_equal (grads [0 , 0 , 208 , 192 :224 ], expected_gradients1 , decimal = 1 )
142129
143130 expected_gradients2 = np .asarray (
144131 [
145- - 0.00757495 ,
146- - 0.00101332 ,
147- 0.00368362 ,
148- 0.00283334 ,
149- - 0.00096027 ,
150- 0.00873749 ,
151- 0.00546095 ,
152- - 0.00823532 ,
153- - 0.00710872 ,
154- 0.00389713 ,
155- - 0.00966289 ,
156- 0.00448294 ,
157- 0.00754991 ,
158- - 0.00934104 ,
159- - 0.00350194 ,
160- - 0.00541577 ,
161- - 0.00395624 ,
162- 0.00147651 ,
163- 0.0105616 ,
164- 0.01231265 ,
165- - 0.00148831 ,
166- - 0.0043609 ,
167- 0.00093031 ,
168- 0.00884939 ,
169- - 0.00356749 ,
170- 0.00093475 ,
171- - 0.00353712 ,
172- - 0.0060132 ,
173- - 0.00067899 ,
174- - 0.00886974 ,
175- 0.00108483 ,
176- - 0.00052412 ,
132+ - 0.00549417 ,
133+ - 0.01592844 ,
134+ - 0.01073932 ,
135+ - 0.00443333 ,
136+ - 0.00780143 ,
137+ - 0.02033146 ,
138+ - 0.0191503 ,
139+ 0.01227987 ,
140+ 0.019971 ,
141+ 0.01034214 ,
142+ - 0.00918145 ,
143+ - 0.02458049 ,
144+ - 0.00708776 ,
145+ - 0.00826812 ,
146+ - 0.01284431 ,
147+ - 0.00195021 ,
148+ 0.00523211 ,
149+ 0.00661678 ,
150+ 0.00851441 ,
151+ 0.01157211 ,
152+ - 0.00324841 ,
153+ - 0.00395823 ,
154+ 0.00756641 ,
155+ 0.00405913 ,
156+ - 0.00055517 ,
157+ 0.00221484 ,
158+ - 0.02415526 ,
159+ - 0.02096599 ,
160+ 0.00980014 ,
161+ 0.00174731 ,
162+ - 0.01008899 ,
163+ 0.00305779 ,
177164 ]
178165 )
179- np .testing .assert_array_almost_equal (grads [1 , 0 , 10 , : 32 ], expected_gradients2 , decimal = 1 )
166+ np .testing .assert_array_almost_equal (grads [0 , 0 , 192 : 224 , 208 ], expected_gradients2 , decimal = 1 )
180167
181168 except ARTTestException as e :
182169 art_warning (e )
@@ -239,18 +226,13 @@ def test_preprocessing_defences(art_warning, get_pytorch_detr):
239226 "boxes" : result [0 ]["boxes" ],
240227 "labels" : result [0 ]["labels" ],
241228 "scores" : np .ones_like (result [0 ]["labels" ]),
242- },
243- {
244- "boxes" : result [1 ]["boxes" ],
245- "labels" : result [1 ]["labels" ],
246- "scores" : np .ones_like (result [1 ]["labels" ]),
247- },
229+ }
248230 ]
249231
250232 # Compute gradients
251233 grads = object_detector .loss_gradient (x = x_test , y = y )
252234
253- assert grads .shape == (2 , 3 , 800 , 800 )
235+ assert grads .shape == (1 , 3 , 416 , 416 )
254236
255237 except ARTTestException as e :
256238 art_warning (e )
@@ -275,7 +257,7 @@ def test_compute_loss(art_warning, get_pytorch_detr):
275257 # Compute loss
276258 loss = object_detector .compute_loss (x = x_test , y = y_test )
277259
278- assert pytest .approx (6.7767677 , abs = 0.1 ) == float (loss )
260+ assert pytest .approx (2.172381 , abs = 0.1 ) == float (loss )
279261
280262 except ARTTestException as e :
281263 art_warning (e )
0 commit comments