
Commit 8e4c89d

Fixing formatting
Signed-off-by: Kieran Fraser <[email protected]>
1 parent 64db977 commit 8e4c89d

File tree

2 files changed: +107 -46 lines changed


art/estimators/object_detection/pytorch_detection_transformer.py

Lines changed: 23 additions & 25 deletions
@@ -832,10 +832,13 @@ def compute_loss( # type: ignore

         return loss.detach().cpu().numpy()

-    def _apply_resizing(self, x: Union[np.ndarray, "torch.Tensor"],
-                        y: List[Dict[str, Union[np.ndarray, "torch.Tensor"]]],
-                        height: int = 800,
-                        width: int = 800):
+    def _apply_resizing(
+        self,
+        x: Union[np.ndarray, "torch.Tensor"],
+        y: List[Dict[str, Union[np.ndarray, "torch.Tensor"]]],
+        height: int = 800,
+        width: int = 800,
+    ):
         """
         Resize the input and targets to dimensions expected by DETR.
@@ -856,45 +859,39 @@ def _apply_resizing(self, x: Union[np.ndarray, "torch.Tensor"],
         ):
             resized_imgs = []
             if isinstance(x, torch.Tensor):
-                x = T.Resize(size = (height, width))(x)
+                x = T.Resize(size=(height, width))(x)
             else:
                 for i, _ in enumerate(x):
                     resized = cv2.resize(
-                        (x)[i].transpose(1, 2, 0),
-                        dsize=(height, width),
-                        interpolation=cv2.INTER_CUBIC,
-                    )
-                    resized = resized.transpose(2, 0, 1)
-                    resized_imgs.append(
-                        resized
+                        (x)[i].transpose(1, 2, 0),
+                        dsize=(height, width),
+                        interpolation=cv2.INTER_CUBIC,
                     )
+                    resized = resized.transpose(2, 0, 1)
+                    resized_imgs.append(resized)
                 x = np.array(resized_imgs)

         elif self._input_shape[1] != self._input_shape[2]:
             rescale_dim = max(self._input_shape[1], self._input_shape[2])
             resized_imgs = []
             if isinstance(x, torch.Tensor):
-                x = T.Resize(size = (rescale_dim,rescale_dim))(x)
+                x = T.Resize(size=(rescale_dim, rescale_dim))(x)
             else:
                 for i, _ in enumerate(x):
                     resized = cv2.resize(
-                        (x)[i].transpose(1, 2, 0),
-                        dsize=(rescale_dim, rescale_dim),
-                        interpolation=cv2.INTER_CUBIC,
-                    )
-                    resized = resized.transpose(2, 0, 1)
-                    resized_imgs.append(
-                        resized
+                        (x)[i].transpose(1, 2, 0),
+                        dsize=(rescale_dim, rescale_dim),
+                        interpolation=cv2.INTER_CUBIC,
                     )
+                    resized = resized.transpose(2, 0, 1)
+                    resized_imgs.append(resized)
                 x = np.array(resized_imgs)
-
+
         targets = []
         if y is not None:
-            if isinstance(y[0]['boxes'], torch.Tensor):
+            if isinstance(y[0]["boxes"], torch.Tensor):
                 for target in y:
-                    cxcy_norm = revert_rescale_bboxes(
-                        target["boxes"], (self.input_shape[2], self.input_shape[1])
-                    )
+                    cxcy_norm = revert_rescale_bboxes(target["boxes"], (self.input_shape[2], self.input_shape[1]))
                     targets.append(
                         {
                             "labels": target["labels"].type(torch.int64).to(self.device),
@@ -917,6 +914,7 @@ def _apply_resizing(self, x: Union[np.ndarray, "torch.Tensor"],

         return x, targets

+
 class NestedTensor:
     """
     From DETR source: https://github.com/facebookresearch/detr
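An editorial aside, not part of the commit: the reformatted _apply_resizing keeps two resize paths, torchvision's T.Resize for torch.Tensor batches and a per-image cv2.resize loop for NCHW numpy arrays. A minimal, self-contained sketch of those two paths (shapes, sizes, and random inputs are assumptions for illustration):

# Illustrative sketch only; mirrors the two branches above with assumed shapes.
import cv2
import numpy as np
import torch
import torchvision.transforms as T

height, width = 800, 800

# torch.Tensor path: torchvision Resize handles (N, C, H, W) batches directly.
x_tensor = torch.rand(2, 3, 640, 640)
x_tensor = T.Resize(size=(height, width))(x_tensor)

# numpy path: cv2.resize works on HWC images, so transpose around each call.
# Note cv2's dsize is ordered (width, height); the two coincide here because the target is square.
x_np = np.random.rand(2, 3, 640, 640).astype(np.float32)
resized_imgs = []
for img in x_np:
    resized = cv2.resize(img.transpose(1, 2, 0), dsize=(width, height), interpolation=cv2.INTER_CUBIC)
    resized_imgs.append(resized.transpose(2, 0, 1))
x_np = np.array(resized_imgs)

assert tuple(x_tensor.shape) == x_np.shape == (2, 3, height, width)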

tests/estimators/object_detection/test_pytorch_detection_transformer.py

Lines changed: 84 additions & 21 deletions
@@ -74,19 +74,29 @@ def test_predict(get_pytorch_detr):
     assert list(result[0].keys()) == ["boxes", "labels", "scores"]

     assert result[0]["boxes"].shape == (100, 4)
-    expected_detection_boxes = np.asarray([-5.9490204e-03, 1.1947733e+01, 3.1993944e+01, 3.1925127e+01])
+    expected_detection_boxes = np.asarray([-5.9490204e-03, 1.1947733e01, 3.1993944e01, 3.1925127e01])
     np.testing.assert_array_almost_equal(result[0]["boxes"][2, :], expected_detection_boxes, decimal=3)

     assert result[0]["scores"].shape == (100,)
     expected_detection_scores = np.asarray(
-        [0.00679839, 0.0250559 , 0.07205943, 0.01115368, 0.03321039,
-         0.10407761, 0.00113309, 0.01442852, 0.00527624, 0.01240906]
+        [
+            0.00679839,
+            0.0250559,
+            0.07205943,
+            0.01115368,
+            0.03321039,
+            0.10407761,
+            0.00113309,
+            0.01442852,
+            0.00527624,
+            0.01240906,
+        ]
     )
-    np.testing.assert_array_almost_equal(result[0]["scores"][:10], expected_detection_scores, decimal=6)
+    np.testing.assert_array_almost_equal(result[0]["scores"][:10], expected_detection_scores, decimal=5)

     assert result[0]["labels"].shape == (100,)
     expected_detection_classes = np.asarray([17, 17, 33, 17, 17, 17, 74, 17, 17, 17])
-    np.testing.assert_array_almost_equal(result[0]["labels"][:10], expected_detection_classes, decimal=6)
+    np.testing.assert_array_almost_equal(result[0]["labels"][:10], expected_detection_classes, decimal=5)


 @pytest.mark.only_with_platform("pytorch")
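Editorial note on the tolerance change from decimal=6 to decimal=5: np.testing.assert_array_almost_equal passes when abs(desired - actual) < 1.5 * 10**(-decimal) elementwise, so lowering decimal loosens the comparison slightly. A tiny illustration with made-up values:

import numpy as np

a = np.asarray([0.0067983])
b = np.asarray([0.0068083])  # differs from a by 1e-5

np.testing.assert_array_almost_equal(a, b, decimal=5)  # passes: 1e-5 < 1.5e-5
# np.testing.assert_array_almost_equal(a, b, decimal=6)  # would fail: 1e-5 >= 1.5e-6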
@@ -99,26 +109,79 @@ def test_loss_gradient(get_pytorch_detr):
     assert grads.shape == (2, 3, 800, 800)

     expected_gradients1 = np.asarray(
-        [-0.00061366, 0.00322502, -0.00039866, -0.00807413, -0.00476555,
-         0.00181204, 0.01007765, 0.00415828, -0.00073114, 0.00018387,
-         -0.00146992, -0.00119636, -0.00098966, -0.00295517, -0.0024271 ,
-         -0.00131314, -0.00149217, -0.00104926, -0.00154239, -0.00110989,
-         0.00092887, 0.00049146, -0.00292508, -0.00124526, 0.00140347,
-         0.00019833, 0.00191074, -0.00117537, -0.00080604, 0.00057427,
-         -0.00061728, -0.00206535]
+        [
+            -0.00061366,
+            0.00322502,
+            -0.00039866,
+            -0.00807413,
+            -0.00476555,
+            0.00181204,
+            0.01007765,
+            0.00415828,
+            -0.00073114,
+            0.00018387,
+            -0.00146992,
+            -0.00119636,
+            -0.00098966,
+            -0.00295517,
+            -0.0024271,
+            -0.00131314,
+            -0.00149217,
+            -0.00104926,
+            -0.00154239,
+            -0.00110989,
+            0.00092887,
+            0.00049146,
+            -0.00292508,
+            -0.00124526,
+            0.00140347,
+            0.00019833,
+            0.00191074,
+            -0.00117537,
+            -0.00080604,
+            0.00057427,
+            -0.00061728,
+            -0.00206535,
+        ]
     )

     np.testing.assert_array_almost_equal(grads[0, 0, 10, :32], expected_gradients1, decimal=2)

     expected_gradients2 = np.asarray(
-        [-1.1787530e-03, -2.8500680e-03, 5.0884970e-03, 6.4504531e-04,
-         -6.8841036e-05, 2.8184296e-03, 3.0257765e-03, 2.8565727e-04,
-         -1.0701057e-04, 1.2945699e-03, 7.3593057e-04, 1.0177144e-03,
-         -2.4692707e-03, -1.3801848e-03, 6.3182280e-04, -4.2305476e-04,
-         4.4307750e-04, 8.5821096e-04, -7.1204413e-04, -3.1404425e-03,
-         -1.5964351e-03, -1.9222996e-03, -5.3157361e-04, -9.9202688e-04,
-         -1.5815455e-03, 2.0060266e-04, -2.0584739e-03, 6.6960667e-04,
-         9.7393827e-04, -1.6040013e-03, -6.9741381e-04, 1.4657658e-04]
+        [
+            -1.1787530e-03,
+            -2.8500680e-03,
+            5.0884970e-03,
+            6.4504531e-04,
+            -6.8841036e-05,
+            2.8184296e-03,
+            3.0257765e-03,
+            2.8565727e-04,
+            -1.0701057e-04,
+            1.2945699e-03,
+            7.3593057e-04,
+            1.0177144e-03,
+            -2.4692707e-03,
+            -1.3801848e-03,
+            6.3182280e-04,
+            -4.2305476e-04,
+            4.4307750e-04,
+            8.5821096e-04,
+            -7.1204413e-04,
+            -3.1404425e-03,
+            -1.5964351e-03,
+            -1.9222996e-03,
+            -5.3157361e-04,
+            -9.9202688e-04,
+            -1.5815455e-03,
+            2.0060266e-04,
+            -2.0584739e-03,
+            6.6960667e-04,
+            9.7393827e-04,
+            -1.6040013e-03,
+            -6.9741381e-04,
+            1.4657658e-04,
+        ]
     )
     np.testing.assert_array_almost_equal(grads[1, 0, 10, :32], expected_gradients2, decimal=2)

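Editorial note: reference arrays like expected_gradients1 occasionally need to be regenerated when upstream weights or numerics drift. A hypothetical helper (not part of ART or this commit) that prints a slice in the same one-value-per-line, Black-friendly layout used above:

import numpy as np

def format_expected(values: np.ndarray, indent: int = 12) -> str:
    # Render one value per line so Black leaves the pasted literal unchanged.
    pad = " " * indent
    body = "\n".join(f"{pad}{v!r}," for v in values.tolist())
    return "np.asarray(\n    [\n" + body + "\n    ]\n)"

# Example: paste the printed text into the test as the new expected value.
print(format_expected(np.asarray([-0.00061366, 0.00322502, -0.00039866])))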
@@ -236,7 +299,7 @@ def test_pgd(get_pytorch_detr):

     imgs = []
     for i in x_test:
-        img = Image.fromarray((i*255).astype(np.uint8).transpose(1,2,0))
+        img = Image.fromarray((i * 255).astype(np.uint8).transpose(1, 2, 0))
         img = img.resize(size=(800, 800))
         imgs.append(np.array(img))
     x_test = np.array(imgs).transpose(0, 3, 1, 2)
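For context (not part of the commit; input shapes are assumed for illustration): the loop above scales float images in [0, 1] with layout (N, C, H, W) to uint8 HWC, resizes them to 800x800 with PIL, and stacks them back into an (N, C, H, W) batch. A self-contained version:

import numpy as np
from PIL import Image

x_test = np.random.rand(2, 3, 32, 32).astype(np.float32)  # assumed stand-in input

imgs = []
for i in x_test:
    # (C, H, W) float in [0, 1] -> (H, W, C) uint8 for PIL
    img = Image.fromarray((i * 255).astype(np.uint8).transpose(1, 2, 0))
    img = img.resize(size=(800, 800))
    imgs.append(np.array(img))
x_test = np.array(imgs).transpose(0, 3, 1, 2)  # back to (N, C, H, W)

assert x_test.shape == (2, 3, 800, 800)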
