Skip to content

Commit 4cb09f3

Browse files
authored
Merge branch 'main' into threads1
2 parents 2d3fab9 + bc89ce1 commit 4cb09f3

File tree

8 files changed

+102
-52
lines changed

8 files changed

+102
-52
lines changed

.github/workflows/linux_cuda_wheel.yaml

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,20 +82,25 @@ jobs:
8282
name: pytorch_torchcodec__3.9_cu${{ env.cuda_version_without_periods }}_x86_64
8383
path: pytorch/torchcodec/dist/
8484
- name: Setup miniconda using test-infra
85-
uses: ahmadsharif1/test-infra/.github/actions/setup-miniconda@14bc3c29f88d13b0237ab4ddf00aa409e45ade40
85+
uses: pytorch/test-infra/.github/actions/setup-miniconda@main
8686
with:
8787
python-version: ${{ matrix.python-version }}
88-
default-packages: "conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}"
88+
#
89+
# For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does.
90+
# So we use the latter convention for libnpp.
91+
# We install conda packages at the start because otherwise conda may have conflicts with dependencies.
92+
default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}"
8993
- name: Check env
9094
run: |
9195
${CONDA_RUN} env
9296
${CONDA_RUN} conda info
9397
${CONDA_RUN} nvidia-smi
98+
${CONDA_RUN} conda list
9499
- name: Update pip
95100
run: ${CONDA_RUN} python -m pip install --upgrade pip
96101
- name: Install PyTorch
97102
run: |
98-
${CONDA_RUN} python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }}
103+
${CONDA_RUN} python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }}
99104
${CONDA_RUN} python -c 'import torch; print(f"{torch.__version__}"); print(f"{torch.__file__}"); print(f"{torch.cuda.is_available()=}")'
100105
- name: Install torchcodec from the wheel
101106
run: |
@@ -106,14 +111,8 @@ jobs:
106111
- name: Check out repo
107112
uses: actions/checkout@v3
108113

109-
- name: Install cuda runtime dependencies
110-
run: |
111-
# For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does.
112-
# So we use the latter convention for libnpp.
113-
${CONDA_RUN} conda install --yes nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }}
114114
- name: Install test dependencies
115115
run: |
116-
${CONDA_RUN} python -m pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
117116
# Ideally we would find a way to get those dependencies from pyproject.toml
118117
${CONDA_RUN} python -m pip install numpy pytest pillow
119118

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ decoder.get_frame_at(len(decoder) - 1)
6565
# pts_seconds: 9.960000038146973
6666
# duration_seconds: 0.03999999910593033
6767

68-
decoder.get_frames_at(start=10, stop=30, step=5)
68+
decoder.get_frames_in_range(start=10, stop=30, step=5)
6969
# FrameBatch:
7070
# data (shape): torch.Size([4, 3, 400, 640])
7171
# pts_seconds: tensor([0.4000, 0.6000, 0.8000, 1.0000])

benchmarks/samplers/benchmark_samplers.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,22 @@ def bench(f, *args, num_exp=100, warmup=0, **kwargs):
1616
for _ in range(warmup):
1717
f(*args, **kwargs)
1818

19+
num_frames = None
1920
times = []
2021
for _ in range(num_exp):
2122
start = perf_counter_ns()
22-
f(*args, **kwargs)
23+
clips = f(*args, **kwargs)
2324
end = perf_counter_ns()
2425
times.append(end - start)
25-
return torch.tensor(times).float()
26+
num_frames = (
27+
clips.data.shape[0] * clips.data.shape[1]
28+
) # should be constant across calls
29+
return torch.tensor(times).float(), num_frames
30+
2631

32+
def report_stats(times, num_frames, unit="ms"):
33+
fps = num_frames * 1e9 / torch.median(times)
2734

