Added doc for nvdec #335
Changes from 72 commits
Changes to the CI workflow. This is a flattened diff excerpt; where a step changed, the pre-change and post-change lines appear consecutively:

```yaml
@@ -5,45 +5,113 @@ on:
    branches: [ main ]
  pull_request:

permissions:
  id-token: write
  contents: write

defaults:
  run:
    shell: bash -l -eo pipefail {0}

jobs:
  generate-matrix:
    uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
    with:
      package-type: wheel
      os: linux
      test-infra-repository: pytorch/test-infra
      test-infra-ref: main
      with-cpu: disable
      with-xpu: disable
      with-rocm: disable
      with-cuda: enable
      build-python-only: "disable"
  build:
    runs-on: ubuntu-latest
    needs: generate-matrix
    strategy:
      fail-fast: false
    name: Build and Upload wheel
    uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
    with:
      repository: pytorch/torchcodec
      ref: ""
      test-infra-repository: pytorch/test-infra
      test-infra-ref: main
      build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
      post-script: packaging/post_build_script.sh
      smoke-test-script: packaging/fake_smoke_test.py
      package-name: torchcodec
      trigger-event: ${{ github.event_name }}
      build-platform: "python-build-package"
      build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 ENABLE_CUDA=1 python -m build --wheel -vvv --no-isolation"

  build-docs:
    runs-on: linux.4xlarge.nvidia.gpu
    strategy:
      fail-fast: false
      matrix:
        # 3.9 corresponds to the minimum python version for which we build
        # the wheel unless the label ciflow/binaries/all is present in the
        # PR.
        # For the actual release we should add that label and change this to
        # include more python versions.
        python-version: ['3.9']
        cuda-version: ['12.4']
        ffmpeg-version-for-tests: ['7']
    container:
      image: "pytorch/manylinux-builder:cuda${{ matrix.cuda-version }}"
      options: "--gpus all -e NVIDIA_DRIVER_CAPABILITIES=video,compute,utility"
    needs: build
    steps:
      - name: Check out repo
        uses: actions/checkout@v3
      - name: Setup conda env
        uses: conda-incubator/setup-miniconda@v2
      - name: Setup env vars
        run: |
          cuda_version_without_periods=$(echo "${{ matrix.cuda-version }}" | sed 's/\.//g')
          echo cuda_version_without_periods=${cuda_version_without_periods} >> $GITHUB_ENV
      - uses: actions/download-artifact@v3
        with:
          auto-update-conda: true
          miniconda-version: "latest"
          activate-environment: test
          python-version: '3.12'
          name: pytorch_torchcodec__3.9_cu${{ env.cuda_version_without_periods }}_x86_64
          path: pytorch/torchcodec/dist/
      - name: Setup miniconda using test-infra
        uses: pytorch/test-infra/.github/actions/setup-miniconda@main
        with:
          python-version: ${{ matrix.python-version }}
          # For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does.
          # So we use the latter convention for libnpp.
          # We install conda packages at the start because otherwise conda may have conflicts with dependencies.
          default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}"
      - name: Check env
        run: |
          ${CONDA_RUN} env
          ${CONDA_RUN} conda info
          ${CONDA_RUN} nvidia-smi
          ${CONDA_RUN} conda list
      - name: Assert ffmpeg exists
        run: |
          ${CONDA_RUN} ffmpeg -buildconf
      - name: Update pip
        run: python -m pip install --upgrade pip
      - name: Install dependencies and FFmpeg
        run: ${CONDA_RUN} python -m pip install --upgrade pip
      - name: Install PyTorch
        run: |
          # TODO: torchvision and torchaudio shouldn't be needed. They were only added
          # to silence an error as seen in https://github.com/pytorch/torchcodec/issues/203
          python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
          conda install "ffmpeg=7.0.1" pkg-config -c conda-forge
          ffmpeg -version
      - name: Build and install torchcodec
          ${CONDA_RUN} python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }}
          ${CONDA_RUN} python -c 'import torch; print(f"{torch.__version__}"); print(f"{torch.__file__}"); print(f"{torch.cuda.is_available()=}")'
      - name: Install torchcodec from the wheel
        run: |
          python -m pip install -e ".[dev]" --no-build-isolation -vvv
          wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"`
          echo Installing $wheel_path
          ${CONDA_RUN} python -m pip install $wheel_path -vvv
      - name: Check out repo
        uses: actions/checkout@v3
      - name: Install doc dependencies
        run: |
          cd docs
          python -m pip install -r requirements.txt
          ${CONDA_RUN} python -m pip install -r requirements.txt
```
Contributor: Do we understand why we need CONDA_RUN here?

Author: Everything is installed in a conda env. Without CONDA_RUN the …
```yaml
      - name: Build docs
        run: |
          cd docs
          make html
          ${CONDA_RUN} make html
      - uses: actions/upload-artifact@v3
        with:
          name: Built-Docs
```
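The "Setup env vars" step above turns a CUDA version string like `12.4` into `124` for use in the wheel artifact name. The `sed` call in the workflow is authoritative; a minimal Python sketch of the same transformation, for illustration only:

```python
# Equivalent of: echo "12.4" | sed 's/\.//g'
cuda_version = "12.4"  # example value of matrix.cuda-version
cuda_version_without_periods = cuda_version.replace(".", "")
print(cuda_version_without_periods)  # -> 124
```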
New file (tutorial source, hunk `@@ -0,0 +1,174 @@`):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Accelerated video decoding on GPUs with CUDA and NVDEC
======================================================

.. _nvdec_tutorial:

TorchCodec can use supported Nvidia hardware (see the support matrix
`here <https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new>`_) to speed up
video decoding. This is called "CUDA Decoding", and it uses Nvidia's
`NVDEC hardware decoder <https://developer.nvidia.com/video-codec-sdk>`_
and CUDA kernels to decompress frames and convert them to RGB, respectively.
CUDA Decoding can be faster than CPU Decoding for the actual decoding step and also for
subsequent transform steps like scaling, cropping, or rotating. This is because the decode step leaves
the decoded tensor in GPU memory, so the GPU doesn't have to fetch from main memory before
running the transform steps. Encoded packets are often much smaller than decoded frames, so
CUDA Decoding also uses less PCI-e bandwidth.

CUDA Decoding can offer a speed-up over CPU Decoding in a few scenarios:

#. You are decoding a large-resolution video.
#. You are decoding a large batch of videos that is saturating the CPU.
#. You want to do whole-image transforms like scaling or convolutions on the decoded
   tensors after decoding.
#. Your CPU is saturated and you want to free it up for other work.

Here are situations where CUDA Decoding may not make sense:

#. You want bit-exact results compared to CPU Decoding.
#. You have small-resolution videos and the PCI-e transfer latency is large.
#. Your GPU is already busy and your CPU is not.

It's best to experiment with CUDA Decoding to see if it improves your use case. With
TorchCodec you can simply pass in a device parameter to the
:class:`~torchcodec.decoders.VideoDecoder` class to use CUDA Decoding.

In order to use CUDA Decoding you will need the following installed in your environment:

#. An Nvidia GPU that supports decoding the video format you want to decode. See
   the support matrix `here <https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new>`_.
#. `CUDA-enabled pytorch <https://pytorch.org/get-started/locally/>`_.
#. FFmpeg binaries that support NVDEC-enabled codecs.
#. libnpp and nvrtc (these are usually installed when you install the full cuda-toolkit).

FFmpeg versions 5, 6 and 7 from conda-forge are built with NVDEC support and you can
install them with conda. For example, to install FFmpeg version 7:

.. code-block:: bash

    conda install ffmpeg=7 -c conda-forge
    conda install libnpp cuda-nvrtc -c nvidia
"""
```
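One of the prerequisites above is an FFmpeg build with NVDEC-enabled codecs. A quick way to check this locally (an illustration, not part of the tutorial: it assumes `ffmpeg` is on `PATH`, and relies on FFmpeg naming its NVDEC-backed decoders with a `_cuvid` suffix, e.g. `h264_cuvid`):

```python
import shutil
import subprocess

# Look for NVDEC-backed ("cuvid") decoders in the local ffmpeg build.
cuvid_decoders = []
ffmpeg = shutil.which("ffmpeg")
if ffmpeg is None:
    print("ffmpeg not found on PATH")
else:
    out = subprocess.run(
        [ffmpeg, "-hide_banner", "-decoders"],
        capture_output=True, text=True, check=True,
    ).stdout
    # Decoder listing lines look like: " V..... h264_cuvid  Nvidia CUVID ..."
    cuvid_decoders = [
        line.split()[1] for line in out.splitlines() if "_cuvid" in line
    ]
    print("NVDEC decoders:", cuvid_decoders or "none found")
```

If the list is empty, the FFmpeg build was compiled without NVDEC support and CUDA Decoding will not work.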
```python
# %%
# Checking if PyTorch has CUDA enabled
# -------------------------------------
#
# .. note::
#
#    This tutorial requires FFmpeg libraries compiled with CUDA support.

import torch

print(f"{torch.__version__=}")
print(f"{torch.cuda.is_available()=}")
print(f"{torch.cuda.get_device_properties(0)=}")
```
```python
# %%
# Downloading the video
# -------------------------------------
#
# We will use the following video, which has these properties:
#
# - Codec: H.264
# - Resolution: 960x540
# - FPS: 29.97
# - Pixel format: YUV420P
#
# .. raw:: html
#
#    <video style="max-width: 100%" controls>
#      <source src="https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4" type="video/mp4">
#    </video>

import urllib.request

video_file = "video.mp4"
urllib.request.urlretrieve(
    "https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4",
    video_file,
)
```
```python
# %%
# CUDA Decoding using VideoDecoder
# -------------------------------------
#
# To use the CUDA decoder, you need to pass in a CUDA device to the decoder.

from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder(video_file, device="cuda")
frame = decoder[0]

# %%
#
# The video frames are decoded and returned as tensors in NCHW format.

print(frame.data.shape, frame.data.dtype)

# %%
#
# The video frames are left in GPU memory.

print(frame.data.device)
```
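Because the decoded tensor already lives on the GPU, downstream transforms can run there directly, which is the speed-up the tutorial's introduction describes. A small sketch using a synthetic stand-in frame (not from the tutorial; it falls back to CPU when CUDA is unavailable):

```python
import torch

# Use CUDA if present; otherwise the same code runs on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Synthetic stand-in for a decoded CHW uint8 frame (3 x 540 x 960).
frame = torch.randint(0, 256, (3, 540, 960), dtype=torch.uint8, device=device)

# A whole-image transform (2x downscale) executed on the device that
# already holds the frame, with no copy back to main memory.
resized = torch.nn.functional.interpolate(
    frame.unsqueeze(0).float(), scale_factor=0.5, mode="bilinear"
).squeeze(0)

print(resized.shape, resized.device)
```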
```python
# %%
# Visualizing Frames
# -------------------------------------
#
# Let's look at the frames decoded by the CUDA decoder and compare them
# against the equivalent results from the CPU decoder.

import matplotlib.pyplot as plt
from torchvision.transforms.v2.functional import to_pil_image


def get_frames(timestamps: list[float], device: str):
    decoder = VideoDecoder(video_file, device=device)
    return [decoder.get_frame_played_at(ts).data for ts in timestamps]


timestamps = [12, 19, 45, 131, 180]
cpu_frames = get_frames(timestamps, device="cpu")
cuda_frames = get_frames(timestamps, device="cuda:0")


def plot_cpu_and_cuda_frames(
    cpu_frames: list[torch.Tensor], cuda_frames: list[torch.Tensor]
):
    n_rows = len(timestamps)
    fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
    for i in range(n_rows):
        axes[i][0].imshow(to_pil_image(cpu_frames[i].to("cpu")))
        axes[i][1].imshow(to_pil_image(cuda_frames[i].to("cpu")))

    axes[0][0].set_title("CPU decoder", fontsize=24)
    axes[0][1].set_title("CUDA decoder", fontsize=24)
    plt.setp(axes, xticks=[], yticks=[])
    plt.tight_layout()


plot_cpu_and_cuda_frames(cpu_frames, cuda_frames)
```
```python
# %%
#
# They look visually similar to the human eye, but there may be subtle
# differences because CUDA math is not bit-exact with respect to CPU math.

first_cpu_frame = cpu_frames[0].data.to("cpu")
first_cuda_frame = cuda_frames[0].data.to("cpu")
frames_equal = torch.equal(first_cpu_frame, first_cuda_frame)
print(f"{frames_equal=}")
```
Contributor: Nit: instead of this binary indicator, we may want to print the max abs diff and mean abs diff? We could even do it across all frames.
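Following up on the reviewer's nit above, one way to report the maximum and mean absolute difference instead of a single boolean. This is a sketch with small synthetic uint8 tensors standing in for decoded frames; `abs_diff_stats` is a hypothetical helper, not a TorchCodec API:

```python
import torch


def abs_diff_stats(a: torch.Tensor, b: torch.Tensor) -> tuple[float, float]:
    # Compare in float to avoid uint8 wrap-around on subtraction.
    diff = (a.float() - b.float()).abs()
    return diff.max().item(), diff.mean().item()


# Synthetic stand-ins for one CPU-decoded and one CUDA-decoded frame.
cpu_frame = torch.tensor([[10, 20], [30, 40]], dtype=torch.uint8)
cuda_frame = torch.tensor([[11, 20], [28, 40]], dtype=torch.uint8)

max_diff, mean_diff = abs_diff_stats(cpu_frame, cuda_frame)
print(f"{max_diff=}, {mean_diff=}")  # -> max_diff=2.0, mean_diff=0.75
```

The same helper could be applied across all decoded frames by stacking them into one tensor per device before the comparison.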

Uh oh!
There was an error while loading. Please reload this page.