Skip to content

Commit 67e89b3

Browse files
committed
linting and style checks
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 5fe8e36 commit 67e89b3

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,11 @@ def fit( # pylint: disable=W0221
415415
# Apply preprocessing and convert to tensors
416416
x_preprocessed, y_preprocessed = self._preprocess_and_convert_inputs(x=x, y=y, fit=True, no_grad=True)
417417

418-
class ObjectDetectorDataset(Dataset):
418+
class ObjectDetectionDataset(Dataset):
419+
"""
420+
Object detection dataset in PyTorch.
421+
"""
422+
419423
def __init__(self, x, y):
420424
self.x = x
421425
self.y = y
@@ -427,7 +431,7 @@ def __getitem__(self, idx):
427431
return self.x[idx], self.y[idx]
428432

429433
# Create dataloader
430-
dataset = ObjectDetectorDataset(x_preprocessed, y_preprocessed)
434+
dataset = ObjectDetectionDataset(x_preprocessed, y_preprocessed)
431435
dataloader = DataLoader(
432436
dataset=dataset,
433437
batch_size=batch_size,
@@ -442,7 +446,7 @@ def __getitem__(self, idx):
442446
for x_batch, y_batch in dataloader:
443447
# Move inputs to device
444448
x_batch = torch.stack(x_batch).to(self.device)
445-
y_batch = y_batch = [{k: v.to(self.device) for k, v in y_i.items()} for y_i in y_batch]
449+
y_batch = [{k: v.to(self.device) for k, v in y_i.items()} for y_i in y_batch]
446450

447451
# Zero the parameter gradients
448452
self._optimizer.zero_grad()

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,9 @@ def _get_losses(
347347
width = self.input_shape[1]
348348

349349
# Convert labels to YOLO format
350-
y_preprocessed_yolo = translate_labels_x1y1x2y2_to_xcycwh(labels_x1y1x2y2=y_preprocessed, height=height, width=width)
350+
y_preprocessed_yolo = translate_labels_x1y1x2y2_to_xcycwh(
351+
labels_x1y1x2y2=y_preprocessed, height=height, width=width
352+
)
351353

352354
# Move inputs to device
353355
x_preprocessed = x_preprocessed.to(self.device)
@@ -514,7 +516,11 @@ def fit( # pylint: disable=W0221
514516
# Apply preprocessing and convert to tensors
515517
x_preprocessed, y_preprocessed = self._preprocess_and_convert_inputs(x=x, y=y, fit=True, no_grad=True)
516518

517-
class ObjectDetectorDataset(Dataset):
519+
class ObjectDetectionDataset(Dataset):
520+
"""
521+
Object detection dataset in PyTorch.
522+
"""
523+
518524
def __init__(self, x, y):
519525
self.x = x
520526
self.y = y
@@ -526,7 +532,7 @@ def __getitem__(self, idx):
526532
return self.x[idx], self.y[idx]
527533

528534
# Create dataloader
529-
dataset = ObjectDetectorDataset(x_preprocessed, y_preprocessed)
535+
dataset = ObjectDetectionDataset(x_preprocessed, y_preprocessed)
530536
dataloader = DataLoader(
531537
dataset=dataset,
532538
batch_size=batch_size,

art/estimators/object_detection/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919
This module contains utility functions for object detection.
2020
"""
21-
from typing import Dict, List, Any, Tuple, Union, Optional, TYPE_CHECKING
21+
from typing import Dict, List, Union, Tuple, Optional, TYPE_CHECKING
2222

2323
import numpy as np
2424

@@ -95,8 +95,8 @@ def convert_pt_to_tf(y: List[Dict[str, np.ndarray]], height: int, width: int) ->
9595

9696

9797
def cast_inputs_to_pt(
98-
x: np.ndarray,
99-
y: Optional[List[Dict[str, np.ndarray]]] = None,
98+
x: Union[np.ndarray, "torch.Tensor"],
99+
y: Optional[List[Dict[str, Union[np.ndarray, "torch.Tensor"]]]] = None,
100100
) -> Tuple["torch.Tensor", List[Dict[str, "torch.Tensor"]]]:
101101
"""
102102
Cast object detection inputs `(x, y)` to PyTorch tensors.

0 commit comments

Comments
 (0)