Skip to content

Commit 68c808e

Browse files
committed
update-ci
1 parent c0c15cc commit 68c808e

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

.github/unittest/linux/scripts/run_all.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ uv_pip_install \
129129
wandb \
130130
mlflow \
131131
av \
132+
torchcodec \
132133
coverage \
133134
transformers \
134135
ninja \

test/test_loggers.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
from tensordict import MemoryMappedTensor
1919

2020
from torchrl.envs import check_env_specs, GymEnv, ParallelEnv
21+
from torchrl.record.loggers.common import _has_torchcodec, _has_tv
2122
from torchrl.record.loggers.csv import CSVLogger
22-
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
23+
from torchrl.record.loggers.mlflow import _has_mlflow, MLFlowLogger
2324
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
2425
from torchrl.record.loggers.trackio import _has_trackio, TrackioLogger
2526
from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
@@ -34,6 +35,8 @@
3435
else:
3536
TORCHVISION_VERSION = version.parse("0.0.1")
3637

38+
_has_mp4 = (_has_tv and TORCHVISION_VERSION < version.parse("0.22")) or _has_torchcodec
39+
3740
if _has_tb:
3841
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
3942

@@ -170,10 +173,7 @@ def test_log_scalar(self, steps, tmpdir):
170173

171174
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
172175
@pytest.mark.parametrize(
173-
"video_format", ["pt", "memmap"] + ["mp4"] if _has_tv else []
174-
)
175-
@pytest.mark.skipif(
176-
TORCHVISION_VERSION < version.parse("0.20.0"), reason="av compatibility bug"
176+
"video_format", ["pt", "memmap"] + (["mp4"] if _has_mp4 else [])
177177
)
178178
def test_log_video(self, steps, video_format, tmpdir):
179179
torch.manual_seed(0)
@@ -217,11 +217,18 @@ def test_log_video(self, steps, video_format, tmpdir):
217217
)
218218
assert torch.equal(video, logged_video), logged_video
219219
elif video_format == "mp4":
220-
import torchvision
220+
if _has_torchcodec:
221+
from torchcodec.decoders import VideoDecoder
221222

222-
logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][
223-
:, :1
224-
]
223+
logged_video = (
224+
VideoDecoder(path)
225+
.get_frames_in_range(start=0, stop=128)
226+
.data[:, :1]
227+
)
228+
else:
229+
logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][
230+
:, :1
231+
]
225232
logged_video = logged_video.unsqueeze(0)
226233
torch.testing.assert_close(video, logged_video)
227234

@@ -376,7 +383,7 @@ def test_log_scalar(self, steps, mlflow_fixture):
376383
assert metric.value == values[i].item()
377384

378385
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
379-
@pytest.mark.skipif(not _has_tv, reason="torchvision not installed")
386+
@pytest.mark.skipif(not _has_mp4, reason="no mp4 video backend available")
380387
def test_log_video(self, steps, mlflow_fixture):
381388

382389
logger, client = mlflow_fixture

0 commit comments

Comments
 (0)