Skip to content

Commit 7818154

Browse files
author
Daniel Flores
committed
Add tolerances for various cases
1 parent 4f20ff1 commit 7818154

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

test/test_ops.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import os
1010
from functools import partial
1111

12-
from .utils import in_fbcode
12+
from .utils import get_ffmpeg_major_version, in_fbcode
1313

1414
os.environ["TORCH_LOGS"] = "output_code"
1515
import json
@@ -55,7 +55,6 @@
5555
SINE_MONO_S32,
5656
SINE_MONO_S32_44100,
5757
SINE_MONO_S32_8000,
58-
TESTSRC2_VIDEO,
5958
)
6059

6160
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -1310,9 +1309,18 @@ def decode(self, file_path) -> torch.Tensor:
13101309
return frames
13111310

13121311
@pytest.mark.parametrize("format", ("mov", "mp4", "avi", "mkv", "webm", "flv"))
1313-
# TODO-VideoEncoder: enable additional formats ("mkv", "webm", "flv")
13141312
def test_video_encoder_test_round_trip(self, tmp_path, format):
1315-
asset = TESTSRC2_VIDEO
1313+
1314+
ffmpeg_version = get_ffmpeg_major_version()
1315+
if ffmpeg_version == 4 and format == "webm":
1316+
pytest.skip("Codec for webm is not available in the FFmpeg4 installation.")
1317+
# The output pixel format depends on the codecs available, and FFmpeg version.
1318+
# In the cases where YUV420P is chosen and chroma subsampling happens, we need higher tolerance.
1319+
if ffmpeg_version == 6 or format in ("avi", "flv"):
1320+
atol = 55
1321+
else:
1322+
atol = 2
1323+
asset = NASA_VIDEO
13161324

13171325
# Test that decode(encode(decode(asset))) == decode(asset)
13181326
source_frames = self.decode(str(asset.path)).data
@@ -1326,7 +1334,7 @@ def test_video_encoder_test_round_trip(self, tmp_path, format):
13261334
for s_frame, rt_frame in zip(source_frames, round_trip_frames):
13271335
res = psnr(s_frame, rt_frame)
13281336
assert res > 30
1329-
torch.testing.assert_close(s_frame, rt_frame, atol=0, rtol=0)
1337+
torch.testing.assert_close(s_frame, rt_frame, atol=atol, rtol=0)
13301338

13311339

13321340
if __name__ == "__main__":

0 commit comments

Comments
 (0)