
Commit a95b42b

optimize pytorch yolo loops
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent b2c8c2a commit a95b42b

File tree

3 files changed: +171 -65 lines changed

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 64 additions & 0 deletions
@@ -24,6 +24,7 @@
 import numpy as np
 
 from art.estimators.object_detection.object_detector import ObjectDetectorMixin
+from art.estimators.object_detection.utils import cast_inputs_to_pt
 from art.estimators.pytorch import PyTorchEstimator
 
 if TYPE_CHECKING:
@@ -180,6 +181,69 @@ def device(self) -> "torch.device":
         """
         return self._device
 
+    def _preprocess_and_convert_inputs(
+        self,
+        x: Union[np.ndarray, "torch.Tensor"],
+        y: Optional[List[Dict[str, Union[np.ndarray, "torch.Tensor"]]]] = None,
+        fit: bool = False,
+        no_grad: bool = True,
+    ) -> Tuple["torch.Tensor", List[Dict[str, "torch.Tensor"]]]:
+        """
+        Apply preprocessing on inputs `(x, y)` and convert to tensors, if needed.
+
+        :param x: Samples of shape NCHW or NHWC.
+        :param y: Target values of format `List[Dict[str, Union[np.ndarray, torch.Tensor]]]`, one for each input
+                  image. The fields of the Dict are as follows:
+
+                  - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
+                  - labels [N]: the labels for each image.
+        :param fit: `True` if the function is called before fit/training and `False` if the function is called
+                    before a predict operation.
+        :param no_grad: `True` if no gradients are required.
+        :return: Preprocessed inputs `(x, y)` as tensors.
+        """
+        import torch
+
+        if self.clip_values is not None:
+            norm_factor = self.clip_values[1]
+        else:
+            norm_factor = 1.0
+
+        if self.all_framework_preprocessing:
+            # Convert samples into tensor
+            x_tensor, y_tensor = cast_inputs_to_pt(x, y)
+
+            if not self.channels_first:
+                x_tensor = torch.permute(x_tensor, (0, 3, 1, 2))
+            x_tensor /= norm_factor
+
+            # Set gradients
+            if not no_grad:
+                x_tensor.requires_grad = True
+
+            # Apply framework-specific preprocessing
+            x_preprocessed, y_preprocessed = self._apply_preprocessing(x=x_tensor, y=y_tensor, fit=fit, no_grad=no_grad)
+
+        elif isinstance(x, np.ndarray):
+            # Apply preprocessing
+            x_preprocessed, y_preprocessed = self._apply_preprocessing(x=x, y=y, fit=fit, no_grad=no_grad)
+
+            # Convert inputs into tensor
+            x_preprocessed, y_preprocessed = cast_inputs_to_pt(x_preprocessed, y_preprocessed)
+
+            if not self.channels_first:
+                x_preprocessed = torch.permute(x_preprocessed, (0, 3, 1, 2))
+            x_preprocessed /= norm_factor
+
+            # Set gradients
+            if not no_grad:
+                x_preprocessed.requires_grad = True
+
+        else:
+            raise NotImplementedError("Combination of inputs and preprocessing not supported.")
+
+        return x_preprocessed, y_preprocessed
+
     def _get_losses(
         self, x: np.ndarray, y: List[Dict[str, Union[np.ndarray, "torch.Tensor"]]]
     ) -> Tuple[Dict[str, "torch.Tensor"], List["torch.Tensor"], List["torch.Tensor"]]:
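For orientation, here is a minimal standalone sketch (not ART's public API; the shapes and the 255.0 clip value are illustrative assumptions) of the channels-last conversion this new method performs:

    import numpy as np
    import torch

    x = np.random.randint(0, 256, size=(2, 416, 416, 3)).astype(np.float32)  # NHWC samples
    norm_factor = 255.0  # stands in for clip_values[1]

    x_tensor = torch.from_numpy(x)                    # cast to tensor, still on CPU
    x_tensor = torch.permute(x_tensor, (0, 3, 1, 2))  # NHWC -> NCHW
    x_tensor /= norm_factor                           # scale into [0, 1]

    assert x_tensor.shape == (2, 3, 416, 416)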

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 55 additions & 64 deletions
@@ -26,6 +26,7 @@
 import numpy as np
 
 from art.estimators.object_detection.object_detector import ObjectDetectorMixin
+from art.estimators.object_detection.utils import cast_inputs_to_pt
 from art.estimators.pytorch import PyTorchEstimator
 
 if TYPE_CHECKING:
@@ -296,28 +297,12 @@ def _preprocess_and_convert_inputs(
             norm_factor = 1.0
 
         if self.all_framework_preprocessing:
-            if isinstance(x, np.ndarray):
-                # Convert samples into tensor
-                x_tensor = torch.from_numpy(x / norm_factor).to(self.device)
-            else:
-                x_tensor = (x / norm_factor).to(self.device)
+            # Convert samples into tensor
+            x_tensor, y_tensor = cast_inputs_to_pt(x, y)
 
             if not self.channels_first:
                 x_tensor = torch.permute(x_tensor, (0, 3, 1, 2))
-
-            # Convert targets into tensor
-            if y is not None and isinstance(y[0]["boxes"], np.ndarray):
-                y_tensor = []
-                for y_i in y:
-                    y_t = {
-                        "boxes": torch.from_numpy(y_i["boxes"]).to(device=self.device, dtype=torch.float32),
-                        "labels": torch.from_numpy(y_i["labels"]).to(device=self.device, dtype=torch.int64),
-                    }
-                    if "masks" in y_i:
-                        y_t["masks"] = torch.from_numpy(y_i["masks"]).to(device=self.device, dtype=torch.uint8)
-                    y_tensor.append(y_t)
-            else:
-                y_tensor = y  # type: ignore
+            x_tensor /= norm_factor
 
             # Set gradients
             if not no_grad:
@@ -328,33 +313,19 @@ def _preprocess_and_convert_inputs(
 
         elif isinstance(x, np.ndarray):
             # Apply preprocessing
-            x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y=y, fit=fit, no_grad=no_grad)
+            x_preprocessed, y_preprocessed = self._apply_preprocessing(x=x, y=y, fit=fit, no_grad=no_grad)
 
-            # Convert samples into tensor
-            x_preprocessed = torch.from_numpy(x_preprocessed / norm_factor).to(self.device)
+            # Convert inputs into tensor
+            x_preprocessed, y_preprocessed = cast_inputs_to_pt(x_preprocessed, y_preprocessed)
 
             if not self.channels_first:
                 x_preprocessed = torch.permute(x_preprocessed, (0, 3, 1, 2))
+            x_preprocessed /= norm_factor
 
             # Set gradients
             if not no_grad:
                 x_preprocessed.requires_grad = True
 
-            # Convert targets into tensor
-            if y_preprocessed is not None and isinstance(y_preprocessed[0]["boxes"], np.ndarray):
-                y_preprocessed_tensor = []
-                for y_i in y_preprocessed:
-                    y_preprocessed_t = {
-                        "boxes": torch.from_numpy(y_i["boxes"]).to(device=self.device, dtype=torch.float32),
-                        "labels": torch.from_numpy(y_i["labels"]).to(device=self.device, dtype=torch.int64),
-                    }
-                    if "masks" in y_i:
-                        y_preprocessed_t["masks"] = torch.from_numpy(y_i["masks"]).to(
-                            device=self.device, dtype=torch.uint8
-                        )
-                    y_preprocessed_tensor.append(y_preprocessed_t)
-                y_preprocessed = y_preprocessed_tensor
-
         else:
             raise NotImplementedError("Combination of inputs and preprocessing not supported.")

@@ -380,6 +351,7 @@ def _get_losses(
         x_preprocessed, y_preprocessed = self._preprocess_and_convert_inputs(x=x, y=y, fit=False, no_grad=False)
         x_grad = x_preprocessed
 
+        # Extract height and width
         if self.channels_first:
            height = self.input_shape[1]
            width = self.input_shape[2]
@@ -389,7 +361,7 @@ def _get_losses(
 
         labels_t = translate_labels_x1y1x2y2_to_xcycwh(labels_x1y1x2y2=y_preprocessed, height=height, width=width)
 
-        loss_components = self._model(x_grad, labels_t)
+        loss_components = self._model(x_grad.to(self.device), labels_t.to(self.device))
 
         return loss_components, x_grad
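Because `cast_inputs_to_pt` leaves tensors on the CPU, the forward pass now moves them to the estimator's device at call time. A tiny sketch of the idiom (device and shapes are illustrative; `.to(device)` copies only when the tensor is not already there):

    import torch

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    x_grad = torch.rand(2, 3, 416, 416)  # produced on CPU by the conversion helper
    labels_t = torch.rand(2, 6)          # placeholder for the translated labels

    x_dev = x_grad.to(device)            # no-op if already on the target device
    labels_dev = labels_t.to(device)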

@@ -463,30 +435,34 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
           - scores [N]: the scores of each prediction.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model to evaluation mode
         self._model.eval()
 
         # Apply preprocessing and convert to tensors
         x_preprocessed, _ = self._preprocess_and_convert_inputs(x=x, y=None, fit=False, no_grad=True)
 
-        predictions: List[Dict[str, np.ndarray]] = []
+        # Create dataloader
+        dataset = TensorDataset(x_preprocessed)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
 
+        # Extract height and width
         if self.channels_first:
             height = self.input_shape[1]
             width = self.input_shape[2]
         else:
             height = self.input_shape[0]
             width = self.input_shape[1]
 
-        # Run prediction
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
-        for m in range(num_batch):
-            # Batch using indices
-            i_batch = x_preprocessed[m * batch_size : (m + 1) * batch_size]
+        predictions: List[Dict[str, np.ndarray]] = []
+        for (x_batch,) in dataloader:
+            # Move inputs to device
+            x_batch = x_batch.to(self._device)
 
+            # Run prediction
             with torch.no_grad():
-                predictions_xcycwh = self._model(i_batch)
+                predictions_xcycwh = self._model(x_batch.to(self.device))
 
             predictions_x1y1x2y2 = translate_predictions_xcycwh_to_x1y1x2y2(
                 y_pred_xcycwh=predictions_xcycwh, height=height, width=width
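A standalone sketch of the batching change in `predict()`: manual index slicing is replaced by a `TensorDataset`/`DataLoader` pair. Each item yielded is a 1-tuple because the dataset wraps a single tensor; shapes and batch size are illustrative:

    import torch
    from torch.utils.data import TensorDataset, DataLoader

    x_preprocessed = torch.rand(10, 3, 416, 416)
    dataloader = DataLoader(TensorDataset(x_preprocessed), batch_size=4, shuffle=False)

    for (x_batch,) in dataloader:
        print(x_batch.shape)  # [4, 3, 416, 416] for full batches, [2, ...] for the last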
@@ -533,6 +509,8 @@ def fit( # pylint: disable=W0221
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
                        and providing it takes no effect.
         """
+        import torch
+        from torch.utils.data import Dataset, DataLoader
 
         # Set model to train mode
         self._model.train()
@@ -541,41 +519,54 @@ def fit( # pylint: disable=W0221
             raise ValueError("An optimizer is needed to train the model, but none for provided.")
 
         # Apply preprocessing and convert to tensors
-        x_preprocessed, y_preprocessed_list = self._preprocess_and_convert_inputs(x=x, y=y, fit=True, no_grad=True)
-
-        # Cast to np.ndarray to use list indexing
-        y_preprocessed = np.asarray(y_preprocessed_list)
+        x_preprocessed, y_preprocessed = self._preprocess_and_convert_inputs(x=x, y=y, fit=True, no_grad=True)
+
+        class ObjectDetectorDataset(Dataset):
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+            def __len__(self):
+                return len(self.x)
+
+            def __getitem__(self, idx):
+                return self.x[idx], self.y[idx]
+
+        # Create dataloader
+        dataset = ObjectDetectorDataset(x_preprocessed, y_preprocessed)
+        dataloader = DataLoader(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=True,
+            drop_last=drop_last,
+            collate_fn=lambda batch: list(zip(*batch)),
+        )
 
+        # Extract height and width
         if self.channels_first:
             height = self.input_shape[1]
             width = self.input_shape[2]
         else:
             height = self.input_shape[0]
             width = self.input_shape[1]
 
-        num_batch = len(x_preprocessed) / float(batch_size)
-        if drop_last:
-            num_batch = int(np.floor(num_batch))
-        else:
-            num_batch = int(np.ceil(num_batch))
-        ind = np.arange(len(x_preprocessed))
-
         # Start training
         for _ in range(nb_epochs):
-            # Shuffle the examples
-            np.random.shuffle(ind)
-
             # Train for one epoch
-            for m in range(num_batch):
-                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
-                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+            for x_batch, y_batch in dataloader:
+                # Convert labels to YOLO format
+                x_batch = torch.stack(x_batch)
+                y_batch = translate_labels_x1y1x2y2_to_xcycwh(labels_x1y1x2y2=y_batch, height=height, width=width)
+
+                # Move inputs to device
+                x_batch = x_batch.to(self.device)
+                y_batch = y_batch.to(self.device)
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
 
                 # Form the loss function
-                labels_t = translate_labels_x1y1x2y2_to_xcycwh(labels_x1y1x2y2=o_batch, height=height, width=width)
-                loss_components = self._model(i_batch, labels_t)
+                loss_components = self._model(x_batch, y_batch)
                 if isinstance(loss_components, dict):
                     loss = sum(loss_components.values())
                 else:
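The key design choice in the new `fit()` loop is the collate function: object detection targets are per-image dicts with a variable number of boxes, so the default collate cannot stack them into one tensor. `collate_fn=lambda batch: list(zip(*batch))` instead yields a tuple of images and a tuple of target dicts, and the loop stacks the images itself. A standalone sketch with illustrative shapes (the Dataset class mirrors the inline one above):

    import torch
    from torch.utils.data import Dataset, DataLoader

    class ToyDetectionDataset(Dataset):  # hypothetical stand-in for ObjectDetectorDataset
        def __init__(self, x, y):
            self.x = x
            self.y = y

        def __len__(self):
            return len(self.x)

        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]

    x = torch.rand(4, 3, 416, 416)
    y = [{"boxes": torch.rand(n, 4), "labels": torch.zeros(n, dtype=torch.int64)} for n in (1, 2, 3, 4)]

    loader = DataLoader(
        ToyDetectionDataset(x, y),
        batch_size=2,
        shuffle=True,
        collate_fn=lambda batch: list(zip(*batch)),
    )

    for x_batch, y_batch in loader:
        x_batch = torch.stack(x_batch)  # tuple of CHW tensors -> NCHW batch
        print(x_batch.shape, [t["boxes"].shape[0] for t in y_batch])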

art/estimators/object_detection/utils.py

Lines changed: 52 additions & 1 deletion
@@ -18,10 +18,14 @@
 """
 This module contains utility functions for object detection.
 """
-from typing import Dict, List
+from typing import Dict, List, Any, Tuple, Union, Optional, TYPE_CHECKING
 
 import numpy as np
 
+if TYPE_CHECKING:
+    # pylint: disable=C0412
+    import torch
+
 
 def convert_tf_to_pt(y: List[Dict[str, np.ndarray]], height: int, width: int) -> List[Dict[str, np.ndarray]]:
     """
@@ -88,3 +92,50 @@ def convert_pt_to_tf(y: List[Dict[str, np.ndarray]], height: int, width: int) ->
         y[i]["labels"] = y[i]["labels"] - 1
 
     return y
+
+
+def cast_inputs_to_pt(
+    x: np.ndarray,
+    y: Optional[List[Dict[str, np.ndarray]]] = None,
+) -> Tuple["torch.Tensor", List[Dict[str, "torch.Tensor"]]]:
+    """
+    Cast object detection inputs `(x, y)` to PyTorch tensors.
+
+    :param x: Samples of shape NCHW or NHWC.
+    :param y: Target values of format `List[Dict[str, Union[np.ndarray, torch.Tensor]]]`, one for each input image.
+              The fields of the Dict are as follows:
+
+              - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
+              - labels [N]: the labels for each image.
+    :return: Object detection inputs `(x, y)` as tensors.
+    """
+    import torch
+
+    # Convert images into tensor
+    if isinstance(x, np.ndarray):
+        x_tensor = torch.from_numpy(x)
+    else:
+        x_tensor = x
+
+    # Convert labels into tensor
+    if y is not None and isinstance(y, list) and isinstance(y[0]["boxes"], np.ndarray):
+        y_tensor = []
+        for y_i in y:
+            y_t = {
+                "boxes": torch.from_numpy(y_i["boxes"]).to(dtype=torch.float32),
+                "labels": torch.from_numpy(y_i["labels"]).to(dtype=torch.int64),
+            }
+            if "masks" in y_i:
+                y_t["masks"] = torch.from_numpy(y_i["masks"]).to(dtype=torch.uint8)
+            y_tensor.append(y_t)
+    elif y is not None and isinstance(y, dict):
+        y_tensor = []
+        for i in range(y["boxes"].shape[0]):
+            y_t = {}
+            y_t["boxes"] = y["boxes"][i]
+            y_t["labels"] = y["labels"][i]
+            y_tensor.append(y_t)
+    else:
+        y_tensor = y  # type: ignore
+
+    return x_tensor, y_tensor
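A minimal usage sketch of the new helper (assumes this commit's ART is importable; shapes and label values are illustrative):

    import numpy as np
    from art.estimators.object_detection.utils import cast_inputs_to_pt

    x = np.random.rand(2, 416, 416, 3).astype(np.float32)
    y = [
        {"boxes": np.array([[10.0, 10.0, 50.0, 50.0]], dtype=np.float32), "labels": np.array([1])},
        {"boxes": np.array([[20.0, 30.0, 60.0, 90.0]], dtype=np.float32), "labels": np.array([2])},
    ]

    x_t, y_t = cast_inputs_to_pt(x, y)
    print(type(x_t))               # <class 'torch.Tensor'>
    print(y_t[0]["boxes"].dtype)   # torch.float32
    print(y_t[0]["labels"].dtype)  # torch.int64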
