30
30
@pytest .mark .only_with_platform ("pytorch" )
31
31
def test_predict (art_warning , get_pytorch_faster_rcnn ):
32
32
try :
33
- object_detector , x_test , _ = get_pytorch_faster_rcnn
33
+ from art . utils import non_maximum_suppression
34
34
35
- result = object_detector .predict (x_test )
36
- assert list (result [0 ].keys ()) == ["boxes" , "labels" , "scores" ]
35
+ object_detector , x_test , _ = get_pytorch_faster_rcnn
37
36
38
- assert result [ 0 ][ "boxes" ]. shape == ( 7 , 4 )
39
- expected_detection_boxes = np . asarray ([ 4.4017954 , 6.3090835 , 22.128296 , 27.570665 ] )
40
- np . testing . assert_array_almost_equal (result [ 0 ][ "boxes" ][ 2 , :], expected_detection_boxes , decimal = 3 )
37
+ preds = object_detector . predict ( x_test )
38
+ result = non_maximum_suppression ( preds [ 0 ], iou_threshold = 0.4 , confidence_threshold = 0.3 )
39
+ assert list (result . keys ()) == [ "boxes" , "labels" , "scores" ]
41
40
42
- assert result [0 ]["scores" ].shape == (7 ,)
43
- expected_detection_scores = np .asarray (
44
- [0.3314798 , 0.14125851 , 0.13928168 , 0.0996184 , 0.08550017 , 0.06690315 , 0.05359321 ]
41
+ assert result ["boxes" ].shape == (2 , 4 )
42
+ expected_detection_boxes = np .asarray (
43
+ [
44
+ [6.136914 , 22.481018 , 413.05814 , 346.08746 ],
45
+ [0.000000 , 24.181173 , 406.47644 , 342.62213 ],
46
+ ]
45
47
)
46
- np .testing .assert_array_almost_equal (result [0 ]["scores" ][:10 ], expected_detection_scores , decimal = 6 )
48
+ np .testing .assert_array_almost_equal (result ["boxes" ], expected_detection_boxes , decimal = 3 )
49
+
50
+ assert result ["scores" ].shape == (2 ,)
51
+ expected_detection_scores = np .asarray ([0.4237412 , 0.35696018 ])
52
+ np .testing .assert_array_almost_equal (result ["scores" ], expected_detection_scores , decimal = 3 )
47
53
48
- assert result [0 ][ "labels" ].shape == (7 ,)
49
- expected_detection_classes = np .asarray ([72 , 79 , 1 , 72 , 78 , 72 , 82 ])
50
- np .testing .assert_array_almost_equal (result [0 ][ "labels" ][: 10 ] , expected_detection_classes , decimal = 6 )
54
+ assert result ["labels" ].shape == (2 ,)
55
+ expected_detection_classes = np .asarray ([21 , 18 ])
56
+ np .testing .assert_array_equal (result ["labels" ], expected_detection_classes )
51
57
52
58
except ARTTestException as e :
53
59
art_warning (e )
@@ -80,75 +86,59 @@ def test_loss_gradient(art_warning, get_pytorch_faster_rcnn):
80
86
81
87
# Compute gradients
82
88
grads = object_detector .loss_gradient (x_test , y_test )
83
- assert grads .shape == (2 , 28 , 28 , 3 )
89
+ assert grads .shape == (1 , 3 , 416 , 416 )
84
90
85
91
expected_gradients1 = np .asarray (
86
92
[
87
- [4.6265591e-04 , 1.2323459e-03 , 1.3915040e-03 ],
88
- [- 3.2658060e-04 , - 3.6941725e-03 , - 4.5638453e-04 ],
89
- [7.8702159e-04 , - 3.3072452e-03 , 3.0583731e-04 ],
90
- [1.0381485e-03 , - 2.0846087e-03 , 2.3015277e-04 ],
91
- [2.1460971e-03 , - 1.3157589e-03 , 3.5176644e-04 ],
92
- [3.3839934e-03 , 1.3083456e-03 , 1.6155940e-03 ],
93
- [3.8621046e-03 , 1.6645766e-03 , 1.8313043e-03 ],
94
- [3.0887076e-03 , 1.4632678e-03 , 1.1174511e-03 ],
95
- [3.3404885e-03 , 2.0578136e-03 , 9.6874911e-04 ],
96
- [3.2202434e-03 , 7.2660763e-04 , 8.9162006e-04 ],
97
- [3.5761783e-03 , 2.3615893e-03 , 8.8510796e-04 ],
98
- [3.4721815e-03 , 1.9500104e-03 , 9.2907902e-04 ],
99
- [3.4767685e-03 , 2.1154548e-03 , 5.5654044e-04 ],
100
- [3.9492580e-03 , 3.5505455e-03 , 6.5863604e-04 ],
101
- [3.9963769e-03 , 4.0338552e-03 , 3.9539216e-04 ],
102
- [2.2312226e-03 , 5.1399925e-06 , - 1.0743635e-03 ],
103
- [2.3955442e-03 , 6.7116896e-04 , - 1.2389944e-03 ],
104
- [1.9969011e-03 , - 4.5717746e-04 , - 1.5225793e-03 ],
105
- [1.8131963e-03 , - 7.7948131e-04 , - 1.6078206e-03 ],
106
- [1.4277012e-03 , - 7.7973347e-04 , - 1.3463887e-03 ],
107
- [7.3705515e-04 , - 1.1704378e-03 , - 9.8979671e-04 ],
108
- [1.0899740e-04 , - 1.2144407e-03 , - 1.1339665e-03 ],
109
- [1.2254890e-04 , - 4.7438752e-04 , - 8.8673591e-04 ],
110
- [7.0695346e-04 , 7.2568876e-04 , - 2.5591519e-04 ],
111
- [5.0835893e-04 , 2.6866698e-04 , 2.2731400e-04 ],
112
- [- 5.9932750e-04 , - 1.1667561e-03 , - 4.8044650e-04 ],
113
- [4.0421321e-04 , 3.1692928e-04 , - 8.3296909e-05 ],
114
- [4.0506107e-05 , - 3.1728629e-04 , - 4.4132984e-04 ],
93
+ - 2.7225273e-05 ,
94
+ - 2.7225284e-05 ,
95
+ - 3.2535860e-05 ,
96
+ - 9.3287526e-06 ,
97
+ - 1.1088990e-05 ,
98
+ - 3.4527478e-05 ,
99
+ 5.7807661e-06 ,
100
+ 1.1616970e-05 ,
101
+ 2.9732121e-06 ,
102
+ 1.1190044e-05 ,
103
+ - 6.4673945e-06 ,
104
+ - 1.6562306e-05 ,
105
+ - 1.5946282e-05 ,
106
+ - 1.8079168e-06 ,
107
+ - 9.7664342e-06 ,
108
+ 6.2075532e-07 ,
109
+ - 8.9023115e-06 ,
110
+ - 1.5546989e-06 ,
111
+ - 7.2730008e-06 ,
112
+ - 7.5181362e-07 ,
115
113
]
116
114
)
117
- np .testing .assert_array_almost_equal (grads [0 , 0 , : , :], expected_gradients1 , decimal = 2 )
115
+ np .testing .assert_array_almost_equal (grads [0 , 0 , 0 , :20 ], expected_gradients1 , decimal = 2 )
118
116
119
117
expected_gradients2 = np .asarray (
120
118
[
121
- [4.7986404e-04 , 7.7701372e-04 , 1.1786318e-03 ],
122
- [7.3503907e-04 , - 2.3474507e-03 , - 3.9008856e-04 ],
123
- [4.1874062e-04 , - 2.5707064e-03 , - 1.1054531e-03 ],
124
- [- 1.7942721e-03 , - 3.3968450e-03 , - 1.4989552e-03 ],
125
- [- 2.9697213e-03 , - 4.6922294e-03 , - 1.3162185e-03 ],
126
- [- 3.1759157e-03 , - 9.8660104e-03 , - 4.7163852e-03 ],
127
- [1.8666144e-03 , - 2.8793041e-03 , - 3.1324378e-03 ],
128
- [1.0555880e-02 , 7.6373261e-03 , 5.3013843e-03 ],
129
- [8.9815725e-04 , - 1.0321697e-02 , 1.4192325e-03 ],
130
- [8.5643278e-03 , 3.0152409e-03 , 2.0114987e-03 ],
131
- [- 2.7870361e-03 , - 1.1686913e-02 , - 7.0649502e-03 ],
132
- [- 7.7482774e-03 , - 1.3334424e-03 , - 9.1927368e-03 ],
133
- [- 8.1487820e-03 , - 3.8133820e-03 , - 4.3300558e-03 ],
134
- [- 7.7006700e-03 , - 1.2594147e-02 , - 3.9680018e-03 ],
135
- [- 9.5743872e-03 , - 2.1007264e-02 , - 9.1963671e-03 ],
136
- [- 8.6777220e-03 , - 1.7278835e-02 , - 1.3328674e-02 ],
137
- [- 1.7368209e-02 , - 2.3461722e-02 , - 1.1538444e-02 ],
138
- [- 4.6307812e-03 , - 5.7058665e-03 , 1.3555109e-03 ],
139
- [4.8570461e-03 , - 5.8050654e-03 , 8.1082489e-03 ],
140
- [6.4304657e-03 , 2.8407066e-03 , 8.7463465e-03 ],
141
- [5.0593228e-03 , 1.4102085e-03 , 5.2116364e-03 ],
142
- [2.5003455e-03 , - 6.0178695e-04 , 2.0183939e-03 ],
143
- [2.1247163e-03 , 4.7659015e-04 , 7.5940741e-04 ],
144
- [1.3499497e-03 , 6.2203623e-04 , 1.2288829e-04 ],
145
- [2.8991612e-04 , - 4.0216290e-04 , - 7.2287643e-05 ],
146
- [6.6898909e-05 , - 6.3778006e-04 , - 3.6294860e-04 ],
147
- [5.3613615e-04 , 9.9137833e-05 , - 1.6657988e-05 ],
148
- [- 3.9828232e-05 , - 3.8453130e-04 , - 2.3702848e-04 ],
119
+ - 2.7307957e-05 ,
120
+ - 1.9417710e-05 ,
121
+ - 2.0928457e-05 ,
122
+ - 2.1384752e-05 ,
123
+ - 2.5035972e-05 ,
124
+ - 3.6572790e-05 ,
125
+ - 8.2444545e-05 ,
126
+ - 7.3255811e-05 ,
127
+ - 4.5060227e-05 ,
128
+ - 1.9829258e-05 ,
129
+ - 2.2043951e-05 ,
130
+ - 3.6746951e-05 ,
131
+ - 4.2588043e-05 ,
132
+ - 3.1833035e-05 ,
133
+ - 1.5923406e-05 ,
134
+ - 3.5026955e-05 ,
135
+ - 4.4511849e-05 ,
136
+ - 3.3867167e-05 ,
137
+ - 1.8569792e-05 ,
138
+ - 3.5141209e-05 ,
149
139
]
150
140
)
151
- np .testing .assert_array_almost_equal (grads [1 , 0 , :, : ], expected_gradients2 , decimal = 2 )
141
+ np .testing .assert_array_almost_equal (grads [0 , 0 , :20 , 0 ], expected_gradients2 , decimal = 2 )
152
142
153
143
except ARTTestException as e :
154
144
art_warning (e )
@@ -198,7 +188,7 @@ def test_preprocessing_defences(art_warning, get_pytorch_faster_rcnn):
198
188
# Compute gradients
199
189
grads = object_detector .loss_gradient (x = x_test , y = y_test )
200
190
201
- assert grads .shape == (2 , 28 , 28 , 3 )
191
+ assert grads .shape == (1 , 3 , 416 , 416 )
202
192
203
193
except ARTTestException as e :
204
194
art_warning (e )
@@ -221,7 +211,7 @@ def test_compute_loss(art_warning, get_pytorch_faster_rcnn):
221
211
# Compute loss
222
212
loss = object_detector .compute_loss (x = x_test , y = y_test )
223
213
224
- assert pytest .approx (0.84883332 , abs = 0.01 ) == float (loss )
214
+ assert pytest .approx (0.0995874 , abs = 0.05 ) == float (loss )
225
215
226
216
except ARTTestException as e :
227
217
art_warning (e )
0 commit comments