Skip to content

Commit e542bf7

Browse files
author
Beat Buesser
committed
Add support for images in range [0, 1] in VideoCompression
Signed-off-by: Beat Buesser <[email protected]>
1 parent 300e607 commit e542bf7

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

art/defences/preprocessor/video_compression.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
import os
2828
from tempfile import TemporaryDirectory
2929
from typing import Optional, Tuple
30+
import warnings
3031

3132
import numpy as np
3233
from tqdm.auto import tqdm
@@ -78,7 +79,8 @@ def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.nd
7879
"""
7980
Apply video compression to sample `x`.
8081
81-
:param x: Sample to compress of shape NCFHW or NFHWC. `x` values are expected to be in the data range [0, 255].
82+
:param x: Sample to compress of shape NCFHW or NFHWC. `x` values are expected to be either in range [0, 1] or
83+
[0, 255].
8284
:param y: Labels of the sample `x`. This function does not affect them in any way.
8385
:return: Compressed sample.
8486
"""
@@ -92,6 +94,9 @@ def compress_video(x: np.ndarray, video_format: str, constant_rate_factor: int,
9294
video_path = os.path.join(dir_, f"tmp_video.{video_format}")
9395
_, height, width, _ = x.shape
9496

97+
if (height % 2) != 0 or (width % 2) != 0:
98+
warnings.warn("Codec might require even number of pixels in height and width.")
99+
95100
# numpy to local video file
96101
process = (
97102
ffmpeg.input("pipe:", format="rawvideo", pix_fmt="rgb24", s=f"{width}x{height}")
@@ -118,11 +123,19 @@ def compress_video(x: np.ndarray, video_format: str, constant_rate_factor: int,
118123
x = np.transpose(x, (0, 2, 3, 4, 1))
119124

120125
# apply video compression per video item
126+
scale = 1
127+
if x.min() >= 0 and x.max() <= 1.0:
128+
scale = 255
129+
121130
x_compressed = x.copy()
122131
with TemporaryDirectory(dir=config.ART_DATA_PATH) as tmp_dir:
123132
for i, x_i in enumerate(tqdm(x, desc="Video compression", disable=not self.verbose)):
133+
x_i *= scale
124134
x_compressed[i] = compress_video(x_i, self.video_format, self.constant_rate_factor, dir_=tmp_dir)
125135

136+
x_compressed = x_compressed / scale
137+
x_compressed = x_compressed.astype(x.dtype)
138+
126139
if self.channels_first:
127140
x_compressed = np.transpose(x_compressed, (0, 4, 1, 2, 3))
128141

art/defences/preprocessor/video_compression_pytorch.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -116,11 +116,17 @@ def forward(
116116
"""
117117
Apply video compression to sample `x`.
118118
119-
:param x: Sample to compress of shape NCFHW or NFHWC. `x` values are expected to be in the data range [0, 255].
119+
:param x: Sample to compress of shape NCFHW or NFHWC. `x` values are expected to be either in range [0, 1] or
120+
[0, 255].
120121
:param y: Labels of the sample `x`. This function does not affect them in any way.
121122
:return: Compressed sample.
122123
"""
124+
scale = 1
125+
if x.min() >= 0 and x.max() <= 1.0:
126+
scale = 255
127+
x = x * scale
123128
x_compressed = self._compression_pytorch_numpy.apply(x)
129+
x_compressed = x_compressed / scale
124130
return x_compressed, y
125131

126132
def _check_params(self) -> None:

0 commit comments

Comments
 (0)