-
Notifications
You must be signed in to change notification settings - Fork 64
fix: Solve CUDA AV1 decoding #448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
d92149c
b4825a7
7e6bc92
a29287c
5e82e48
70c1985
a9fa4bb
d49fde5
25437b0
c64b833
9a9b50f
e453891
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,11 +42,12 @@ git clone [email protected]:pytorch/torchcodec.git | |
cd torchcodec | ||
|
||
pip install -e ".[dev]" --no-build-isolation -vv | ||
# Or, for cuda support: ENABLE_CUDA=1 pip install -e ".[dev]" --no-build-isolation -vv | ||
``` | ||
|
||
### Running unit tests | ||
|
||
To run python tests run: | ||
To run python tests run (please make sure `torchvision` is installed): | ||
|
||
```bash | ||
pytest test -vvv | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -256,4 +256,36 @@ void convertAVFrameToDecodedOutputOnCuda( | |
<< " took: " << duration.count() << "us" << std::endl; | ||
} | ||
|
||
// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 | ||
void forceCudaCodec( | ||
|
||
const torch::Device& device, | ||
AVCodecPtr* codec, | ||
const AVCodecID& codecId) { | ||
if (device.type() != torch::kCUDA) { | ||
|
||
return; | ||
} | ||
|
||
const AVCodec* c; | ||
|
||
void* i = NULL; | ||
bool found = false; | ||
|
||
while (!found && (c = av_codec_iterate(&i))) { | ||
const AVCodecHWConfig* config; | ||
|
||
if (c->id != codecId || !av_codec_is_decoder(c)) { | ||
continue; | ||
} | ||
|
||
for (int j = 0; config = avcodec_get_hw_config(c, j); j++) { | ||
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { | ||
found = true; | ||
} | ||
} | ||
} | ||
|
||
if (found) { | ||
*codec = c; | ||
} | ||
} | ||
|
||
} // namespace facebook::torchcodec |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -456,6 +456,12 @@ void VideoDecoder::addVideoStreamDecoder( | |
"Stream with index " + std::to_string(streamNumber) + | ||
" is not a video stream."); | ||
} | ||
|
||
if (options.device.type() == torch::kCUDA) { | ||
forceCudaCodec( | ||
options.device, &codec, streamInfo.stream->codecpar->codec_id); | ||
|
||
} | ||
|
||
AVCodecContext* codecContext = avcodec_alloc_context3(codec); | ||
codecContext->thread_count = options.ffmpegThreadCount.value_or(0); | ||
TORCH_CHECK(codecContext != nullptr); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,14 @@ | |
|
||
from torchcodec.decoders import _core, VideoDecoder | ||
|
||
from ..utils import assert_frames_equal, cpu_and_cuda, H265_VIDEO, in_fbcode, NASA_VIDEO | ||
from ..utils import ( | ||
assert_frames_equal, | ||
AV1_VIDEO, | ||
cpu_and_cuda, | ||
H265_VIDEO, | ||
in_fbcode, | ||
NASA_VIDEO, | ||
) | ||
|
||
|
||
class TestVideoDecoder: | ||
|
@@ -409,6 +416,17 @@ def test_get_frames_at_fails(self, device): | |
with pytest.raises(RuntimeError, match="Expected a value of type"): | ||
decoder.get_frames_at([0.3]) | ||
|
||
def test_get_frame_at_av1(self): | ||
# We don't parametrize with CUDA because the current GPUs on CI do not | ||
# support AV1: | ||
decoder = VideoDecoder(AV1_VIDEO.path, device="cpu") | ||
ref_frame11 = AV1_VIDEO.get_frame_data_by_index(10) | ||
|
||
ref_frame_info11 = AV1_VIDEO.get_frame_info(10) | ||
decoded_frame11 = decoder.get_frame_at(10) | ||
assert decoded_frame11.duration_seconds == ref_frame_info11.duration_seconds | ||
assert decoded_frame11.pts_seconds == ref_frame_info11.pts_seconds | ||
assert_frames_equal(decoded_frame11.data, ref_frame11.to(device="cpu")) | ||
|
||
@pytest.mark.parametrize("device", cpu_and_cuda()) | ||
def test_get_frame_played_at(self, device): | ||
decoder = VideoDecoder(NASA_VIDEO.path, device=device) | ||
|
Uh oh!
There was an error while loading. Please reload this page.