From f06c28684edd0570cb44e9d7cbe440b60e37ff0e Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Wed, 26 Feb 2025 01:56:45 +0000 Subject: [PATCH] Enable Intel GPU support in torchcodec on Linux (xpu device) This commit enables support for Intel GPUs in torchcodec. It adds: * ffmpeg-vaapi for decoding * VAAPI based color space conversion (decoding output to RGBA) * RGBA surface import as torch tensor (on torch xpu device) * RGBA to RGB24 tensor slicing To build torchcodec with Intel GPU support: * Install pytorch with XPU backend support. For example, with: ``` pip3 install torch --index-url https://download.pytorch.org/whl/xpu ``` * Install oneAPI development environment following https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support * Build and install FFmpeg with `--enable-vaapi` * Install torcheval (for tests): `pip3 install torcheval` * Build torchcodec with: `ENABLE_XPU=1 python3 setup.py devel` Notes: * RGB24 is not supported color format on current Intel GPUs (as it is considered to be suboptimal due to odd alignments) * Intel media and compute APIs can't seamlessly work with the memory from each other. For example, Intel computes's Unified Shared Memory pointers are not recognized by media APIs. Thus, lower level sharing via dma fds is needed. This alos makes this part of the solution OS dependent. * Color space conversion algoriths might be quite different as it happens for Intel. This requires to check PSNR values instead of per-pixel atol/rtol differences. * Installing oneAPI environment is neded due to https://github.com/pytorch/pytorch/issues/149075 This commit was primary verfied on Intel Battlemage G21 (0xe20b) and Intel Data Center GPU Flex (0x56c0). Co-authored-by: Edgar Romo Montiel Signed-off-by: Edgar Romo Montiel Signed-off-by: Dmitry Rogozhkin --- setup.py | 2 + src/torchcodec/_core/CMakeLists.txt | 20 +- src/torchcodec/_core/XpuDeviceInterface.cpp | 404 ++++++++++++++++++++ src/torchcodec/_core/XpuDeviceInterface.h | 35 ++ test/conftest.py | 19 +- test/test_ops.py | 14 + test/utils.py | 20 +- 7 files changed, 510 insertions(+), 4 deletions(-) create mode 100644 src/torchcodec/_core/XpuDeviceInterface.cpp create mode 100644 src/torchcodec/_core/XpuDeviceInterface.h diff --git a/setup.py b/setup.py index 59b5ef53c..4b3698aaf 100644 --- a/setup.py +++ b/setup.py @@ -112,6 +112,7 @@ def _build_all_extensions_with_cmake(self): torch_dir = Path(torch.utils.cmake_prefix_path) / "Torch" cmake_build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release") enable_cuda = os.environ.get("ENABLE_CUDA", "") + enable_xpu = os.environ.get("ENABLE_XPU", "") torchcodec_disable_compile_warning_as_error = os.environ.get( "TORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR", "OFF" ) @@ -123,6 +124,7 @@ def _build_all_extensions_with_cmake(self): f"-DCMAKE_BUILD_TYPE={cmake_build_type}", f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}", f"-DENABLE_CUDA={enable_cuda}", + f"-DENABLE_XPU={enable_xpu}", f"-DTORCHCODEC_DISABLE_COMPILE_WARNING_AS_ERROR={torchcodec_disable_compile_warning_as_error}", ] diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 1e6d2ec80..1fe5e729b 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -34,6 +34,15 @@ else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}") endif() +if(ENABLE_CUDA) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_CUDA") +endif() +if(ENABLE_XPU) + find_package(PkgConfig REQUIRED) + pkg_check_modules(L0 REQUIRED IMPORTED_TARGET level-zero) + pkg_check_modules(LIBVA REQUIRED IMPORTED_TARGET libva) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_XPU") +endif() function(make_torchcodec_sublibrary library_name @@ -109,7 +118,11 @@ function(make_torchcodec_libraries ) if(ENABLE_CUDA) - list(APPEND core_sources CudaDeviceInterface.cpp) + list(APPEND core_sources CudaDeviceInterface.cpp) + endif() + + if(ENABLE_XPU) + list(APPEND core_sources XpuDeviceInterface.cpp) endif() set(core_library_dependencies @@ -124,6 +137,11 @@ function(make_torchcodec_libraries ) endif() + if(ENABLE_XPU) + list(APPEND core_library_dependencies + PkgConfig::L0 PkgConfig::LIBVA) + endif() + make_torchcodec_sublibrary( "${core_library_name}" SHARED diff --git a/src/torchcodec/_core/XpuDeviceInterface.cpp b/src/torchcodec/_core/XpuDeviceInterface.cpp new file mode 100644 index 000000000..99ac26480 --- /dev/null +++ b/src/torchcodec/_core/XpuDeviceInterface.cpp @@ -0,0 +1,404 @@ +#include + +#include +#include + +#include +#include + +#include "src/torchcodec/_core/Cache.h" +#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "src/torchcodec/_core/XpuDeviceInterface.h" + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { +namespace { + +static bool g_xpu = registerDeviceInterface( + torch::kXPU, + [](const torch::Device& device) { return new XpuDeviceInterface(device); }); + +const int MAX_XPU_GPUS = 128; +// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching. +// Set to a positive number to have a cache of that size. +const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; +PerGpuCache> + g_cached_hw_device_ctxs(MAX_XPU_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE); + +UniqueAVBufferRef getVaapiContext(const torch::Device& device) { + enum AVHWDeviceType type = av_hwdevice_find_type_by_name("vaapi"); + TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find vaapi device"); + torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device); + + UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device); + if (hw_device_ctx) { + return hw_device_ctx; + } + + std::string renderD = "/dev/dri/renderD128"; + + sycl::device syclDevice = c10::xpu::get_raw_device(nonNegativeDeviceIndex); + if (syclDevice.has(sycl::aspect::ext_intel_pci_address)) { + auto BDF = + syclDevice.get_info(); + renderD = "/dev/dri/by-path/pci-" + BDF + "-render"; + } + + AVBufferRef* ctx = nullptr; + int err = av_hwdevice_ctx_create(&ctx, type, renderD.c_str(), nullptr, 0); + if (err < 0) { + TORCH_CHECK( + false, + "Failed to create specified HW device: ", + getFFMPEGErrorStringFromErrorCode(err)); + } + return UniqueAVBufferRef(ctx); +} + +} // namespace + +XpuDeviceInterface::XpuDeviceInterface(const torch::Device& device) + : DeviceInterface(device) { + TORCH_CHECK(g_xpu, "XpuDeviceInterface was not registered!"); + TORCH_CHECK( + device_.type() == torch::kXPU, "Unsupported device: ", device_.str()); +} + +XpuDeviceInterface::~XpuDeviceInterface() { + if (ctx_) { + g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_)); + } +} + +void XpuDeviceInterface::initializeContext(AVCodecContext* codecContext) { + TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized"); + + // It is important for pytorch itself to create the xpu context. If ffmpeg + // creates the context it may not be compatible with pytorch. + // This is a dummy tensor to initialize the xpu context. + torch::Tensor dummyTensorForXpuInitialization = torch::empty( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); + ctx_ = getVaapiContext(device_); + codecContext->hw_device_ctx = av_buffer_ref(ctx_.get()); + return; +} + +struct vaapiSurface { + vaapiSurface(VADisplay dpy, uint32_t width, uint32_t height); + + ~vaapiSurface() { + vaDestroySurfaces(dpy_, &id_, 1); + } + + inline VASurfaceID id() const { + return id_; + } + + torch::Tensor toTensor(const torch::Device& device); + + private: + VADisplay dpy_; + VASurfaceID id_; +}; + +vaapiSurface::vaapiSurface(VADisplay dpy, uint32_t width, uint32_t height) + : dpy_(dpy) { + VASurfaceAttrib attrib{}; + + attrib.type = VASurfaceAttribPixelFormat; + attrib.flags = VA_SURFACE_ATTRIB_SETTABLE; + attrib.value.type = VAGenericValueTypeInteger; + attrib.value.value.i = VA_FOURCC_RGBX; + + VAStatus res = vaCreateSurfaces( + dpy_, VA_RT_FORMAT_RGB32, width, height, &id_, 1, &attrib, 1); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, + "Failed to create VAAPI surface: ", + vaErrorStr(res)); +} + +void deleter(DLManagedTensor* self) { + std::unique_ptr tensor(self); + std::unique_ptr context( + (ze_context_handle_t*)self->manager_ctx); + zeMemFree(*context, self->dl_tensor.data); +} + +torch::Tensor vaapiSurface::toTensor(const torch::Device& device) { + VADRMPRIMESurfaceDescriptor desc{}; + + VAStatus sts = vaExportSurfaceHandle( + dpy_, + id_, + VA_SURFACE_ATTRIB_MEM_TYPE_DRM_PRIME_2, + VA_EXPORT_SURFACE_READ_ONLY, + &desc); + TORCH_CHECK( + sts == VA_STATUS_SUCCESS, + "vaExportSurfaceHandle failed: ", + vaErrorStr(sts)); + + TORCH_CHECK(desc.num_objects == 1, "Expected 1 fd, got ", desc.num_objects); + TORCH_CHECK(desc.num_layers == 1, "Expected 1 layer, got ", desc.num_layers); + TORCH_CHECK( + desc.layers[0].num_planes == 1, + "Expected 1 plane, got ", + desc.num_layers); + + std::unique_ptr ze_context = + std::make_unique(); + ze_device_handle_t ze_device{}; + sycl::queue queue = c10::xpu::getCurrentXPUStream(device.index()); + + queue + .submit([&](sycl::handler& cgh) { + cgh.host_task([&](const sycl::interop_handle& ih) { + *ze_context = + ih.get_native_context(); + ze_device = + ih.get_native_device(); + }); + }) + .wait(); + + ze_external_memory_import_fd_t import_fd_desc{}; + import_fd_desc.stype = ZE_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMPORT_FD; + import_fd_desc.flags = ZE_EXTERNAL_MEMORY_TYPE_FLAG_DMA_BUF; + import_fd_desc.fd = desc.objects[0].fd; + + ze_device_mem_alloc_desc_t alloc_desc{}; + alloc_desc.pNext = &import_fd_desc; + void* usm_ptr = nullptr; + + ze_result_t res = zeMemAllocDevice( + *ze_context, &alloc_desc, desc.objects[0].size, 0, ze_device, &usm_ptr); + TORCH_CHECK( + res == ZE_RESULT_SUCCESS, "Failed to import fd=", desc.objects[0].fd); + + close(desc.objects[0].fd); + + std::unique_ptr dl_dst = std::make_unique(); + int64_t shape[3] = {desc.height, desc.width, 4}; + + dl_dst->manager_ctx = ze_context.release(); + dl_dst->deleter = deleter; + dl_dst->dl_tensor.data = usm_ptr; + dl_dst->dl_tensor.device.device_type = kDLOneAPI; + dl_dst->dl_tensor.device.device_id = device.index(); + dl_dst->dl_tensor.ndim = 3; + dl_dst->dl_tensor.dtype.code = kDLUInt; + dl_dst->dl_tensor.dtype.bits = 8; + dl_dst->dl_tensor.dtype.lanes = 1; + dl_dst->dl_tensor.shape = shape; + dl_dst->dl_tensor.strides = nullptr; + dl_dst->dl_tensor.byte_offset = desc.layers[0].offset[0]; + + auto dst = at::fromDLPack(dl_dst.release()); + + return dst; +} + +VADisplay getVaDisplayFromAV(UniqueAVFrame& avFrame) { + AVHWFramesContext* hwfc = (AVHWFramesContext*)avFrame->hw_frames_ctx->data; + AVHWDeviceContext* hwdc = hwfc->device_ctx; + AVVAAPIDeviceContext* vactx = (AVVAAPIDeviceContext*)hwdc->hwctx; + return vactx->display; +} + +struct vaapiVpContext { + VADisplay dpy_; + VAConfigID config_id_ = VA_INVALID_ID; + VAContextID context_id_ = VA_INVALID_ID; + VABufferID pipeline_buf_id_ = VA_INVALID_ID; + + // These structures must be available thru all life + // circle of the struct since they are reused by the media + // driver internally during vaRenderPicture(). + VAProcPipelineParameterBuffer pipeline_{}; + VARectangle surface_region_{}; + + vaapiVpContext() = delete; + vaapiVpContext( + VADisplay dpy, + UniqueAVFrame& avFrame, + uint16_t width, + uint16_t height); + + ~vaapiVpContext() { + if (pipeline_buf_id_ != VA_INVALID_ID) + vaDestroyBuffer(dpy_, pipeline_buf_id_); + if (context_id_ != VA_INVALID_ID) + vaDestroyContext(dpy_, context_id_); + if (config_id_ != VA_INVALID_ID) + vaDestroyConfig(dpy_, config_id_); + } + + void convertTo(VASurfaceID id); +}; + +vaapiVpContext::vaapiVpContext( + VADisplay dpy, + UniqueAVFrame& avFrame, + uint16_t width, + uint16_t height) + : dpy_(dpy) { + VAStatus res = vaCreateConfig( + dpy_, VAProfileNone, VAEntrypointVideoProc, nullptr, 0, &config_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, + "Failed to create VAAPI config: ", + vaErrorStr(res)); + + res = vaCreateContext( + dpy_, + config_id_, + width, + height, + VA_PROGRESSIVE, + nullptr, + 0, + &context_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, + "Failed to create VAAPI VP context: ", + vaErrorStr(res)); + + surface_region_.width = width; + surface_region_.height = height; + + pipeline_.surface = (VASurfaceID)(uintptr_t)avFrame->data[3]; + pipeline_.surface_region = &surface_region_; + pipeline_.output_region = &surface_region_; + if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) + pipeline_.surface_color_standard = VAProcColorStandardBT709; + + res = vaCreateBuffer( + dpy_, + context_id_, + VAProcPipelineParameterBufferType, + sizeof(pipeline_), + 1, + &pipeline_, + &pipeline_buf_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaCreateBuffer failed: ", vaErrorStr(res)); +} + +void vaapiVpContext::convertTo(VASurfaceID id) { + VAStatus res = vaBeginPicture(dpy_, context_id_, id); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaBeginPicture failed: ", vaErrorStr(res)); + + res = vaRenderPicture(dpy_, context_id_, &pipeline_buf_id_, 1); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaRenderPicture failed: ", vaErrorStr(res)); + + res = vaEndPicture(dpy_, context_id_); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaEndPicture failed: ", vaErrorStr(res)); + + res = vaSyncSurface(dpy_, id); + TORCH_CHECK( + res == VA_STATUS_SUCCESS, "vaSyncSurface failed: ", vaErrorStr(res)); +} + +torch::Tensor convertAVFrameToTensor( + const torch::Device& device, + UniqueAVFrame& avFrame, + int width, + int height) { + TORCH_CHECK(height > 0, "height must be > 0, got: ", height); + TORCH_CHECK(width > 0, "width must be > 0, got: ", width); + + // Allocating intermediate tensor we can convert input to with VAAPI. + // This tensor should be WxHx4 since VAAPI does not support RGB24 + // and works only with RGB32. + VADisplay va_dpy = getVaDisplayFromAV(avFrame); + // Importing tensor to VAAPI. + vaapiSurface va_surface(va_dpy, width, height); + + vaapiVpContext va_vp(va_dpy, avFrame, width, height); + va_vp.convertTo(va_surface.id()); + + return va_surface.toTensor(device); +} + +void XpuDeviceInterface::convertAVFrameToFrameOutput( + const VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] const AVRational& timeBase, + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor) { + // TODO: consider to copy handling of CPU frame from CUDA + // TODO: consider to copy NV12 format check from CUDA + TORCH_CHECK( + avFrame->format == AV_PIX_FMT_VAAPI, + "Expected format to be AV_PIX_FMT_VAAPI, got " + + std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format))); + auto frameDims = + getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); + int height = frameDims.height; + int width = frameDims.width; + torch::Tensor& dst = frameOutput.data; + if (preAllocatedOutputTensor.has_value()) { + dst = preAllocatedOutputTensor.value(); + auto shape = dst.sizes(); + TORCH_CHECK( + (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) && + (shape[2] == 3), + "Expected tensor of shape ", + height, + "x", + width, + "x3, got ", + shape); + } else { + dst = allocateEmptyHWCTensor(height, width, device_); + } + + auto start = std::chrono::high_resolution_clock::now(); + + // We convert input to the RGBX color format with VAAPI getting WxHx4 + // tensor on the output. + torch::Tensor dst_rgb4 = + convertAVFrameToTensor(device_, avFrame, width, height); + dst.copy_(dst_rgb4.narrow(2, 0, 3)); + + auto end = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end - start; + VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width + << " took: " << duration.count() << "us" << std::endl; +} + +// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 +// we have to do this because of an FFmpeg bug where hardware decoding is not +// appropriately set, so we just go off and find the matching codec for the CUDA +// device +std::optional XpuDeviceInterface::findCodec( + const AVCodecID& codecId) { + void* i = nullptr; + const AVCodec* codec = nullptr; + while ((codec = av_codec_iterate(&i)) != nullptr) { + if (codec->id != codecId || !av_codec_is_decoder(codec)) { + continue; + } + + const AVCodecHWConfig* config = nullptr; + for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; + ++j) { + if (config->device_type == AV_HWDEVICE_TYPE_VAAPI) { + return codec; + } + } + } + + return std::nullopt; +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/XpuDeviceInterface.h b/src/torchcodec/_core/XpuDeviceInterface.h new file mode 100644 index 000000000..b294725c6 --- /dev/null +++ b/src/torchcodec/_core/XpuDeviceInterface.h @@ -0,0 +1,35 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include "src/torchcodec/_core/DeviceInterface.h" + +namespace facebook::torchcodec { + +class XpuDeviceInterface : public DeviceInterface { + public: + XpuDeviceInterface(const torch::Device& device); + + virtual ~XpuDeviceInterface(); + + std::optional findCodec(const AVCodecID& codecId) override; + + void initializeContext(AVCodecContext* codecContext) override; + + void convertAVFrameToFrameOutput( + const VideoStreamOptions& videoStreamOptions, + const AVRational& timeBase, + UniqueAVFrame& avFrame, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = + std::nullopt) override; + + private: + UniqueAVBufferRef ctx_; +}; + +} // namespace facebook::torchcodec diff --git a/test/conftest.py b/test/conftest.py index bef5291e5..067f730e7 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -10,6 +10,9 @@ def pytest_configure(config): config.addinivalue_line( "markers", "needs_cuda: mark for tests that rely on a CUDA device" ) + config.addinivalue_line( + "markers", "needs_xpu: mark for tests that rely on a XPU device" + ) def pytest_collection_modifyitems(items): @@ -19,8 +22,8 @@ def pytest_collection_modifyitems(items): out_items = [] for item in items: - # The needs_cuda mark will exist if the test was explicitly decorated - # with the @needs_cuda decorator. It will also exist if it was + # The needs_[cuda|xpu] mark will exist if the test was explicitly decorated + # with the respective @needs_* decorator. It will also exist if it was # parametrized with a parameter that has the mark: for example if a test # is parametrized with # @pytest.mark.parametrize('device', all_supported_devices()) @@ -28,6 +31,7 @@ def pytest_collection_modifyitems(items): # 'needs_cuda' mark, and the ones with device == 'cpu' won't have the # mark. needs_cuda = item.get_closest_marker("needs_cuda") is not None + needs_xpu = item.get_closest_marker("needs_xpu") is not None if ( needs_cuda @@ -42,6 +46,13 @@ def pytest_collection_modifyitems(items): # those for whatever reason, we need to know. item.add_marker(pytest.mark.skip(reason="CUDA not available.")) + if ( + needs_xpu + and not torch.xpu.is_available() + and os.environ.get("FAIL_WITHOUT_XPU") is None + ): + item.add_marker(pytest.mark.skip(reason="XPU not available.")) + out_items.append(item) items[:] = out_items @@ -56,6 +67,8 @@ def prevent_leaking_rng(): builtin_rng_state = random.getstate() if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() + if torch.xpu.is_available(): + xpu_rng_state = torch.xpu.get_rng_state() yield @@ -63,3 +76,5 @@ def prevent_leaking_rng(): random.setstate(builtin_rng_state) if torch.cuda.is_available(): torch.cuda.set_rng_state(cuda_rng_state) + if torch.xpu.is_available(): + torch.xpu.set_rng_state(xpu_rng_state) diff --git a/test/test_ops.py b/test/test_ops.py index d2f2fd3b1..f5ad492b8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -49,6 +49,7 @@ NASA_AUDIO_MP3, NASA_VIDEO, needs_cuda, + needs_xpu, SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, @@ -740,6 +741,19 @@ def test_cuda_decoder(self): duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 ) + @needs_xpu + def test_xpu_decoder(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, device="xpu") + frame0, pts, duration = get_next_frame(decoder) + assert frame0.device.type == "xpu" + reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) + assert_frames_equal(frame0, reference_frame0.to("xpu")) + assert pts == torch.tensor([0]) + torch.testing.assert_close( + duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 + ) + class TestAudioDecoderOps: @pytest.mark.parametrize( diff --git a/test/utils.py b/test/utils.py index ed611cfda..5d3a7a5f1 100644 --- a/test/utils.py +++ b/test/utils.py @@ -23,8 +23,19 @@ def needs_cuda(test_item): return pytest.mark.needs_cuda(test_item) +# Decorator for skipping XPU tests when XPU isn't available. The tests are +# effectively marked to be skipped in pytest_collection_modifyitems() of +# conftest.py +def needs_xpu(test_item): + return pytest.mark.needs_xpu(test_item) + + def all_supported_devices(): - return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + return ( + "cpu", + pytest.param("cuda", marks=pytest.mark.needs_cuda), + pytest.param("xpu", marks=pytest.mark.needs_xpu), + ) def get_ffmpeg_major_version(): @@ -77,6 +88,13 @@ def assert_frames_equal(*args, **kwargs): ) else: torch.testing.assert_close(*args, **kwargs, atol=atol, rtol=0) + elif args[0].device.type == "xpu": + if not torch.allclose(*args, atol=0, rtol=0): + from torcheval.metrics import PeakSignalNoiseRatio + + metric = PeakSignalNoiseRatio() + metric.update(args[0], args[1]) + assert metric.compute() >= 40 else: torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) else: