From 9a75b571cad91db4be249fd790354f0173cba319 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Wed, 25 Feb 2026 02:14:32 +0000 Subject: [PATCH 1/8] Add new function to download textures as an MR::Image --- cpp/core/gpu/gpu.cpp | 95 ++++++++++++++++++++++++++++++++++++++++++++ cpp/core/gpu/gpu.h | 14 +++++++ 2 files changed, 109 insertions(+) diff --git a/cpp/core/gpu/gpu.cpp b/cpp/core/gpu/gpu.cpp index fa52ff17c5..6b7b4dc685 100644 --- a/cpp/core/gpu/gpu.cpp +++ b/cpp/core/gpu/gpu.cpp @@ -15,6 +15,8 @@ */ #include "gpu.h" +#include "adapter/extract.h" +#include "algo/threaded_copy.h" #include "exception.h" #include "image_helpers.h" #include "match_variant.h" @@ -77,6 +79,17 @@ uint32_t pixel_size_in_bytes(const TextureFormat format) { } } +uint32_t texture_channels_count(const TextureFormat format) { + switch (format) { + case TextureFormat::R32Float: + return 1; + case TextureFormat::RGBA32Float: + return 4; + default: + throw MR::Exception("Only R32Float and RGBA32Float textures are supported!"); + } +} + wgpu::ShaderModule make_wgsl_shader_module(std::string_view name, std::string_view code, const wgpu::Device &device) { wgpu::ShaderSourceWGSL wgsl; wgsl.code = code; @@ -536,6 +549,88 @@ void ComputeContext::download_texture(const Texture &texture, tcb::span d staging_buffer.Unmap(); } +Image ComputeContext::download_texture_as_image(const Texture &texture, + const Header &header, + std::string_view label, + DownloadTextureAlphaMode alpha_mode) const { + Image image = Image::scratch(header, label); + const uint32_t texture_channels = texture_channels_count(texture.spec.format); + const uint32_t texture_dims = [&texture]() { + switch (texture.wgpu_handle.GetDimension()) { + case wgpu::TextureDimension::e2D: + return 2U; + case wgpu::TextureDimension::e3D: + return 3U; + default: + throw MR::Exception("Unsupported texture dimension"); + } + }(); + + const uint32_t header_dims = static_cast(header.ndim()); + const bool has_channel_axis = header_dims == 
(texture_dims + 1U); + if (texture_dims != header_dims && !has_channel_axis) { + throw MR::Exception("Texture dimension (" + std::to_string(texture_dims) + ") does not match header dimension (" + + std::to_string(header_dims) + ")"); + } + + const uint32_t texture_width = texture.spec.width; + const uint32_t texture_height = texture.spec.height; + const uint32_t texture_depth = texture.spec.depth; + + if (header.size(0) != static_cast(texture_width) || header.size(1) != static_cast(texture_height) || + (texture_dims == 3U && header.size(2) != static_cast(texture_depth))) { + throw MR::Exception("Header dimensions do not match texture size"); + } + + const uint32_t expected_channels = [&]() -> uint32_t { + switch (texture_channels) { + case 1U: + return 1U; + case 4U: + return alpha_mode == DownloadTextureAlphaMode::KeepAlpha ? 4U : 3U; + default: + throw MR::Exception("Unsupported texture channel count"); + } + }(); + + if (has_channel_axis) { + if (header.size(texture_dims) != static_cast(expected_channels)) { + throw MR::Exception("Header channel axis does not match expected channel count"); + } + } else if (expected_channels != 1U) { + throw MR::Exception("Header must include a channel axis for multi-channel textures"); + } + + const size_t voxel_count = static_cast(texture_width) * texture_height * texture_depth; + Header source_header(header); + if (has_channel_axis) { + source_header.size(texture_dims) = static_cast(texture_channels); + } + + // Force the strides to match the memory layout written by download_texture: + // Channel (if present) is fastest, then X, then Y, then Z. + for (size_t i = 0; i < texture_dims; ++i) { + source_header.stride(i) = has_channel_axis ? 
i + 2 : i + 1; + } + if (has_channel_axis) { + source_header.stride(texture_dims) = 1; + } + + Image source_image = Image::scratch(source_header); + + download_texture(texture, tcb::span(source_image.address(), voxel_count * texture_channels)); + + if (texture_channels == 4U && alpha_mode == DownloadTextureAlphaMode::IgnoreAlpha) { + const std::vector rgb_channels = {0U, 1U, 2U}; + Adapter::Extract1D source_rgb(source_image, texture_dims, rgb_channels); + threaded_copy(source_rgb, image); + } else { + threaded_copy(source_image, image); + } + + return image; +} + Kernel ComputeContext::new_kernel(const KernelSpec &kernel_spec) const { struct BindingEntries { std::vector bind_group_entries; diff --git a/cpp/core/gpu/gpu.h b/cpp/core/gpu/gpu.h index de3ca91252..dcbc6681a3 100644 --- a/cpp/core/gpu/gpu.h +++ b/cpp/core/gpu/gpu.h @@ -294,6 +294,20 @@ struct ComputeContext { // This function blocks until the download is complete. void download_texture(const Texture &texture, tcb::span dst_memory_region) const; + enum class DownloadTextureAlphaMode : uint8_t { + IgnoreAlpha, + KeepAlpha, + }; + + // This function blocks until the download is complete. + // The returned image will have the same strides as the provided header. + // The texture data is downloaded in a strict row-major format (Channels -> X -> Y -> Z) + // and then reshuffled to match the requested strides. 
+ [[nodiscard]] Image download_texture_as_image(const Texture &texture, + const Header &header, + std::string_view label, + DownloadTextureAlphaMode alpha_mode) const; + [[nodiscard]] Kernel new_kernel(const KernelSpec &kernel_spec) const; void dispatch_kernel(const Kernel &kernel, const DispatchGrid &dispatch_grid) const; From 30663c91930bf9c2676c0e5acb842bb6327b3cdf Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Thu, 26 Feb 2026 13:44:48 +0000 Subject: [PATCH 2/8] Add comment for instructions for MacOS profiling --- cpp/core/gpu/gpu.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/core/gpu/gpu.cpp b/cpp/core/gpu/gpu.cpp index 6b7b4dc685..ca19f4e5f1 100644 --- a/cpp/core/gpu/gpu.cpp +++ b/cpp/core/gpu/gpu.cpp @@ -154,6 +154,8 @@ ComputeContext::ComputeContext() : m_slang_session_info(std::make_unique Date: Sat, 28 Feb 2026 15:19:05 +0000 Subject: [PATCH 3/8] Pack MR::Image into contiguous buffer before texture upload --- cpp/core/gpu/gpu.cpp | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/cpp/core/gpu/gpu.cpp b/cpp/core/gpu/gpu.cpp index ca19f4e5f1..6d63972699 100644 --- a/cpp/core/gpu/gpu.cpp +++ b/cpp/core/gpu/gpu.cpp @@ -471,7 +471,23 @@ Texture ComputeContext::new_texture_from_host_image(const MR::Image &imag .usage = usage, }; const auto image_size = MR::voxel_count(image); - return new_texture_from_host_memory(textureSpec, tcb::span(image.address(), image_size)); + // We need to pack the image data into a contiguous buffer in the layout expected by the GPU texture. 
+ // TODO: we cannot rely on Image::with_direct_io() to do this packing for us + // See discussion at https://github.com/MRtrix3/mrtrix3/pull/3108 + std::vector contiguous_host_data(image_size, 0.0F); + auto source = image; + const size_t width = static_cast(source.size(0)); + const size_t height = static_cast(source.size(1)); + const auto pack_voxel = [&contiguous_host_data, width, height](auto &vox) { + const size_t x = static_cast(vox.index(0)); + const size_t y = static_cast(vox.index(1)); + const size_t z = static_cast(vox.index(2)); + const size_t linear_offset = x + width * (y + height * z); + contiguous_host_data[linear_offset] = vox.value(); + }; + ThreadedLoop(source, 0, 3).run(pack_voxel, source); + + return new_texture_from_host_memory(textureSpec, contiguous_host_data); } void ComputeContext::download_texture(const Texture &texture, tcb::span dst_memory_region) const { From 73b9b7e64f4667d6125b8e895a1f3aa27ab4aa47 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Mon, 2 Mar 2026 22:33:20 +0000 Subject: [PATCH 4/8] Add support for read-write for RGBAF32 textures in shaders --- cpp/core/gpu/gpu.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cpp/core/gpu/gpu.cpp b/cpp/core/gpu/gpu.cpp index 6d63972699..d2689dc92b 100644 --- a/cpp/core/gpu/gpu.cpp +++ b/cpp/core/gpu/gpu.cpp @@ -206,7 +206,13 @@ ComputeContext::ComputeContext() : m_slang_session_info(std::make_unique required_device_features = { - wgpu::FeatureName::R8UnormStorage, wgpu::FeatureName::Float32Filterable, wgpu::FeatureName::Subgroups}; + wgpu::FeatureName::R8UnormStorage, + wgpu::FeatureName::Float32Filterable, + wgpu::FeatureName::Subgroups, + // Required for read-write support for RGBA32Float textures in shaders. + // Should be ubiquitous on modern desktop GPUs. 
+ wgpu::FeatureName::TextureFormatsTier2, + }; wgpu::Limits supported_limits; wgpu_adapter.GetLimits(&supported_limits); From 036be6330da8c3f1baec5fb945cfa22f8739fcf8 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Tue, 3 Mar 2026 17:55:01 +0000 Subject: [PATCH 5/8] Add function for GPU texture->texture copy --- cpp/core/gpu/gpu.cpp | 84 ++++++++++++++++++++++++++++++++++++++++++-- cpp/core/gpu/gpu.h | 15 ++++++++ 2 files changed, 96 insertions(+), 3 deletions(-) diff --git a/cpp/core/gpu/gpu.cpp b/cpp/core/gpu/gpu.cpp index d2689dc92b..e184861ea0 100644 --- a/cpp/core/gpu/gpu.cpp +++ b/cpp/core/gpu/gpu.cpp @@ -267,7 +267,12 @@ ComputeContext::ComputeContext() : m_slang_session_info(std::make_unique search_paths = {executable_dir_cstr, registration_dir_cstr}; std::vector slang_compiler_options; { @@ -282,8 +287,8 @@ ComputeContext::ComputeContext() : m_slang_session_info(std::make_unique= srcTexture.spec.width || info.srcY >= srcTexture.spec.height || + info.srcZ >= srcTexture.spec.depth || info.dstX >= dstTexture.spec.width || + info.dstY >= dstTexture.spec.height || info.dstZ >= dstTexture.spec.depth) { + return; + } + + final_width = std::min(srcTexture.spec.width - info.srcX, dstTexture.spec.width - info.dstX); + final_height = std::min(srcTexture.spec.height - info.srcY, dstTexture.spec.height - info.dstY); + final_depth = std::min(srcTexture.spec.depth - info.srcZ, dstTexture.spec.depth - info.dstZ); + } else if (info.width == 0 || info.height == 0 || info.depth == 0) { + throw MR::Exception("copyTextureToTexture: width, height and depth must all be non-zero unless all are zero"); + } + + if (final_width == 0 || final_height == 0 || final_depth == 0) { + return; + } + + const auto range_exceeds = [](const uint32_t offset, const uint32_t size, const uint32_t limit) { + return static_cast(offset) + size > limit; + }; + + if (range_exceeds(info.srcX, final_width, srcTexture.spec.width) || + range_exceeds(info.srcY, final_height, srcTexture.spec.height) || + 
+ range_exceeds(info.srcZ, final_depth, srcTexture.spec.depth)) { + throw MR::Exception("copyTextureToTexture: source range out of bounds"); + } + + if (range_exceeds(info.dstX, final_width, dstTexture.spec.width) || + range_exceeds(info.dstY, final_height, dstTexture.spec.height) || + range_exceeds(info.dstZ, final_depth, dstTexture.spec.depth)) { + throw MR::Exception("copyTextureToTexture: destination range out of bounds"); + } + + const wgpu::TexelCopyTextureInfo src_copy{ + .texture = srcTexture.wgpu_handle, + .mipLevel = 0, + .origin = {.x = info.srcX, .y = info.srcY, .z = info.srcZ}, + .aspect = wgpu::TextureAspect::All, + }; + const wgpu::TexelCopyTextureInfo dst_copy{ + .texture = dstTexture.wgpu_handle, + .mipLevel = 0, + .origin = {.x = info.dstX, .y = info.dstY, .z = info.dstZ}, + .aspect = wgpu::TextureAspect::All, + }; + const wgpu::Extent3D copy_size{ + .width = final_width, + .height = final_height, + .depthOrArrayLayers = final_depth, + }; + + const wgpu::CommandEncoder command_encoder = m_device.CreateCommandEncoder(); + command_encoder.CopyTextureToTexture(&src_copy, &dst_copy, &copy_size); + const wgpu::CommandBuffer command_buffer = command_encoder.Finish(); + m_device.GetQueue().Submit(1, &command_buffer); +} + wgpu::Buffer ComputeContext::inner_new_empty_buffer(size_t byteSize, BufferType bufferType) const { wgpu::BufferUsage buffer_usage = wgpu::BufferUsage::None; size_t buffer_byte_size = byteSize; diff --git a/cpp/core/gpu/gpu.h b/cpp/core/gpu/gpu.h index dcbc6681a3..d7763687e3 100644 --- a/cpp/core/gpu/gpu.h +++ b/cpp/core/gpu/gpu.h @@ -276,6 +276,21 @@ struct ComputeContext { const BufferVariant &dstBuffer, const BufferCopyInfo &info) const; + // Copy a region from a source texture to a destination texture. + // If width, height and depth are all 0, the largest region that fits in both textures from the given offsets is copied. 
+ struct TextureCopyInfo { + uint32_t srcX = 0; + uint32_t srcY = 0; + uint32_t srcZ = 0; + uint32_t dstX = 0; + uint32_t dstY = 0; + uint32_t dstZ = 0; + uint32_t width = 0; + uint32_t height = 0; + uint32_t depth = 0; + }; + void copy_texture_to_texture(const Texture &srcTexture, const Texture &dstTexture, const TextureCopyInfo &info) const; + template void clear_buffer(const Buffer &buffer) const { inner_clear_buffer(buffer.wgpu_handle); } From 89b7c073079ec7d8f7fcdc31e3f01c55e1c46166 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Wed, 4 Mar 2026 01:51:47 +0000 Subject: [PATCH 6/8] Log GPU adapter information --- cpp/core/gpu/gpu.cpp | 59 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/cpp/core/gpu/gpu.cpp b/cpp/core/gpu/gpu.cpp index e184861ea0..b4d46ad407 100644 --- a/cpp/core/gpu/gpu.cpp +++ b/cpp/core/gpu/gpu.cpp @@ -140,6 +140,54 @@ wgpu::TextureUsage to_wgpu_usage(const MR::GPU::TextureUsage &usage) { return textureUsage; } +std::string to_string(const wgpu::StringView value) { + if (value.data == nullptr) { + return "unknown"; + } + if (value.length == wgpu::kStrlen) { + return std::string(value.data); + } + return std::string(value.data, value.length); +} + +std::string to_string(const wgpu::BackendType backend_type) { + switch (backend_type) { + case wgpu::BackendType::Null: + return "Null"; + case wgpu::BackendType::WebGPU: + return "WebGPU"; + case wgpu::BackendType::D3D11: + return "D3D11"; + case wgpu::BackendType::D3D12: + return "D3D12"; + case wgpu::BackendType::Metal: + return "Metal"; + case wgpu::BackendType::Vulkan: + return "Vulkan"; + case wgpu::BackendType::OpenGL: + return "OpenGL"; + case wgpu::BackendType::OpenGLES: + return "OpenGLES"; + case wgpu::BackendType::Undefined: + default: + return "Undefined"; + } +} + +std::string to_string(const wgpu::AdapterType adapter_type) { + switch (adapter_type) { + case wgpu::AdapterType::DiscreteGPU: + return "DiscreteGPU"; + case 
wgpu::AdapterType::IntegratedGPU: + return "IntegratedGPU"; + case wgpu::AdapterType::CPU: + return "CPU"; + case wgpu::AdapterType::Unknown: + default: + return "Unknown"; + } +} + std::future> request_slang_global_session_async() { return SlangCodegen::request_slang_global_session_async(); } @@ -259,6 +307,17 @@ ComputeContext::ComputeContext() : m_slang_session_info(std::make_uniqueglobalSession = std::move(slang_global_session_request.get()); From 3813447673ed282989f7ac4302d8625136be88c7 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Fri, 20 Mar 2026 14:51:43 +0000 Subject: [PATCH 7/8] Add tests for copy_texture_to_texture --- testing/unit_tests/gputests.cpp | 144 ++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/testing/unit_tests/gputests.cpp b/testing/unit_tests/gputests.cpp index a609a24645..54dc43aa7c 100644 --- a/testing/unit_tests/gputests.cpp +++ b/testing/unit_tests/gputests.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,30 @@ using namespace MR; using namespace MR::GPU; +namespace { +std::vector make_test_texture_data(uint32_t width, uint32_t height, uint32_t depth, uint32_t channels) { + const size_t voxel_count = static_cast(width) * height * depth; + std::vector data(voxel_count * channels, 0.0F); + + // Generate random data for each voxel and channel + std::mt19937 rng(42); // Fixed seed for reproducibility + std::uniform_real_distribution dist(0.0F, 1.0F); + + for (uint32_t z = 0U; z < depth; ++z) { + for (uint32_t y = 0U; y < height; ++y) { + for (uint32_t x = 0U; x < width; ++x) { + const size_t voxel_idx = (static_cast(z) * height + y) * width + x; + const size_t base = voxel_idx * channels; + for (uint32_t c = 0U; c < channels; ++c) { + data[base + c] = dist(rng); + } + } + } + } + return data; +} +} // namespace + class GPUTest : public ::testing::Test { protected: // Static pointer to the single, shared context for all tests in this suite. 
@@ -486,3 +511,122 @@ TEST_F(GPUTest, DownloadBufferAsVector) { EXPECT_EQ(downloaded, host); } + +TEST_F(GPUTest, CopyTextureToTextureFull) { + struct TestCase { + uint32_t channels = 0U; + TextureFormat format = TextureFormat::R32Float; + }; + + const std::array test_cases = {{ + {.channels = 1U, .format = TextureFormat::R32Float}, + {.channels = 4U, .format = TextureFormat::RGBA32Float}, + }}; + + for (const TestCase &test_case : test_cases) { + const uint32_t width = 4U; + const uint32_t height = 5U; + const uint32_t depth = 3U; + const std::vector src_data = make_test_texture_data(width, height, depth, test_case.channels); + const TextureSpec texture_spec{.width = width, .height = height, .depth = depth, .format = test_case.format}; + + const auto src_texture = context.new_texture_from_host_memory(texture_spec, src_data); + + // Copy to empty texture of same size and format + const auto dst_texture = context.new_empty_texture(texture_spec); + + const ComputeContext::TextureCopyInfo copy_info{.width = width, .height = height, .depth = depth}; + context.copy_texture_to_texture(src_texture, dst_texture, copy_info); + + std::vector downloaded_data(src_data.size(), 0.0F); + context.download_texture(dst_texture, downloaded_data); + + EXPECT_EQ(downloaded_data, src_data); + } +} + +TEST_F(GPUTest, CopyTextureToTextureWithOffsets) { + const uint32_t width = 5U; + const uint32_t height = 4U; + const uint32_t depth = 3U; + + const auto linear_index = [width, height](const uint32_t x, const uint32_t y, const uint32_t z) { + return static_cast(x) + + static_cast(width) * (static_cast(y) + static_cast(height) * z); + }; + + const std::vector src_data = make_test_texture_data(width, height, depth, 1U); + + const std::vector dst_initial(static_cast(width) * height * depth, -1.0F); + + const TextureSpec texture_spec{ + .width = width, + .height = height, + .depth = depth, + .format = TextureFormat::R32Float, + }; + + const Texture src_texture = 
context.new_texture_from_host_memory(texture_spec, src_data); + const Texture dst_texture = context.new_texture_from_host_memory(texture_spec, dst_initial); + + const ComputeContext::TextureCopyInfo copy_info{ + .srcX = 1U, + .srcY = 1U, + .srcZ = 1U, + .dstX = 2U, + .dstY = 0U, + .dstZ = 0U, + .width = 2U, + .height = 2U, + .depth = 2U, + }; + context.copy_texture_to_texture(src_texture, dst_texture, copy_info); + + std::vector downloaded_data(dst_initial.size(), 0.0F); + context.download_texture(dst_texture, downloaded_data); + + for (uint32_t z = 0U; z < depth; ++z) { + for (uint32_t y = 0U; y < height; ++y) { + for (uint32_t x = 0U; x < width; ++x) { + const bool is_within_dst_copy_region = x >= copy_info.dstX && x < copy_info.dstX + copy_info.width && + y >= copy_info.dstY && y < copy_info.dstY + copy_info.height && + z >= copy_info.dstZ && z < copy_info.dstZ + copy_info.depth; + + if (is_within_dst_copy_region) { + const uint32_t src_x = copy_info.srcX + (x - copy_info.dstX); + const uint32_t src_y = copy_info.srcY + (y - copy_info.dstY); + const uint32_t src_z = copy_info.srcZ + (z - copy_info.dstZ); + EXPECT_FLOAT_EQ(downloaded_data[linear_index(x, y, z)], src_data[linear_index(src_x, src_y, src_z)]); + } else { + EXPECT_FLOAT_EQ(downloaded_data[linear_index(x, y, z)], dst_initial[linear_index(x, y, z)]); + } + } + } + } +} + +TEST_F(GPUTest, CopyTextureToTextureWithOffsetsOutOfRangeThrows) { + const TextureSpec texture_spec{ + .width = 4U, + .height = 4U, + .depth = 2U, + .format = TextureFormat::R32Float, + }; + + const Texture src_texture = context.new_empty_texture(texture_spec); + const Texture dst_texture = context.new_empty_texture(texture_spec); + + const ComputeContext::TextureCopyInfo copy_info{ + .srcX = 3U, + .srcY = 1U, + .srcZ = 0U, + .dstX = 0U, + .dstY = 0U, + .dstZ = 0U, + .width = 2U, + .height = 1U, + .depth = 1U, + }; + + EXPECT_THROW(context.copy_texture_to_texture(src_texture, dst_texture, copy_info), Exception); +} From 
9dfc383f0f345c002e4ad30c37910eced5266298 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Fri, 20 Mar 2026 16:00:45 +0000 Subject: [PATCH 8/8] Add option to select GPU by using MRTRIX_GPU_ID --- cpp/core/gpu/gpu.cpp | 138 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 112 insertions(+), 26 deletions(-) diff --git a/cpp/core/gpu/gpu.cpp b/cpp/core/gpu/gpu.cpp index b4d46ad407..df35fea364 100644 --- a/cpp/core/gpu/gpu.cpp +++ b/cpp/core/gpu/gpu.cpp @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -34,12 +35,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include #include #include @@ -191,6 +194,97 @@ std::string to_string(const wgpu::AdapterType adapter_type) { std::future> request_slang_global_session_async() { return SlangCodegen::request_slang_global_session_async(); } + +std::optional parse_gpu_adapter_index_env() { + // NOLINTNEXTLINE(concurrency-mt-unsafe) + const char *gpu_id_env = std::getenv("MRTRIX_GPU_ID"); // check_syntax off + if (gpu_id_env == nullptr) { + return std::nullopt; + } + + const std::string_view gpu_id_string(gpu_id_env); + uint32_t gpu_id = 0; + const auto [parsed_to, parse_error] = + std::from_chars(gpu_id_string.data(), gpu_id_string.data() + gpu_id_string.size(), gpu_id); + if (parse_error != std::errc() || parsed_to != gpu_id_string.data() + gpu_id_string.size()) { + throw MR::Exception("Invalid MRTRIX_GPU_ID value: '" + std::string(gpu_id_string) + + "'. 
Expected a non-negative integer adapter index."); + } + + return gpu_id; +} + +wgpu::Adapter request_default_adapter(const wgpu::Instance &instance, + const wgpu::RequestAdapterOptions &adapter_options) { + struct RequestAdapterResult { + wgpu::RequestAdapterStatus status = wgpu::RequestAdapterStatus::Error; + wgpu::Adapter adapter = nullptr; + std::string message; + } request_adapter_result; + + const auto adapter_callback = [&request_adapter_result](wgpu::RequestAdapterStatus status, + wgpu::Adapter found_adapter, + wgpu::StringView message) { + request_adapter_result = {status, std::move(found_adapter), std::string(message)}; + }; + + const wgpu::Future adapter_request = + instance.RequestAdapter(&adapter_options, wgpu::CallbackMode::WaitAnyOnly, adapter_callback); + const wgpu::WaitStatus wait_status = instance.WaitAny(adapter_request, -1); + + if (wait_status == wgpu::WaitStatus::Success) { + if (request_adapter_result.status != wgpu::RequestAdapterStatus::Success) { + throw MR::Exception("Failed to get adapter: " + request_adapter_result.message); + } + } else { + throw MR::Exception("Failed to get adapter: wgpu::Instance::WaitAny failed"); + } + + return request_adapter_result.adapter; +} + +struct SelectedAdapter { + wgpu::Instance instance = nullptr; + wgpu::Adapter adapter = nullptr; +}; + +void log_available_adapters(const std::vector &adapters) { + INFO("Available GPU adapters:"); + for (size_t adapter_index = 0; adapter_index < adapters.size(); ++adapter_index) { + const wgpu::Adapter adapter(adapters[adapter_index].Get()); + wgpu::AdapterInfo adapter_info; + adapter.GetInfo(&adapter_info); + + INFO(" [" + std::to_string(adapter_index) + "] " + to_string(adapter_info.description)); + INFO(" details: backend=" + to_string(adapter_info.backendType) + + ", type=" + to_string(adapter_info.adapterType) + ", vendor=" + to_string(adapter_info.vendor) + + ", architecture=" + to_string(adapter_info.architecture) + ", device=" + to_string(adapter_info.device)); 
+ INFO(" identifiers: vendor_id=" + std::to_string(adapter_info.vendorID) + + ", device_id=" + std::to_string(adapter_info.deviceID)); + } +} + +SelectedAdapter request_adapter_by_index(const wgpu::InstanceDescriptor &instance_descriptor, + const wgpu::RequestAdapterOptions &adapter_options, + const uint32_t adapter_index) { + dawn::native::Instance dawn_instance(&instance_descriptor); + const std::vector adapters = dawn_instance.EnumerateAdapters(&adapter_options); + + if (adapters.empty()) { + throw MR::Exception("Failed to get adapter: no adapters available for the requested backend."); + } + + log_available_adapters(adapters); + if (adapter_index >= adapters.size()) { + throw MR::Exception("Invalid MRTRIX_GPU_ID value: " + std::to_string(adapter_index) + ". Found " + + std::to_string(adapters.size()) + " adapter(s)."); + } + + return { + .instance = wgpu::Instance(dawn_instance.Get()), + .adapter = wgpu::Adapter(adapters[adapter_index].Get()), + }; +} } // namespace ComputeContext::ComputeContext() : m_slang_session_info(std::make_unique()) { @@ -222,37 +316,29 @@ ComputeContext::ComputeContext() : m_slang_session_info(std::make_unique adapter_index_env = parse_gpu_adapter_index_env(); + if (adapter_index_env.has_value()) { + INFO("Selecting GPU adapter from MRTRIX_GPU_ID=" + std::to_string(adapter_index_env.value())); + const wgpu::RequestAdapterOptions adapter_options{ + .powerPreference = wgpu::PowerPreference::Undefined, + .backendType = GPUBackendType, + }; + const SelectedAdapter selected_adapter = + request_adapter_by_index(instance_descriptor, adapter_options, adapter_index_env.value()); + wgpu_instance = selected_adapter.instance; + wgpu_adapter = selected_adapter.adapter; } else { - throw MR::Exception("Failed to get adapter: wgpu::Instance::WaitAny failed"); + wgpu_instance = wgpu::CreateInstance(&instance_descriptor); + const wgpu::RequestAdapterOptions adapter_options{ + .powerPreference = wgpu::PowerPreference::HighPerformance, + .backendType = 
GPUBackendType, + }; + wgpu_adapter = request_default_adapter(wgpu_instance, adapter_options); } - wgpu_adapter = request_adapter_result.adapter; const std::vector required_device_features = { wgpu::FeatureName::R8UnormStorage, wgpu::FeatureName::Float32Filterable,