Skip to content

Commit 6980c62

Browse files
committed
Fix for AttributeError
1 parent 69c4361 commit 6980c62

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,10 @@ def __init__(
140140

141141
self._model: torch.nn.Module
142142
self._model.to(self._device)
143-
self._model.model.eval()
143+
try:
144+
self._model.model.eval()
145+
except AttributeError:
146+
self._model.eval()
144147

145148
@property
146149
def native_label_is_pytorch_format(self) -> bool:
@@ -403,7 +406,10 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> list[dict[s
403406
from torch.utils.data import TensorDataset, DataLoader
404407

405408
# Set model to evaluation mode
406-
self._model.model.eval()
409+
try:
410+
self._model.model.eval()
411+
except AttributeError:
412+
self._model.eval()
407413

408414
# Apply preprocessing and convert to tensors
409415
x_preprocessed, _ = self._preprocess_and_convert_inputs(x=x, y=None, fit=False, no_grad=True)

0 commit comments

Comments
 (0)