Skip to content

Commit 78a8238

Browse files
committed
Removing is_yolov8 flag from pytorch object detector.
Signed-off-by: Kieran Fraser <[email protected]>
1 parent 6aba944 commit 78a8238

File tree

1 file changed

+2
-11
lines changed

1 file changed

+2
-11
lines changed

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def __init__(
6666
"loss_rpn_box_reg",
6767
),
6868
device_type: str = "gpu",
69-
is_yolov8: bool = False,
7069
):
7170
"""
7271
Initialization.
@@ -94,7 +93,6 @@ def __init__(
9493
'loss_objectness', and 'loss_rpn_box_reg'.
9594
:param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
9695
if available otherwise run on CPU.
97-
:param is_yolov8: The flag to be used for marking the YOLOv8 model.
9896
"""
9997
import torch
10098
import torchvision
@@ -139,11 +137,7 @@ def __init__(
139137

140138
self._model: torch.nn.Module
141139
self._model.to(self._device)
142-
self.is_yolov8 = is_yolov8
143-
if self.is_yolov8:
144-
self._model.model.eval() # type: ignore
145-
else:
146-
self._model.eval()
140+
self._model.eval()
147141

148142
@property
149143
def native_label_is_pytorch_format(self) -> bool:
@@ -412,10 +406,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> list[dict[s
412406
from torch.utils.data import TensorDataset, DataLoader
413407

414408
# Set model to evaluation mode
415-
if self.is_yolov8:
416-
self._model.model.eval() # type: ignore
417-
else:
418-
self._model.eval()
409+
self._model.eval()
419410

420411
# Apply preprocessing and convert to tensors
421412
x_preprocessed, _ = self._preprocess_and_convert_inputs(x=x, y=None, fit=False, no_grad=True)

0 commit comments

Comments
 (0)