diff --git a/setup.py b/setup.py
index 59b5ef53c..4b3698aaf 100644
--- a/setup.py
+++ b/setup.py
@@ -112,6 +112,7 @@ def _build_all_extensions_with_cmake(self):
         torch_dir = Path(torch.utils.cmake_prefix_path) / "Torch"
         cmake_build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release")
         enable_cuda = os.environ.get("ENABLE_CUDA", "")
+        enable_xpu = os.environ.get("ENABLE_XPU", "")
         torchcodec_disable_compile_warning_as_error = os.environ.get(
             "TORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR", "OFF"
         )
@@ -123,6 +124,7 @@ def _build_all_extensions_with_cmake(self):
             f"-DCMAKE_BUILD_TYPE={cmake_build_type}",
             f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}",
             f"-DENABLE_CUDA={enable_cuda}",
+            f"-DENABLE_XPU={enable_xpu}",
             f"-DTORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR={torchcodec_disable_compile_warning_as_error}",
         ]
diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index 1e6d2ec80..a01ae0659 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -34,6 +34,15 @@ else()
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
 endif()
 
+if(ENABLE_CUDA)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_CUDA")
+endif()
+if(ENABLE_XPU)
+  find_package(PkgConfig REQUIRED)
+  pkg_check_modules(L0 REQUIRED IMPORTED_TARGET level-zero)
+  pkg_check_modules(LIBVA REQUIRED IMPORTED_TARGET libva)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_XPU")
+endif()
 
 function(make_torchcodec_sublibrary
     library_name
@@ -101,6 +110,7 @@ function(make_torchcodec_libraries
       AVIOContextHolder.cpp
       AVIOTensorContext.cpp
       FFMPEGCommon.cpp
+      FilterGraph.cpp
       Frame.cpp
       DeviceInterface.cpp
       CpuDeviceInterface.cpp
@@ -109,7 +119,11 @@ function(make_torchcodec_libraries
   )
 
   if(ENABLE_CUDA)
-    list(APPEND core_sources CudaDeviceInterface.cpp)
+    list(APPEND core_sources CudaDeviceInterface.cpp)
+  endif()
+
+  if(ENABLE_XPU)
+    list(APPEND core_sources XpuDeviceInterface.cpp)
   endif()
 
   set(core_library_dependencies
@@ -124,6 +138,11 @@ function(make_torchcodec_libraries
     )
   endif()
 
+  if(ENABLE_XPU)
+    list(APPEND core_library_dependencies
+      PkgConfig::L0 PkgConfig::LIBVA)
+  endif()
+
   make_torchcodec_sublibrary(
     "${core_library_name}"
     SHARED
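Note that `ENABLE_XPU` travels the same route as `ENABLE_CUDA`: an environment variable read in `setup.py`, forwarded to CMake, and injected into `CMAKE_CXX_FLAGS` as a `-DENABLE_XPU` preprocessor definition. A minimal sketch of how such a definition typically gates device-specific code at compile time (the function and messages below are illustrative, not part of this patch):

```cpp
#include <cstdio>

// Illustrative only: code guarded by the -DENABLE_XPU definition that the
// CMake change above adds to CMAKE_CXX_FLAGS.
void printBuildConfig() {
#ifdef ENABLE_XPU
  std::printf("XPU (VAAPI + Level Zero) support compiled in\n");
#else
  std::printf("built without XPU support\n");
#endif
}
```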
diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index 4d0cbddf9..f685eb8f6 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -6,11 +6,6 @@
 
 #include "src/torchcodec/_core/CpuDeviceInterface.h"
 
-extern "C" {
-#include <libavfilter/buffersink.h>
-#include <libavfilter/buffersrc.h>
-}
-
 namespace facebook::torchcodec {
 namespace {
 
@@ -20,20 +15,6 @@ static bool g_cpu = registerDeviceInterface(
 
 } // namespace
 
-bool CpuDeviceInterface::DecodedFrameContext::operator==(
-    const CpuDeviceInterface::DecodedFrameContext& other) {
-  return decodedWidth == other.decodedWidth &&
-      decodedHeight == other.decodedHeight &&
-      decodedFormat == other.decodedFormat &&
-      expectedWidth == other.expectedWidth &&
-      expectedHeight == other.expectedHeight;
-}
-
-bool CpuDeviceInterface::DecodedFrameContext::operator!=(
-    const CpuDeviceInterface::DecodedFrameContext& other) {
-  return !(*this == other);
-}
-
 CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
     : DeviceInterface(device) {
   TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
@@ -84,13 +65,22 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   // conversion objects as much as possible for performance reasons.
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
-  auto frameContext = DecodedFrameContext{
-      avFrame->width,
-      avFrame->height,
-      frameFormat,
-      avFrame->sample_aspect_ratio,
-      expectedOutputWidth,
-      expectedOutputHeight};
+  FiltersContext filtersContext;
+
+  filtersContext.inputWidth = avFrame->width;
+  filtersContext.inputHeight = avFrame->height;
+  filtersContext.inputFormat = frameFormat;
+  filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
+  filtersContext.outputWidth = expectedOutputWidth;
+  filtersContext.outputHeight = expectedOutputHeight;
+  filtersContext.outputFormat = AV_PIX_FMT_RGB24;
+  filtersContext.timeBase = timeBase;
+
+  std::stringstream filters;
+  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+  filters << ":sws_flags=bilinear";
+
+  filtersContext.filters = filters.str();
 
   // By default, we want to use swscale for color conversion because it is
   // faster. However, it has width requirements, so we may need to fall back
@@ -114,9 +104,9 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
         expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-    if (!swsContext_ || prevFrameContext_ != frameContext) {
-      createSwsContext(frameContext, avFrame->colorspace);
-      prevFrameContext_ = frameContext;
+    if (!swsContext_ || prevFiltersContext_ != filtersContext) {
+      createSwsContext(filtersContext, avFrame->colorspace);
+      prevFiltersContext_ = std::move(filtersContext);
     }
     int resultHeight =
         convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
@@ -132,9 +122,10 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
-      createFilterGraph(frameContext, videoStreamOptions, timeBase);
-      prevFrameContext_ = frameContext;
+    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
+      filterGraphContext_ =
+          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
+      prevFiltersContext_ = std::move(filtersContext);
     }
     outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
@@ -187,14 +178,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
 torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
     const UniqueAVFrame& avFrame) {
-  int status = av_buffersrc_write_frame(
-      filterGraphContext_.sourceContext, avFrame.get());
-  TORCH_CHECK(
-      status >= AVSUCCESS, "Failed to add frame to buffer source context");
+  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
 
-  UniqueAVFrame filteredAVFrame(av_frame_alloc());
-  status = av_buffersink_get_frame(
-      filterGraphContext_.sinkContext, filteredAVFrame.get());
   TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
 
   auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
@@ -210,118 +195,16 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
       filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 }
 
-void CpuDeviceInterface::createFilterGraph(
-    const DecodedFrameContext& frameContext,
-    const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase) {
-  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
-  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
-
-  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
-    filterGraphContext_.filterGraph->nb_threads =
videoStreamOptions.ffmpegThreadCount.value(); - } - - const AVFilter* buffersrc = avfilter_get_by_name("buffer"); - const AVFilter* buffersink = avfilter_get_by_name("buffersink"); - - std::stringstream filterArgs; - filterArgs << "video_size=" << frameContext.decodedWidth << "x" - << frameContext.decodedHeight; - filterArgs << ":pix_fmt=" << frameContext.decodedFormat; - filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den; - filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/" - << frameContext.decodedAspectRatio.den; - - int status = avfilter_graph_create_filter( - &filterGraphContext_.sourceContext, - buffersrc, - "in", - filterArgs.str().c_str(), - nullptr, - filterGraphContext_.filterGraph.get()); - TORCH_CHECK( - status >= 0, - "Failed to create filter graph: ", - filterArgs.str(), - ": ", - getFFMPEGErrorStringFromErrorCode(status)); - - status = avfilter_graph_create_filter( - &filterGraphContext_.sinkContext, - buffersink, - "out", - nullptr, - nullptr, - filterGraphContext_.filterGraph.get()); - TORCH_CHECK( - status >= 0, - "Failed to create filter graph: ", - getFFMPEGErrorStringFromErrorCode(status)); - - enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE}; - - status = av_opt_set_int_list( - filterGraphContext_.sinkContext, - "pix_fmts", - pix_fmts, - AV_PIX_FMT_NONE, - AV_OPT_SEARCH_CHILDREN); - TORCH_CHECK( - status >= 0, - "Failed to set output pixel formats: ", - getFFMPEGErrorStringFromErrorCode(status)); - - UniqueAVFilterInOut outputs(avfilter_inout_alloc()); - UniqueAVFilterInOut inputs(avfilter_inout_alloc()); - - outputs->name = av_strdup("in"); - outputs->filter_ctx = filterGraphContext_.sourceContext; - outputs->pad_idx = 0; - outputs->next = nullptr; - inputs->name = av_strdup("out"); - inputs->filter_ctx = filterGraphContext_.sinkContext; - inputs->pad_idx = 0; - inputs->next = nullptr; - - std::stringstream description; - description << "scale=" << frameContext.expectedWidth << ":" - << frameContext.expectedHeight; - description << ":sws_flags=bilinear"; - - AVFilterInOut* outputsTmp = outputs.release(); - AVFilterInOut* inputsTmp = inputs.release(); - status = avfilter_graph_parse_ptr( - filterGraphContext_.filterGraph.get(), - description.str().c_str(), - &inputsTmp, - &outputsTmp, - nullptr); - outputs.reset(outputsTmp); - inputs.reset(inputsTmp); - TORCH_CHECK( - status >= 0, - "Failed to parse filter description: ", - getFFMPEGErrorStringFromErrorCode(status)); - - status = - avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr); - TORCH_CHECK( - status >= 0, - "Failed to configure filter graph: ", - getFFMPEGErrorStringFromErrorCode(status)); -} - void CpuDeviceInterface::createSwsContext( - const DecodedFrameContext& frameContext, + const FiltersContext& filtersContext, const enum AVColorSpace colorspace) { SwsContext* swsContext = sws_getContext( - frameContext.decodedWidth, - frameContext.decodedHeight, - frameContext.decodedFormat, - frameContext.expectedWidth, - frameContext.expectedHeight, - AV_PIX_FMT_RGB24, + filtersContext.inputWidth, + filtersContext.inputHeight, + filtersContext.inputFormat, + filtersContext.outputWidth, + filtersContext.outputHeight, + filtersContext.outputFormat, SWS_BILINEAR, nullptr, nullptr, diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 404289bd6..54411efd6 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -8,6 +8,7 @@ #include 
"src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/FFMPEGCommon.h" +#include "src/torchcodec/_core/FilterGraph.h" namespace facebook::torchcodec { @@ -41,40 +42,18 @@ class CpuDeviceInterface : public DeviceInterface { torch::Tensor convertAVFrameToTensorUsingFilterGraph( const UniqueAVFrame& avFrame); - struct FilterGraphContext { - UniqueAVFilterGraph filterGraph; - AVFilterContext* sourceContext = nullptr; - AVFilterContext* sinkContext = nullptr; - }; - - struct DecodedFrameContext { - int decodedWidth; - int decodedHeight; - AVPixelFormat decodedFormat; - AVRational decodedAspectRatio; - int expectedWidth; - int expectedHeight; - bool operator==(const DecodedFrameContext&); - bool operator!=(const DecodedFrameContext&); - }; - void createSwsContext( - const DecodedFrameContext& frameContext, + const FiltersContext& filtersContext, const enum AVColorSpace colorspace); - void createFilterGraph( - const DecodedFrameContext& frameContext, - const VideoStreamOptions& videoStreamOptions, - const AVRational& timeBase); - // color-conversion fields. Only one of FilterGraphContext and // UniqueSwsContext should be non-null. - FilterGraphContext filterGraphContext_; + std::unique_ptr filterGraphContext_; UniqueSwsContext swsContext_; // Used to know whether a new FilterGraphContext or UniqueSwsContext should // be created before decoding a new frame. - DecodedFrameContext prevFrameContext_; + FiltersContext prevFiltersContext_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp new file mode 100644 index 000000000..5cefe806e --- /dev/null +++ b/src/torchcodec/_core/FilterGraph.cpp @@ -0,0 +1,150 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp
new file mode 100644
index 000000000..5cefe806e
--- /dev/null
+++ b/src/torchcodec/_core/FilterGraph.cpp
@@ -0,0 +1,150 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "src/torchcodec/_core/FilterGraph.h"
+
+extern "C" {
+#include <libavfilter/buffersink.h>
+#include <libavfilter/buffersrc.h>
+}
+
+namespace facebook::torchcodec {
+
+bool operator==(const AVRational& lhs, const AVRational& rhs) {
+  return lhs.num == rhs.num && lhs.den == rhs.den;
+}
+
+bool FiltersContext::operator==(const FiltersContext& other) {
+  return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
+      inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
+      outputHeight == other.outputHeight &&
+      outputFormat == other.outputFormat && filters == other.filters &&
+      timeBase == other.timeBase &&
+      hwFramesCtx.get() == other.hwFramesCtx.get();
+}
+
+bool FiltersContext::operator!=(const FiltersContext& other) {
+  return !(*this == other);
+}
+
+FilterGraph::FilterGraph(
+    const FiltersContext& filtersContext,
+    const VideoStreamOptions& videoStreamOptions) {
+  filterGraph_.reset(avfilter_graph_alloc());
+  TORCH_CHECK(filterGraph_.get() != nullptr);
+
+  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
+    filterGraph_->nb_threads = videoStreamOptions.ffmpegThreadCount.value();
+  }
+
+  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
+  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
+
+  auto deleter = [](AVBufferSrcParameters* p) {
+    if (p) {
+      av_freep(&p);
+    }
+  };
+  std::unique_ptr<AVBufferSrcParameters, decltype(deleter)> srcParams(
+      nullptr, deleter);
+
+  srcParams.reset(av_buffersrc_parameters_alloc());
+  TORCH_CHECK(srcParams, "Failed to allocate buffersrc params");
+
+  srcParams->format = filtersContext.inputFormat;
+  srcParams->width = filtersContext.inputWidth;
+  srcParams->height = filtersContext.inputHeight;
+  srcParams->sample_aspect_ratio = filtersContext.inputAspectRatio;
+  srcParams->time_base = filtersContext.timeBase;
+  if (filtersContext.hwFramesCtx) {
+    srcParams->hw_frames_ctx = av_buffer_ref(filtersContext.hwFramesCtx.get());
+  }
+
+  sourceContext_ =
+      avfilter_graph_alloc_filter(filterGraph_.get(), buffersrc, "in");
+  TORCH_CHECK(sourceContext_, "Failed to allocate filter graph");
+
+  int status = av_buffersrc_parameters_set(sourceContext_, srcParams.get());
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  status = avfilter_init_str(sourceContext_, nullptr);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  status = avfilter_graph_create_filter(
+      &sinkContext_, buffersink, "out", nullptr, nullptr, filterGraph_.get());
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  enum AVPixelFormat pix_fmts[] = {
+      filtersContext.outputFormat, AV_PIX_FMT_NONE};
+
+  status = av_opt_set_int_list(
+      sinkContext_,
+      "pix_fmts",
+      pix_fmts,
+      AV_PIX_FMT_NONE,
+      AV_OPT_SEARCH_CHILDREN);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to set output pixel formats: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
+  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
+
+  outputs->name = av_strdup("in");
+  outputs->filter_ctx = sourceContext_;
+  outputs->pad_idx = 0;
+  outputs->next = nullptr;
+  inputs->name = av_strdup("out");
+  inputs->filter_ctx = sinkContext_;
+  inputs->pad_idx = 0;
+  inputs->next = nullptr;
+
+  AVFilterInOut* outputsTmp = outputs.release();
+  AVFilterInOut* inputsTmp = inputs.release();
+  status = avfilter_graph_parse_ptr(
+      filterGraph_.get(),
+      filtersContext.filters.c_str(),
+      &inputsTmp,
+      &outputsTmp,
+      nullptr);
+  outputs.reset(outputsTmp);
+  inputs.reset(inputsTmp);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to parse filter description: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  status = avfilter_graph_config(filterGraph_.get(), nullptr);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to configure filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+}
+
+UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {
+  int status = av_buffersrc_write_frame(sourceContext_, avFrame.get());
+  TORCH_CHECK(
+      status >= AVSUCCESS, "Failed to add frame to buffer source context");
+
+  UniqueAVFrame filteredAVFrame(av_frame_alloc());
+  status = av_buffersink_get_frame(sinkContext_, filteredAVFrame.get());
+  TORCH_CHECK(
+      status >= AVSUCCESS, "Failed to get frame from buffer sink context");
+
+  return filteredAVFrame;
+}
+
+} // namespace facebook::torchcodec
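Taken together, the new class reduces callers to a describe-then-convert flow. A hedged usage sketch; it assumes a decoded frame and stream options are in scope, and the target dimensions and time base are invented for illustration:

```cpp
#include "src/torchcodec/_core/FilterGraph.h"

// Sketch of the intended call pattern for FilterGraph; values are made up.
UniqueAVFrame scaleToRgb24(
    const UniqueAVFrame& decodedFrame,
    const VideoStreamOptions& videoStreamOptions) {
  FiltersContext ctx;
  ctx.inputWidth = decodedFrame->width;
  ctx.inputHeight = decodedFrame->height;
  ctx.inputFormat = static_cast<enum AVPixelFormat>(decodedFrame->format);
  ctx.inputAspectRatio = decodedFrame->sample_aspect_ratio;
  ctx.outputWidth = 640;
  ctx.outputHeight = 360;
  ctx.outputFormat = AV_PIX_FMT_RGB24;
  ctx.timeBase = {1, 25};  // invented; comes from the stream in real use
  ctx.filters = "scale=640:360:sws_flags=bilinear";

  FilterGraph graph(ctx, videoStreamOptions);
  return graph.convert(decodedFrame);  // one frame in, one filtered frame out
}
```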
diff --git a/src/torchcodec/_core/FilterGraph.h b/src/torchcodec/_core/FilterGraph.h
new file mode 100644
index 000000000..dd3eca610
--- /dev/null
+++ b/src/torchcodec/_core/FilterGraph.h
@@ -0,0 +1,45 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include "src/torchcodec/_core/FFMPEGCommon.h"
+#include "src/torchcodec/_core/StreamOptions.h"
+
+namespace facebook::torchcodec {
+
+struct FiltersContext {
+  int inputWidth = 0;
+  int inputHeight = 0;
+  AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
+  AVRational inputAspectRatio = {0, 0};
+  int outputWidth = 0;
+  int outputHeight = 0;
+  AVPixelFormat outputFormat = AV_PIX_FMT_NONE;
+
+  std::string filters;
+  AVRational timeBase = {0, 0};
+  UniqueAVBufferRef hwFramesCtx;
+
+  bool operator==(const FiltersContext&);
+  bool operator!=(const FiltersContext&);
+};
+
+class FilterGraph {
+ public:
+  FilterGraph(
+      const FiltersContext& filtersContext,
+      const VideoStreamOptions& videoStreamOptions);
+
+  UniqueAVFrame convert(const UniqueAVFrame& avFrame);
+
+ private:
+  UniqueAVFilterGraph filterGraph_;
+  AVFilterContext* sourceContext_ = nullptr;
+  AVFilterContext* sinkContext_ = nullptr;
+};
+
+} // namespace facebook::torchcodec
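The XPU interface below plugs into the same self-registration mechanism the CPU interface uses (`static bool g_cpu = registerDeviceInterface(...)` above): a file-scope initializer registers a factory when the translation unit is linked in. A simplified model of that idiom; the registry internals here are invented for illustration, since `DeviceInterface.h` is not part of this diff:

```cpp
#include <functional>
#include <map>
#include <string>

// Simplified model of the self-registration idiom; the real registry lives in
// DeviceInterface.h and is keyed on torch device types, not strings.
struct DeviceInterfaceStub {};
using Factory = std::function<DeviceInterfaceStub*()>;

static std::map<std::string, Factory>& registry() {
  static std::map<std::string, Factory> r;  // constructed on first use
  return r;
}

static bool registerDeviceInterfaceStub(const std::string& key, Factory f) {
  registry().emplace(key, std::move(f));
  return true;
}

// Runs during static initialization, so merely linking the translation unit
// makes the device available; no central list has to be edited.
static bool g_registered = registerDeviceInterfaceStub(
    "xpu", [] { return new DeviceInterfaceStub(); });
```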
diff --git a/src/torchcodec/_core/XpuDeviceInterface.cpp b/src/torchcodec/_core/XpuDeviceInterface.cpp
new file mode 100644
index 000000000..11507e474
--- /dev/null
+++ b/src/torchcodec/_core/XpuDeviceInterface.cpp
@@ -0,0 +1,317 @@
+#include <unistd.h>
+
+#include <level_zero/ze_api.h>
+#include <va/va_drmcommon.h>
+
+#include <c10/xpu/XPUFunctions.h>
+#include <c10/xpu/XPUStream.h>
+
+#include "src/torchcodec/_core/Cache.h"
+#include "src/torchcodec/_core/FFMPEGCommon.h"
+#include "src/torchcodec/_core/XpuDeviceInterface.h"
+
+extern "C" {
+#include <libavcodec/avcodec.h>
+#include <libavutil/hwcontext.h>
+#include <libavutil/hwcontext_vaapi.h>
+#include <libavutil/pixdesc.h>
+}
+
+namespace facebook::torchcodec {
+namespace {
+
+static bool g_xpu = registerDeviceInterface(
+    torch::kXPU,
+    [](const torch::Device& device) { return new XpuDeviceInterface(device); });
+
+const int MAX_XPU_GPUS = 128;
+// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
+// Set to a positive number to have a cache of that size.
+const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
+PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
+    g_cached_hw_device_ctxs(MAX_XPU_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
+
+UniqueAVBufferRef getVaapiContext(const torch::Device& device) {
+  enum AVHWDeviceType type = av_hwdevice_find_type_by_name("vaapi");
+  TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find vaapi device");
+  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
+
+  UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
+  if (hw_device_ctx) {
+    return hw_device_ctx;
+  }
+
+  std::string renderD = "/dev/dri/renderD128";
+
+  sycl::device syclDevice = c10::xpu::get_raw_device(nonNegativeDeviceIndex);
+  if (syclDevice.has(sycl::aspect::ext_intel_pci_address)) {
+    auto BDF =
+        syclDevice.get_info<sycl::ext::intel::info::device::pci_address>();
+    renderD = "/dev/dri/by-path/pci-" + BDF + "-render";
+  }
+
+  AVBufferRef* ctx = nullptr;
+  int err = av_hwdevice_ctx_create(&ctx, type, renderD.c_str(), nullptr, 0);
+  if (err < 0) {
+    TORCH_CHECK(
+        false,
+        "Failed to create specified HW device: ",
+        getFFMPEGErrorStringFromErrorCode(err));
+  }
+  return UniqueAVBufferRef(ctx);
+}
+
+} // namespace
+
+XpuDeviceInterface::XpuDeviceInterface(const torch::Device& device)
+    : DeviceInterface(device) {
+  TORCH_CHECK(g_xpu, "XpuDeviceInterface was not registered!");
+  TORCH_CHECK(
+      device_.type() == torch::kXPU, "Unsupported device: ", device_.str());
+}
+
+XpuDeviceInterface::~XpuDeviceInterface() {
+  if (ctx_) {
+    g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
+  }
+}
+
+VADisplay getVaDisplayFromAV(AVFrame* avFrame) {
+  AVHWFramesContext* hwfc = (AVHWFramesContext*)avFrame->hw_frames_ctx->data;
+  AVHWDeviceContext* hwdc = hwfc->device_ctx;
+  AVVAAPIDeviceContext* vactx = (AVVAAPIDeviceContext*)hwdc->hwctx;
+  return vactx->display;
+}
+
+void XpuDeviceInterface::initializeContext(AVCodecContext* codecContext) {
+  TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");
+
+  // It is important for pytorch itself to create the xpu context. If ffmpeg
+  // creates the context it may not be compatible with pytorch.
+  // This is a dummy tensor to initialize the xpu context.
+  torch::Tensor dummyTensorForXpuInitialization = torch::empty(
+      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
+  ctx_ = getVaapiContext(device_);
+  codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
+  return;
+}
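The conversion path that follows hands VAAPI surface memory to PyTorch through DLPack, so the lifetime rules deserve a close read: the `manager_ctx` keeps everything the tensor's storage depends on alive (a reference to the source `AVFrame` plus the Level Zero context), and the `deleter` runs exactly once, when PyTorch drops its last reference. A reduced model of that contract, with simplified types (the real code additionally frees the imported device memory with `zeMemFree`):

```cpp
#include <dlpack/dlpack.h>
#include <memory>

// Reduced model of the DLPack ownership contract used below. The manager
// context owns the resources backing the tensor; the consumer (PyTorch) calls
// the deleter exactly once when the tensor is destroyed.
struct ManagerCtxStub {
  // Real code: a UniqueAVFrame reference plus a ze_context_handle_t.
};

void dlpackDeleterStub(DLManagedTensor* self) {
  // Reclaim the descriptor and the manager context...
  std::unique_ptr<DLManagedTensor> tensor(self);
  std::unique_ptr<ManagerCtxStub> ctx(
      static_cast<ManagerCtxStub*>(self->manager_ctx));
  // ...then release the device memory (zeMemFree in the real code below).
}
```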
+struct xpuManagerCtx {
+  UniqueAVFrame avFrame;
+  ze_context_handle_t zeCtx = nullptr;
+};
+
+void deleter(DLManagedTensor* self) {
+  std::unique_ptr<DLManagedTensor> tensor(self);
+  std::unique_ptr<xpuManagerCtx> context((xpuManagerCtx*)self->manager_ctx);
+  zeMemFree(context->zeCtx, self->dl_tensor.data);
+}
+
+torch::Tensor AVFrameToTensor(
+    const torch::Device& device,
+    const UniqueAVFrame& frame) {
+  TORCH_CHECK_EQ(frame->format, AV_PIX_FMT_VAAPI);
+
+  VADRMPRIMESurfaceDescriptor desc{};
+
+  VAStatus sts = vaExportSurfaceHandle(
+      getVaDisplayFromAV(frame.get()),
+      (VASurfaceID)(uintptr_t)frame->data[3],
+      VA_SURFACE_ATTRIB_MEM_TYPE_DRM_PRIME_2,
+      VA_EXPORT_SURFACE_READ_ONLY,
+      &desc);
+  TORCH_CHECK(
+      sts == VA_STATUS_SUCCESS,
+      "vaExportSurfaceHandle failed: ",
+      vaErrorStr(sts));
+
+  TORCH_CHECK(desc.num_objects == 1, "Expected 1 fd, got ", desc.num_objects);
+  TORCH_CHECK(desc.num_layers == 1, "Expected 1 layer, got ", desc.num_layers);
+  TORCH_CHECK(
+      desc.layers[0].num_planes == 1,
+      "Expected 1 plane, got ",
+      desc.layers[0].num_planes);
+
+  std::unique_ptr<xpuManagerCtx> context = std::make_unique<xpuManagerCtx>();
+  ze_device_handle_t ze_device{};
+  sycl::queue queue = c10::xpu::getCurrentXPUStream(device.index());
+
+  queue
+      .submit([&](sycl::handler& cgh) {
+        cgh.host_task([&](const sycl::interop_handle& ih) {
+          context->zeCtx =
+              ih.get_native_context<sycl::backend::ext_oneapi_level_zero>();
+          ze_device =
+              ih.get_native_device<sycl::backend::ext_oneapi_level_zero>();
+        });
+      })
+      .wait();
+
+  ze_external_memory_import_fd_t import_fd_desc{};
+  import_fd_desc.stype = ZE_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMPORT_FD;
+  import_fd_desc.flags = ZE_EXTERNAL_MEMORY_TYPE_FLAG_DMA_BUF;
+  import_fd_desc.fd = desc.objects[0].fd;
+
+  ze_device_mem_alloc_desc_t alloc_desc{};
+  alloc_desc.pNext = &import_fd_desc;
+  void* usm_ptr = nullptr;
+
+  ze_result_t res = zeMemAllocDevice(
+      context->zeCtx,
+      &alloc_desc,
+      desc.objects[0].size,
+      0,
+      ze_device,
+      &usm_ptr);
+  TORCH_CHECK(
+      res == ZE_RESULT_SUCCESS, "Failed to import fd=", desc.objects[0].fd);
+
+  close(desc.objects[0].fd);
+
+  std::unique_ptr<DLManagedTensor> dl_dst = std::make_unique<DLManagedTensor>();
+  int64_t shape[3] = {desc.height, desc.width, 4};
+
+  context->avFrame.reset(av_frame_alloc());
+  TORCH_CHECK(context->avFrame.get(), "Failed to allocate AVFrame");
+
+  int status = av_frame_ref(context->avFrame.get(), frame.get());
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to reference AVFrame: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  dl_dst->manager_ctx = context.release();
+  dl_dst->deleter = deleter;
+  dl_dst->dl_tensor.data = usm_ptr;
+  dl_dst->dl_tensor.device.device_type = kDLOneAPI;
+  dl_dst->dl_tensor.device.device_id = device.index();
+  dl_dst->dl_tensor.ndim = 3;
+  dl_dst->dl_tensor.dtype.code = kDLUInt;
+  dl_dst->dl_tensor.dtype.bits = 8;
+  dl_dst->dl_tensor.dtype.lanes = 1;
+  dl_dst->dl_tensor.shape = shape;
+  dl_dst->dl_tensor.strides = nullptr;
+  dl_dst->dl_tensor.byte_offset = desc.layers[0].offset[0];
+
+  auto dst = at::fromDLPack(dl_dst.release());
+
+  return dst;
+}
+
+VADisplay getVaDisplayFromAV(UniqueAVFrame& avFrame) {
+  AVHWFramesContext* hwfc = (AVHWFramesContext*)avFrame->hw_frames_ctx->data;
+  AVHWDeviceContext* hwdc = hwfc->device_ctx;
+  AVVAAPIDeviceContext* vactx = (AVVAAPIDeviceContext*)hwdc->hwctx;
+  return vactx->display;
+}
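`AVFrameToTensor` returns an HxWx4 view of the RGBX surface, where the fourth channel is padding. The conversion routine below strips it with `narrow`; a small self-contained illustration of that slicing, with a plain tensor standing in for the imported surface:

```cpp
#include <torch/torch.h>

// Illustration of the HxWx4 -> HxWx3 slice performed in
// convertAVFrameToFrameOutput below; the input is a stand-in for the surface
// imported by AVFrameToTensor.
torch::Tensor dropPaddingChannel(const torch::Tensor& rgbx) {
  // narrow(dim=2, start=0, length=3) keeps R, G, B and drops X. It is a view,
  // so no data moves until the caller copies it into the destination tensor.
  return rgbx.narrow(2, 0, 3);
}
```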
+void XpuDeviceInterface::convertAVFrameToFrameOutput(
+    const VideoStreamOptions& videoStreamOptions,
+    const AVRational& timeBase,
+    UniqueAVFrame& avFrame,
+    FrameOutput& frameOutput,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  // TODO: consider copying the CPU-frame handling from CUDA
+  // TODO: consider copying the NV12 format check from CUDA
+  TORCH_CHECK(
+      avFrame->format == AV_PIX_FMT_VAAPI,
+      "Expected format to be AV_PIX_FMT_VAAPI, got " +
+          std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int height = frameDims.height;
+  int width = frameDims.width;
+  torch::Tensor& dst = frameOutput.data;
+  if (preAllocatedOutputTensor.has_value()) {
+    dst = preAllocatedOutputTensor.value();
+    auto shape = dst.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
+            (shape[2] == 3),
+        "Expected tensor of shape ",
+        height,
+        "x",
+        width,
+        "x3, got ",
+        shape);
+  } else {
+    dst = allocateEmptyHWCTensor(height, width, device_);
+  }
+
+  auto start = std::chrono::high_resolution_clock::now();
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+  FiltersContext filtersContext;
+
+  filtersContext.inputWidth = avFrame->width;
+  filtersContext.inputHeight = avFrame->height;
+  filtersContext.inputFormat = frameFormat;
+  filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
+  // Actual output color format will be set via filter options
+  filtersContext.outputFormat = AV_PIX_FMT_VAAPI;
+  filtersContext.timeBase = timeBase;
+  filtersContext.hwFramesCtx.reset(av_buffer_ref(avFrame->hw_frames_ctx));
+
+  std::stringstream filters;
+  filters << "scale_vaapi=" << width << ":" << height;
+  // CPU device interface outputs RGB in full (pc) color range.
+  // We are doing the same to match.
+  filters << ":format=rgba:out_range=pc";
+
+  filtersContext.filters = filters.str();
+
+  if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
+    filterGraphContext_ =
+        std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
+    prevFiltersContext_ = std::move(filtersContext);
+  }
+
+  // We convert the input to the RGBX color format with VAAPI, getting a
+  // WxHx4 tensor as output.
+  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
+
+  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_VAAPI);
+
+  torch::Tensor dst_rgb4 = AVFrameToTensor(device_, filteredAVFrame);
+  dst.copy_(dst_rgb4.narrow(2, 0, 3));
+
+  auto end = std::chrono::high_resolution_clock::now();
+
+  std::chrono::duration<double, std::micro> duration = end - start;
+  VLOG(9) << "Conversion of frame height=" << height << " width=" << width
+          << " took: " << duration.count() << "us" << std::endl;
+}
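For reference, the filter description assembled above is an ordinary FFmpeg filter string; isolated, the construction looks like this (hypothetical helper mirroring the code above):

```cpp
#include <sstream>
#include <string>

// Hypothetical helper mirroring the string built in convertAVFrameToFrameOutput:
// for 640x360 it yields "scale_vaapi=640:360:format=rgba:out_range=pc".
std::string buildVaapiFilterSpec(int width, int height) {
  std::stringstream filters;
  filters << "scale_vaapi=" << width << ":" << height;
  // format=rgba keeps the conversion on the GPU; out_range=pc matches the
  // CPU path's full-range RGB output.
  filters << ":format=rgba:out_range=pc";
  return filters.str();
}
```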
+
+// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
+// we have to do this because of an FFmpeg bug where hardware decoding is not
+// appropriately set, so we just go off and find the matching codec for the
+// VAAPI device
+std::optional<const AVCodec*> XpuDeviceInterface::findCodec(
+    const AVCodecID& codecId) {
+  void* i = nullptr;
+  const AVCodec* codec = nullptr;
+  while ((codec = av_codec_iterate(&i)) != nullptr) {
+    if (codec->id != codecId || !av_codec_is_decoder(codec)) {
+      continue;
+    }
+
+    const AVCodecHWConfig* config = nullptr;
+    for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr;
+         ++j) {
+      if (config->device_type == AV_HWDEVICE_TYPE_VAAPI) {
+        return codec;
+      }
+    }
+  }
+
+  return std::nullopt;
+}
+
+} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/XpuDeviceInterface.h b/src/torchcodec/_core/XpuDeviceInterface.h
new file mode 100644
index 000000000..6acce1c37
--- /dev/null
+++ b/src/torchcodec/_core/XpuDeviceInterface.h
@@ -0,0 +1,42 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include "src/torchcodec/_core/DeviceInterface.h"
+#include "src/torchcodec/_core/FilterGraph.h"
+
+namespace facebook::torchcodec {
+
+class XpuDeviceInterface : public DeviceInterface {
+ public:
+  XpuDeviceInterface(const torch::Device& device);
+
+  virtual ~XpuDeviceInterface();
+
+  std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
+
+  void initializeContext(AVCodecContext* codecContext) override;
+
+  void convertAVFrameToFrameOutput(
+      const VideoStreamOptions& videoStreamOptions,
+      const AVRational& timeBase,
+      UniqueAVFrame& avFrame,
+      FrameOutput& frameOutput,
+      std::optional<torch::Tensor> preAllocatedOutputTensor =
+          std::nullopt) override;
+
+ private:
+  UniqueAVBufferRef ctx_;
+
+  std::unique_ptr<FilterGraph> filterGraphContext_;
+
+  // Used to know whether a new FilterGraph should
+  // be created before decoding a new frame.
+  FiltersContext prevFiltersContext_;
+};
+
+} // namespace facebook::torchcodec
diff --git a/test/conftest.py b/test/conftest.py
index bef5291e5..067f730e7 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -10,6 +10,9 @@ def pytest_configure(config):
     config.addinivalue_line(
         "markers", "needs_cuda: mark for tests that rely on a CUDA device"
     )
+    config.addinivalue_line(
+        "markers", "needs_xpu: mark for tests that rely on an XPU device"
+    )
 
 
 def pytest_collection_modifyitems(items):
@@ -19,8 +22,8 @@ def pytest_collection_modifyitems(items):
 
     out_items = []
     for item in items:
-        # The needs_cuda mark will exist if the test was explicitly decorated
-        # with the @needs_cuda decorator. It will also exist if it was
+        # The needs_[cuda|xpu] mark will exist if the test was explicitly decorated
+        # with the respective @needs_* decorator.
It will also exist if it was # parametrized with a parameter that has the mark: for example if a test # is parametrized with # @pytest.mark.parametrize('device', all_supported_devices()) @@ -28,6 +31,7 @@ def pytest_collection_modifyitems(items): # 'needs_cuda' mark, and the ones with device == 'cpu' won't have the # mark. needs_cuda = item.get_closest_marker("needs_cuda") is not None + needs_xpu = item.get_closest_marker("needs_xpu") is not None if ( needs_cuda @@ -42,6 +46,13 @@ def pytest_collection_modifyitems(items): # those for whatever reason, we need to know. item.add_marker(pytest.mark.skip(reason="CUDA not available.")) + if ( + needs_xpu + and not torch.xpu.is_available() + and os.environ.get("FAIL_WITHOUT_XPU") is None + ): + item.add_marker(pytest.mark.skip(reason="XPU not available.")) + out_items.append(item) items[:] = out_items @@ -56,6 +67,8 @@ def prevent_leaking_rng(): builtin_rng_state = random.getstate() if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() + if torch.xpu.is_available(): + xpu_rng_state = torch.xpu.get_rng_state() yield @@ -63,3 +76,5 @@ def prevent_leaking_rng(): random.setstate(builtin_rng_state) if torch.cuda.is_available(): torch.cuda.set_rng_state(cuda_rng_state) + if torch.xpu.is_available(): + torch.xpu.set_rng_state(xpu_rng_state) diff --git a/test/test_ops.py b/test/test_ops.py index d2f2fd3b1..f5ad492b8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -49,6 +49,7 @@ NASA_AUDIO_MP3, NASA_VIDEO, needs_cuda, + needs_xpu, SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, @@ -740,6 +741,19 @@ def test_cuda_decoder(self): duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 ) + @needs_xpu + def test_xpu_decoder(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, device="xpu") + frame0, pts, duration = get_next_frame(decoder) + assert frame0.device.type == "xpu" + reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) + assert_frames_equal(frame0, reference_frame0.to("xpu")) + assert pts == torch.tensor([0]) + torch.testing.assert_close( + duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 + ) + class TestAudioDecoderOps: @pytest.mark.parametrize( diff --git a/test/utils.py b/test/utils.py index ed611cfda..5d3a7a5f1 100644 --- a/test/utils.py +++ b/test/utils.py @@ -23,8 +23,19 @@ def needs_cuda(test_item): return pytest.mark.needs_cuda(test_item) +# Decorator for skipping XPU tests when XPU isn't available. The tests are +# effectively marked to be skipped in pytest_collection_modifyitems() of +# conftest.py +def needs_xpu(test_item): + return pytest.mark.needs_xpu(test_item) + + def all_supported_devices(): - return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + return ( + "cpu", + pytest.param("cuda", marks=pytest.mark.needs_cuda), + pytest.param("xpu", marks=pytest.mark.needs_xpu), + ) def get_ffmpeg_major_version(): @@ -77,6 +88,13 @@ def assert_frames_equal(*args, **kwargs): ) else: torch.testing.assert_close(*args, **kwargs, atol=atol, rtol=0) + elif args[0].device.type == "xpu": + if not torch.allclose(*args, atol=0, rtol=0): + from torcheval.metrics import PeakSignalNoiseRatio + + metric = PeakSignalNoiseRatio() + metric.update(args[0], args[1]) + assert metric.compute() >= 40 else: torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) else:
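On the XPU path, `assert_frames_equal` tolerates small numeric differences because VAAPI scaling and color conversion are not bit-exact with the CPU reference; instead of exact equality it requires a PSNR of at least 40 dB. Using the standard definition for 8-bit images (textbook PSNR, not taken from torcheval's documentation):

$$\mathrm{PSNR} = 10 \log_{10}\!\left(\frac{255^2}{\mathrm{MSE}}\right)$$

so the 40 dB threshold corresponds to $\mathrm{MSE} \le 255^2 / 10^4 \approx 6.5$, i.e. a root-mean-square error of roughly 2.5 intensity levels per channel out of 255.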