diff --git a/setup.py b/setup.py
index 59b5ef53c..4b3698aaf 100644
--- a/setup.py
+++ b/setup.py
@@ -112,6 +112,7 @@ def _build_all_extensions_with_cmake(self):
         torch_dir = Path(torch.utils.cmake_prefix_path) / "Torch"
         cmake_build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release")
         enable_cuda = os.environ.get("ENABLE_CUDA", "")
+        enable_xpu = os.environ.get("ENABLE_XPU", "")
         torchcodec_disable_compile_warning_as_error = os.environ.get(
             "TORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR", "OFF"
         )
@@ -123,6 +124,7 @@ def _build_all_extensions_with_cmake(self):
             f"-DCMAKE_BUILD_TYPE={cmake_build_type}",
             f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}",
             f"-DENABLE_CUDA={enable_cuda}",
+            f"-DENABLE_XPU={enable_xpu}",
             f"-DTORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR={torchcodec_disable_compile_warning_as_error}",
         ]
 
diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index 1e6d2ec80..1fe5e729b 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -34,6 +34,15 @@ else()
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
 endif()
 
+if(ENABLE_CUDA)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_CUDA")
+endif()
+if(ENABLE_XPU)
+    find_package(PkgConfig REQUIRED)
+    pkg_check_modules(L0 REQUIRED IMPORTED_TARGET level-zero)
+    pkg_check_modules(LIBVA REQUIRED IMPORTED_TARGET libva)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_XPU")
+endif()
 
 function(make_torchcodec_sublibrary
     library_name
@@ -109,7 +118,11 @@ function(make_torchcodec_libraries
     )
 
     if(ENABLE_CUDA)
-      list(APPEND core_sources CudaDeviceInterface.cpp)
+        list(APPEND core_sources CudaDeviceInterface.cpp)
+    endif()
+
+    if(ENABLE_XPU)
+        list(APPEND core_sources XpuDeviceInterface.cpp)
     endif()
 
     set(core_library_dependencies
@@ -124,6 +137,11 @@ function(make_torchcodec_libraries
         )
     endif()
 
+    if(ENABLE_XPU)
+        list(APPEND core_library_dependencies
+            PkgConfig::L0 PkgConfig::LIBVA)
+    endif()
+
     make_torchcodec_sublibrary(
         "${core_library_name}"
         SHARED
diff --git a/src/torchcodec/_core/XpuDeviceInterface.cpp b/src/torchcodec/_core/XpuDeviceInterface.cpp
new file mode 100644
index 000000000..99ac26480
--- /dev/null
+++ b/src/torchcodec/_core/XpuDeviceInterface.cpp
@@ -0,0 +1,421 @@
+#include <unistd.h>
+
+#include <level_zero/ze_api.h>
+#include <va/va_drmcommon.h>
+
+#include <ATen/DLConvertor.h>
+#include <c10/xpu/XPUStream.h>
+
+#include "src/torchcodec/_core/Cache.h"
+#include "src/torchcodec/_core/FFMPEGCommon.h"
+#include "src/torchcodec/_core/XpuDeviceInterface.h"
+
+extern "C" {
+#include <libavutil/hwcontext_vaapi.h>
+#include <libavutil/pixdesc.h>
+}
+
+namespace facebook::torchcodec {
+namespace {
+
+static bool g_xpu = registerDeviceInterface(
+    torch::kXPU,
+    [](const torch::Device& device) { return new XpuDeviceInterface(device); });
+
+const int MAX_XPU_GPUS = 128;
+// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
+// Set to a positive number to have a cache of that size.
+const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
+PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
+    g_cached_hw_device_ctxs(MAX_XPU_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
+
+// Returns an FFmpeg VAAPI hardware device context for the given XPU device.
+// A cached context is reused when available; otherwise a new one is created
+// on the DRM render node that matches the device.
+UniqueAVBufferRef getVaapiContext(const torch::Device& device) {
+  enum AVHWDeviceType type = av_hwdevice_find_type_by_name("vaapi");
+  TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find vaapi device");
+  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
+
+  UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
+  if (hw_device_ctx) {
+    return hw_device_ctx;
+  }
+
+  // Fallback node, replaced below when the device reports its PCI address.
+  std::string renderD = "/dev/dri/renderD128";
+
+  sycl::device syclDevice = c10::xpu::get_raw_device(nonNegativeDeviceIndex);
+  if (syclDevice.has(sycl::aspect::ext_intel_pci_address)) {
+    auto BDF =
+        syclDevice.get_info<sycl::ext::intel::info::device::pci_address>();
+    renderD = "/dev/dri/by-path/pci-" + BDF + "-render";
+  }
+
+  AVBufferRef* ctx = nullptr;
+  int err = av_hwdevice_ctx_create(&ctx, type, renderD.c_str(), nullptr, 0);
+  if (err < 0) {
+    TORCH_CHECK(
+        false,
+        "Failed to create specified HW device: ",
+        getFFMPEGErrorStringFromErrorCode(err));
+  }
+  return UniqueAVBufferRef(ctx);
+}
+
+} // namespace
+
+XpuDeviceInterface::XpuDeviceInterface(const torch::Device& device)
+    : DeviceInterface(device) {
+  TORCH_CHECK(g_xpu, "XpuDeviceInterface was not registered!");
+  TORCH_CHECK(
+      device_.type() == torch::kXPU, "Unsupported device: ", device_.str());
+}
+
+// Offers the hardware device context back to the per-GPU cache for reuse.
+XpuDeviceInterface::~XpuDeviceInterface() {
+  if (ctx_) {
+    g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
+  }
+}
+
+void XpuDeviceInterface::initializeContext(AVCodecContext* codecContext) {
+  TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");
+
+  // It is important for pytorch itself to create the xpu context. If ffmpeg
+  // creates the context it may not be compatible with pytorch.
+  // This is a dummy tensor to initialize the xpu context.
+  torch::Tensor dummyTensorForXpuInitialization = torch::empty(
+      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
+  ctx_ = getVaapiContext(device_);
+  codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
+  return;
+}
+
+// Owns one RGBX VAAPI surface and destroys it when going out of scope.
+struct vaapiSurface {
+  vaapiSurface(VADisplay dpy, uint32_t width, uint32_t height);
+
+  ~vaapiSurface() {
+    vaDestroySurfaces(dpy_, &id_, 1);
+  }
+
+  inline VASurfaceID id() const {
+    return id_;
+  }
+
+  torch::Tensor toTensor(const torch::Device& device);
+
+ private:
+  VADisplay dpy_;
+  VASurfaceID id_;
+};
+
+vaapiSurface::vaapiSurface(VADisplay dpy, uint32_t width, uint32_t height)
+    : dpy_(dpy) {
+  VASurfaceAttrib attrib{};
+
+  attrib.type = VASurfaceAttribPixelFormat;
+  attrib.flags = VA_SURFACE_ATTRIB_SETTABLE;
+  attrib.value.type = VAGenericValueTypeInteger;
+  attrib.value.value.i = VA_FOURCC_RGBX;
+
+  VAStatus res = vaCreateSurfaces(
+      dpy_, VA_RT_FORMAT_RGB32, width, height, &id_, 1, &attrib, 1);
+  TORCH_CHECK(
+      res == VA_STATUS_SUCCESS,
+      "Failed to create VAAPI surface: ",
+      vaErrorStr(res));
+}
+
+// DLPack deleter for tensors created by vaapiSurface::toTensor(). Frees the
+// imported device memory, the heap-allocated Level Zero context handle kept in
+// manager_ctx and the DLManagedTensor itself.
+void deleter(DLManagedTensor* self) {
+  std::unique_ptr<DLManagedTensor> tensor(self);
+  std::unique_ptr<ze_context_handle_t> context(
+      (ze_context_handle_t*)self->manager_ctx);
+  zeMemFree(*context, self->dl_tensor.data);
+}
+
+// Exports the surface as a DMA-BUF, imports it into Level Zero as device
+// memory and wraps that memory into an HxWx4 uint8 tensor through DLPack.
+torch::Tensor vaapiSurface::toTensor(const torch::Device& device) {
+  VADRMPRIMESurfaceDescriptor desc{};
+
+  VAStatus sts = vaExportSurfaceHandle(
+      dpy_,
+      id_,
+      VA_SURFACE_ATTRIB_MEM_TYPE_DRM_PRIME_2,
+      VA_EXPORT_SURFACE_READ_ONLY,
+      &desc);
+  TORCH_CHECK(
+      sts == VA_STATUS_SUCCESS,
+      "vaExportSurfaceHandle failed: ",
+      vaErrorStr(sts));
+
+  TORCH_CHECK(desc.num_objects == 1, "Expected 1 fd, got ", desc.num_objects);
+  TORCH_CHECK(desc.num_layers == 1, "Expected 1 layer, got ", desc.num_layers);
+  TORCH_CHECK(
+      desc.layers[0].num_planes == 1,
+      "Expected 1 plane, got ",
+      desc.layers[0].num_planes);
+
+  std::unique_ptr<ze_context_handle_t> ze_context =
+      std::make_unique<ze_context_handle_t>();
+  ze_device_handle_t ze_device{};
+  sycl::queue queue = c10::xpu::getCurrentXPUStream(device.index());
+
+  queue
+      .submit([&](sycl::handler& cgh) {
+        cgh.host_task([&](const sycl::interop_handle& ih) {
+          *ze_context =
+              ih.get_native_context<sycl::backend::ext_oneapi_level_zero>();
+          ze_device =
+              ih.get_native_device<sycl::backend::ext_oneapi_level_zero>();
+        });
+      })
+      .wait();
+
+  ze_external_memory_import_fd_t import_fd_desc{};
+  import_fd_desc.stype = ZE_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMPORT_FD;
+  import_fd_desc.flags = ZE_EXTERNAL_MEMORY_TYPE_FLAG_DMA_BUF;
+  import_fd_desc.fd = desc.objects[0].fd;
+
+  ze_device_mem_alloc_desc_t alloc_desc{};
+  alloc_desc.pNext = &import_fd_desc;
+  void* usm_ptr = nullptr;
+
+  ze_result_t res = zeMemAllocDevice(
+      *ze_context, &alloc_desc, desc.objects[0].size, 0, ze_device, &usm_ptr);
+  // Close the exported fd before checking the result so that it does not leak
+  // when the import fails.
+  close(desc.objects[0].fd);
+  TORCH_CHECK(
+      res == ZE_RESULT_SUCCESS, "Failed to import fd=", desc.objects[0].fd);
+
+  std::unique_ptr<DLManagedTensor> dl_dst = std::make_unique<DLManagedTensor>();
+  int64_t shape[3] = {desc.height, desc.width, 4};
+
+  dl_dst->manager_ctx = ze_context.release();
+  dl_dst->deleter = deleter;
+  dl_dst->dl_tensor.data = usm_ptr;
+  dl_dst->dl_tensor.device.device_type = kDLOneAPI;
+  dl_dst->dl_tensor.device.device_id = device.index();
+  dl_dst->dl_tensor.ndim = 3;
+  dl_dst->dl_tensor.dtype.code = kDLUInt;
+  dl_dst->dl_tensor.dtype.bits = 8;
+  dl_dst->dl_tensor.dtype.lanes = 1;
+  dl_dst->dl_tensor.shape = shape;
+  dl_dst->dl_tensor.strides = nullptr;
+  dl_dst->dl_tensor.byte_offset = desc.layers[0].offset[0];
+
+  auto dst = at::fromDLPack(dl_dst.release());
+
+  return dst;
+}
+
+// Returns the VADisplay behind the hardware frames context of avFrame.
+VADisplay getVaDisplayFromAV(UniqueAVFrame& avFrame) {
+  AVHWFramesContext* hwfc = (AVHWFramesContext*)avFrame->hw_frames_ctx->data;
+  AVHWDeviceContext* hwdc = hwfc->device_ctx;
+  AVVAAPIDeviceContext* vactx = (AVVAAPIDeviceContext*)hwdc->hwctx;
+  return vactx->display;
+}
+
+// VAAPI video processing pipeline which converts the surface of a decoded
+// frame into a caller provided output surface.
+struct vaapiVpContext {
+  VADisplay dpy_;
+  VAConfigID config_id_ = VA_INVALID_ID;
+  VAContextID context_id_ = VA_INVALID_ID;
+  VABufferID pipeline_buf_id_ = VA_INVALID_ID;
+
+  // These structures must stay alive for the whole life cycle of the
+  // struct since they are reused by the media driver internally during
+  // vaRenderPicture().
+  VAProcPipelineParameterBuffer pipeline_{};
+  VARectangle surface_region_{};
+
+  vaapiVpContext() = delete;
+  vaapiVpContext(
+      VADisplay dpy,
+      UniqueAVFrame& avFrame,
+      uint16_t width,
+      uint16_t height);
+
+  ~vaapiVpContext() {
+    if (pipeline_buf_id_ != VA_INVALID_ID)
+      vaDestroyBuffer(dpy_, pipeline_buf_id_);
+    if (context_id_ != VA_INVALID_ID)
+      vaDestroyContext(dpy_, context_id_);
+    if (config_id_ != VA_INVALID_ID)
+      vaDestroyConfig(dpy_, config_id_);
+  }
+
+  void convertTo(VASurfaceID id);
+};
+
+vaapiVpContext::vaapiVpContext(
+    VADisplay dpy,
+    UniqueAVFrame& avFrame,
+    uint16_t width,
+    uint16_t height)
+    : dpy_(dpy) {
+  VAStatus res = vaCreateConfig(
+      dpy_, VAProfileNone, VAEntrypointVideoProc, nullptr, 0, &config_id_);
+  TORCH_CHECK(
+      res == VA_STATUS_SUCCESS,
+      "Failed to create VAAPI config: ",
+      vaErrorStr(res));
+
+  res = vaCreateContext(
+      dpy_,
+      config_id_,
+      width,
+      height,
+      VA_PROGRESSIVE,
+      nullptr,
+      0,
+      &context_id_);
+  TORCH_CHECK(
+      res == VA_STATUS_SUCCESS,
+      "Failed to create VAAPI VP context: ",
+      vaErrorStr(res));
+
+  surface_region_.width = width;
+  surface_region_.height = height;
+
+  pipeline_.surface = (VASurfaceID)(uintptr_t)avFrame->data[3];
+  pipeline_.surface_region = &surface_region_;
+  pipeline_.output_region = &surface_region_;
+  if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709)
+    pipeline_.surface_color_standard = VAProcColorStandardBT709;
+
+  res = vaCreateBuffer(
+      dpy_,
+      context_id_,
+      VAProcPipelineParameterBufferType,
+      sizeof(pipeline_),
+      1,
+      &pipeline_,
+      &pipeline_buf_id_);
+  TORCH_CHECK(
+      res == VA_STATUS_SUCCESS, "vaCreateBuffer failed: ", vaErrorStr(res));
+}
+
+// Runs the pipeline with id as the render target and waits for completion.
+void vaapiVpContext::convertTo(VASurfaceID id) {
+  VAStatus res = vaBeginPicture(dpy_, context_id_, id);
+  TORCH_CHECK(
+      res == VA_STATUS_SUCCESS, "vaBeginPicture failed: ", vaErrorStr(res));
+
+  res = vaRenderPicture(dpy_, context_id_, &pipeline_buf_id_, 1);
+  TORCH_CHECK(
+      res == VA_STATUS_SUCCESS, "vaRenderPicture failed: ", vaErrorStr(res));
+
+  res = vaEndPicture(dpy_, context_id_);
+  TORCH_CHECK(
+      res == VA_STATUS_SUCCESS, "vaEndPicture failed: ", vaErrorStr(res));
+
+  res = vaSyncSurface(dpy_, id);
+  TORCH_CHECK(
+      res == VA_STATUS_SUCCESS, "vaSyncSurface failed: ", vaErrorStr(res));
+}
+
+// Converts a VAAPI frame into an HxWx4 (RGBX) uint8 tensor via Level Zero.
+torch::Tensor convertAVFrameToTensor(
+    const torch::Device& device,
+    UniqueAVFrame& avFrame,
+    int width,
+    int height) {
+  TORCH_CHECK(height > 0, "height must be > 0, got: ", height);
+  TORCH_CHECK(width > 0, "width must be > 0, got: ", width);
+
+  // Allocating intermediate tensor we can convert input to with VAAPI.
+  // This tensor should be WxHx4 since VAAPI does not support RGB24
+  // and works only with RGB32.
+  VADisplay va_dpy = getVaDisplayFromAV(avFrame);
+  // Importing tensor to VAAPI.
+  vaapiSurface va_surface(va_dpy, width, height);
+
+  vaapiVpContext va_vp(va_dpy, avFrame, width, height);
+  va_vp.convertTo(va_surface.id());
+
+  return va_surface.toTensor(device);
+}
+
+void XpuDeviceInterface::convertAVFrameToFrameOutput(
+    const VideoStreamOptions& videoStreamOptions,
+    [[maybe_unused]] const AVRational& timeBase,
+    UniqueAVFrame& avFrame,
+    FrameOutput& frameOutput,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  // TODO: consider to copy handling of CPU frame from CUDA
+  // TODO: consider to copy NV12 format check from CUDA
+  TORCH_CHECK(
+      avFrame->format == AV_PIX_FMT_VAAPI,
+      "Expected format to be AV_PIX_FMT_VAAPI, got " +
+          std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int height = frameDims.height;
+  int width = frameDims.width;
+  torch::Tensor& dst = frameOutput.data;
+  if (preAllocatedOutputTensor.has_value()) {
+    dst = preAllocatedOutputTensor.value();
+    auto shape = dst.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
+            (shape[2] == 3),
+        "Expected tensor of shape ",
+        height,
+        "x",
+        width,
+        "x3, got ",
+        shape);
+  } else {
+    dst = allocateEmptyHWCTensor(height, width, device_);
+  }
+
+  auto start = std::chrono::high_resolution_clock::now();
+
+  // We convert input to the RGBX color format with VAAPI getting WxHx4
+  // tensor on the output.
+  torch::Tensor dst_rgb4 =
+      convertAVFrameToTensor(device_, avFrame, width, height);
+  dst.copy_(dst_rgb4.narrow(2, 0, 3));
+
+  auto end = std::chrono::high_resolution_clock::now();
+
+  std::chrono::duration<double, std::micro> duration = end - start;
+  VLOG(9) << "VAAPI conversion of frame height=" << height << " width=" << width
+          << " took: " << duration.count() << "us" << std::endl;
+}
+
+// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
+// we have to do this because of an FFmpeg bug where hardware decoding is not
+// appropriately set, so we just go off and find the matching codec for the
+// VAAPI device
+std::optional<const AVCodec*> XpuDeviceInterface::findCodec(
+    const AVCodecID& codecId) {
+  void* i = nullptr;
+  const AVCodec* codec = nullptr;
+  while ((codec = av_codec_iterate(&i)) != nullptr) {
+    if (codec->id != codecId || !av_codec_is_decoder(codec)) {
+      continue;
+    }
+
+    const AVCodecHWConfig* config = nullptr;
+    for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr;
+         ++j) {
+      if (config->device_type == AV_HWDEVICE_TYPE_VAAPI) {
+        return codec;
+      }
+    }
+  }
+
+  return std::nullopt;
+}
+
+} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/XpuDeviceInterface.h b/src/torchcodec/_core/XpuDeviceInterface.h
new file mode 100644
index 000000000..b294725c6
--- /dev/null
+++ b/src/torchcodec/_core/XpuDeviceInterface.h
@@ -0,0 +1,35 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include "src/torchcodec/_core/DeviceInterface.h"
+
+namespace facebook::torchcodec {
+
+class XpuDeviceInterface : public DeviceInterface {
+ public:
+  XpuDeviceInterface(const torch::Device& device);
+
+  virtual ~XpuDeviceInterface();
+
+  std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
+
+  void initializeContext(AVCodecContext* codecContext) override;
+
+  void convertAVFrameToFrameOutput(
+      const VideoStreamOptions& videoStreamOptions,
+      const AVRational& timeBase,
+      UniqueAVFrame& avFrame,
+      FrameOutput& frameOutput,
+      std::optional<torch::Tensor> preAllocatedOutputTensor =
+          std::nullopt) override;
+
+ private:
+  UniqueAVBufferRef ctx_;
+};
+
+} // namespace facebook::torchcodec
diff --git a/test/conftest.py b/test/conftest.py
index bef5291e5..067f730e7 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -10,6 +10,9 @@ def pytest_configure(config):
     config.addinivalue_line(
         "markers", "needs_cuda: mark for tests that rely on a CUDA device"
     )
+    config.addinivalue_line(
+        "markers", "needs_xpu: mark for tests that rely on a XPU device"
+    )
 
 
 def pytest_collection_modifyitems(items):
@@ -19,8 +22,8 @@ def pytest_collection_modifyitems(items):
 
     out_items = []
     for item in items:
-        # The needs_cuda mark will exist if the test was explicitly decorated
-        # with the @needs_cuda decorator. It will also exist if it was
+        # The needs_[cuda|xpu] mark will exist if the test was explicitly decorated
+        # with the respective @needs_* decorator. It will also exist if it was
         # parametrized with a parameter that has the mark: for example if a test
         # is parametrized with
         # @pytest.mark.parametrize('device', all_supported_devices())
@@ -28,6 +31,7 @@ def pytest_collection_modifyitems(items):
         # 'needs_cuda' mark, and the ones with device == 'cpu' won't have the
         # mark.
         needs_cuda = item.get_closest_marker("needs_cuda") is not None
+        needs_xpu = item.get_closest_marker("needs_xpu") is not None
 
         if (
             needs_cuda
@@ -42,6 +46,13 @@ def pytest_collection_modifyitems(items):
             # those for whatever reason, we need to know.
             item.add_marker(pytest.mark.skip(reason="CUDA not available."))
 
+        if (
+            needs_xpu
+            and not torch.xpu.is_available()
+            and os.environ.get("FAIL_WITHOUT_XPU") is None
+        ):
+            item.add_marker(pytest.mark.skip(reason="XPU not available."))
+
         out_items.append(item)
 
     items[:] = out_items
@@ -56,6 +67,8 @@ def prevent_leaking_rng():
     builtin_rng_state = random.getstate()
     if torch.cuda.is_available():
         cuda_rng_state = torch.cuda.get_rng_state()
+    if torch.xpu.is_available():
+        xpu_rng_state = torch.xpu.get_rng_state()
 
     yield
 
@@ -63,3 +76,5 @@ def prevent_leaking_rng():
     random.setstate(builtin_rng_state)
     if torch.cuda.is_available():
         torch.cuda.set_rng_state(cuda_rng_state)
+    if torch.xpu.is_available():
+        torch.xpu.set_rng_state(xpu_rng_state)
diff --git a/test/test_ops.py b/test/test_ops.py
index d2f2fd3b1..f5ad492b8 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -49,6 +49,7 @@
     NASA_AUDIO_MP3,
     NASA_VIDEO,
     needs_cuda,
+    needs_xpu,
     SINE_MONO_S32,
     SINE_MONO_S32_44100,
     SINE_MONO_S32_8000,
@@ -740,6 +741,19 @@ def test_cuda_decoder(self):
             duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3
         )
 
+    @needs_xpu
+    def test_xpu_decoder(self):
+        decoder = create_from_file(str(NASA_VIDEO.path))
+        add_video_stream(decoder, device="xpu")
+        frame0, pts, duration = get_next_frame(decoder)
+        assert frame0.device.type == "xpu"
+        reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
+        assert_frames_equal(frame0, reference_frame0.to("xpu"))
+        assert pts == torch.tensor([0])
+        torch.testing.assert_close(
+            duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3
+        )
+
 
 class TestAudioDecoderOps:
     @pytest.mark.parametrize(
diff --git a/test/utils.py b/test/utils.py
index ed611cfda..5d3a7a5f1 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -23,8 +23,19 @@ def needs_cuda(test_item):
     return pytest.mark.needs_cuda(test_item)
 
 
+# Decorator for skipping XPU tests when XPU isn't available. The tests are
+# effectively marked to be skipped in pytest_collection_modifyitems() of
+# conftest.py
+def needs_xpu(test_item):
+    return pytest.mark.needs_xpu(test_item)
+
+
 def all_supported_devices():
-    return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
+    return (
+        "cpu",
+        pytest.param("cuda", marks=pytest.mark.needs_cuda),
+        pytest.param("xpu", marks=pytest.mark.needs_xpu),
+    )
 
 
 def get_ffmpeg_major_version():
@@ -77,6 +88,13 @@ def assert_frames_equal(*args, **kwargs):
                 )
             else:
                 torch.testing.assert_close(*args, **kwargs, atol=atol, rtol=0)
+        elif args[0].device.type == "xpu":
+            if not torch.allclose(*args, atol=0, rtol=0):
+                from torcheval.metrics import PeakSignalNoiseRatio
+
+                metric = PeakSignalNoiseRatio()
+                metric.update(args[0], args[1])
+                assert metric.compute() >= 40
         else:
             torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0)
     else: