Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/torchcodec/_core/BetaCudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,4 +699,9 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
}

// Returns a short human-readable description of this interface, including
// whether it is currently flagged as using the CPU fallback (cpuFallback_) or
// NVDEC.
std::string BetaCudaDeviceInterface::getDetails() {
  std::string details = "Beta CUDA Device Interface. Using ";
  if (cpuFallback_) {
    details += "CPU fallback.";
  } else {
    details += "NVDEC.";
  }
  return details;
}

} // namespace facebook::torchcodec
2 changes: 2 additions & 0 deletions src/torchcodec/_core/BetaCudaDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
int frameReadyForDecoding(CUVIDPICPARAMS* picParams);
int frameReadyInDisplayOrder(CUVIDPARSERDISPINFO* dispInfo);

std::string getDetails() override;

private:
int sendCuvidPacket(CUVIDSOURCEDATAPACKET& cuvidPacket);

Expand Down
4 changes: 4 additions & 0 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,4 +346,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
}

// Returns a short human-readable description identifying this interface as
// the CPU one.
std::string CpuDeviceInterface::getDetails() {
  std::string details = "CPU Device Interface.";
  return details;
}

} // namespace facebook::torchcodec
2 changes: 2 additions & 0 deletions src/torchcodec/_core/CpuDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class CpuDeviceInterface : public DeviceInterface {
std::optional<torch::Tensor> preAllocatedOutputTensor =
std::nullopt) override;

std::string getDetails() override;

private:
int convertAVFrameToTensorUsingSwScale(
const UniqueAVFrame& avFrame,
Expand Down
11 changes: 11 additions & 0 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,12 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
frameOutput.data = cpuFrameOutput.data.to(device_);
}

usingCPUFallback_ = true;
return;
}

usingCPUFallback_ = false;

// Above we checked that the AVFrame was on GPU, but that's not enough, we
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
// because this is what the NPP color conversion routines expect. This SHOULD
Expand Down Expand Up @@ -351,4 +354,12 @@ std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
return std::nullopt;
}

// Returns a short human-readable description of this interface, including
// whether the CPU fallback or NVDEC is in use.
std::string CudaDeviceInterface::getDetails() {
  // Note: for this interface specifically, the fallback is only known after a
  // frame has been decoded, not before: that's when FFmpeg decides to fall
  // back, so we can't know earlier. Until then usingCPUFallback_ still holds
  // its initial value.
  std::string details = "FFmpeg CUDA Device Interface. Using ";
  if (usingCPUFallback_) {
    details += "CPU fallback.";
  } else {
    details += "NVDEC.";
  }
  return details;
}

} // namespace facebook::torchcodec
4 changes: 4 additions & 0 deletions src/torchcodec/_core/CudaDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class CudaDeviceInterface : public DeviceInterface {
std::optional<torch::Tensor> preAllocatedOutputTensor =
std::nullopt) override;

std::string getDetails() override;

private:
// Our CUDA decoding code assumes NV12 format. In order to handle other
// kinds of input, we need to convert them to NV12. Our current implementation
Expand All @@ -60,6 +62,8 @@ class CudaDeviceInterface : public DeviceInterface {
// maybeConvertAVFrameToNV12().
std::unique_ptr<FiltersContext> nv12ConversionContext_;
std::unique_ptr<FilterGraph> nv12Conversion_;

bool usingCPUFallback_ = false;
};

} // namespace facebook::torchcodec
4 changes: 4 additions & 0 deletions src/torchcodec/_core/DeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ class DeviceInterface {
avcodec_flush_buffers(codecContext_.get());
}

// Human-readable description of this device interface. The base
// implementation returns an empty string; the method is virtual so that
// derived interfaces can provide their own description.
virtual std::string getDetails() {
  return std::string();
}

protected:
torch::Device device_;
SharedAVCodecContext codecContext_;
Expand Down
5 changes: 5 additions & 0 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1702,4 +1702,9 @@ double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) {
streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase);
}

// Returns the description string reported by the decoder's device interface.
// Raises (via TORCH_CHECK) if no device interface is set.
std::string SingleStreamDecoder::getInterfaceDetails() const {
  // Fail with an explicit error instead of dereferencing a null interface.
  TORCH_CHECK(deviceInterface_ != nullptr, "Device interface doesn't exist.");
  std::string details = deviceInterface_->getDetails();
  return details;
}

} // namespace facebook::torchcodec
2 changes: 2 additions & 0 deletions src/torchcodec/_core/SingleStreamDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ class SingleStreamDecoder {
DecodeStats getDecodeStats() const;
void resetDecodeStats();

std::string getInterfaceDetails() const;

private:
// --------------------------------------------------------------------------
// STREAMINFO AND ASSOCIATED STRUCTS
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from .ops import (
_add_video_stream,
_get_backend_details,
_get_key_frame_indices,
_test_frame_pts_equality,
add_audio_stream,
Expand Down
8 changes: 8 additions & 0 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
m.def("_get_json_ffmpeg_library_versions() -> str");
m.def("_get_backend_details(Tensor(a!) decoder) -> str");
m.def(
"_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool");
m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
Expand Down Expand Up @@ -869,6 +870,11 @@ std::string _get_json_ffmpeg_library_versions() {
return ss.str();
}

// Implementation of the "_get_backend_details" op: unwraps the decoder held by
// the tensor and returns the details string of its device interface.
std::string get_backend_details(at::Tensor& decoder) {
  return unwrapTensorToGetDecoder(decoder)->getInterfaceDetails();
}

// Scans video packets to get more accurate metadata like frame count, exact
// keyframe positions, etc. Exact keyframe positions are useful for efficient
// accurate seeking. Note that this function reads the entire video but it does
Expand Down Expand Up @@ -912,6 +918,8 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl(
"scan_all_streams_to_update_metadata",
&scan_all_streams_to_update_metadata);

m.impl("_get_backend_details", &get_backend_details);
}

} // namespace facebook::torchcodec
6 changes: 6 additions & 0 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def load_torchcodec_shared_libraries():
_get_json_ffmpeg_library_versions = (
torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default
)
_get_backend_details = torch.ops.torchcodec_ns._get_backend_details.default


# =============================
Expand Down Expand Up @@ -509,3 +510,8 @@ def scan_all_streams_to_update_metadata_abstract(decoder: torch.Tensor) -> None:
def get_ffmpeg_library_versions():
    """Return the FFmpeg library versions parsed from the JSON string
    produced by ``_get_json_ffmpeg_library_versions``."""
    return json.loads(_get_json_ffmpeg_library_versions())


@register_fake("torchcodec_ns::_get_backend_details")
def _get_backend_details_abstract(decoder: torch.Tensor) -> str:
    """Fake implementation registered for ``torchcodec_ns::_get_backend_details``.

    It does not inspect ``decoder``; it always returns an empty string as a
    placeholder of the right type.
    """
    placeholder_details = ""
    return placeholder_details
22 changes: 3 additions & 19 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,26 +1738,10 @@ def test_set_cuda_backend(self):
with set_cuda_backend("BETA"):
assert _get_cuda_backend() == "beta"

def assert_decoder_uses(decoder, *, expected_backend):
# TODO: This doesn't work anymore after
# https://github.com/meta-pytorch/torchcodec/pull/977
# We need to define a better way to assert which backend a decoder
# is using.
return
# Assert that a decoder instance is using a given backend.
#
# We know H265_VIDEO fails on the BETA backend while it works on the
# ffmpeg one.
# if expected_backend == "ffmpeg":
# decoder.get_frame_at(0) # this would fail if this was BETA
# else:
# with pytest.raises(RuntimeError, match="Video is too small"):
# decoder.get_frame_at(0)

# Check that the default is the ffmpeg backend
assert _get_cuda_backend() == "ffmpeg"
dec = VideoDecoder(H265_VIDEO.path, device="cuda")
assert_decoder_uses(dec, expected_backend="ffmpeg")
assert _core._get_backend_details(dec._decoder).startswith("FFmpeg CUDA")

# Check the setting "beta" effectively uses the BETA backend.
# We also show that this affects decoder creation only. When the decoder
Expand All @@ -1766,9 +1750,9 @@ def assert_decoder_uses(decoder, *, expected_backend):
with set_cuda_backend("beta"):
dec = VideoDecoder(H265_VIDEO.path, device="cuda")
assert _get_cuda_backend() == "ffmpeg"
assert_decoder_uses(dec, expected_backend="beta")
assert _core._get_backend_details(dec._decoder).startswith("Beta CUDA")
with set_cuda_backend("ffmpeg"):
assert_decoder_uses(dec, expected_backend="beta")
assert _core._get_backend_details(dec._decoder).startswith("Beta CUDA")

# Hacky way to ensure passing "cuda:1" is supported by both backends. We
# just check that there's an error when passing cuda:N where N is too
Expand Down
Loading