28-
def report_stats(times, unit="ms"):
2935
mul = {
3036
"ns": 1,
3137
"µs": 1e-3,
@@ -35,13 +41,13 @@ def report_stats(times, unit="ms"):
3541
times = times * mul
3642
std = times.std().item()
3743
med = times.median().item()
38-
print(f"{med = :.2f}{unit} +- {std:.2f}")
39-
return med
44+
print(f"{med = :.2f}{unit} +- {std:.2f} med fps = {fps:.1f}")
45+
return med, fps
4046

4147

4248
def sample(sampler, **kwargs):
4349
decoder = VideoDecoder(VIDEO_PATH)
44-
sampler(
50+
return sampler(
4551
decoder,
4652
num_frames_per_clip=10,
4753
**kwargs,
@@ -56,34 +62,34 @@ def sample(sampler, **kwargs):
5662
print(f"{num_clips = }")
5763

5864
print("clips_at_random_indices ", end="")
59-
times = bench(
65+
times, num_frames = bench(
6066
sample, clips_at_random_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
6167
)
62-
report_stats(times, unit="ms")
68+
report_stats(times, num_frames, unit="ms")
6369

6470
print("clips_at_regular_indices ", end="")
65-
times = bench(
71+
times, num_frames = bench(
6672
sample, clips_at_regular_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
6773
)
68-
report_stats(times, unit="ms")
74+
report_stats(times, num_frames, unit="ms")
6975

7076
print("clips_at_random_timestamps ", end="")
71-
times = bench(
77+
times, num_frames = bench(
7278
sample,
7379
clips_at_random_timestamps,
7480
num_clips=num_clips,
7581
num_exp=NUM_EXP,
7682
warmup=2,
7783
)
78-
report_stats(times, unit="ms")
84+
report_stats(times, num_frames, unit="ms")
7985

8086
print("clips_at_regular_timestamps ", end="")
8187
seconds_between_clip_starts = 13 / num_clips # approximate. video is 13s long
82-
times = bench(
88+
times, num_frames = bench(
8389
sample,
8490
clips_at_regular_timestamps,
8591
seconds_between_clip_starts=seconds_between_clip_starts,
8692
num_exp=NUM_EXP,
8793
warmup=2,
8894
)
89-
report_stats(times, unit="ms")
95+
report_stats(times, num_frames, unit="ms")

examples/basic_example.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
120120
# their :term:`pts` (Presentation Time Stamp), and their duration.
121121
# This can be achieved using the
122122
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` and
123-
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` methods, which
123+
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range` methods, which
124124
# will return :class:`~torchcodec.Frame` and
125125
# :class:`~torchcodec.FrameBatch` objects respectively.
126126

@@ -129,7 +129,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
129129
print(last_frame)
130130

131131
# %%
132-
middle_frames = decoder.get_frames_at(start=10, stop=20, step=2)
132+
middle_frames = decoder.get_frames_in_range(start=10, stop=20, step=2)
133133
print(f"{type(middle_frames) = }")
134134
print(middle_frames)
135135

@@ -152,7 +152,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
152152
# So far, we have retrieved frames based on their index. We can also retrieve
153153
# frames based on *when* they are displayed with
154154
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at` and
155-
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_at`, which
155+
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_in_range`, which
156156
# also return :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch`
157157
# respectively.
158158

@@ -161,7 +161,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
161161
print(frame_at_2_seconds)
162162

163163
# %%
164-
first_two_seconds = decoder.get_frames_displayed_at(
164+
first_two_seconds = decoder.get_frames_displayed_in_range(
165165
start_seconds=0,
166166
stop_seconds=2,
167167
)

src/torchcodec/decoders/_video_decoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def get_frame_at(self, index: int) -> Frame:
190190
duration_seconds=duration_seconds.item(),
191191
)
192192

193-
def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch:
193+
def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch:
194194
"""Return multiple frames at the given index range.
195195
196196
Frames are in [start, stop).
@@ -247,7 +247,7 @@ def get_frame_displayed_at(self, seconds: float) -> Frame:
247247
duration_seconds=duration_seconds.item(),
248248
)
249249

250-
def get_frames_displayed_at(
250+
def get_frames_displayed_in_range(
251251
self, start_seconds: float, stop_seconds: float
252252
) -> FrameBatch:
253253
"""Returns multiple frames in the given range.

src/torchcodec/samplers/_time_based.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _validate_sampling_range_time_based(
7171
if sampling_range_start is None:
7272
sampling_range_start = begin_stream_seconds
7373
else:
74-
if sampling_range_start <= begin_stream_seconds:
74+
if sampling_range_start < begin_stream_seconds:
7575
raise ValueError(
7676
f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}"
7777
)

test/decoders/test_video_decoder.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -367,14 +367,14 @@ def test_get_frame_displayed_at_fails(self):
367367
frame = decoder.get_frame_displayed_at(100.0) # noqa
368368

369369
@pytest.mark.parametrize("stream_index", [0, 3, None])
370-
def test_get_frames_at(self, stream_index):
370+
def test_get_frames_in_range(self, stream_index):
371371
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)
372372

373373
# test degenerate case where we only actually get 1 frame
374374
ref_frames9 = NASA_VIDEO.get_frame_data_by_range(
375375
start=9, stop=10, stream_index=stream_index
376376
)
377-
frames9 = decoder.get_frames_at(start=9, stop=10)
377+
frames9 = decoder.get_frames_in_range(start=9, stop=10)
378378

379379
assert_tensor_equal(ref_frames9, frames9.data)
380380
assert frames9.pts_seconds[0].item() == pytest.approx(
@@ -390,7 +390,7 @@ def test_get_frames_at(self, stream_index):
390390
ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(
391391
start=0, stop=10, stream_index=stream_index
392392
)
393-
frames0_9 = decoder.get_frames_at(start=0, stop=10)
393+
frames0_9 = decoder.get_frames_in_range(start=0, stop=10)
394394
assert frames0_9.data.shape == torch.Size(
395395
[
396396
10,
@@ -413,7 +413,7 @@ def test_get_frames_at(self, stream_index):
413413
ref_frames0_8_2 = NASA_VIDEO.get_frame_data_by_range(
414414
start=0, stop=10, step=2, stream_index=stream_index
415415
)
416-
frames0_8_2 = decoder.get_frames_at(start=0, stop=10, step=2)
416+
frames0_8_2 = decoder.get_frames_in_range(start=0, stop=10, step=2)
417417
assert frames0_8_2.data.shape == torch.Size(
418418
[
419419
5,
@@ -435,13 +435,13 @@ def test_get_frames_at(self, stream_index):
435435
)
436436

437437
# test numpy.int64 for indices
438-
frames0_8_2 = decoder.get_frames_at(
438+
frames0_8_2 = decoder.get_frames_in_range(
439439
start=numpy.int64(0), stop=numpy.int64(10), step=numpy.int64(2)
440440
)
441441
assert_tensor_equal(ref_frames0_8_2, frames0_8_2.data)
442442

443443
# an empty range is valid!
444-
empty_frames = decoder.get_frames_at(5, 5)
444+
empty_frames = decoder.get_frames_in_range(5, 5)
445445
assert_tensor_equal(
446446
empty_frames.data,
447447
NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index),
@@ -457,10 +457,10 @@ def test_get_frames_at(self, stream_index):
457457
(
458458
lambda decoder: decoder[0],
459459
lambda decoder: decoder.get_frame_at(0).data,
460-
lambda decoder: decoder.get_frames_at(0, 4).data,
460+
lambda decoder: decoder.get_frames_in_range(0, 4).data,
461461
lambda decoder: decoder.get_frame_displayed_at(0).data,
462462
# TODO: uncomment once D60001893 lands
463-
# lambda decoder: decoder.get_frames_displayed_at(0, 1).data,
463+
# lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
464464
),
465465
)
466466
def test_dimension_order(self, dimension_order, frame_getter):
@@ -488,7 +488,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
488488
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)
489489

490490
# Note that we are comparing the results of VideoDecoder's method:
491-
# get_frames_displayed_at()
491+
# get_frames_displayed_in_range()
492492
# With the testing framework's method:
493493
# get_frame_data_by_range()
494494
# That is, we are testing the correctness of a pts-based range against an index-
@@ -505,7 +505,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
505505
# value for frame 5 that we have access to on the Python side is slightly less than the pts
506506
# value on the C++ side. This test still produces the correct result because a slightly
507507
# smaller value still falls into the correct window.
508-
frames0_4 = decoder.get_frames_displayed_at(
508+
frames0_4 = decoder.get_frames_displayed_in_range(
509509
decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(5).pts_seconds
510510
)
511511
assert_tensor_equal(
@@ -514,15 +514,15 @@ def test_get_frames_by_pts_in_range(self, stream_index):
514514
)
515515

516516
# Range where the stop seconds is about halfway between pts values for two frames.
517-
also_frames0_4 = decoder.get_frames_displayed_at(
517+
also_frames0_4 = decoder.get_frames_displayed_in_range(
518518
decoder.get_frame_at(0).pts_seconds,
519519
decoder.get_frame_at(4).pts_seconds + HALF_DURATION,
520520
)
521521
assert_tensor_equal(also_frames0_4.data, frames0_4.data)
522522

523523
# Again, the intention here is to provide the exact values we care about. In practice, our
524524
# pts values are slightly smaller, so we nudge the start upwards.
525-
frames5_9 = decoder.get_frames_displayed_at(
525+
frames5_9 = decoder.get_frames_displayed_in_range(
526526
decoder.get_frame_at(5).pts_seconds,
527527
decoder.get_frame_at(10).pts_seconds,
528528
)
@@ -534,7 +534,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
534534
# Range where we provide start_seconds and stop_seconds that are different, but
535535
# also should land in the same window of time between two frames' pts values. As
536536
# a result, we should only get back one frame.
537-
frame6 = decoder.get_frames_displayed_at(
537+
frame6 = decoder.get_frames_displayed_in_range(
538538
decoder.get_frame_at(6).pts_seconds,
539539
decoder.get_frame_at(6).pts_seconds + HALF_DURATION,
540540
)
@@ -544,7 +544,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
544544
)
545545

546546
# Very small range that falls in the same frame.
547-
frame35 = decoder.get_frames_displayed_at(
547+
frame35 = decoder.get_frames_displayed_in_range(
548548
decoder.get_frame_at(35).pts_seconds,
549549
decoder.get_frame_at(35).pts_seconds + 1e-10,
550550
)
@@ -556,7 +556,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
556556
# Single frame where the start seconds is before frame i's pts, and the stop is
557557
# after frame i's pts, but before frame i+1's pts. In that scenario, we expect
558558
# to see frames i-1 and i.
559-
frames7_8 = decoder.get_frames_displayed_at(
559+
frames7_8 = decoder.get_frames_displayed_in_range(
560560
NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds
561561
- HALF_DURATION,
562562
NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds
@@ -568,7 +568,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
568568
)
569569

570570
# Start and stop seconds are the same value, which should not return a frame.
571-
empty_frame = decoder.get_frames_displayed_at(
571+
empty_frame = decoder.get_frames_displayed_in_range(
572572
NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds,
573573
NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds,
574574
)
@@ -584,7 +584,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
584584
)
585585

586586
# Start and stop seconds land within the first frame.
587-
frame0 = decoder.get_frames_displayed_at(
587+
frame0 = decoder.get_frames_displayed_in_range(
588588
NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds,
589589
NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds
590590
+ HALF_DURATION,
@@ -596,7 +596,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
596596

597597
# We should be able to get all frames by giving the beginning and ending time
598598
# for the stream.
599-
all_frames = decoder.get_frames_displayed_at(
599+
all_frames = decoder.get_frames_displayed_in_range(
600600
decoder.metadata.begin_stream_seconds, decoder.metadata.end_stream_seconds
601601
)
602602
assert_tensor_equal(all_frames.data, decoder[:])
@@ -605,13 +605,13 @@ def test_get_frames_by_pts_in_range_fails(self):
605605
decoder = VideoDecoder(NASA_VIDEO.path)
606606

607607
with pytest.raises(ValueError, match="Invalid start seconds"):
608-
frame = decoder.get_frames_displayed_at(100.0, 1.0) # noqa
608+
frame = decoder.get_frames_displayed_in_range(100.0, 1.0) # noqa
609609

610610
with pytest.raises(ValueError, match="Invalid start seconds"):
611-
frame = decoder.get_frames_displayed_at(20, 23) # noqa
611+
frame = decoder.get_frames_displayed_in_range(20, 23) # noqa
612612

613613
with pytest.raises(ValueError, match="Invalid stop seconds"):
614-
frame = decoder.get_frames_displayed_at(0, 23) # noqa
614+
frame = decoder.get_frames_displayed_in_range(0, 23) # noqa
615615

616616

617617
if __name__ == "__main__":

0 commit comments

Comments
 (0)