3030@pytest .mark .only_with_platform ("pytorch" )
3131def test_predict (art_warning , get_pytorch_faster_rcnn ):
3232 try :
33- object_detector , x_test , _ = get_pytorch_faster_rcnn
33+ from art . utils import non_maximum_suppression
3434
35- result = object_detector .predict (x_test )
36- assert list (result [0 ].keys ()) == ["boxes" , "labels" , "scores" ]
35+ object_detector , x_test , _ = get_pytorch_faster_rcnn
3736
38- assert result [ 0 ][ "boxes" ]. shape == ( 7 , 4 )
39- expected_detection_boxes = np . asarray ([ 4.4017954 , 6.3090835 , 22.128296 , 27.570665 ] )
40- np . testing . assert_array_almost_equal (result [ 0 ][ "boxes" ][ 2 , :], expected_detection_boxes , decimal = 3 )
37+ preds = object_detector . predict ( x_test )
38+ result = non_maximum_suppression ( preds [ 0 ], iou_threshold = 0.4 , confidence_threshold = 0.3 )
39+ assert list (result . keys ()) == [ "boxes" , "labels" , "scores" ]
4140
42- assert result [0 ]["scores" ].shape == (7 ,)
43- expected_detection_scores = np .asarray (
44- [0.3314798 , 0.14125851 , 0.13928168 , 0.0996184 , 0.08550017 , 0.06690315 , 0.05359321 ]
41+ assert result ["boxes" ].shape == (2 , 4 )
42+ expected_detection_boxes = np .asarray (
43+ [
44+ [6.136914 , 22.481018 , 413.05814 , 346.08746 ],
45+ [0.000000 , 24.181173 , 406.47644 , 342.62213 ],
46+ ]
4547 )
46- np .testing .assert_array_almost_equal (result [0 ]["scores" ][:10 ], expected_detection_scores , decimal = 6 )
48+ np .testing .assert_array_almost_equal (result ["boxes" ], expected_detection_boxes , decimal = 3 )
49+
50+ assert result ["scores" ].shape == (2 ,)
51+ expected_detection_scores = np .asarray ([0.4237412 , 0.35696018 ])
52+ np .testing .assert_array_almost_equal (result ["scores" ], expected_detection_scores , decimal = 3 )
4753
48- assert result [0 ][ "labels" ].shape == (7 ,)
49- expected_detection_classes = np .asarray ([72 , 79 , 1 , 72 , 78 , 72 , 82 ])
50- np .testing .assert_array_almost_equal (result [0 ][ "labels" ][: 10 ] , expected_detection_classes , decimal = 6 )
54+ assert result ["labels" ].shape == (2 ,)
55+ expected_detection_classes = np .asarray ([21 , 18 ])
56+ np .testing .assert_array_equal (result ["labels" ], expected_detection_classes )
5157
5258 except ARTTestException as e :
5359 art_warning (e )
@@ -80,75 +86,59 @@ def test_loss_gradient(art_warning, get_pytorch_faster_rcnn):
8086
8187 # Compute gradients
8288 grads = object_detector .loss_gradient (x_test , y_test )
83- assert grads .shape == (2 , 28 , 28 , 3 )
89+ assert grads .shape == (1 , 3 , 416 , 416 )
8490
8591 expected_gradients1 = np .asarray (
8692 [
87- [4.6265591e-04 , 1.2323459e-03 , 1.3915040e-03 ],
88- [- 3.2658060e-04 , - 3.6941725e-03 , - 4.5638453e-04 ],
89- [7.8702159e-04 , - 3.3072452e-03 , 3.0583731e-04 ],
90- [1.0381485e-03 , - 2.0846087e-03 , 2.3015277e-04 ],
91- [2.1460971e-03 , - 1.3157589e-03 , 3.5176644e-04 ],
92- [3.3839934e-03 , 1.3083456e-03 , 1.6155940e-03 ],
93- [3.8621046e-03 , 1.6645766e-03 , 1.8313043e-03 ],
94- [3.0887076e-03 , 1.4632678e-03 , 1.1174511e-03 ],
95- [3.3404885e-03 , 2.0578136e-03 , 9.6874911e-04 ],
96- [3.2202434e-03 , 7.2660763e-04 , 8.9162006e-04 ],
97- [3.5761783e-03 , 2.3615893e-03 , 8.8510796e-04 ],
98- [3.4721815e-03 , 1.9500104e-03 , 9.2907902e-04 ],
99- [3.4767685e-03 , 2.1154548e-03 , 5.5654044e-04 ],
100- [3.9492580e-03 , 3.5505455e-03 , 6.5863604e-04 ],
101- [3.9963769e-03 , 4.0338552e-03 , 3.9539216e-04 ],
102- [2.2312226e-03 , 5.1399925e-06 , - 1.0743635e-03 ],
103- [2.3955442e-03 , 6.7116896e-04 , - 1.2389944e-03 ],
104- [1.9969011e-03 , - 4.5717746e-04 , - 1.5225793e-03 ],
105- [1.8131963e-03 , - 7.7948131e-04 , - 1.6078206e-03 ],
106- [1.4277012e-03 , - 7.7973347e-04 , - 1.3463887e-03 ],
107- [7.3705515e-04 , - 1.1704378e-03 , - 9.8979671e-04 ],
108- [1.0899740e-04 , - 1.2144407e-03 , - 1.1339665e-03 ],
109- [1.2254890e-04 , - 4.7438752e-04 , - 8.8673591e-04 ],
110- [7.0695346e-04 , 7.2568876e-04 , - 2.5591519e-04 ],
111- [5.0835893e-04 , 2.6866698e-04 , 2.2731400e-04 ],
112- [- 5.9932750e-04 , - 1.1667561e-03 , - 4.8044650e-04 ],
113- [4.0421321e-04 , 3.1692928e-04 , - 8.3296909e-05 ],
114- [4.0506107e-05 , - 3.1728629e-04 , - 4.4132984e-04 ],
93+ - 2.7225273e-05 ,
94+ - 2.7225284e-05 ,
95+ - 3.2535860e-05 ,
96+ - 9.3287526e-06 ,
97+ - 1.1088990e-05 ,
98+ - 3.4527478e-05 ,
99+ 5.7807661e-06 ,
100+ 1.1616970e-05 ,
101+ 2.9732121e-06 ,
102+ 1.1190044e-05 ,
103+ - 6.4673945e-06 ,
104+ - 1.6562306e-05 ,
105+ - 1.5946282e-05 ,
106+ - 1.8079168e-06 ,
107+ - 9.7664342e-06 ,
108+ 6.2075532e-07 ,
109+ - 8.9023115e-06 ,
110+ - 1.5546989e-06 ,
111+ - 7.2730008e-06 ,
112+ - 7.5181362e-07 ,
115113 ]
116114 )
117- np .testing .assert_array_almost_equal (grads [0 , 0 , : , :], expected_gradients1 , decimal = 2 )
115+ np .testing .assert_array_almost_equal (grads [0 , 0 , 0 , :20 ], expected_gradients1 , decimal = 2 )
118116
119117 expected_gradients2 = np .asarray (
120118 [
121- [4.7986404e-04 , 7.7701372e-04 , 1.1786318e-03 ],
122- [7.3503907e-04 , - 2.3474507e-03 , - 3.9008856e-04 ],
123- [4.1874062e-04 , - 2.5707064e-03 , - 1.1054531e-03 ],
124- [- 1.7942721e-03 , - 3.3968450e-03 , - 1.4989552e-03 ],
125- [- 2.9697213e-03 , - 4.6922294e-03 , - 1.3162185e-03 ],
126- [- 3.1759157e-03 , - 9.8660104e-03 , - 4.7163852e-03 ],
127- [1.8666144e-03 , - 2.8793041e-03 , - 3.1324378e-03 ],
128- [1.0555880e-02 , 7.6373261e-03 , 5.3013843e-03 ],
129- [8.9815725e-04 , - 1.0321697e-02 , 1.4192325e-03 ],
130- [8.5643278e-03 , 3.0152409e-03 , 2.0114987e-03 ],
131- [- 2.7870361e-03 , - 1.1686913e-02 , - 7.0649502e-03 ],
132- [- 7.7482774e-03 , - 1.3334424e-03 , - 9.1927368e-03 ],
133- [- 8.1487820e-03 , - 3.8133820e-03 , - 4.3300558e-03 ],
134- [- 7.7006700e-03 , - 1.2594147e-02 , - 3.9680018e-03 ],
135- [- 9.5743872e-03 , - 2.1007264e-02 , - 9.1963671e-03 ],
136- [- 8.6777220e-03 , - 1.7278835e-02 , - 1.3328674e-02 ],
137- [- 1.7368209e-02 , - 2.3461722e-02 , - 1.1538444e-02 ],
138- [- 4.6307812e-03 , - 5.7058665e-03 , 1.3555109e-03 ],
139- [4.8570461e-03 , - 5.8050654e-03 , 8.1082489e-03 ],
140- [6.4304657e-03 , 2.8407066e-03 , 8.7463465e-03 ],
141- [5.0593228e-03 , 1.4102085e-03 , 5.2116364e-03 ],
142- [2.5003455e-03 , - 6.0178695e-04 , 2.0183939e-03 ],
143- [2.1247163e-03 , 4.7659015e-04 , 7.5940741e-04 ],
144- [1.3499497e-03 , 6.2203623e-04 , 1.2288829e-04 ],
145- [2.8991612e-04 , - 4.0216290e-04 , - 7.2287643e-05 ],
146- [6.6898909e-05 , - 6.3778006e-04 , - 3.6294860e-04 ],
147- [5.3613615e-04 , 9.9137833e-05 , - 1.6657988e-05 ],
148- [- 3.9828232e-05 , - 3.8453130e-04 , - 2.3702848e-04 ],
119+ - 2.7307957e-05 ,
120+ - 1.9417710e-05 ,
121+ - 2.0928457e-05 ,
122+ - 2.1384752e-05 ,
123+ - 2.5035972e-05 ,
124+ - 3.6572790e-05 ,
125+ - 8.2444545e-05 ,
126+ - 7.3255811e-05 ,
127+ - 4.5060227e-05 ,
128+ - 1.9829258e-05 ,
129+ - 2.2043951e-05 ,
130+ - 3.6746951e-05 ,
131+ - 4.2588043e-05 ,
132+ - 3.1833035e-05 ,
133+ - 1.5923406e-05 ,
134+ - 3.5026955e-05 ,
135+ - 4.4511849e-05 ,
136+ - 3.3867167e-05 ,
137+ - 1.8569792e-05 ,
138+ - 3.5141209e-05 ,
149139 ]
150140 )
151- np .testing .assert_array_almost_equal (grads [1 , 0 , :, : ], expected_gradients2 , decimal = 2 )
141+ np .testing .assert_array_almost_equal (grads [0 , 0 , :20 , 0 ], expected_gradients2 , decimal = 2 )
152142
153143 except ARTTestException as e :
154144 art_warning (e )
@@ -198,7 +188,7 @@ def test_preprocessing_defences(art_warning, get_pytorch_faster_rcnn):
198188 # Compute gradients
199189 grads = object_detector .loss_gradient (x = x_test , y = y_test )
200190
201- assert grads .shape == (2 , 28 , 28 , 3 )
191+ assert grads .shape == (1 , 3 , 416 , 416 )
202192
203193 except ARTTestException as e :
204194 art_warning (e )
@@ -221,7 +211,7 @@ def test_compute_loss(art_warning, get_pytorch_faster_rcnn):
221211 # Compute loss
222212 loss = object_detector .compute_loss (x = x_test , y = y_test )
223213
224- assert pytest .approx (0.84883332 , abs = 0.01 ) == float (loss )
214+ assert pytest .approx (0.0995874 , abs = 0.05 ) == float (loss )
225215
226216 except ARTTestException as e :
227217 art_warning (e )
0 commit comments