Skip to content

Commit 79e5c87

Browse files
committed
comments, names, assert pix_fmt
1 parent dd04495 commit 79e5c87

File tree

6 files changed

+36
-29
lines changed

6 files changed

+36
-29
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ const Npp32f defaultLimitedRangeRgbToNv12[3][4] = {
378378
{0.439f, -0.368f, -0.071f, 128.0f}};
379379
} // namespace
380380

381-
std::optional<UniqueAVFrame> CudaDeviceInterface::convertTensorToAVFrame(
381+
UniqueAVFrame CudaDeviceInterface::convertCUDATensorToAVFrameForEncoding(
382382
const torch::Tensor& tensor,
383383
int frameIndex,
384384
AVCodecContext* codecContext) {
@@ -440,7 +440,7 @@ std::optional<UniqueAVFrame> CudaDeviceInterface::convertTensorToAVFrame(
440440
// Allocates and initializes AVHWFramesContext, and sets pixel format fields
441441
// to enable encoding with CUDA device. The hw_frames_ctx field is needed by
442442
// FFmpeg to allocate frames on GPU's memory.
443-
void CudaDeviceInterface::setupHardwareFrameContext(
443+
void CudaDeviceInterface::setupHardwareFrameContextForEncoding(
444444
AVCodecContext* codecContext) {
445445
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
446446
TORCH_CHECK(

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,13 @@ class CudaDeviceInterface : public DeviceInterface {
4141

4242
std::string getDetails() override;
4343

44-
std::optional<UniqueAVFrame> convertTensorToAVFrame(
44+
UniqueAVFrame convertCUDATensorToAVFrameForEncoding(
4545
const torch::Tensor& tensor,
4646
int frameIndex,
4747
AVCodecContext* codecContext) override;
4848

49-
void setupHardwareFrameContext(AVCodecContext* codecContext) override;
49+
void setupHardwareFrameContextForEncoding(
50+
AVCodecContext* codecContext) override;
5051

5152
private:
5253
// Our CUDA decoding code assumes NV12 format. In order to handle other

src/torchcodec/_core/DeviceInterface.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,22 @@ class DeviceInterface {
139139
}
140140

141141
// Function used for video encoding, only implemented in CudaDeviceInterface.
142-
virtual std::optional<UniqueAVFrame> convertTensorToAVFrame(
142+
// It is here to isolate CUDA dependencies from CPU builds
143+
// TODO Video-Encoder: Reconsider using video encoding functions in device
144+
// interface
145+
virtual UniqueAVFrame convertCUDATensorToAVFrameForEncoding(
143146
[[maybe_unused]] const torch::Tensor& tensor,
144147
[[maybe_unused]] int frameIndex,
145148
[[maybe_unused]] AVCodecContext* codecContext) {
146-
return std::nullopt;
149+
TORCH_CHECK(false);
147150
}
148151

149152
// Function used for video encoding, only implemented in CudaDeviceInterface.
150-
virtual void setupHardwareFrameContext(
151-
[[maybe_unused]] AVCodecContext* codecContext) {}
153+
// It is here to isolate CUDA dependencies from CPU builds
154+
virtual void setupHardwareFrameContextForEncoding(
155+
[[maybe_unused]] AVCodecContext* codecContext) {
156+
TORCH_CHECK(false);
157+
}
152158

153159
protected:
154160
torch::Device device_;

src/torchcodec/_core/Encoder.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,8 @@ void VideoEncoder::initializeEncoder(
830830
// When frames are on a CUDA device, deviceInterface_ will be defined.
831831
if (frames_.device().is_cuda() && deviceInterface_) {
832832
deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get());
833-
deviceInterface_->setupHardwareFrameContext(avCodecContext_.get());
833+
deviceInterface_->setupHardwareFrameContextForEncoding(
834+
avCodecContext_.get());
834835
}
835836

836837
int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions);
@@ -875,15 +876,15 @@ void VideoEncoder::encode() {
875876
torch::Tensor currFrame = frames_[i];
876877
UniqueAVFrame avFrame;
877878
if (frames_.device().is_cuda() && deviceInterface_) {
878-
auto cudaFrame = deviceInterface_->convertTensorToAVFrame(
879+
auto cudaFrame = deviceInterface_->convertCUDATensorToAVFrameForEncoding(
879880
currFrame, i, avCodecContext_.get());
880881
TORCH_CHECK(
881-
cudaFrame.has_value(),
882-
"convertTensorToAVFrame failed for frame ",
882+
cudaFrame != nullptr,
883+
"convertCUDATensorToAVFrameForEncoding failed for frame ",
883884
i,
884-
"on device: ",
885+
" on device: ",
885886
frames_.device());
886-
avFrame = std::move(*cudaFrame);
887+
avFrame = std::move(cudaFrame);
887888
} else {
888889
avFrame = convertTensorToAVFrame(currFrame, i);
889890
}

src/torchcodec/_core/StreamOptions.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ struct VideoStreamOptions {
4141
ColorConversionLibrary::FILTERGRAPH;
4242

4343
// By default we use CPU for decoding for both C++ and python users.
44-
// Note: For video encoding, device is determined by the location of the input
45-
// frame tensor.
44+
// Note: This is not used for video encoding, because device is determined by
45+
// the device of the input frame tensor.
4646
torch::Device device = torch::kCPU;
4747
// Device variant (e.g., "ffmpeg", "beta", etc.)
4848
std::string_view deviceVariant = "ffmpeg";

test/test_encoders.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
assert_tensor_close_on_at_least,
1818
get_ffmpeg_major_version,
1919
get_ffmpeg_minor_version,
20-
in_fbcode,
2120
IS_WINDOWS,
2221
NASA_AUDIO_MP3,
2322
needs_ffmpeg_cli,
@@ -1304,14 +1303,16 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range
13041303
assert metadata["color_space"] == colorspace
13051304
assert metadata["color_range"] == color_range
13061305

1306+
@needs_ffmpeg_cli
13071307
@pytest.mark.needs_cuda
1308-
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
1308+
# TODO-VideoEncoder: Auto-select codec for GPU encoding
13091309
@pytest.mark.parametrize(
13101310
"format_codec",
13111311
[
13121312
("mov", "h264_nvenc"),
13131313
("mp4", "hevc_nvenc"),
13141314
("avi", "h264_nvenc"),
1315+
# TODO-VideoEncoder: add in_CI mark, similar to in_fbcode
13151316
# ("mkv", "av1_nvenc"), # av1_nvenc is not supported on CI
13161317
],
13171318
)
@@ -1354,16 +1355,7 @@ def test_nvenc_against_ffmpeg_cli(self, tmp_path, format_codec, method):
13541355
ffmpeg_cmd.extend(["-rc", "constqp"]) # Set rate control mode for AV1
13551356
ffmpeg_cmd.extend(["-qp", str(qp)]) # Use lossless qp for other codecs
13561357
ffmpeg_cmd.extend([ffmpeg_encoded_path])
1357-
1358-
# TODO-VideoEncoder: Ensure CI does not skip this test, as we know NVENC is available.
1359-
try:
1360-
subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
1361-
except subprocess.CalledProcessError as e:
1362-
if b"No NVENC capable devices found" in e.stderr:
1363-
pytest.skip("NVENC not available on this system")
1364-
else:
1365-
raise
1366-
1358+
subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
13671359
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
13681360

13691361
encoder_extra_options = {"qp": qp}
@@ -1404,4 +1396,11 @@ def test_nvenc_against_ffmpeg_cli(self, tmp_path, format_codec, method):
14041396
assert ffmpeg_frames.shape[0] == encoder_frames.shape[0]
14051397
for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames):
14061398
assert psnr(ff_frame, enc_frame) > 25
1407-
assert_tensor_close_on_at_least(ff_frame, enc_frame, percentage=95, atol=2)
1399+
assert_tensor_close_on_at_least(ff_frame, enc_frame, percentage=96, atol=2)
1400+
1401+
if method == "to_file":
1402+
ffmpeg_metadata = self._get_video_metadata(ffmpeg_encoded_path, ["pix_fmt"])
1403+
encoder_metadata = self._get_video_metadata(encoder_output, ["pix_fmt"])
1404+
# pix_fmt nv12 is stored as yuv420p in metadata
1405+
assert encoder_metadata["pix_fmt"] == "yuv420p"
1406+
assert ffmpeg_metadata["pix_fmt"] == "yuv420p"

0 commit comments

Comments
 (0)