Skip to content

Commit acbf67d

Browse files
authored
Merge pull request #1470 from Trusted-AI/development_issue_1469
Preprocessing defences in PyTorchGoturn and VideoCompression for [0, 1] range
2 parents 676b996 + a537934 commit acbf67d

File tree

3 files changed: +23 additions, -3 deletions

art/defences/preprocessor/video_compression.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import os
2828
from tempfile import TemporaryDirectory
2929
from typing import Optional, Tuple
30+
import warnings
3031

3132
import numpy as np
3233
from tqdm.auto import tqdm
@@ -78,7 +79,8 @@ def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.nd
7879
"""
7980
Apply video compression to sample `x`.
8081
81-
:param x: Sample to compress of shape NCFHW or NFHWC. `x` values are expected to be in the data range [0, 255].
82+
:param x: Sample to compress of shape NCFHW or NFHWC. `x` values are expected to be either in range [0, 1] or
83+
[0, 255].
8284
:param y: Labels of the sample `x`. This function does not affect them in any way.
8385
:return: Compressed sample.
8486
"""
@@ -92,6 +94,9 @@ def compress_video(x: np.ndarray, video_format: str, constant_rate_factor: int,
9294
video_path = os.path.join(dir_, f"tmp_video.{video_format}")
9395
_, height, width, _ = x.shape
9496

97+
if (height % 2) != 0 or (width % 2) != 0:
98+
warnings.warn("Codec might require even number of pixels in height and width.")
99+
95100
# numpy to local video file
96101
process = (
97102
ffmpeg.input("pipe:", format="rawvideo", pix_fmt="rgb24", s=f"{width}x{height}")
@@ -118,11 +123,19 @@ def compress_video(x: np.ndarray, video_format: str, constant_rate_factor: int,
118123
x = np.transpose(x, (0, 2, 3, 4, 1))
119124

120125
# apply video compression per video item
126+
scale = 1
127+
if x.min() >= 0 and x.max() <= 1.0:
128+
scale = 255
129+
121130
x_compressed = x.copy()
122131
with TemporaryDirectory(dir=config.ART_DATA_PATH) as tmp_dir:
123132
for i, x_i in enumerate(tqdm(x, desc="Video compression", disable=not self.verbose)):
133+
x_i *= scale
124134
x_compressed[i] = compress_video(x_i, self.video_format, self.constant_rate_factor, dir_=tmp_dir)
125135

136+
x_compressed = x_compressed / scale
137+
x_compressed = x_compressed.astype(x.dtype)
138+
126139
if self.channels_first:
127140
x_compressed = np.transpose(x_compressed, (0, 4, 1, 2, 3))
128141

art/defences/preprocessor/video_compression_pytorch.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,17 @@ def forward(
116116
"""
117117
Apply video compression to sample `x`.
118118
119-
:param x: Sample to compress of shape NCFHW or NFHWC. `x` values are expected to be in the data range [0, 255].
119+
:param x: Sample to compress of shape NCFHW or NFHWC. `x` values are expected to be either in range [0, 1] or
120+
[0, 255].
120121
:param y: Labels of the sample `x`. This function does not affect them in any way.
121122
:return: Compressed sample.
122123
"""
124+
scale = 1
125+
if x.min() >= 0 and x.max() <= 1.0:
126+
scale = 255
127+
x = x * scale
123128
x_compressed = self._compression_pytorch_numpy.apply(x)
129+
x_compressed = x_compressed / scale
124130
return x_compressed, y
125131

126132
def _check_params(self) -> None:

art/estimators/object_tracking/pytorch_goturn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,10 +670,11 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
670670
x_i = x[i].to(self.device)
671671

672672
# Apply preprocessing
673+
x_i = torch.unsqueeze(x_i, dim=0)
673674
x_i, _ = self._apply_preprocessing(x_i, y=None, fit=False, no_grad=False)
675+
x_i = torch.squeeze(x_i)
674676

675677
y_pred = self._track(x=x_i, y_init=y_init[i])
676-
677678
prediction_dict = dict()
678679
if isinstance(x, np.ndarray):
679680
prediction_dict["boxes"] = y_pred.detach().cpu().numpy()

Commit comments: 0