|
18 | 18 | from tensordict import MemoryMappedTensor |
19 | 19 |
|
20 | 20 | from torchrl.envs import check_env_specs, GymEnv, ParallelEnv |
| 21 | +from torchrl.record.loggers.common import _has_torchcodec, _has_tv |
21 | 22 | from torchrl.record.loggers.csv import CSVLogger |
22 | | -from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger |
| 23 | +from torchrl.record.loggers.mlflow import _has_mlflow, MLFlowLogger |
23 | 24 | from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger |
24 | 25 | from torchrl.record.loggers.trackio import _has_trackio, TrackioLogger |
25 | 26 | from torchrl.record.loggers.wandb import _has_wandb, WandbLogger |
|
34 | 35 | else: |
35 | 36 | TORCHVISION_VERSION = version.parse("0.0.1") |
36 | 37 |
|
| 38 | +_has_mp4 = (_has_tv and TORCHVISION_VERSION < version.parse("0.22")) or _has_torchcodec |
| 39 | + |
37 | 40 | if _has_tb: |
38 | 41 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator |
39 | 42 |
|
@@ -170,10 +173,7 @@ def test_log_scalar(self, steps, tmpdir): |
170 | 173 |
|
171 | 174 | @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) |
172 | 175 | @pytest.mark.parametrize( |
173 | | - "video_format", ["pt", "memmap"] + ["mp4"] if _has_tv else [] |
174 | | - ) |
175 | | - @pytest.mark.skipif( |
176 | | - TORCHVISION_VERSION < version.parse("0.20.0"), reason="av compatibility bug" |
| 176 | + "video_format", ["pt", "memmap"] + (["mp4"] if _has_mp4 else []) |
177 | 177 | ) |
178 | 178 | def test_log_video(self, steps, video_format, tmpdir): |
179 | 179 | torch.manual_seed(0) |
@@ -217,11 +217,18 @@ def test_log_video(self, steps, video_format, tmpdir): |
217 | 217 | ) |
218 | 218 | assert torch.equal(video, logged_video), logged_video |
219 | 219 | elif video_format == "mp4": |
220 | | - import torchvision |
| 220 | + if _has_torchcodec: |
| 221 | + from torchcodec.decoders import VideoDecoder |
221 | 222 |
|
222 | | - logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][ |
223 | | - :, :1 |
224 | | - ] |
| 223 | + logged_video = ( |
| 224 | + VideoDecoder(path) |
| 225 | + .get_frames_in_range(start=0, stop=128) |
| 226 | + .data[:, :1] |
| 227 | + ) |
| 228 | +            else:
| 229 | +                import torchvision
| 230 | +
| 231 | +                logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][
| 232 | +                    :, :1
| 233 | +                ]
225 | 232 | logged_video = logged_video.unsqueeze(0) |
226 | 233 | torch.testing.assert_close(video, logged_video) |
227 | 234 |
|
@@ -376,7 +383,7 @@ def test_log_scalar(self, steps, mlflow_fixture): |
376 | 383 | assert metric.value == values[i].item() |
377 | 384 |
|
378 | 385 | @pytest.mark.parametrize("steps", [None, [1, 10, 11]]) |
379 | | - @pytest.mark.skipif(not _has_tv, reason="torchvision not installed") |
| 386 | + @pytest.mark.skipif(not _has_mp4, reason="no mp4 video backend available") |
380 | 387 | def test_log_video(self, steps, mlflow_fixture): |
381 | 388 |
|
382 | 389 | logger, client = mlflow_fixture |
|
0 commit comments