Commit ef88ed2

Fixed pylint, mypy issues
Signed-off-by: Kieran Fraser <[email protected]>
1 parent cacc829 commit ef88ed2

File tree: 1 file changed, +22 -24 lines changed

art/estimators/object_detection/pytorch_detection_transformer.py

Lines changed: 22 additions & 24 deletions
@@ -21,7 +21,7 @@
 | Paper link: https://arxiv.org/abs/2005.12872
 """
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, Any
 
 import numpy as np
 
@@ -581,10 +581,10 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
         import torch
 
         self._model.eval()
-        x, _ = self._apply_resizing(x, None)
+        x_resized, _ = self._apply_resizing(x)
 
         # Apply preprocessing
-        x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)
+        x_preprocessed, _ = self._apply_preprocessing(x_resized, y=None, fit=False)
 
         if self.clip_values is not None:
             norm_factor = self.clip_values[1]
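For orientation, a minimal usage sketch of the revised predict() path. The constructor arguments are illustrative, not taken from this commit; with no model argument the estimator is expected to load a pretrained DETR, and clip_values=(0, 1) matches float images in [0, 1]:

import numpy as np

from art.estimators.object_detection import PyTorchDetectionTransformer

detector = PyTorchDetectionTransformer(input_shape=(3, 800, 800), clip_values=(0, 1))

x = np.random.rand(2, 3, 640, 640).astype(np.float32)  # NCHW batch of any spatial size
predictions = detector.predict(x)  # resizing now happens on an internal copy; x is untouched
print(predictions[0]["boxes"].shape)  # one dict of boxes, labels, scores per image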
@@ -644,6 +644,7 @@ def _get_losses(
 
         # Apply preprocessing
         if self.all_framework_preprocessing:
+            print(y)
             if y is not None and isinstance(y, list) and isinstance(y[0]["boxes"], np.ndarray):
                 y_tensor = []
                 for y_i in y:
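The guarded branch keys off isinstance(y[0]["boxes"], np.ndarray), which implies targets of roughly this shape when NumPy inputs are used (values illustrative; boxes follow ART's (x1, y1, x2, y2) pixel convention, and a "scores" entry is also expected further down in _apply_resizing):

import numpy as np

y = [
    {
        "boxes": np.array([[10.0, 20.0, 110.0, 220.0]], dtype=np.float32),
        "labels": np.array([3], dtype=np.int64),
        "scores": np.array([1.0], dtype=np.float32),
    }
]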
@@ -733,15 +734,15 @@ def loss_gradient(
                  - labels (Tensor[N]): the predicted labels for each image
         :return: Loss gradients of the same shape as `x`.
         """
-        x, y = self._apply_resizing(x, y)
-        output, inputs_t, image_tensor_list_grad = self._get_losses(x=x, y=y)
+        x_resized, y_resized = self._apply_resizing(x, y)
+        output, inputs_t, image_tensor_list_grad = self._get_losses(x=x_resized, y=y_resized)
         loss = sum(output[k] * self.weight_dict[k] for k in output.keys() if k in self.weight_dict)
 
         self._model.zero_grad()
 
         loss.backward(retain_graph=True)  # type: ignore
 
-        if isinstance(x, np.ndarray):
+        if isinstance(x_resized, np.ndarray):
             if image_tensor_list_grad.grad is not None:
                 grads = image_tensor_list_grad.grad.cpu().numpy().copy()
             else:
@@ -756,9 +757,7 @@ def loss_gradient(
             grads = grads / self.clip_values[1]
 
         if not self.all_framework_preprocessing:
-            grads = self._apply_preprocessing_gradient(x, grads)
-
-        assert grads.shape == x.shape
+            grads = self._apply_preprocessing_gradient(x_resized, grads)
 
         return grads
 
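Dropping the assert grads.shape == x.shape is consistent with the rename: gradients are computed with respect to x_resized, so their spatial size is DETR's input size rather than the caller's. A sketch reusing detector and y from the snippets above (the printed shape is an expectation, not a tested result):

x = np.random.rand(1, 3, 640, 640).astype(np.float32)
grads = detector.loss_gradient(x=x, y=y)
print(grads.shape)  # expected (1, 3, 800, 800), the resized shape, not (1, 3, 640, 640)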
@@ -787,8 +786,8 @@ def compute_losses(
                  - scores (Tensor[N]): the scores or each prediction.
         :return: Dictionary of loss components.
         """
-        x, y = self._apply_resizing(x, y)
-        output_tensor, _, _ = self._get_losses(x=x, y=y)
+        x_resized, y = self._apply_resizing(x, y)
+        output_tensor, _, _ = self._get_losses(x=x_resized, y=y)
         output = {}
         for key, value in output_tensor.items():
             if key in self.attack_losses:
@@ -824,7 +823,6 @@ def compute_loss(  # type: ignore
                 loss = output[loss_name]
             else:
                 loss = loss + output[loss_name]
-
         assert loss is not None
 
         if isinstance(x, torch.Tensor):
@@ -835,10 +833,10 @@ def compute_loss(  # type: ignore
     def _apply_resizing(
         self,
         x: Union[np.ndarray, "torch.Tensor"],
-        y: List[Dict[str, Union[np.ndarray, "torch.Tensor"]]],
+        y: Any = None,
         height: int = 800,
         width: int = 800,
-    ):
+    ) -> Tuple[Union[np.ndarray, "torch.Tensor"], List[Any]]:
         """
         Resize the input and targets to dimensions expected by DETR.
 
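A self-contained stub of the relaxed signature. The old List[Dict[...]] annotation rejected calls that pass no targets (predict() previously called _apply_resizing(x, None)); y: Any = None accepts both call sites. The stub body is illustrative only:

from typing import Any, List, Tuple, Union

import numpy as np
import torch

def apply_resizing_stub(
    x: Union[np.ndarray, "torch.Tensor"],
    y: Any = None,
    height: int = 800,
    width: int = 800,
) -> Tuple[Union[np.ndarray, "torch.Tensor"], List[Any]]:
    # Mirrors the real method's contract: always return a (possibly empty) target list.
    targets: List[Any] = [] if y is None else list(y)
    return x, targets

x_resized, _ = apply_resizing_stub(np.zeros((1, 3, 640, 640), dtype=np.float32))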
@@ -861,9 +859,9 @@ def _apply_resizing(
         if isinstance(x, torch.Tensor):
             x = T.Resize(size=(height, width))(x)
         else:
-            for i, _ in enumerate(x):
+            for i in x:
                 resized = cv2.resize(
-                    (x)[i].transpose(1, 2, 0),
+                    i.transpose(1, 2, 0),
                     dsize=(height, width),
                     interpolation=cv2.INTER_CUBIC,
                 )
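Iterating over the array directly is cleaner than enumerate-and-index, but each image still needs the channels-first/channels-last round trip, since cv2.resize operates on HWC arrays. A standalone sketch of the pattern:

import cv2
import numpy as np

chw = np.random.rand(3, 480, 640).astype(np.float32)  # channels-first, as ART stores images
hwc = chw.transpose(1, 2, 0)                          # cv2.resize expects channels-last
resized = cv2.resize(hwc, dsize=(800, 800), interpolation=cv2.INTER_CUBIC)
chw_resized = resized.transpose(2, 0, 1)              # back to channels-first
print(chw_resized.shape)  # (3, 800, 800); note cv2's dsize is (width, height), square here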
@@ -877,20 +875,23 @@ def _apply_resizing(
         if isinstance(x, torch.Tensor):
             x = T.Resize(size=(rescale_dim, rescale_dim))(x)
         else:
-            for i, _ in enumerate(x):
+            for i in x:
                 resized = cv2.resize(
-                    (x)[i].transpose(1, 2, 0),
+                    i.transpose(1, 2, 0),
                     dsize=(rescale_dim, rescale_dim),
                     interpolation=cv2.INTER_CUBIC,
                 )
                 resized = resized.transpose(2, 0, 1)
                 resized_imgs.append(resized)
             x = np.array(resized_imgs)
 
-        targets = []
+        targets: List[Any] = []
         if y is not None:
             if isinstance(y[0]["boxes"], torch.Tensor):
                 for target in y:
+                    assert isinstance(target["boxes"], torch.Tensor)
+                    assert isinstance(target["labels"], torch.Tensor)
+                    assert isinstance(target["scores"], torch.Tensor)
                     cxcy_norm = revert_rescale_bboxes(target["boxes"], (self.input_shape[2], self.input_shape[1]))
                     targets.append(
                         {
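The three asserts are the standard mypy narrowing idiom: an isinstance assert collapses a Union so that tensor-only operations type-check. A self-contained sketch of the same idiom (function and data are hypothetical):

from typing import Dict, Union

import numpy as np
import torch

def box_areas(target: Dict[str, Union[np.ndarray, "torch.Tensor"]]) -> "torch.Tensor":
    boxes = target["boxes"]
    assert isinstance(boxes, torch.Tensor)  # narrows Union[np.ndarray, Tensor] to Tensor
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

t: Dict[str, Union[np.ndarray, "torch.Tensor"]] = {"boxes": torch.tensor([[0.0, 0.0, 2.0, 3.0]])}
print(box_areas(t))  # tensor([6.])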
@@ -901,9 +902,8 @@ def _apply_resizing(
                     )
             else:
                 for target in y:
-                    cxcy_norm = revert_rescale_bboxes(
-                        torch.from_numpy(target["boxes"]), (self.input_shape[2], self.input_shape[1])
-                    )
+                    tensor_box = torch.from_numpy(target["boxes"])
+                    cxcy_norm = revert_rescale_bboxes(tensor_box, (self.input_shape[2], self.input_shape[1]))
                     targets.append(
                         {
                             "labels": torch.from_numpy(target["labels"]).type(torch.int64).to(self.device),
@@ -988,11 +988,9 @@ def grad_enabled_forward(self, samples: NestedTensor):
         if isinstance(samples, (list, torch.Tensor)):
             samples = nested_tensor_from_tensor_list(samples)
         features, pos = self.backbone(samples)
-
         src, mask = features[-1].decompose()
         assert mask is not None
         h_s = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
-
         outputs_class = self.class_embed(h_s)
         outputs_coord = self.bbox_embed(h_s).sigmoid()
         out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
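grad_enabled_forward mirrors DETR's standard forward pass, so the returned dict follows DETR's convention: pred_logits of shape (batch, num_queries, num_classes + 1) and pred_boxes of shape (batch, num_queries, 4) in normalized (cx, cy, w, h). A post-processing sketch with dummy tensors in those shapes:

import torch

# Dummy outputs: batch=1, 100 queries, 91 COCO classes + 1 "no object" class.
outputs = {"pred_logits": torch.randn(1, 100, 92), "pred_boxes": torch.rand(1, 100, 4)}

probas = outputs["pred_logits"].softmax(-1)[0, :, :-1]  # drop the "no object" column
keep = probas.max(-1).values > 0.9                      # keep confident queries only
boxes_kept = outputs["pred_boxes"][0][keep]             # normalized (cx, cy, w, h)
print(boxes_kept.shape)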
