diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 4274ca483c..741d289e56 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -29,7 +29,7 @@ jobs: - name: install dependencies run: | sudo apt-get update - sudo apt-get install clang qt6-base-dev libglvnd-dev zlib1g-dev libfftw3-dev ninja-build python3-numpy libpng-dev + sudo apt-get install clang qt6-base-dev libglvnd-dev zlib1g-dev libfftw3-dev ninja-build python3-numpy libpng-dev mesa-vulkan-drivers - name: Run sccache-cache uses: mozilla-actions/sccache-action@v0.0.9 @@ -96,7 +96,7 @@ jobs: - name: install dependencies run: | sudo apt-get update - sudo apt-get install g++-9 qt6-base-dev libglvnd-dev zlib1g-dev libfftw3-dev ninja-build python3-numpy libpng-dev + sudo apt-get install g++-9 qt6-base-dev libglvnd-dev zlib1g-dev libfftw3-dev ninja-build python3-numpy libpng-dev mesa-vulkan-drivers - name: Run sccache-cache uses: mozilla-actions/sccache-action@v0.0.9 @@ -233,6 +233,7 @@ jobs: ${{env.MINGW_PACKAGE_PREFIX}}-diffutils ${{env.MINGW_PACKAGE_PREFIX}}-fftw ${{env.MINGW_PACKAGE_PREFIX}}-gcc + ${{env.MINGW_PACKAGE_PREFIX}}-mesa ${{env.MINGW_PACKAGE_PREFIX}}-ninja ${{env.MINGW_PACKAGE_PREFIX}}-pkg-config ${{env.MINGW_PACKAGE_PREFIX}}-qt6-base diff --git a/CMakeLists.txt b/CMakeLists.txt index 3e7fc38db6..7ae18f04aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,12 +25,16 @@ option(MRTRIX_USE_PCH "Use precompiled headers" ON) option(MRTRIX_PYTHON_SOFTLINK "Build directory softlink to Python source code rather than copying" ON) option(MRTRIX_BUILD_STATIC "Build MRtrix's library statically" OFF) option(MRTRIX_USE_LLD "Use lld as the linker" OFF) +option(MRTRIX_ENABLE_GPU "Enable GPU computing features" ON) option(MRTRIX_IGNORE_VERSION_MISMATCH "Ignore version mismatch between git tag and MRtrix base version" OFF) option(MRTRIX_USE_SYSTEM_EIGEN "Use system-installed Eigen3 library" OFF) option(MRTRIX_USE_SYSTEM_JSON "Use system-installed Json for 
Modern C++ library" OFF) option(MRTRIX_USE_SYSTEM_NIFTI "Use system-installed NIfTI C headers" OFF) option(MRTRIX_USE_SYSTEM_GTEST "Use system-installed Google Test library" OFF) +option(MRTRIX_USE_SYSTEM_DAWN "Use system-installed Dawn library" OFF) +option(MRTRIX_USE_SYSTEM_SLANG "Use system-installed Slang library" OFF) +option(MRTRIX_USE_SYSTEM_TCB_SPAN "Use system-installed TCB Span library" OFF) if(MRTRIX_BUILD_TESTS) list(APPEND CMAKE_CTEST_ARGUMENTS "--output-on-failure") @@ -68,6 +72,10 @@ file(RELATIVE_PATH relDir set(CMAKE_INSTALL_RPATH ${base} ${base}/${relDir}) +if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") + cmake_policy(SET CMP0135 NEW) +endif() + include(BuildType) include(BuildInfo) include(LinkerSetup) diff --git a/README.md b/README.md index 1eaf3655b1..37fb25350c 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ You can address all *MRtrix3*-related queries there, using your GitHub or Google ## Quick install 1. Install dependencies by whichever means your system uses. - These include: CMake (>= 3.16), Python3, a C++ compiler with full C++17 support, + These include: CMake (>= 3.22), Python3, a C++ compiler with full C++17 support, Eigen (>=3.2.8), zlib, OpenGL (>=3.3), and Qt (>=5.5). 2. 
Clone Git repository and compile: diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 1c0910a9d2..ba40cee6bc 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -96,3 +96,137 @@ if(MRTRIX_BUILD_TESTS) FetchContent_MakeAvailable(googletest) endif() endif() + + +if(MRTRIX_ENABLE_GPU) + # Dawn + + # Threads (required by Dawn exported targets) + find_package(Threads REQUIRED) + + if(NOT MRTRIX_USE_SYSTEM_DAWN) + message(STATUS "Downloading prebuilt binaries for Dawn...") + include(FetchContent) + set(FETCHCONTENT_QUIET OFF) + + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + set(DAWN_PLATFORM "linux") + elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + set(DAWN_PLATFORM "macos") + elseif(CMAKE_SYSTEM_NAME STREQUAL "Windows") + set(DAWN_PLATFORM "windows-msys2") + else() + message(FATAL_ERROR "Unsupported platform: ${CMAKE_SYSTEM_NAME}") + endif() + + set(DAWN_VERSION 7495) + + + set(DAWN_BINARIES_URL_PREFIX + "https://github.com/mrtrix3/webgpu-dawn-binaries/releases/download/chromium") + set(DAWN_BINARIES_URL + ${DAWN_BINARIES_URL_PREFIX}-${DAWN_VERSION}/webgpu-dawn-chromium-${DAWN_VERSION}-${DAWN_PLATFORM}.zip + ) + + FetchContent_Declare( + dawn + DOWNLOAD_NO_PROGRESS 1 + URL ${DAWN_BINARIES_URL} + ) + FetchContent_MakeAvailable(dawn) + + # On Linux, Dawn prebuilt packages use lib64; others use lib + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + set(DAWN_LIB_DIR_NAME "lib64") + else() + set(DAWN_LIB_DIR_NAME "lib") + endif() + set( + Dawn_DIR + "${dawn_SOURCE_DIR}/${DAWN_LIB_DIR_NAME}/cmake/Dawn" + CACHE PATH "Folder containing DawnConfig.cmake" + FORCE + ) + set(FETCHCONTENT_QUIET ON) + endif() + + + # Slang + + find_package(slang QUIET) + + if(NOT MRTRIX_USE_SYSTEM_SLANG) + message(STATUS "Downloading prebuilt binaries for Slang...") + set(SLANG_VERSION "2025.22.1" CACHE STRING "Slang version to download from GitHub releases") + + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + set(SLANG_OS "linux") + elseif(APPLE) + set(SLANG_OS "macos") + else() + 
set(SLANG_OS "windows") + endif() + + if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "arm64") + set(SLANG_ARCH "aarch64") + else() + set(SLANG_ARCH "x86_64") + endif() + + set(SLANG_SUBSTRING "-${SLANG_OS}-${SLANG_ARCH}") + + set(SLANG_DOWNLOAD_LINK + "https://github.com/shader-slang/slang/releases/download/v${SLANG_VERSION}/slang-${SLANG_VERSION}${SLANG_SUBSTRING}.zip" + ) + + message(STATUS "Downloading Slang ${SLANG_VERSION} (${SLANG_OS}/${SLANG_ARCH})...") + + FetchContent_Declare( + slang + DOWNLOAD_NO_PROGRESS 1 + URL ${SLANG_DOWNLOAD_LINK} + ) + FetchContent_MakeAvailable(slang) + + if(WIN32) + set(slang_DIR_PATH "${slang_SOURCE_DIR}/cmake") + else() + set(slang_DIR_PATH "${slang_SOURCE_DIR}/lib/cmake/slang") + endif() + + set( + slang_DIR + "${slang_DIR_PATH}" + CACHE PATH "Folder containing SlangConfig.cmake" + FORCE + ) + endif() +endif() + +# tcb::span +if(MRTRIX_USE_SYSTEM_TCB_SPAN) + find_path(TCB_SPAN_INCLUDE_DIR + NAMES tcb/span.hpp + PATHS /usr/include /usr/local/include + ) + if(NOT TCB_SPAN_INCLUDE_DIR) + message(FATAL_ERROR "Could not find tcb::span headers. 
Please install tcb::span or disable MRTRIX_USE_SYSTEM_TCB_SPAN.") + endif() + + add_library(tcb_span INTERFACE) + target_include_directories(tcb_span INTERFACE ${TCB_SPAN_INCLUDE_DIR}) + add_library(tcb::span ALIAS tcb_span) +else() + message(STATUS "Downloading tcb::span...") + + FetchContent_Populate( + tcb_span + GIT_REPOSITORY https://github.com/tcbrindle/span.git + GIT_TAG 836dc6a0efd9849cb194e88e4aa2387436bb079b + ) + + add_library(tcb_span INTERFACE) + target_include_directories(tcb_span INTERFACE ${tcb_span_SOURCE_DIR}/include) + add_library(tcb::span ALIAS tcb_span) +endif() + diff --git a/cpp/core/CMakeLists.txt b/cpp/core/CMakeLists.txt index fba0ab2123..28a6951372 100644 --- a/cpp/core/CMakeLists.txt +++ b/cpp/core/CMakeLists.txt @@ -21,6 +21,14 @@ endif() file(GLOB_RECURSE HEADLESS_SOURCES *.h *.cpp) + +if(MRTRIX_ENABLE_GPU) + find_package(Dawn CONFIG REQUIRED) + find_package(slang CONFIG REQUIRED) +else() + list(FILTER HEADLESS_SOURCES EXCLUDE REGEX "${CMAKE_CURRENT_LIST_DIR}/gpu/.*") +endif() + find_package(Git QUIET) # Create version target and library @@ -110,13 +118,27 @@ target_link_libraries(mrtrix-core PUBLIC Threads::Threads nlohmann_json::nlohmann_json nifti::nifti + tcb::span ) +if(MRTRIX_ENABLE_GPU) + target_link_libraries(mrtrix-core PUBLIC dawn::webgpu_dawn slang::slang) +endif() + # On Windows, the libraries need to be in the same directory as the executables if(WIN32) set_target_properties(mrtrix-core PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/bin ) + + add_custom_command(TARGET mrtrix-core POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ + $ + $ + COMMENT "Copying imported dependencies to output directory" + ) + endif() install(TARGETS mrtrix-core diff --git a/cpp/core/gpu/gpu.cpp b/cpp/core/gpu/gpu.cpp new file mode 100644 index 0000000000..fa52ff17c5 --- /dev/null +++ b/cpp/core/gpu/gpu.cpp @@ -0,0 +1,714 @@ +/* Copyright (c) 2008-2026 the MRtrix3 contributors. 
+ * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#include "gpu.h" +#include "exception.h" +#include "image_helpers.h" +#include "match_variant.h" +#include "platform.h" +#include "slangcodegen.h" + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace MR::GPU { + +struct SlangSessionInfo { + Slang::ComPtr globalSession; + Slang::ComPtr session; +}; + +constexpr auto GPUBackendType +#ifdef __APPLE__ + = wgpu::BackendType::Metal; +#else + = wgpu::BackendType::Vulkan; +#endif + +namespace { + +uint32_t next_multiple_of(const uint32_t value, const uint32_t multiple) { + if (value > std::numeric_limits::max() - multiple) { + return std::numeric_limits::max(); + } + return (value + multiple - 1) / multiple * multiple; +} + +uint32_t pixel_size_in_bytes(const TextureFormat format) { + switch (format) { + case TextureFormat::R32Float: + return 4; + case TextureFormat::RGBA32Float: + return 16; + default: + throw MR::Exception("Only R32Float and RGBA32Float textures are supported!"); + } +} + +wgpu::ShaderModule make_wgsl_shader_module(std::string_view name, std::string_view code, const wgpu::Device &device) { + wgpu::ShaderSourceWGSL wgsl; + wgsl.code = code; + + const wgpu::ShaderModuleDescriptor shader_module_descriptor{ + .nextInChain 
= &wgsl, + .label = name, + }; + + return device.CreateShaderModule(&shader_module_descriptor); +} + +wgpu::ShaderModule +make_spirv_shader_module(std::string_view name, tcb::span spirvCode, const wgpu::Device &device) { + wgpu::ShaderSourceSPIRV spirv; + spirv.codeSize = spirvCode.size_bytes(); + spirv.code = spirvCode.data(); + + const wgpu::ShaderModuleDescriptor shader_module_descriptor{ + .nextInChain = &spirv, + .label = name, + }; + + return device.CreateShaderModule(&shader_module_descriptor); +} + +wgpu::TextureFormat to_wgpu_format(const MR::GPU::TextureFormat &format) { + switch (format) { + case MR::GPU::TextureFormat::R32Float: + return wgpu::TextureFormat::R32Float; + case MR::GPU::TextureFormat::RGBA32Float: + return wgpu::TextureFormat::RGBA32Float; + default: + return wgpu::TextureFormat::Undefined; + }; +} + +wgpu::TextureUsage to_wgpu_usage(const MR::GPU::TextureUsage &usage) { + wgpu::TextureUsage textureUsage = + wgpu::TextureUsage::CopySrc | wgpu::TextureUsage::CopyDst | wgpu::TextureUsage::TextureBinding; + + if (usage.storage_binding) { + textureUsage |= wgpu::TextureUsage::StorageBinding; + } + if (usage.render_target) { + textureUsage |= wgpu::TextureUsage::RenderAttachment; + } + return textureUsage; +} + +std::future> request_slang_global_session_async() { + return SlangCodegen::request_slang_global_session_async(); +} +} // namespace + +ComputeContext::ComputeContext() : m_slang_session_info(std::make_unique()) { + // We request the creation of the slang global session asynchronously + // as it can take some time to complete. This allows the WebGPU instance + // and adapter to be created in parallel with the Slang global session. 
+ auto slang_global_session_request = request_slang_global_session_async(); + + { + std::vector dawn_toggles{"allow_unsafe_apis", "enable_immediate_error_handling", "disable_robustness"}; + + // NOLINTNEXTLINE(concurrency-mt-unsafe) + const char *dawn_gpu_debug_env = std::getenv("MRTRIX_GPU_DEBUG_TRACE"); // check_syntax off + if (dawn_gpu_debug_env != nullptr && std::string(dawn_gpu_debug_env) == "1") { + dawn_toggles.emplace_back("dump_shaders"); + dawn_toggles.emplace_back("disable_symbol_renaming"); + } + + wgpu::DawnTogglesDescriptor dawn_toggles_desc; + dawn_toggles_desc.enabledToggles = dawn_toggles.data(); + dawn_toggles_desc.enabledToggleCount = dawn_toggles.size(); + + constexpr std::array instance_features = {wgpu::InstanceFeatureName::TimedWaitAny}; + const wgpu::InstanceDescriptor instance_descriptor{ + .nextInChain = nullptr, + .requiredFeatureCount = instance_features.size(), + .requiredFeatures = instance_features.data(), + }; + + const wgpu::Instance wgpu_instance = wgpu::CreateInstance(&instance_descriptor); + wgpu::Adapter wgpu_adapter; + + const wgpu::RequestAdapterOptions adapter_options{.powerPreference = wgpu::PowerPreference::HighPerformance, + .backendType = GPUBackendType}; + + struct RequestAdapterResult { + wgpu::RequestAdapterStatus status = wgpu::RequestAdapterStatus::Error; + wgpu::Adapter adapter = nullptr; + std::string message; + } request_adapter_result; + + const auto adapter_callback = [&request_adapter_result](wgpu::RequestAdapterStatus status, + wgpu::Adapter foundAdapter, + wgpu::StringView message) { + request_adapter_result = {status, std::move(foundAdapter), std::string(message)}; + }; + + const wgpu::Future adapter_request = + wgpu_instance.RequestAdapter(&adapter_options, wgpu::CallbackMode::WaitAnyOnly, adapter_callback); + const wgpu::WaitStatus wait_status = wgpu_instance.WaitAny(adapter_request, -1); + + if (wait_status == wgpu::WaitStatus::Success) { + if (request_adapter_result.status != 
wgpu::RequestAdapterStatus::Success) { + throw MR::Exception("Failed to get adapter: " + request_adapter_result.message); + } + } else { + throw MR::Exception("Failed to get adapter: wgpu::Instance::WaitAny failed"); + } + + wgpu_adapter = request_adapter_result.adapter; + const std::vector required_device_features = { + wgpu::FeatureName::R8UnormStorage, wgpu::FeatureName::Float32Filterable, wgpu::FeatureName::Subgroups}; + + wgpu::Limits supported_limits; + wgpu_adapter.GetLimits(&supported_limits); + + constexpr uint64_t desired_max_storage_buffer_binding_size = 1'073'741'824ULL; // 1 GiB + constexpr uint64_t desired_max_buffer_size = 1'073'741'824ULL; // 1 GiB + + const wgpu::Limits required_device_limits{ + .maxStorageTexturesPerShaderStage = 8, + .maxStorageBufferBindingSize = + std::min(desired_max_storage_buffer_binding_size, supported_limits.maxStorageBufferBindingSize), + .maxBufferSize = std::min(desired_max_buffer_size, supported_limits.maxBufferSize), + .maxComputeWorkgroupStorageSize = 32768, + .maxComputeInvocationsPerWorkgroup = 1024, + .maxComputeWorkgroupSizeX = 1024, + }; + + wgpu::DeviceDescriptor device_descriptor{}; + device_descriptor.nextInChain = &dawn_toggles_desc; + device_descriptor.requiredFeatures = required_device_features.data(); + device_descriptor.requiredFeatureCount = required_device_features.size(); + device_descriptor.requiredLimits = &required_device_limits; + + device_descriptor.SetDeviceLostCallback( + wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device &, wgpu::DeviceLostReason reason, wgpu::StringView message) { + if (reason != wgpu::DeviceLostReason::Destroyed) { + throw MR::Exception("GPU device lost: " + std::string(message)); // StringView is length-delimited; .data may be null or unterminated + } + }); + device_descriptor.SetUncapturedErrorCallback( + [](const wgpu::Device &, wgpu::ErrorType type, wgpu::StringView message) { + (void)type; + FAIL("Uncaptured gpu error: " + std::string(message)); + throw MR::Exception("Uncaptured gpu error: " + std::string(message)); + });
+ + m_instance = wgpu_instance; + m_adapter = wgpu_adapter; + m_device = wgpu_adapter.CreateDevice(&device_descriptor); + wgpu::AdapterInfo adapter_info; + wgpu_adapter.GetInfo(&adapter_info); + + wgpu::Limits device_limits; + m_device.GetLimits(&device_limits); + + m_device_info = DeviceInfo{.subgroup_min_size = adapter_info.subgroupMinSize, .limits = device_limits}; + } + m_slang_session_info->globalSession = std::move(slang_global_session_request.get()); + + const slang::TargetDesc target_desc{.format = SLANG_WGSL}; + + const auto executable_path = MR::Platform::get_executable_path(); + const std::string executable_dir_string = (std::filesystem::path(executable_path).parent_path() / "shaders").string(); + const char *executable_dir_cstr = executable_dir_string.c_str(); // check_syntax off + + std::vector slang_compiler_options; + { + const slang::CompilerOptionEntry uniformity_analysis_option = { + .name = slang::CompilerOptionName::ValidateUniformity, + .value = {.intValue0 = 1}, + }; + slang_compiler_options.push_back(uniformity_analysis_option); + } + + const slang::SessionDesc slang_session_desc{ + .targets = &target_desc, + .targetCount = 1, + .defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR, + .searchPaths = &executable_dir_cstr, + .searchPathCount = 1, + .compilerOptionEntries = slang_compiler_options.data(), .compilerOptionEntryCount = static_cast<uint32_t>(slang_compiler_options.size()), // the entry count defaults to 0, which would leave the options above unused + }; + + { + const SlangResult slang_res = m_slang_session_info->globalSession->createSession( + slang_session_desc, m_slang_session_info->session.writeRef()); + if (SLANG_FAILED(slang_res)) { + throw MR::Exception("Failed to create Slang session!"); + } + } +} + +ComputeContext &ComputeContext::operator=(ComputeContext &&) noexcept = default; +ComputeContext::ComputeContext(ComputeContext &&) noexcept = default; +ComputeContext::~ComputeContext() = default; + +std::future ComputeContext::request_async() { + return std::async(std::launch::async, []() { return ComputeContext(); }); +} + +void ComputeContext::copy_buffer_to_buffer(const
BufferVariant &srcBuffer, + const BufferVariant &dstBuffer, + const BufferCopyInfo &info) const { + const auto extract_handle = [](const BufferVariant &buffer) -> wgpu::Buffer { + return MR::match_v(buffer, [](auto &&arg) { return arg.wgpu_handle; }); + }; + + const wgpu::Buffer src_handle = extract_handle(srcBuffer); + const wgpu::Buffer dst_handle = extract_handle(dstBuffer); + + assert(dst_handle.GetUsage() & wgpu::BufferUsage::CopyDst && + "Destination buffer must have CopyDst usage for copyBufferToBuffer"); + + const uint64_t src_size = src_handle.GetSize(); + const uint64_t dst_size = dst_handle.GetSize(); + + uint64_t final_byte_size = info.byteSize; + // If byteSize == 0, copy as much as possible from src->dst given offsets + if (info.byteSize == 0) { + if (info.srcOffset >= src_size || info.dstOffset >= dst_size) { + // Nothing to copy + return; + } + final_byte_size = static_cast(std::min(src_size - info.srcOffset, dst_size - info.dstOffset)); + } + + if (info.srcOffset > src_size || final_byte_size > src_size - info.srcOffset) { // subtraction form: offset + size could wrap around and slip past the check + throw MR::Exception("copyBufferToBuffer: source range out of bounds"); + } + if (info.dstOffset > dst_size || final_byte_size > dst_size - info.dstOffset) { // same wrap-safe form for the destination range + throw MR::Exception("copyBufferToBuffer: destination range out of bounds"); + } + + if (final_byte_size == 0) { + return; + } + + const wgpu::CommandEncoder command_encoder = m_device.CreateCommandEncoder(); + command_encoder.CopyBufferToBuffer( + src_handle, static_cast(info.srcOffset), dst_handle, info.dstOffset, final_byte_size); + const wgpu::CommandBuffer command_buffer = command_encoder.Finish(); + m_device.GetQueue().Submit(1, &command_buffer); +} + +wgpu::Buffer ComputeContext::inner_new_empty_buffer(size_t byteSize, BufferType bufferType) const { + wgpu::BufferUsage buffer_usage = wgpu::BufferUsage::None; + size_t buffer_byte_size = byteSize; + + switch (bufferType) { + case BufferType::StorageBuffer: + buffer_usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage; + break; + case
BufferType::UniformBuffer: + buffer_usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Uniform; + // Align buffer size to 16 bytes for uniform buffers + buffer_byte_size = next_multiple_of(byteSize, 16); + break; + }; + + const wgpu::BufferDescriptor buffer_descriptor{ + .usage = buffer_usage, + .size = buffer_byte_size, + }; + + return m_device.CreateBuffer(&buffer_descriptor); +} + +wgpu::Buffer ComputeContext::inner_new_buffer_from_host_memory(const void *srcMemory, + size_t srcByteSize, + BufferType bufferType) const { + const auto buffer = inner_new_empty_buffer(srcByteSize, bufferType); + inner_write_to_buffer(buffer, srcMemory, srcByteSize, 0); + return buffer; +} + +void ComputeContext::inner_download_buffer(const wgpu::Buffer &buffer, void *dstMemory, size_t dstByteSize) const { + assert(buffer.GetSize() == dstByteSize); + assert(dstByteSize % 4 == 0 && "Destination buffer size must be a multiple of 4 bytes"); + const wgpu::BufferDescriptor staging_buffer_descriptor{ + .usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, + .size = dstByteSize, + }; + + const wgpu::Buffer staging_buffer = m_device.CreateBuffer(&staging_buffer_descriptor); + const wgpu::CommandEncoder command_encoder = m_device.CreateCommandEncoder(); + command_encoder.CopyBufferToBuffer(buffer, 0, staging_buffer, 0, dstByteSize); + const wgpu::CommandBuffer command_buffer = command_encoder.Finish(); + m_device.GetQueue().Submit(1, &command_buffer); + + auto mapping_callback = [](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + throw MR::Exception("Failed to map buffer: " + std::string(message)); + } + }; + const wgpu::Future mapping_future = staging_buffer.MapAsync( + wgpu::MapMode::Read, 0, staging_buffer.GetSize(), wgpu::CallbackMode::WaitAnyOnly, mapping_callback); + const wgpu::WaitStatus wait_status = m_instance.WaitAny(mapping_future, std::numeric_limits::max()); + if (wait_status 
!= wgpu::WaitStatus::Success) { + throw MR::Exception("Failed to map buffer to host memory: wgpu::Instance::WaitAny failed"); + } + + const void *mapped_data = staging_buffer.GetConstMappedRange(); + if (mapped_data != nullptr && dstMemory != nullptr) { // the mapped range is what can come back null; a null destination still takes the throwing branch as before + std::memcpy(dstMemory, mapped_data, dstByteSize); + staging_buffer.Unmap(); + } else { + throw MR::Exception("Failed to map buffer to host memory: wgpu::Buffer::GetMappedRange returned nullptr"); + } +} + +void ComputeContext::inner_write_to_buffer(const wgpu::Buffer &buffer, + const void *data, + size_t srcByteSize, + uint64_t offset) const { + + // WebGPU requirement is that srcByteSize is a multiple of 4 + // See https://www.w3.org/TR/webgpu/#dom-gpuqueue-writebuffer + if ((offset & 3U) != 0U || (srcByteSize & 3U) != 0U) { + throw MR::Exception("Buffer writes require 4-byte aligned offset and size"); + } + if (buffer.GetUsage() & wgpu::BufferUsage::Uniform) { + const uint64_t align = m_device_info.limits.minUniformBufferOffsetAlignment; + if (align != 0 && (offset % align) != 0) { + const std::string min_offset_alignment = std::to_string(align); + throw MR::Exception("Uniform buffer offset must be aligned to minUniformBufferOffsetAlignment: " + + min_offset_alignment); + } + } + m_device.GetQueue().WriteBuffer(buffer, offset, data, srcByteSize); +} + +void ComputeContext::inner_clear_buffer(const wgpu::Buffer &buffer) const { + const wgpu::CommandEncoder command_encoder = m_device.CreateCommandEncoder(); + command_encoder.ClearBuffer(buffer); + const wgpu::CommandBuffer command_buffer = command_encoder.Finish(); + m_device.GetQueue().Submit(1, &command_buffer); +} + +Texture ComputeContext::new_empty_texture(const TextureSpec &textureSpec) const { + const wgpu::TextureDescriptor wgpu_texture_desc{.usage = to_wgpu_usage(textureSpec.usage), + .dimension = textureSpec.depth > 1 ?
wgpu::TextureDimension::e3D + : wgpu::TextureDimension::e2D, + .size = {textureSpec.width, textureSpec.height, textureSpec.depth}, + .format = to_wgpu_format(textureSpec.format)}; + return {textureSpec, m_device.CreateTexture(&wgpu_texture_desc)}; +} + +Texture ComputeContext::new_texture_from_host_memory(const TextureSpec &texture_desc, + tcb::span src_memory_region) const { + const Texture texture = new_empty_texture(texture_desc); + const wgpu::TexelCopyTextureInfo image_copy_texture{.texture = texture.wgpu_handle}; + const wgpu::TexelCopyBufferLayout texture_data_layout{ + .bytesPerRow = texture_desc.width * pixel_size_in_bytes(texture_desc.format), + .rowsPerImage = texture_desc.height, + }; + + const wgpu::Extent3D texture_size{texture_desc.width, texture_desc.height, texture_desc.depth}; + m_device.GetQueue().WriteTexture(&image_copy_texture, + src_memory_region.data(), + src_memory_region.size_bytes(), + &texture_data_layout, + &texture_size); + return texture; +} + +Texture ComputeContext::new_texture_from_host_image(const MR::Image &image, const TextureUsage &usage) const { + const TextureSpec textureSpec = { + .width = static_cast(image.size(0)), + .height = static_cast(image.size(1)), + .depth = static_cast(image.size(2)), + .usage = usage, + }; + const auto image_size = MR::voxel_count(image); + return new_texture_from_host_memory(textureSpec, tcb::span(image.address(), image_size)); +} + +void ComputeContext::download_texture(const Texture &texture, tcb::span dst_memory_region) const { + const uint32_t components_per_texel = pixel_size_in_bytes(texture.spec.format) / sizeof(float); + assert(dst_memory_region.size() >= static_cast(texture.wgpu_handle.GetWidth()) * + texture.wgpu_handle.GetHeight() * texture.wgpu_handle.GetDepthOrArrayLayers() * + components_per_texel && + "Memory region size is too small for the texture"); + + const uint32_t bytes_per_row = + next_multiple_of(texture.wgpu_handle.GetWidth() * pixel_size_in_bytes(texture.spec.format), 
256); + const size_t padded_data_size = static_cast(bytes_per_row) * texture.wgpu_handle.GetHeight() * + texture.wgpu_handle.GetDepthOrArrayLayers(); + const wgpu::CommandEncoder command_encoder = m_device.CreateCommandEncoder(); + const wgpu::BufferDescriptor staging_buffer_desc{ + .usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, + .size = padded_data_size, + }; + const wgpu::Buffer staging_buffer = m_device.CreateBuffer(&staging_buffer_desc); + + const wgpu::TexelCopyTextureInfo image_copy_texture{.texture = texture.wgpu_handle}; + const wgpu::TexelCopyBufferInfo image_copy_buffer{.layout = + wgpu::TexelCopyBufferLayout{ + .bytesPerRow = bytes_per_row, + .rowsPerImage = texture.wgpu_handle.GetHeight(), + }, + .buffer = staging_buffer}; + + const wgpu::Extent3D image_copy_size{ + .width = texture.wgpu_handle.GetWidth(), + .height = texture.wgpu_handle.GetHeight(), + .depthOrArrayLayers = texture.wgpu_handle.GetDepthOrArrayLayers(), + }; + command_encoder.CopyTextureToBuffer(&image_copy_texture, &image_copy_buffer, &image_copy_size); + const wgpu::CommandBuffer command_buffer = command_encoder.Finish(); + m_device.GetQueue().Submit(1, &command_buffer); + + auto mapping_callback = [](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + throw MR::Exception("Failed to map buffer: " + std::string(message)); + } + }; + + const wgpu::Future mapping_future = staging_buffer.MapAsync( + wgpu::MapMode::Read, 0, staging_buffer.GetSize(), wgpu::CallbackMode::WaitAnyOnly, mapping_callback); + + const wgpu::WaitStatus wait_status = m_instance.WaitAny(mapping_future, -1); + if (wait_status != wgpu::WaitStatus::Success) { + throw MR::Exception("Failed to map buffer to host memory: wgpu::Instance::WaitAny failed"); + } + + const void *mapped_data = staging_buffer.GetConstMappedRange(); + + // Copy the unpadded data + if (mapped_data != nullptr) { + // This is amount of data we will copy per row + const size_t 
texture_width_in_floats = static_cast(texture.wgpu_handle.GetWidth()) * components_per_texel; + // The full distance in floats from the start of one row to the start of the next in the source GPU buffer + const size_t padded_row_width_in_floats = bytes_per_row / sizeof(float); + const size_t num_rows = + static_cast(texture.wgpu_handle.GetDepthOrArrayLayers()) * texture.wgpu_handle.GetHeight(); + + const tcb::span src_span(static_cast(mapped_data), + padded_row_width_in_floats * num_rows); + const tcb::span dst_span(dst_memory_region.data(), texture_width_in_floats * num_rows); + + for (size_t row = 0; row < num_rows; ++row) { + const auto row_src = src_span.subspan(row * padded_row_width_in_floats, texture_width_in_floats); + auto row_dst = dst_span.subspan(row * texture_width_in_floats, texture_width_in_floats); + // copy exactly 'textureWidthInFloats' texels + std::copy_n(row_src.begin(), texture_width_in_floats, row_dst.begin()); + } + } else { + throw MR::Exception("Failed to map buffer to host memory: wgpu::Buffer::GetMappedRange returned nullptr"); + } + + staging_buffer.Unmap(); +} + +Kernel ComputeContext::new_kernel(const KernelSpec &kernel_spec) const { + struct BindingEntries { + std::vector bind_group_entries; + std::vector bind_group_layout_entries; + + void add(const wgpu::BindGroupEntry &bindGroupEntry, const wgpu::BindGroupLayoutEntry &bindGroupLayoutEntry) { + bind_group_entries.push_back(bindGroupEntry); + bind_group_layout_entries.push_back(bindGroupLayoutEntry); + } + }; + BindingEntries binding_entries; + + const auto &slang_session = m_slang_session_info->session; + const auto compiled_kernel = + SlangCodegen::compile_kernel_code_to_wgsl(kernel_spec, slang_session.get(), m_shader_cache); + + const auto reflected_bindings_map = + SlangCodegen::reflect_bindings(compiled_kernel.linked_program->getLayout(), compiled_kernel.entry_point_name); + + const auto reflected_wg_size = + 
SlangCodegen::workgroup_size(compiled_kernel.linked_program->getLayout(), compiled_kernel.entry_point_name); + + for (const auto &[name, resource] : kernel_spec.bindings_map) { + auto it = reflected_bindings_map.find(name); + if (it == reflected_bindings_map.end()) { + throw MR::Exception("Slang reflection failed to find binding: " + name + " in " + + kernel_spec.compute_shader.name + " with entry point " + + kernel_spec.compute_shader.entryPoint); + } + + const auto &binding_info = it->second; + const uint32_t binding_index = binding_info.binding_index; + auto *type_layout = binding_info.layout->getTypeLayout(); + + const auto access = type_layout->getResourceAccess(); + MR::match_v( + resource, + // we can't capture structure bindings till C++20 + [&, name = name](const BufferVariant &buffer) { + DEBUG("Buffer binding: " + name); + auto binding_kind = type_layout->getKind(); + wgpu::BufferBindingType buffer_binding_type = wgpu::BufferBindingType::Undefined; + if (binding_kind == slang::TypeReflection::Kind::ConstantBuffer) { + buffer_binding_type = wgpu::BufferBindingType::Uniform; + } else if (binding_kind == slang::TypeReflection::Kind::Resource || + binding_kind == slang::TypeReflection::Kind::ShaderStorageBuffer) { + switch (access) { + case SLANG_RESOURCE_ACCESS_READ: + buffer_binding_type = wgpu::BufferBindingType::ReadOnlyStorage; + break; + case SLANG_RESOURCE_ACCESS_READ_WRITE: + buffer_binding_type = wgpu::BufferBindingType::Storage; + break; + default: + throw MR::Exception("Unsupported buffer access type for '" + name + "'"); + } + } else { + throw MR::Exception("Cannot determine WGPU buffer binding type for '" + name + + "'. 
Its Slang type kind is not a recognized buffer type."); + } + const wgpu::BindGroupLayoutEntry layout_entry{.binding = binding_index, + .visibility = wgpu::ShaderStage::Compute, + .buffer = {.type = buffer_binding_type}}; + const wgpu::BindGroupEntry bind_group_entry{ + .binding = layout_entry.binding, + .buffer = MR::match_v(buffer, [](auto &&arg) { return arg.wgpu_handle; })}; + binding_entries.add(bind_group_entry, layout_entry); + }, + [&, name = name](const Texture &texture) { + wgpu::BindGroupLayoutEntry layout_entry; + if (access == SLANG_RESOURCE_ACCESS_READ) { + layout_entry = {.binding = binding_index, + .visibility = wgpu::ShaderStage::Compute, + .texture = { + .sampleType = wgpu::TextureSampleType::Float, + .viewDimension = texture.wgpu_handle.GetDepthOrArrayLayers() > 1 + ? wgpu::TextureViewDimension::e3D + : wgpu::TextureViewDimension::e2D, + }}; + } else if (access == SLANG_RESOURCE_ACCESS_WRITE || access == SLANG_RESOURCE_ACCESS_READ_WRITE) { + layout_entry = {.binding = binding_index, + .visibility = wgpu::ShaderStage::Compute, + .storageTexture = {.access = access == SLANG_RESOURCE_ACCESS_WRITE + ? wgpu::StorageTextureAccess::WriteOnly + : wgpu::StorageTextureAccess::ReadWrite, + .format = texture.wgpu_handle.GetFormat(), + .viewDimension = texture.wgpu_handle.GetDepthOrArrayLayers() > 1 + ? 
wgpu::TextureViewDimension::e3D + : wgpu::TextureViewDimension::e2D}}; + } else { + throw MR::Exception("Unsupported texture access type for '" + name + "'"); + } + const wgpu::BindGroupEntry bind_group_entry{ + .binding = layout_entry.binding, + .textureView = texture.wgpu_handle.CreateView(), + }; + binding_entries.add(bind_group_entry, layout_entry); + }, + [&](const Sampler &sampler) { + const wgpu::BindGroupLayoutEntry layout_entry{ + .binding = binding_index, + .visibility = wgpu::ShaderStage::Compute, + .sampler = {.type = wgpu::SamplerBindingType::Filtering}, + }; + + const wgpu::BindGroupEntry bind_group_entry{ + .binding = layout_entry.binding, + .sampler = sampler.wgpu_handle, + }; + binding_entries.add(bind_group_entry, layout_entry); + }); + } + + const auto layout_desc_label = kernel_spec.compute_shader.name + " layout descriptor"; + + const wgpu::BindGroupLayoutDescriptor bind_group_layout_desc{ + .label = layout_desc_label.c_str(), + .entryCount = binding_entries.bind_group_layout_entries.size(), + .entries = binding_entries.bind_group_layout_entries.data(), + }; + + const wgpu::BindGroupLayout bind_group_layout = m_device.CreateBindGroupLayout(&bind_group_layout_desc); + + const wgpu::PipelineLayoutDescriptor pipeline_layout_desc{ + .bindGroupLayoutCount = 1, + .bindGroupLayouts = &bind_group_layout, + }; + const wgpu::PipelineLayout pipeline_layout = m_device.CreatePipelineLayout(&pipeline_layout_desc); + + const std::string compute_pipeline_label = kernel_spec.compute_shader.name + " compute pipeline"; + const wgpu::ComputePipelineDescriptor compute_pipeline_desc{ + .label = compute_pipeline_label.c_str(), + .layout = pipeline_layout, + .compute = {.module = + make_wgsl_shader_module(kernel_spec.compute_shader.name, compiled_kernel.wgsl_source, m_device), + .entryPoint = compiled_kernel.entry_point_name.c_str()}}; + + const wgpu::BindGroupDescriptor bind_group_desc{ + .layout = bind_group_layout, + .entryCount = 
binding_entries.bind_group_entries.size(), + .entries = binding_entries.bind_group_entries.data(), + }; + + const WorkgroupSize wg_size = {.x = reflected_wg_size[0], .y = reflected_wg_size[1], .z = reflected_wg_size[2]}; + + return Kernel{.name = kernel_spec.compute_shader.name, + .pipeline = m_device.CreateComputePipeline(&compute_pipeline_desc), + .bind_group = m_device.CreateBindGroup(&bind_group_desc), + .shader_source = compiled_kernel.wgsl_source, + .workgroup_size = wg_size}; +} +void ComputeContext::dispatch_kernel(const Kernel &kernel, const DispatchGrid &dispatch_grid) const { + const wgpu::ComputePassDescriptor pass_desc{ + .label = kernel.name.c_str(), + }; + const wgpu::CommandEncoder command_encoder = m_device.CreateCommandEncoder(); + const wgpu::ComputePassEncoder compute_pass = command_encoder.BeginComputePass(&pass_desc); + compute_pass.SetPipeline(kernel.pipeline); + compute_pass.SetBindGroup(0, kernel.bind_group); + compute_pass.DispatchWorkgroups(dispatch_grid.x, dispatch_grid.y, dispatch_grid.z); + compute_pass.End(); + + const wgpu::CommandBuffer command_buffer = command_encoder.Finish(); + m_device.GetQueue().Submit(1, &command_buffer); +} + +Sampler ComputeContext::new_linear_sampler() const { + const wgpu::SamplerDescriptor sampler_desc{.magFilter = wgpu::FilterMode::Linear, + .minFilter = wgpu::FilterMode::Linear, + .mipmapFilter = wgpu::MipmapFilterMode::Linear, + .maxAnisotropy = 1}; + + return {Sampler::FilterMode::Linear, m_device.CreateSampler(&sampler_desc)}; +} + +} // namespace MR::GPU diff --git a/cpp/core/gpu/gpu.h b/cpp/core/gpu/gpu.h new file mode 100644 index 0000000000..de3ca91252 --- /dev/null +++ b/cpp/core/gpu/gpu.h @@ -0,0 +1,328 @@ +/* Copyright (c) 2008-2026 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include "image.h" +#include "match_variant.h" +#include "shadercache.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace MR::GPU { + +enum class BufferType : uint8_t { StorageBuffer, UniformBuffer }; + +template struct Buffer { + BufferType buffer_type = BufferType::StorageBuffer; + wgpu::Buffer wgpu_handle{}; + + uint64_t elementsCount() const { + assert(wgpu_handle.GetSize() % sizeof(T) == 0); + return wgpu_handle.GetSize() / sizeof(T); + } + uint64_t bytesSize() const { return wgpu_handle.GetSize(); } + + using element_t = std::remove_cv_t>; + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v, + "Buffer supports float, int32_t, uint32_t, std::byte."); +}; + +using BufferVariant = std::variant, Buffer, Buffer, Buffer>; + +struct TextureUsage { + bool storage_binding = false; + bool render_target = false; +}; + +enum class TextureFormat : uint8_t { R32Float, RGBA32Float }; + +struct TextureSpec { + uint32_t width = 0; + uint32_t height = 0; + uint32_t depth = 1; + TextureFormat format = TextureFormat::R32Float; + TextureUsage usage; +}; + +struct Texture { + TextureSpec spec; + wgpu::Texture wgpu_handle; +}; + +struct Sampler { + enum class FilterMode : uint8_t { Nearest, Linear }; + FilterMode filter_mode = FilterMode::Linear; + wgpu::Sampler wgpu_handle; +}; + +// A workgroup is a collection of threads that execute the same kernel +// 
function in parallel. Each thread within a workgroup can cooperate with others +// through shared memory. +struct WorkgroupSize { + uint32_t x = 1; + uint32_t y = 1; + uint32_t z = 1; + + // As a rule of thumb, for optimal performance across different hardware, the + // total number of threads in a workgroup should be a multiple of 64. + uint32_t threadCount() const { return x * y * z; } +}; + +// The dispatch grid defines the number of workgroups to be dispatched for a +// kernel. The total number of threads dispatched is the product of the number of +// workgroups in each dimension and the number of threads per workgroup. +struct DispatchGrid { + // Number of workgroups for each dimension. + uint32_t x = 1; + uint32_t y = 1; + uint32_t z = 1; + + uint32_t workgroup_count() const { return x * y * z; } + + // Given `workgroup_size`, the returned grid contains the number + // of workgroups per dimension so that at most one thread is dispatched + // per logical element (i.e. an injective, element-wise dispatch so that each + // element is processed by a single thread). + static DispatchGrid element_wise(const std::array &data_dimensions, const WorkgroupSize &workgroup_size) { + assert(workgroup_size.x > 0 && workgroup_size.y > 0 && workgroup_size.z > 0); + return {static_cast((data_dimensions[0] + workgroup_size.x - 1) / workgroup_size.x), + static_cast((data_dimensions[1] + workgroup_size.y - 1) / workgroup_size.y), + static_cast((data_dimensions[2] + workgroup_size.z - 1) / workgroup_size.z)}; + } + + // Convenience function for 3D textures. + static DispatchGrid element_wise_texture(const Texture &texture, const WorkgroupSize &workgroup_size) { + return element_wise({texture.spec.width, texture.spec.height, texture.spec.depth}, workgroup_size); + } +}; + +// Absolute/relative (to working dir) path of a Slang file. 
+struct ShaderFile { + std::filesystem::path file_path; +}; + +struct InlineShaderText { + std::string text; +}; + +using ShaderSource = std::variant; + +using ShaderConstantValue = std::variant; + +struct ShaderEntry { + ShaderSource shader_source; + + std::string entryPoint = "main"; + + std::string name = MR::match_v( + shader_source, + [](const ShaderFile &file) { return file.file_path.stem().string(); }, + [](const InlineShaderText &) { return std::string("inline_shader"); }); + + // Convenience property to set the kWorkgroupSizeX, kWorkgroupSizeY, and + // kWorkgroupSizeZ constants in the shader. These constant must be declared + // as extern static const in the shader code. + std::optional workgroup_size; + + using ShaderConstantMap = std::unordered_map; + // Link time constants to specialise the shader module. + // To use a constant in the shader code, declare it as extern static const. + ShaderConstantMap constants; + // Generic specialisation arguments for the shader entry point. + std::vector entry_point_args; +}; + +using ShaderBindingResource = std::variant; + +using ShaderBindingsMap = std::unordered_map; + +struct KernelSpec { + ShaderEntry compute_shader; + ShaderBindingsMap bindings_map; +}; + +struct Kernel { + std::string name; + wgpu::ComputePipeline pipeline; + wgpu::BindGroup bind_group; + // For debugging purposes, the shader source code is stored here. 
+ std::string shader_source; + WorkgroupSize workgroup_size; +}; + +struct SlangSessionInfo; + +struct ComputeContext { + explicit ComputeContext(); + ComputeContext(const ComputeContext &) = delete; + ComputeContext &operator=(const ComputeContext &) = delete; + ComputeContext(ComputeContext &&) noexcept; + ComputeContext &operator=(ComputeContext &&) noexcept; + ~ComputeContext(); + + [[nodiscard]] static std::future request_async(); + + // NOTE: For all buffer creation and write operations, it's safe to discard + // the original data on the host side after the operation is complete as + // the data is internally copied to a staging buffer by Dawn's implementation. + template + [[nodiscard]] Buffer new_empty_buffer(size_t size, BufferType buffer_type = BufferType::StorageBuffer) const { + return {buffer_type, inner_new_empty_buffer(size * sizeof(T), buffer_type)}; + } + + template + [[nodiscard]] Buffer new_buffer_from_host_memory(std::initializer_list srcMemory, + BufferType bufferType = BufferType::StorageBuffer) const { + return new_buffer_from_host_memory(tcb::span(srcMemory), bufferType); + } + + template + [[nodiscard]] Buffer new_buffer_from_host_memory(tcb::span src_memory, + BufferType buffer_type = BufferType::StorageBuffer) const { + return {buffer_type, inner_new_buffer_from_host_memory(src_memory.data(), src_memory.size_bytes(), buffer_type)}; + } + + // Creates a GPU buffer by copying the raw bytes of a host-side object into device memory. + // Intended for uploading small POD-like structs that live on the stack. 
+ template + [[nodiscard]] Buffer + new_buffer_from_host_object(const Object &object, BufferType buffer_type = BufferType::StorageBuffer) const { + static_assert(std::is_trivially_copyable_v, "Object must be trivially copyable"); + static_assert(std::is_standard_layout_v, "Object must be standard layout"); + return {buffer_type, inner_new_buffer_from_host_memory(&object, sizeof(object), buffer_type)}; + } + + template + [[nodiscard]] Buffer new_buffer_from_host_memory(const std::vector> &src_memory_regions, + BufferType bufferType = BufferType::StorageBuffer) const { + size_t totalBytes = 0; + for (const auto ®ion : src_memory_regions) + totalBytes += region.size_bytes(); + + auto buffer = inner_new_empty_buffer(totalBytes, bufferType); + uint64_t offset = 0; + for (const auto ®ion : src_memory_regions) { + inner_write_to_buffer(buffer, region.data(), region.size_bytes(), offset); + offset += region.size_bytes(); + } + return Buffer{bufferType, std::move(buffer)}; + } + + // This function blocks until the download is complete. + template [[nodiscard]] std::vector download_buffer_as_vector(const Buffer &buffer) const { + std::vector result(buffer.wgpu_handle.GetSize() / sizeof(T)); + download_buffer(buffer, result.data(), result.size() * sizeof(T)); + return result; + } + + // This function blocks until the download is complete. + template void download_buffer(const Buffer &buffer, tcb::span dst_memory_region) const { + download_buffer(buffer, dst_memory_region.data(), dst_memory_region.size_bytes()); + } + + // This function blocks until the download is complete. + template void download_buffer(const Buffer &buffer, void *data, size_t dst_byte_size) const { + inner_download_buffer(buffer.wgpu_handle, data, dst_byte_size); + } + + // Writes to the buffer at the specified offset. 
+ template + void write_to_buffer(const Buffer &buffer, tcb::span src_memory_region, uint64_t offset = 0) const { + inner_write_to_buffer( + buffer.wgpu_handle, src_memory_region.data(), src_memory_region.size_bytes(), offset * sizeof(T)); + } + + // Copy bytes from a source buffer to a destination buffer. + // if byteSize is 0, the whole source buffer is copied. + struct BufferCopyInfo { + uint64_t srcOffset = 0; + uint64_t dstOffset = 0; + uint64_t byteSize = 0; + }; + void copy_buffer_to_buffer(const BufferVariant &srcBuffer, + const BufferVariant &dstBuffer, + const BufferCopyInfo &info) const; + + template void clear_buffer(const Buffer &buffer) const { + inner_clear_buffer(buffer.wgpu_handle); + } + + [[nodiscard]] Texture new_empty_texture(const TextureSpec &textureSpec) const; + + [[nodiscard]] Texture new_texture_from_host_memory(const TextureSpec &texture_desc, + tcb::span src_memory_region) const; + + [[nodiscard]] Texture new_texture_from_host_image(const MR::Image &image, + const TextureUsage &usage = {}) const; + + [[nodiscard]] Buffer new_buffer_from_host_image(const MR::Image &image, + BufferType bufferType = BufferType::StorageBuffer) const; + + // This function blocks until the download is complete. 
+ void download_texture(const Texture &texture, tcb::span dst_memory_region) const; + + [[nodiscard]] Kernel new_kernel(const KernelSpec &kernel_spec) const; + + void dispatch_kernel(const Kernel &kernel, const DispatchGrid &dispatch_grid) const; + + [[nodiscard]] Sampler new_linear_sampler() const; + +private: + wgpu::Buffer inner_new_empty_buffer(size_t byteSize, BufferType bufferType = BufferType::StorageBuffer) const; + wgpu::Buffer inner_new_buffer_from_host_memory(const void *srcMemory, + size_t srcByteSize, + BufferType bufferType = BufferType::StorageBuffer) const; + void inner_download_buffer(const wgpu::Buffer &buffer, void *dstMemory, size_t dstByteSize) const; + void inner_write_to_buffer(const wgpu::Buffer &buffer, const void *data, size_t srcByteSize, uint64_t offset) const; + void inner_clear_buffer(const wgpu::Buffer &buffer) const; + + wgpu::Instance m_instance; + wgpu::Adapter m_adapter; + wgpu::Device m_device; + + struct DeviceInfo { + uint32_t subgroup_min_size = 0; + wgpu::Limits limits; + }; + + DeviceInfo m_device_info; + + std::unique_ptr m_slang_session_info; + + // Cache of Slang-compiled WGSL shaders + mutable ShaderCache m_shader_cache; +}; +} // namespace MR::GPU diff --git a/cpp/core/gpu/shadercache.cpp b/cpp/core/gpu/shadercache.cpp new file mode 100644 index 0000000000..555f8d1dd7 --- /dev/null +++ b/cpp/core/gpu/shadercache.cpp @@ -0,0 +1,27 @@ +/* Copyright (c) 2008-2026 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. 
+ * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#include "shadercache.h" + +namespace MR::GPU { +bool ShaderCache::contains(const CacheKey &key) const { return m_cache.find(key) != m_cache.end(); } + +void ShaderCache::insert(const CacheKey &key, const CacheValue &value) { m_cache.insert_or_assign(key, value); } + +const ShaderCache::CacheValue &ShaderCache::get(const CacheKey &key) const { return m_cache.at(key); } + +void ShaderCache::clear() { m_cache.clear(); } +} // namespace MR::GPU diff --git a/cpp/core/gpu/shadercache.h b/cpp/core/gpu/shadercache.h new file mode 100644 index 0000000000..e16830b8d5 --- /dev/null +++ b/cpp/core/gpu/shadercache.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2008-2026 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#pragma once + +#include +#include + +namespace MR::GPU { +class ShaderCache { +public: + using CacheKey = std::string; + using CacheValue = std::string; + using CacheMap = std::unordered_map; + + explicit ShaderCache() = default; + + bool contains(const CacheKey &key) const; + + void insert(const CacheKey &key, const CacheValue &value); + + const CacheValue &get(const CacheKey &key) const; + + void clear(); + +private: + CacheMap m_cache; +}; + +} // namespace MR::GPU diff --git a/cpp/core/gpu/slangcodegen.cpp b/cpp/core/gpu/slangcodegen.cpp new file mode 100644 index 0000000000..5680f4c92e --- /dev/null +++ b/cpp/core/gpu/slangcodegen.cpp @@ -0,0 +1,419 @@ +/* Copyright (c) 2008-2026 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#include "slangcodegen.h" + +#include "exception.h" +#include "gpu/gpu.h" +#include "match_variant.h" +#include "shadercache.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace MR::GPU::SlangCodegen { + +namespace { + +enum ReadFileMode : uint8_t { Text, Binary }; +std::string read_file(const std::filesystem::path &filePath, ReadFileMode mode = ReadFileMode::Text) { + using namespace std::string_literals; + if (!std::filesystem::exists(filePath)) { + throw std::runtime_error("File not found: "s + filePath.string()); + } + + const auto openMode = (mode == ReadFileMode::Binary) ? std::ios::in | std::ios::binary : std::ios::in; + std::ifstream f(filePath, std::ios::in | openMode); + const auto fileSize64 = std::filesystem::file_size(filePath); + if (fileSize64 > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("File too large to read into memory: "s + filePath.string()); + } + const std::streamsize fileSize = static_cast(fileSize64); + std::string result(static_cast(fileSize), '\0'); + f.read(result.data(), fileSize); + + return result; +} + +std::string hash_string(std::string_view input) { + const std::hash hasher; + const size_t hashValue = hasher(std::string(input)); + return std::to_string(hashValue); +} + +void check_slang_result(SlangResult res, + std::string_view errorMessage = "", + const Slang::ComPtr &diagnostics = nullptr) { + if (SLANG_FAILED(res)) { + std::string full_error = "Slang Error: " + errorMessage; + if (diagnostics != nullptr) { + const std::string diag_string = + std::string(static_cast(diagnostics->getBufferPointer()), diagnostics->getBufferSize()); + if (!diag_string.empty()) { + full_error += "\nDiagnostics:\n" + diag_string; + } + } + throw SlangCodeGenException(full_error); + } +} + +void find_bindings_in_type_layout(slang::TypeLayoutReflection 
*typeLayout, + std::unordered_map &bindings); + +struct EntryPointSelection { + SlangUInt index = 0; + slang::EntryPointLayout *layout = nullptr; + std::string name; +}; + +EntryPointSelection select_entry_point(slang::ProgramLayout *programLayout, std::string_view requested_entry_point) { + if (programLayout == nullptr) { + throw SlangCodeGenException("Slang program layout is null!"); + } + + const SlangUInt entry_point_count = programLayout->getEntryPointCount(); + if (entry_point_count == 0) { + throw SlangCodeGenException("Slang program layout has no entry points!"); + } + + for (SlangUInt i = 0; i < entry_point_count; ++i) { + auto *entry_point_layout = programLayout->getEntryPointByIndex(i); + if (entry_point_layout == nullptr) { + continue; + } + + const char *const name_override = entry_point_layout->getNameOverride(); // check_syntax off + const char *const name = entry_point_layout->getName(); // check_syntax off + + const bool override_matches = (name_override != nullptr) && (requested_entry_point == name_override); + const bool name_matches = (name != nullptr) && (requested_entry_point == name); + + if (!override_matches && !name_matches) { + continue; + } + + std::string resolved_name; + if (name_override != nullptr) { + resolved_name = name_override; + } else if (name != nullptr) { + resolved_name = name; + } + return EntryPointSelection{.index = i, .layout = entry_point_layout, .name = resolved_name}; + } + + // Produce a human-readable list of available entry points + std::string available; + for (SlangUInt i = 0; i < entry_point_count; ++i) { + auto *entry_point_layout = programLayout->getEntryPointByIndex(i); + if (entry_point_layout == nullptr) { + continue; + } + const char *const name_override = entry_point_layout->getNameOverride(); // check_syntax off + const char *const name = entry_point_layout->getName(); // check_syntax off + const char *const resolved = (name_override != nullptr) ? 
name_override : name; // check_syntax off + if (resolved == nullptr) { + continue; + } + + if (!available.empty()) { + available += ", "; + } + available += resolved; + } + + throw SlangCodeGenException("Failed to find entry point '" + std::string(requested_entry_point) + + "' in linked Slang program layout. Available entry points: [" + available + "]"); +} + +void find_bindings_in_variable_layout(slang::VariableLayoutReflection *varLayout, + std::unordered_map &bindings) { + if (varLayout == nullptr) { + return; + } + + const char *var_name = varLayout->getName(); // check_syntax off + + if (var_name != nullptr) { + for (uint32_t i = 0; i < varLayout->getCategoryCount(); ++i) { + auto category = varLayout->getCategoryByIndex(i); + if (category == slang::ParameterCategory::DescriptorTableSlot) { + // This is a Texture, Buffer, or Sampler. + const ReflectedBindingInfo binding_info{.binding_index = static_cast(varLayout->getOffset(category)), + .layout = varLayout, + .category = category}; + + bindings[var_name] = binding_info; + break; // A variable can only have one slot binding. + } + } + } else { + // This is an anonymous variable (e.g., the element inside a ConstantBuffer). + // It doesn't have a name or binding itself, but we must traverse its type. + // We pass the parent's path along without modification. + find_bindings_in_type_layout(varLayout->getTypeLayout(), bindings); + } +} + +// Traverses the members of a type layout (like a struct or container). +void find_bindings_in_type_layout(slang::TypeLayoutReflection *typeLayout, + std::unordered_map &bindings) { + if (typeLayout == nullptr) + return; + + switch (typeLayout->getKind()) { + case slang::TypeReflection::Kind::Struct: { + // For a struct, iterate over its fields and process each one. 
+ for (SlangUInt i = 0; i < typeLayout->getFieldCount(); ++i) { + find_bindings_in_variable_layout(typeLayout->getFieldByIndex(i), bindings); + } + break; + } + + case slang::TypeReflection::Kind::ConstantBuffer: + case slang::TypeReflection::Kind::ParameterBlock: { + // For a container, we get the layout of its contents. + // getElementVarLayout() is the key idiomatic call here. + slang::VariableLayoutReflection *element_layout = typeLayout->getElementVarLayout(); + find_bindings_in_variable_layout(element_layout, bindings); + break; + } + + default: + // Other types (Scalar, Vector, Array, etc.) don't contain resource bindings themselves. + break; + } +} + +} // anonymous namespace + +std::future> request_slang_global_session_async() { + auto r = std::async(std::launch::async, []() { + Slang::ComPtr global_session; + const SlangGlobalSessionDesc global_session_desc; + check_slang_result(createGlobalSession(&global_session_desc, global_session.writeRef()), + "Failed to create Slang global session!"); + return global_session; + }); + + return r; +} + +CompiledKernelWGSL compile_kernel_code_to_wgsl(const MR::GPU::KernelSpec &kernel_spec, + slang::ISession *session, + ShaderCache &shader_cache) { + Slang::ComPtr diagnostics; + Slang::ComPtr shader_module; + + auto log_diagnostics = [&diagnostics]() { + if (diagnostics != nullptr) { + const std::string diag_string = + std::string(static_cast(diagnostics->getBufferPointer()), diagnostics->getBufferSize()); + if (!diag_string.empty()) { + DEBUG("Slang diagnostics:\n" + diag_string); + } + } + }; + + MR::match_v( + kernel_spec.compute_shader.shader_source, + [&](const ShaderFile &shaderFile) { + const auto shader_path_string = shaderFile.file_path.string(); + const std::string module_name = std::filesystem::path(shader_path_string).stem().string(); + const auto shader_source = read_file(shaderFile.file_path); + shader_module = session->loadModuleFromSourceString( + module_name.c_str(), shader_path_string.c_str(), 
shader_source.c_str(), diagnostics.writeRef()); + }, + [&](const InlineShaderText &inlineString) { + const std::string path_string = "inline_" + hash_string(inlineString.text); + // Use the unique path string as the module name to prevent collisions between different inline shaders + // that might otherwise share the same default name. + shader_module = session->loadModuleFromSourceString( + path_string.c_str(), path_string.c_str(), inlineString.text.c_str(), diagnostics.writeRef()); + }); + log_diagnostics(); + if (shader_module == nullptr) { + throw SlangCodeGenException("Failed to load shader module: " + kernel_spec.compute_shader.name); + } + + Slang::ComPtr entry_point; + check_slang_result( + shader_module->findEntryPointByName(kernel_spec.compute_shader.entryPoint.c_str(), entry_point.writeRef()), + "Slang failed to findEntryPointByName", + diagnostics); + + const auto &generic_type_args = kernel_spec.compute_shader.entry_point_args; + + // Specialisation arguments + Slang::ComPtr specialized_entry_point; + { + std::vector slang_generic_args(generic_type_args.size()); + if (!generic_type_args.empty()) { + auto *program_layout = shader_module->getLayout(); + std::transform(generic_type_args.begin(), + generic_type_args.end(), + slang_generic_args.begin(), + [program_layout](const std::string &arg) { // check_syntax off + auto *const spec_type = program_layout->findTypeByName(arg.c_str()); + if (spec_type == nullptr) { + throw SlangCodeGenException("Failed to find specialization type: " + arg); + } + return slang::SpecializationArg{slang::SpecializationArg::Kind::Type, spec_type}; + }); + check_slang_result(entry_point->specialize(slang_generic_args.data(), + static_cast(slang_generic_args.size()), + specialized_entry_point.writeRef(), + diagnostics.writeRef()), + "Slang failed to specialise entry point", + diagnostics); + } + } + + Slang::ComPtr slang_program; + std::vector shader_components; + shader_components.push_back(shader_module.get()); + 
shader_components.push_back(generic_type_args.empty() ? entry_point.get() : specialized_entry_point.get()); + + if (kernel_spec.compute_shader.workgroup_size.has_value()) { + const auto &wg_size = *kernel_spec.compute_shader.workgroup_size; + std::ostringstream oss; + oss << "export static const uint kWorkgroupSizeX = " << wg_size.x << ";\n" + << "export static const uint kWorkgroupSizeY = " << wg_size.y << ";\n" + << "export static const uint kWorkgroupSizeZ = " << wg_size.z << ";\n"; + + const std::string workgroup_size_constants = oss.str(); + const std::string workgroup_size_constants_name = + "workgroup_size_constants_" + hash_string(workgroup_size_constants); + Slang::ComPtr workgroup_size_constants_module; + workgroup_size_constants_module = session->loadModuleFromSourceString(workgroup_size_constants_name.c_str(), + workgroup_size_constants_name.c_str(), + workgroup_size_constants.data(), + diagnostics.writeRef()); + shader_components.push_back(workgroup_size_constants_module.get()); + } + + if (!kernel_spec.compute_shader.constants.empty()) { + std::ostringstream oss; + for (const auto &[name, value] : kernel_spec.compute_shader.constants) { + MR::match_v( + value, + [&oss, name = name](int32_t v) { oss << "export static const int32_t " << name << " = " << v << ";\n"; }, + [&oss, name = name](uint32_t v) { oss << "export static const uint32_t " << name << " = " << v << ";\n"; }, + [&oss, name = name](float v) { oss << "export static const float " << name << " = " << v << ";\n"; }); + } + + const std::string constant_definitions = oss.str(); + const std::string constant_definitions_name = "constant_definitions_" + hash_string(constant_definitions); + Slang::ComPtr constant_definitions_module; + constant_definitions_module = session->loadModuleFromSourceString(constant_definitions_name.c_str(), + constant_definitions_name.c_str(), + constant_definitions.data(), + diagnostics.writeRef()); + shader_components.push_back(constant_definitions_module.get()); + } + 
+ check_slang_result(session->createCompositeComponentType( + shader_components.data(), static_cast(shader_components.size()), slang_program.writeRef())); + + Slang::ComPtr linked_slang_program; + check_slang_result(slang_program->link(linked_slang_program.writeRef(), diagnostics.writeRef()), + "Slang failed to link program", + diagnostics); + + const auto entry_point_selection = + select_entry_point(linked_slang_program->getLayout(), kernel_spec.compute_shader.entryPoint); + const auto entry_point_index = static_cast(entry_point_selection.index); + + Slang::ComPtr hash_blob; + linked_slang_program->getEntryPointHash(entry_point_index, 0, hash_blob.writeRef()); + const std::string hash_key = + std::string(static_cast(hash_blob->getBufferPointer()), hash_blob->getBufferSize()); + + std::string wgsl_code; + Slang::ComPtr slang_kernel_blob; + if (shader_cache.contains(hash_key)) { + wgsl_code = shader_cache.get(hash_key); + } else { + check_slang_result(linked_slang_program->getEntryPointCode( + entry_point_index, 0, slang_kernel_blob.writeRef(), diagnostics.writeRef()), + "Slang failed to get entry point code", + diagnostics); + wgsl_code = std::string(static_cast(slang_kernel_blob->getBufferPointer()), + slang_kernel_blob->getBufferSize()); + shader_cache.insert(hash_key, wgsl_code); + } + + DEBUG(kernel_spec.compute_shader.name + " WGSL code:\n" + wgsl_code); + return CompiledKernelWGSL{ + .wgsl_source = wgsl_code, .linked_program = linked_slang_program, .entry_point_name = entry_point_selection.name}; +} + +/* Builds the map of reflected bindings for the entry point chosen by select_entry_point: bindings found in the global-params variable layout (when it is non-null) are added first, then those found in the entry point's variable layout. Returns the map as-is when the entry point has no variable layout; throws SlangCodeGenException when that variable layout has no type layout. */ std::unordered_map reflect_bindings(slang::ProgramLayout *program_layout, + std::string_view entry_point_name) { + std::unordered_map bindings_map; + const auto entry_point_selection = select_entry_point(program_layout, entry_point_name); + + auto *global_var_layout = program_layout->getGlobalParamsVarLayout(); + if (global_var_layout != nullptr) { + // If the program has global variables, we can find bindings in them. 
+ find_bindings_in_variable_layout(global_var_layout, bindings_map); + } + + auto *entry_point_layout = entry_point_selection.layout; + + auto *entry_point_root_variable_layout = entry_point_layout->getVarLayout(); + if (entry_point_root_variable_layout == nullptr) { + // This can happen if the entry point has no uniform parameters. + return bindings_map; + } + + auto *entry_point_root_type_layout = entry_point_root_variable_layout->getTypeLayout(); + if (entry_point_root_type_layout == nullptr) { + throw SlangCodeGenException("Slang entry point variable layout has no type layout!"); + } + + find_bindings_in_variable_layout(entry_point_root_variable_layout, bindings_map); + return bindings_map; +} + +/* Returns the three values that getComputeThreadGroupSize writes for the selected entry point (requested for 3 axes), each converted with static_cast. */ std::array workgroup_size(slang::ProgramLayout *program_layout, std::string_view entry_point_name) { + const auto entry_point_selection = select_entry_point(program_layout, entry_point_name); + auto *entry_point_layout = entry_point_selection.layout; + + std::array wg_size{}; + entry_point_layout->getComputeThreadGroupSize(3, wg_size.data()); + + return {static_cast(wg_size[0]), static_cast(wg_size[1]), static_cast(wg_size[2])}; +} + +} // namespace MR::GPU::SlangCodegen diff --git a/cpp/core/gpu/slangcodegen.h b/cpp/core/gpu/slangcodegen.h new file mode 100644 index 0000000000..7cbb5d17a8 --- /dev/null +++ b/cpp/core/gpu/slangcodegen.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2008-2026 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 
2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include "exception.h" +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace MR::GPU { +struct KernelSpec; +class ShaderCache; +} // namespace MR::GPU + +namespace MR::GPU::SlangCodegen { + +struct ReflectedBindingInfo { + uint32_t binding_index = 0; + slang::VariableLayoutReflection *layout = nullptr; + slang::ParameterCategory category = slang::ParameterCategory::None; +}; + +struct CompiledKernelWGSL { + std::string wgsl_source; + Slang::ComPtr linked_program; + std::string entry_point_name; +}; + +struct SlangCodeGenException : public Exception { + explicit SlangCodeGenException(std::string_view message) + : Exception(std::string("Slang codegen error: ") + message.data()) {} +}; + +// Request a Slang global session asynchronously. +std::future> request_slang_global_session_async(); + +// Compile a Slang kernel to WGSL. +// Returns the WGSL source string, the linked component type for subsequent +// reflection, and the resolved entry point name for pipeline creation. +CompiledKernelWGSL compile_kernel_code_to_wgsl(const MR::GPU::KernelSpec &kernel_spec, + slang::ISession *session, + ShaderCache &shader_cache); + +// Reflect resource bindings from a linked Slang program layout. +// Produces a map from binding names to their binding index and layout details. +std::unordered_map reflect_bindings(slang::ProgramLayout *program_layout, + std::string_view entry_point_name); + +// Returns the workgroup size specified in the compiled Slang program layout. 
+std::array workgroup_size(slang::ProgramLayout *program_layout, std::string_view entry_point_name); +} // namespace MR::GPU::SlangCodegen diff --git a/testing/unit_tests/CMakeLists.txt b/testing/unit_tests/CMakeLists.txt index 0a3a5a95a9..099e87b213 100644 --- a/testing/unit_tests/CMakeLists.txt +++ b/testing/unit_tests/CMakeLists.txt @@ -19,6 +19,10 @@ set(UNIT_TESTS_CPP_SRCS to_tests.cpp ) +if(MRTRIX_ENABLE_GPU) + list(APPEND UNIT_TESTS_CPP_SRCS gputests.cpp) +endif() + get_filename_component(SOURCE_PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) set(DATA_DIR ${SOURCE_PARENT_DIR}/data) diff --git a/testing/unit_tests/gputests.cpp b/testing/unit_tests/gputests.cpp new file mode 100644 index 0000000000..a609a24645 --- /dev/null +++ b/testing/unit_tests/gputests.cpp @@ -0,0 +1,488 @@ +/* Copyright (c) 2008-2026 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#include +#include + +#include "exception.h" +#include "gpu/gpu.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace MR; +using namespace MR::GPU; + +/* Fixture shared by all GPU tests: SetUpTestSuite creates one context stored in shared_context, each test binds it through the context reference, and TearDownTestSuite resets it. If creating the context throws std::exception, the message is written to std::cerr and GTEST_SKIP is invoked. */ class GPUTest : public ::testing::Test { +protected: + // Static pointer to the single, shared context for all tests in this suite. 
+ inline static std::unique_ptr shared_context; + + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + MR::GPU::ComputeContext &context; + + GPUTest() : context(*shared_context) {} + + static void SetUpTestSuite() { + try { + shared_context = std::make_unique(); + } catch (const std::exception &e) { + std::cerr << "Skipping GPU tests: " << e.what() << std::endl; + GTEST_SKIP() << "Skipping GPU tests: " << e.what(); + } + ASSERT_NE(shared_context, nullptr) << "Failed to create shared GPU context."; + } + + static void TearDownTestSuite() { shared_context.reset(); } +}; + +TEST_F(GPUTest, MakeEmptyBuffer) { /* The host vector is pre-filled with 1, so the all-zero check below only passes if the download overwrote it. */ + const size_t buffer_element_count = 1024; + const Buffer buffer = context.new_empty_buffer(buffer_element_count); + + std::vector downloaded_data(buffer_element_count, 1); // Initialize with non-zero + + context.download_buffer(buffer, downloaded_data); + + for (auto val : downloaded_data) { + EXPECT_EQ(val, 0); + } +} + +TEST_F(GPUTest, BufferFromHostMemory) { + std::vector host_data = {1, 2, 3, 4, 5}; + + const Buffer buffer = context.new_buffer_from_host_memory(tcb::span(host_data)); + + std::vector downloaded_data(host_data.size(), 0); + context.download_buffer(buffer, downloaded_data); + + EXPECT_EQ(downloaded_data, host_data); +} + +TEST_F(GPUTest, BufferFromHostMemoryObject) { + struct Data { + float a; + float b; + float c; + }; + + const Data host_data{1.0F, 2.5F, -3.0F}; + const Buffer buffer = context.new_buffer_from_host_object(host_data); + + std::vector downloaded_bytes(sizeof(Data)); + context.download_buffer(buffer, downloaded_bytes); + + Data downloaded_data{}; + std::memcpy(&downloaded_data, downloaded_bytes.data(), sizeof(Data)); + + EXPECT_EQ(downloaded_data.a, host_data.a); + EXPECT_EQ(downloaded_data.b, host_data.b); + EXPECT_EQ(downloaded_data.c, host_data.c); +} + +TEST_F(GPUTest, BufferFromHostMemoryMultipleRegions) { + std::vector region1 = {1, 2, 3}; + std::vector region2 = {4, 5}; + std::vector region3 = 
{6, 7, 8, 9}; + + const std::vector> regions = {region1, region2, region3}; + const Buffer buffer = context.new_buffer_from_host_memory(regions); + + const std::vector expected_data = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::vector downloaded_data(expected_data.size()); + context.download_buffer(buffer, downloaded_data); + + EXPECT_EQ(downloaded_data, expected_data); +} + +TEST_F(GPUTest, WriteToBuffer) { + std::vector new_data = {0.1F, 0.2F, 0.3F, 0.4F}; + + const Buffer buffer = context.new_empty_buffer(new_data.size()); + std::vector downloaded_data(new_data.size(), 0.0F); + + context.write_to_buffer(buffer, new_data); + context.download_buffer(buffer, downloaded_data); + + for (size_t i = 0; i < new_data.size(); i++) { + EXPECT_FLOAT_EQ(downloaded_data[i], new_data[i]); + } +} + +TEST_F(GPUTest, WriteToBufferWithOffset) { + const size_t buffer_size = 10; + std::vector initial_data(buffer_size); + std::iota(initial_data.begin(), initial_data.end(), 0.0F); // 0, 1, ..., 9 + + const Buffer buffer = context.new_buffer_from_host_memory(initial_data); + + std::vector new_data = {100.0F, 101.0F, 102.0F}; + const uint64_t offset_size = static_cast(new_data.size()); + + /* offset_size equals new_data.size() (3); expected_data below has the new values at element indices 3 to 5. */ context.write_to_buffer(buffer, new_data, offset_size); + + std::vector downloaded_data(buffer_size); + context.download_buffer(buffer, downloaded_data); + + std::vector expected_data = {0.0F, 1.0F, 2.0F, 100.0F, 101.0F, 102.0F, 6.0F, 7.0F, 8.0F, 9.0F}; + for (size_t i = 0; i < buffer_size; ++i) { + EXPECT_FLOAT_EQ(downloaded_data[i], expected_data[i]); + } +} + +TEST_F(GPUTest, EmptyTexture) { + const MR::GPU::TextureSpec textureSpec = { + .width = 4, + .height = 4, + .depth = 1, + .format = TextureFormat::R32Float, + }; + + const auto texture = context.new_empty_texture(textureSpec); + + const uint32_t bytes_per_pixel = 4; // R32Float + const size_t downloaded_size_bytes = + static_cast(textureSpec.width) * textureSpec.height * textureSpec.depth * bytes_per_pixel; + std::vector 
downloaded_data(downloaded_size_bytes / sizeof(float), 1.0F); // Init with non-zero + + context.download_texture(texture, downloaded_data); + + for (uint32_t z = 0; z < textureSpec.depth; ++z) { + for (uint32_t y = 0; y < textureSpec.height; ++y) { + for (uint32_t x = 0; x < textureSpec.width; ++x) { + const size_t idx = ((z * textureSpec.height + y) * textureSpec.width) + x; + EXPECT_FLOAT_EQ(downloaded_data[idx], 0.0F); + } + } + } +} + +TEST_F(GPUTest, KernelWithUniformBuffer) { + const std::string shaderCode = R"slang( + struct Params { + float scale; + float bias; + }; + + [shader("compute")] + [numthreads(1, 1, 1)] + void main( + uint32_t3 id : SV_DispatchThreadID, + RWStructuredBuffer data, + ConstantBuffer params + ){ + let idx = id.x; + uint32_t element_count, stride; + data.GetDimensions(element_count, stride); + if (idx < element_count) { + data[idx] = data[idx] * params.scale + params.bias; + } + } + )slang"; + + struct Params { + float scale = 0.0F; + float bias = 0.0F; + }; + + const Params params{.scale = 3.0F, .bias = 1.0F}; + const Buffer params_buffer = context.new_buffer_from_host_object(params, BufferType::UniformBuffer); + + const std::vector host_data = {1.0F, 2.0F, 3.0F, 4.0F}; + const std::vector expected_data = {4.0F, 7.0F, 10.0F, 13.0F}; + Buffer buffer = context.new_buffer_from_host_memory(host_data); + + const KernelSpec kernel_spec{ + .compute_shader = {.shader_source = InlineShaderText{shaderCode}}, + .bindings_map = {{"data", buffer}, {"params", params_buffer}}, + }; + const Kernel kernel = context.new_kernel(kernel_spec); + const DispatchGrid dispatch_grid = {static_cast(host_data.size()), 1, 1}; + context.dispatch_kernel(kernel, dispatch_grid); + + std::vector result_data(host_data.size()); + context.download_buffer(buffer, result_data); + EXPECT_EQ(result_data, expected_data); +} + +TEST_F(GPUTest, KernelWithInlineShader) { + const std::string shaderCode = R"slang( + [shader("compute")] + [numthreads(1, 1, 1)] + void main( 
+ uint32_t3 id : SV_DispatchThreadID, + RWStructuredBuffer data + ){ + let idx = id.x; + uint element_count, stride; + data.GetDimensions(element_count, stride); + if (idx < element_count) { + data[idx] = data[idx] * 3.0; + } + } + )slang"; + + const std::vector host_data = {1.0F, 2.0F, 3.0F, 4.0F}; + const std::vector expected_data = {3.0F, 6.0F, 9.0F, 12.0F}; + Buffer buffer = context.new_buffer_from_host_memory(host_data); + const KernelSpec kernel_spec{ + .compute_shader = + { + .shader_source = InlineShaderText{shaderCode}, + }, + .bindings_map = {{"data", buffer}}, + }; + const Kernel kernel = context.new_kernel(kernel_spec); + /* NOTE(review): the grid's x extent is 63 larger than host_data.size() while the shader declares numthreads(1, 1, 1); the shader's idx < element_count check skips the surplus invocations. Looks like a leftover from a 64-wide rounding expression - confirm whether the over-dispatch is intended. */ const DispatchGrid dispatch_grid = {static_cast((host_data.size() + 63)), 1, 1}; + context.dispatch_kernel(kernel, dispatch_grid); + + std::vector result_data(host_data.size()); + context.download_buffer(buffer, result_data); + EXPECT_EQ(result_data, expected_data); +} + +TEST_F(GPUTest, ShaderConstants) { + const std::string shaderCode = R"slang( + extern const static uint32_t uConstantValue; + extern const static int32_t iConstantValue; + extern const static float fConstantValue; + + [shader("compute")] + [numthreads(1, 1, 1)] + void main( + uint32_t id : SV_DispatchThreadID, + RWStructuredBuffer floatBuffer, + RWStructuredBuffer uintBuffer, + RWStructuredBuffer intBuffer + ){ + floatBuffer[0] = fConstantValue; + uintBuffer[0] = uConstantValue; + intBuffer[0] = iConstantValue; + })slang"; + + const float f_constant_value = 3.14F; + const uint32_t u_constant_value = 42; + const int32_t i_constant_value = -7; + + Buffer float_buffer = context.new_empty_buffer(1); + Buffer uint_buffer = context.new_empty_buffer(1); + Buffer int_buffer = context.new_empty_buffer(1); + + const KernelSpec kernel_spec{ + .compute_shader = {.shader_source = InlineShaderText{shaderCode}, + .constants = {{"fConstantValue", f_constant_value}, + {"uConstantValue", u_constant_value}, + {"iConstantValue", i_constant_value}}}, + .bindings_map = {{"floatBuffer", 
float_buffer}, {"uintBuffer", uint_buffer}, {"intBuffer", int_buffer}}, + }; + + const Kernel kernel = context.new_kernel(kernel_spec); + const DispatchGrid dispatch_grid = {1, 1, 1}; + context.dispatch_kernel(kernel, dispatch_grid); + float downloaded_float = 0.0F; + uint32_t downloaded_uint = 0; + int32_t downloaded_int = 0; + + context.download_buffer(float_buffer, &downloaded_float, sizeof(float)); + context.download_buffer(uint_buffer, &downloaded_uint, sizeof(uint32_t)); + context.download_buffer(int_buffer, &downloaded_int, sizeof(int32_t)); + + EXPECT_FLOAT_EQ(downloaded_float, f_constant_value); + EXPECT_EQ(downloaded_uint, u_constant_value); + EXPECT_EQ(downloaded_int, i_constant_value); +} + +TEST_F(GPUTest, ShaderEntryPointArgs) { + const std::string shaderCode = R"slang( + interface IOperation { + float execute(float a, float b); + } + + struct Add : IOperation { + float execute(float a, float b) { return a + b; } + } + + struct Multiply : IOperation { + float execute(float a, float b) { return a * b; } + } + + [shader("compute")] + [numthreads(1, 1, 1)] + void main( + uint32_t3 id : SV_DispatchThreadID, + RWStructuredBuffer data + ){ + let idx = id.x; + let op = Op(); + data[idx] = op.execute(data[idx], 2.0); + } + )slang"; + + /* The same shader source is used for two kernels whose entry_point_args are "Add" and then "Multiply"; each run expects that operation applied with 2.0 as the second operand. */ // Test Add + { + const std::vector host_data = {1.0F, 2.0F, 3.0F}; + const std::vector expected_data = {3.0F, 4.0F, 5.0F}; // + 2.0 + Buffer buffer = context.new_buffer_from_host_memory(host_data); + + const KernelSpec kernel_spec{ + .compute_shader = {.shader_source = InlineShaderText{shaderCode}, .entry_point_args = {"Add"}}, + .bindings_map = {{"data", buffer}}, + }; + const Kernel kernel = context.new_kernel(kernel_spec); + const DispatchGrid dispatch_grid = {static_cast(host_data.size()), 1, 1}; + context.dispatch_kernel(kernel, dispatch_grid); + + std::vector result_data(host_data.size()); + context.download_buffer(buffer, result_data); + EXPECT_EQ(result_data, expected_data); + } + + // Test Multiply + { + const 
std::vector host_data = {1.0F, 2.0F, 3.0F}; + const std::vector expected_data = {2.0F, 4.0F, 6.0F}; // * 2.0 + Buffer buffer = context.new_buffer_from_host_memory(host_data); + + const KernelSpec kernel_spec{ + .compute_shader = {.shader_source = InlineShaderText{shaderCode}, .entry_point_args = {"Multiply"}}, + .bindings_map = {{"data", buffer}}, + }; + const Kernel kernel = context.new_kernel(kernel_spec); + const DispatchGrid dispatch_grid = {static_cast(host_data.size()), 1, 1}; + context.dispatch_kernel(kernel, dispatch_grid); + + std::vector result_data(host_data.size()); + context.download_buffer(buffer, result_data); + EXPECT_EQ(result_data, expected_data); + } +} + +TEST_F(GPUTest, CopyBufferToBuffer_Full) { + std::vector src_data = {1, 2, 3, 4, 5}; + + const Buffer src_buffer = context.new_buffer_from_host_memory(src_data); + const Buffer dst_buffer = context.new_empty_buffer(src_data.size()); + + // defaults to zeros -> byteSize == 0 means copy whole buffer + const ComputeContext::BufferCopyInfo info{.byteSize = 0}; + + context.copy_buffer_to_buffer(src_buffer, dst_buffer, info); + + std::vector downloaded_data(src_data.size()); + context.download_buffer(dst_buffer, downloaded_data); + + EXPECT_EQ(downloaded_data, src_data); +} + +TEST_F(GPUTest, CopyBufferToBuffer_Partial) { + std::vector src(10); + std::iota(src.begin(), src.end(), 0U); + + const Buffer src_buffer = context.new_buffer_from_host_memory(src); + Buffer dst_buffer = context.new_buffer_from_host_memory(src); + + const ComputeContext::BufferCopyInfo info{ + .srcOffset = 2 * sizeof(uint32_t), .dstOffset = 5 * sizeof(uint32_t), .byteSize = 3 * sizeof(uint32_t)}; + + context.copy_buffer_to_buffer(src_buffer, dst_buffer, info); + + std::vector downloaded_data(src.size()); + context.download_buffer(dst_buffer, downloaded_data); + + const size_t src_start = info.srcOffset / sizeof(uint32_t); + const size_t dst_start = info.dstOffset / sizeof(uint32_t); + const size_t count = info.byteSize / 
sizeof(uint32_t); + + for (size_t i = 0; i < count; ++i) { + EXPECT_EQ(downloaded_data[dst_start + i], src[src_start + i]); + } + + /* dst_buffer was created from src, so every element outside the copied destination window is expected to still equal src[i]. */ for (size_t i = 0; i < downloaded_data.size(); ++i) { + if (i < dst_start || i >= dst_start + count) { + EXPECT_EQ(downloaded_data[i], src[i]); + } + } +} + +TEST_F(GPUTest, CopyBufferToBuffer_SourceOutOfRangeThrows) { + std::vector src = {1, 2, 3}; + std::vector dst = {0, 0, 0}; + + const Buffer src_buffer = context.new_buffer_from_host_memory(src); + const Buffer dst_buffer = context.new_buffer_from_host_memory(dst); + + // set srcOffset near end and request more bytes than available + const ComputeContext::BufferCopyInfo info{ + .srcOffset = 2 * sizeof(uint32_t), // points to last element + .dstOffset = 0, + .byteSize = 2 * sizeof(uint32_t) // will read last + one beyond -> should throw + }; + EXPECT_THROW(context.copy_buffer_to_buffer(src_buffer, dst_buffer, info), Exception); +} + +TEST_F(GPUTest, CopyBufferToBuffer_DestinationOutOfRangeThrows) { + std::vector src = {10, 20, 30, 40}; + std::vector dst = {0, 0}; + + const Buffer src_buffer = context.new_buffer_from_host_memory(src); + const Buffer dst_buffer = context.new_buffer_from_host_memory(dst); + + const ComputeContext::BufferCopyInfo info{ + .srcOffset = 0, + .dstOffset = 1 * sizeof(uint32_t), // only room for one element, but we'll request more + .byteSize = 2 * sizeof(uint32_t) // exceeds dst size -> should throw + }; + EXPECT_THROW(context.copy_buffer_to_buffer(src_buffer, dst_buffer, info), Exception); +} + +TEST_F(GPUTest, ClearBuffer) { + std::vector data = {1.5F, -2.0F, 3.25F, 4.0F}; + + const Buffer buffer = context.new_buffer_from_host_memory(data); + + // Ensure buffer contains non-zero values initially + std::vector before(data.size(), 0.0F); + context.download_buffer(buffer, before); + for (auto v : before) + EXPECT_NE(v, 0.0F); + + context.clear_buffer(buffer); + + std::vector downloaded(data.size(), 1.0F); + context.download_buffer(buffer, downloaded); + for 
(auto val : downloaded) { + EXPECT_FLOAT_EQ(val, 0.0F); + } +} + +TEST_F(GPUTest, DownloadBufferAsVector) { + const std::vector host = {10, 20, 30, 40}; + + const Buffer buffer = context.new_buffer_from_host_memory(host); + + const std::vector downloaded = context.download_buffer_as_vector(buffer); + + EXPECT_EQ(downloaded, host); +}