Skip to content

Commit 6780e32

Browse files
authored
Merge branch 'main' into cuda11
2 parents 2730ca3 + bc89ce1 commit 6780e32

File tree

8 files changed

+102
-52
lines changed

8 files changed

+102
-52
lines changed

.github/workflows/linux_cuda_wheel.yaml

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,20 +82,25 @@ jobs:
8282
name: pytorch_torchcodec__3.9_cu${{ env.cuda_version_without_periods }}_x86_64
8383
path: pytorch/torchcodec/dist/
8484
- name: Setup miniconda using test-infra
85-
uses: ahmadsharif1/test-infra/.github/actions/setup-miniconda@14bc3c29f88d13b0237ab4ddf00aa409e45ade40
85+
uses: pytorch/test-infra/.github/actions/setup-miniconda@main
8686
with:
8787
python-version: ${{ matrix.python-version }}
88-
default-packages: "conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}"
88+
#
89+
# For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does.
90+
# So we use the latter convention for libnpp.
91+
# We install conda packages at the start because otherwise conda may have conflicts with dependencies.
92+
default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}"
8993
- name: Check env
9094
run: |
9195
${CONDA_RUN} env
9296
${CONDA_RUN} conda info
9397
${CONDA_RUN} nvidia-smi
98+
${CONDA_RUN} conda list
9499
- name: Update pip
95100
run: ${CONDA_RUN} python -m pip install --upgrade pip
96101
- name: Install PyTorch
97102
run: |
98-
${CONDA_RUN} python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }}
103+
${CONDA_RUN} python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }}
99104
${CONDA_RUN} python -c 'import torch; print(f"{torch.__version__}"); print(f"{torch.__file__}"); print(f"{torch.cuda.is_available()=}")'
100105
- name: Install torchcodec from the wheel
101106
run: |
@@ -106,14 +111,8 @@ jobs:
106111
- name: Check out repo
107112
uses: actions/checkout@v3
108113

109-
- name: Install cuda runtime dependencies
110-
run: |
111-
# For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does.
112-
# So we use the latter convention for libnpp.
113-
${CONDA_RUN} conda install --yes nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }}
114114
- name: Install test dependencies
115115
run: |
116-
${CONDA_RUN} python -m pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
117116
# Ideally we would find a way to get those dependencies from pyproject.toml
118117
${CONDA_RUN} python -m pip install numpy pytest pillow
119118

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ decoder.get_frame_at(len(decoder) - 1)
6565
# pts_seconds: 9.960000038146973
6666
# duration_seconds: 0.03999999910593033
6767

68-
decoder.get_frames_at(start=10, stop=30, step=5)
68+
decoder.get_frames_in_range(start=10, stop=30, step=5)
6969
# FrameBatch:
7070
# data (shape): torch.Size([4, 3, 400, 640])
7171
# pts_seconds: tensor([0.4000, 0.6000, 0.8000, 1.0000])

benchmarks/samplers/benchmark_samplers.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,22 @@ def bench(f, *args, num_exp=100, warmup=0, **kwargs):
1616
for _ in range(warmup):
1717
f(*args, **kwargs)
1818

19+
num_frames = None
1920
times = []
2021
for _ in range(num_exp):
2122
start = perf_counter_ns()
22-
f(*args, **kwargs)
23+
clips = f(*args, **kwargs)
2324
end = perf_counter_ns()
2425
times.append(end - start)
25-
return torch.tensor(times).float()
26+
num_frames = (
27+
clips.data.shape[0] * clips.data.shape[1]
28+
) # should be constant across calls
29+
return torch.tensor(times).float(), num_frames
30+
2631

32+
def report_stats(times, num_frames, unit="ms"):
33+
fps = num_frames * 1e9 / torch.median(times)
2734

28-
def report_stats(times, unit="ms"):
2935
mul = {
3036
"ns": 1,
3137
"µs": 1e-3,
@@ -35,13 +41,13 @@ def report_stats(times, unit="ms"):
3541
times = times * mul
3642
std = times.std().item()
3743
med = times.median().item()
38-
print(f"{med = :.2f}{unit} +- {std:.2f}")
39-
return med
44+
print(f"{med = :.2f}{unit} +- {std:.2f} med fps = {fps:.1f}")
45+
return med, fps
4046

4147

4248
def sample(sampler, **kwargs):
4349
decoder = VideoDecoder(VIDEO_PATH)
44-
sampler(
50+
return sampler(
4551
decoder,
4652
num_frames_per_clip=10,
4753
**kwargs,
@@ -56,34 +62,34 @@ def sample(sampler, **kwargs):
5662
print(f"{num_clips = }")
5763

5864
print("clips_at_random_indices ", end="")
59-
times = bench(
65+
times, num_frames = bench(
6066
sample, clips_at_random_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
6167
)
62-
report_stats(times, unit="ms")
68+
report_stats(times, num_frames, unit="ms")
6369

6470
print("clips_at_regular_indices ", end="")
65-
times = bench(
71+
times, num_frames = bench(
6672
sample, clips_at_regular_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
6773
)
68-
report_stats(times, unit="ms")
74+
report_stats(times, num_frames, unit="ms")
6975

7076
print("clips_at_random_timestamps ", end="")
71-
times = bench(
77+
times, num_frames = bench(
7278
sample,
7379
clips_at_random_timestamps,
7480
num_clips=num_clips,
7581
num_exp=NUM_EXP,
7682
warmup=2,
7783
)
78-
report_stats(times, unit="ms")
84+
report_stats(times, num_frames, unit="ms")
7985

8086
print("clips_at_regular_timestamps ", end="")
8187
seconds_between_clip_starts = 13 / num_clips # approximate. video is 13s long
82-
times = bench(
88+
times, num_frames = bench(
8389
sample,
8490
clips_at_regular_timestamps,
8591
seconds_between_clip_starts=seconds_between_clip_starts,
8692
num_exp=NUM_EXP,
8793
warmup=2,
8894
)
89-
report_stats(times, unit="ms")
95+
report_stats(times, num_frames, unit="ms")

examples/basic_example.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
120120
# their :term:`pts` (Presentation Time Stamp), and their duration.
121121
# This can be achieved using the
122122
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` and
123-
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` methods, which
123+
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range` methods, which
124124
# will return a :class:`~torchcodec.Frame` and
125125
# :class:`~torchcodec.FrameBatch` objects respectively.
126126

@@ -129,7 +129,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
129129
print(last_frame)
130130

131131
# %%
132-
middle_frames = decoder.get_frames_at(start=10, stop=20, step=2)
132+
middle_frames = decoder.get_frames_in_range(start=10, stop=20, step=2)
133133
print(f"{type(middle_frames) = }")
134134
print(middle_frames)
135135

@@ -152,7 +152,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
152152
# So far, we have retrieved frames based on their index. We can also retrieve
153153
# frames based on *when* they are displayed with
154154
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at` and
155-
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_at`, which
155+
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_in_range`, which
156156
# also returns :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch`
157157
# respectively.
158158

@@ -161,7 +161,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
161161
print(frame_at_2_seconds)
162162

163163
# %%
164-
first_two_seconds = decoder.get_frames_displayed_at(
164+
first_two_seconds = decoder.get_frames_displayed_in_range(
165165
start_seconds=0,
166166
stop_seconds=2,
167167
)

src/torchcodec/decoders/_video_decoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def get_frame_at(self, index: int) -> Frame:
181181
duration_seconds=duration_seconds.item(),
182182
)
183183

184-
def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch:
184+
def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch:
185185
"""Return multiple frames at the given index range.
186186
187187
Frames are in [start, stop).
@@ -238,7 +238,7 @@ def get_frame_displayed_at(self, seconds: float) -> Frame:
238238
duration_seconds=duration_seconds.item(),
239239
)
240240

241-
def get_frames_displayed_at(
241+
def get_frames_displayed_in_range(
242242
self, start_seconds: float, stop_seconds: float
243243
) -> FrameBatch:
244244
"""Returns multiple frames in the given range.

src/torchcodec/samplers/_time_based.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _validate_sampling_range_time_based(
7171
if sampling_range_start is None:
7272
sampling_range_start = begin_stream_seconds
7373
else:
74-
if sampling_range_start <= begin_stream_seconds:
74+
if sampling_range_start < begin_stream_seconds:
7575
raise ValueError(
7676
f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}"
7777
)

test/decoders/test_video_decoder.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,14 @@ def test_get_frame_displayed_at_fails(self):
366366
frame = decoder.get_frame_displayed_at(100.0) # noqa
367367

368368
@pytest.mark.parametrize("stream_index", [0, 3, None])
369-
def test_get_frames_at(self, stream_index):
369+
def test_get_frames_in_range(self, stream_index):
370370
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)
371371

372372
# test degenerate case where we only actually get 1 frame
373373
ref_frames9 = NASA_VIDEO.get_frame_data_by_range(
374374
start=9, stop=10, stream_index=stream_index
375375
)
376-
frames9 = decoder.get_frames_at(start=9, stop=10)
376+
frames9 = decoder.get_frames_in_range(start=9, stop=10)
377377

378378
assert_tensor_equal(ref_frames9, frames9.data)
379379
assert frames9.pts_seconds[0].item() == pytest.approx(
@@ -389,7 +389,7 @@ def test_get_frames_at(self, stream_index):
389389
ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(
390390
start=0, stop=10, stream_index=stream_index
391391
)
392-
frames0_9 = decoder.get_frames_at(start=0, stop=10)
392+
frames0_9 = decoder.get_frames_in_range(start=0, stop=10)
393393
assert frames0_9.data.shape == torch.Size(
394394
[
395395
10,
@@ -412,7 +412,7 @@ def test_get_frames_at(self, stream_index):
412412
ref_frames0_8_2 = NASA_VIDEO.get_frame_data_by_range(
413413
start=0, stop=10, step=2, stream_index=stream_index
414414
)
415-
frames0_8_2 = decoder.get_frames_at(start=0, stop=10, step=2)
415+
frames0_8_2 = decoder.get_frames_in_range(start=0, stop=10, step=2)
416416
assert frames0_8_2.data.shape == torch.Size(
417417
[
418418
5,
@@ -434,13 +434,13 @@ def test_get_frames_at(self, stream_index):
434434
)
435435

436436
# test numpy.int64 for indices
437-
frames0_8_2 = decoder.get_frames_at(
437+
frames0_8_2 = decoder.get_frames_in_range(
438438
start=numpy.int64(0), stop=numpy.int64(10), step=numpy.int64(2)
439439
)
440440
assert_tensor_equal(ref_frames0_8_2, frames0_8_2.data)
441441

442442
# an empty range is valid!
443-
empty_frames = decoder.get_frames_at(5, 5)
443+
empty_frames = decoder.get_frames_in_range(5, 5)
444444
assert_tensor_equal(
445445
empty_frames.data,
446446
NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index),
@@ -456,10 +456,10 @@ def test_get_frames_at(self, stream_index):
456456
(
457457
lambda decoder: decoder[0],
458458
lambda decoder: decoder.get_frame_at(0).data,
459-
lambda decoder: decoder.get_frames_at(0, 4).data,
459+
lambda decoder: decoder.get_frames_in_range(0, 4).data,
460460
lambda decoder: decoder.get_frame_displayed_at(0).data,
461461
# TODO: uncomment once D60001893 lands
462-
# lambda decoder: decoder.get_frames_displayed_at(0, 1).data,
462+
# lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
463463
),
464464
)
465465
def test_dimension_order(self, dimension_order, frame_getter):
@@ -487,7 +487,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
487487
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)
488488

489489
# Note that we are comparing the results of VideoDecoder's method:
490-
# get_frames_displayed_at()
490+
# get_frames_displayed_in_range()
491491
# With the testing framework's method:
492492
# get_frame_data_by_range()
493493
# That is, we are testing the correctness of a pts-based range against an index-
@@ -504,7 +504,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
504504
# value for frame 5 that we have access to on the Python side is slightly less than the pts
505505
# value on the C++ side. This test still produces the correct result because a slightly
506506
# smaller value still falls into the correct window.
507-
frames0_4 = decoder.get_frames_displayed_at(
507+
frames0_4 = decoder.get_frames_displayed_in_range(
508508
decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(5).pts_seconds
509509
)
510510
assert_tensor_equal(
@@ -513,15 +513,15 @@ def test_get_frames_by_pts_in_range(self, stream_index):
513513
)
514514

515515
# Range where the stop seconds is about halfway between pts values for two frames.
516-
also_frames0_4 = decoder.get_frames_displayed_at(
516+
also_frames0_4 = decoder.get_frames_displayed_in_range(
517517
decoder.get_frame_at(0).pts_seconds,
518518
decoder.get_frame_at(4).pts_seconds + HALF_DURATION,
519519
)
520520
assert_tensor_equal(also_frames0_4.data, frames0_4.data)
521521

522522
# Again, the intention here is to provide the exact values we care about. In practice, our
523523
# pts values are slightly smaller, so we nudge the start upwards.
524-
frames5_9 = decoder.get_frames_displayed_at(
524+
frames5_9 = decoder.get_frames_displayed_in_range(
525525
decoder.get_frame_at(5).pts_seconds,
526526
decoder.get_frame_at(10).pts_seconds,
527527
)
@@ -533,7 +533,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
533533
# Range where we provide start_seconds and stop_seconds that are different, but
534534
# also should land in the same window of time between two frames' pts values. As
535535
# a result, we should only get back one frame.
536-
frame6 = decoder.get_frames_displayed_at(
536+
frame6 = decoder.get_frames_displayed_in_range(
537537
decoder.get_frame_at(6).pts_seconds,
538538
decoder.get_frame_at(6).pts_seconds + HALF_DURATION,
539539
)
@@ -543,7 +543,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
543543
)
544544

545545
# Very small range that falls in the same frame.
546-
frame35 = decoder.get_frames_displayed_at(
546+
frame35 = decoder.get_frames_displayed_in_range(
547547
decoder.get_frame_at(35).pts_seconds,
548548
decoder.get_frame_at(35).pts_seconds + 1e-10,
549549
)
@@ -555,7 +555,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
555555
# Single frame where the start seconds is before frame i's pts, and the stop is
556556
# after frame i's pts, but before frame i+1's pts. In that scenario, we expect
557557
# to see frames i-1 and i.
558-
frames7_8 = decoder.get_frames_displayed_at(
558+
frames7_8 = decoder.get_frames_displayed_in_range(
559559
NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds
560560
- HALF_DURATION,
561561
NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds
@@ -567,7 +567,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
567567
)
568568

569569
# Start and stop seconds are the same value, which should not return a frame.
570-
empty_frame = decoder.get_frames_displayed_at(
570+
empty_frame = decoder.get_frames_displayed_in_range(
571571
NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds,
572572
NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds,
573573
)
@@ -583,7 +583,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
583583
)
584584

585585
# Start and stop seconds land within the first frame.
586-
frame0 = decoder.get_frames_displayed_at(
586+
frame0 = decoder.get_frames_displayed_in_range(
587587
NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds,
588588
NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds
589589
+ HALF_DURATION,
@@ -595,7 +595,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
595595

596596
# We should be able to get all frames by giving the beginning and ending time
597597
# for the stream.
598-
all_frames = decoder.get_frames_displayed_at(
598+
all_frames = decoder.get_frames_displayed_in_range(
599599
decoder.metadata.begin_stream_seconds, decoder.metadata.end_stream_seconds
600600
)
601601
assert_tensor_equal(all_frames.data, decoder[:])
@@ -604,13 +604,13 @@ def test_get_frames_by_pts_in_range_fails(self):
604604
decoder = VideoDecoder(NASA_VIDEO.path)
605605

606606
with pytest.raises(ValueError, match="Invalid start seconds"):
607-
frame = decoder.get_frames_displayed_at(100.0, 1.0) # noqa
607+
frame = decoder.get_frames_displayed_in_range(100.0, 1.0) # noqa
608608

609609
with pytest.raises(ValueError, match="Invalid start seconds"):
610-
frame = decoder.get_frames_displayed_at(20, 23) # noqa
610+
frame = decoder.get_frames_displayed_in_range(20, 23) # noqa
611611

612612
with pytest.raises(ValueError, match="Invalid stop seconds"):
613-
frame = decoder.get_frames_displayed_at(0, 23) # noqa
613+
frame = decoder.get_frames_displayed_in_range(0, 23) # noqa
614614

615615

616616
if __name__ == "__main__":

0 commit comments

Comments
 (0)