diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index bf32f60bf5667..510ac45e06e9e 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -139,6 +139,13 @@ class ComputeContext final { return webgpu_context_.Run(*this, program); } + // + // Get the execution provider. + // + inline const WebGpuExecutionProvider& GetExecutionProvider() const { + return ep_; + } + private: WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index a2777979ae983..e64f8f931a57e 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -8,6 +8,8 @@ #include "core/providers/webgpu/nn/grouped_conv.h" #include "core/providers/webgpu/webgpu_utils.h" #include "core/providers/webgpu/math/matmul.h" +#include "core/providers/webgpu/weight_layout_transform.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" namespace onnxruntime { namespace webgpu { @@ -25,6 +27,37 @@ Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const Tens return Transpose::DoTranspose(context, perm, reshaped_kernel, *transposed_kernel); } +template +Status Conv::GetTransformedWeight(ComputeContext& context, + const Tensor* original_weight, + const Tensor*& transformed_weight) const { + // Return cached weight if already transformed + if (transformed_weight_) { + transformed_weight = transformed_weight_; + return Status::OK(); + } + + // If transformation was attempted but failed, return error + if (weight_transform_attempted_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Weight transformation previously failed"); + } + + weight_transform_attempted_ = true; + + // Use the weight input name extracted during construction + ORT_ENFORCE(!weight_name_.empty(), "Weight input name must be available for transformation caching"); + + // Get cache from execution provider + auto& cache = context.GetExecutionProvider().GetWeightLayoutTransformCache(); + + // Transform weight to HWIO layout + ORT_RETURN_IF_ERROR(TransformWeightLayout(context, original_weight, weight_name_, + "hwio", cache, transformed_weight_)); + + transformed_weight = transformed_weight_; + return Status::OK(); +} + template Status Conv::ComputeInternal(ComputeContext& context) const { bool has_bias = context.InputCount() > 2; @@ -104,11 +137,11 @@ Status Conv::ComputeInternal(ComputeContext& context auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2; std::vector updated_pads{pad0, pad1}; if (conv_attrs_.group > 1) { - Tensor transposed_kernel; if (is_channels_last) { - ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); - inputs[1] = &transposed_kernel; - modified_input_output_shapes[1] = transposed_kernel.Shape(); + const Tensor* kernel_to_use = nullptr; + ORT_RETURN_IF_ERROR(GetTransformedWeight(context, kernel, kernel_to_use)); + inputs[1] = kernel_to_use; + modified_input_output_shapes[1] = kernel_to_use->Shape(); } auto output_channels_per_group = output_channels / conv_attrs_.group; auto components = static_cast(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1); @@ -138,17 +171,16 @@ Status Conv::ComputeInternal(ComputeContext& context const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0; if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) { - Tensor transposed_kernel; TensorShape input_reshape; TensorShape kernel_reshape; TensorShape matmul_output_shape; std::vector matmul_inputs; std::vector matmul_input_reshapes; if (is_channels_last) { - // Transpose weights - - ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); - inputs[1] = &transposed_kernel; + // Transform weights to HWIO layout (cached on first inference) + const Tensor* kernel_to_use = nullptr; + ORT_RETURN_IF_ERROR(GetTransformedWeight(context, kernel, kernel_to_use)); + inputs[1] = kernel_to_use; if (same_size) { const auto shared_dim = input_height * input_width * input_channels; input_reshape = TensorShape({1, batch, shared_dim}); @@ -160,7 +192,7 @@ Status Conv::ComputeInternal(ComputeContext& context matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels}); } matmul_inputs.push_back(input); - matmul_inputs.push_back(&transposed_kernel); + matmul_inputs.push_back(kernel_to_use); matmul_input_reshapes.push_back(input_reshape); matmul_input_reshapes.push_back(kernel_reshape); } else { @@ -204,15 +236,14 @@ Status Conv::ComputeInternal(ComputeContext& context return context.RunProgram(program); } } - // Transpose weights - Tensor transposed_kernel; - ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + // Transpose weights - use cached transformation + const Tensor* kernel_to_use = nullptr; + ORT_RETURN_IF_ERROR(GetTransformedWeight(context, kernel, kernel_to_use)); + inputs[1] = kernel_to_use; + modified_input_output_shapes[1] = kernel_to_use->Shape(); auto dim_a_outer = static_cast(is_channels_last ? output_height * output_width : output_channels); auto dim_b_outer = static_cast(is_channels_last ? output_channels : output_height * output_width); auto dim_inner = static_cast(kernel_height * kernel_width * input_channels); - inputs[1] = &transposed_kernel; - TensorShape transposed_kernel_shape = transposed_kernel.Shape(); - modified_input_output_shapes[1] = transposed_kernel.Shape(); Conv2dMMProgram conv2d_mm_program = CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, is_channels_last, modified_input_output_shapes); return context.RunProgram(conv2d_mm_program); } diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h index cafaa272c0613..433fe4f83b075 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.h +++ b/onnxruntime/core/providers/webgpu/nn/conv.h @@ -20,12 +20,26 @@ class Conv : public WebGpuKernel { if (is_fused) { ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); } + // Extract weight input name (input index 1) for caching + const auto& input_defs = info.node().InputDefs(); + if (input_defs.size() > 1 && input_defs[1]->Exists()) { + weight_name_ = input_defs[1]->Name(); + } } Status ComputeInternal(ComputeContext& context) const override; protected: ConvAttributes conv_attrs_; Activation activation_; + std::string weight_name_; // Name of weight input for cache key + + // Cached transformed weight pointer (set on first inference) + mutable const Tensor* transformed_weight_ = nullptr; + mutable bool weight_transform_attempted_ = false; + + // Get or create transformed weight (lazy transformation on first inference) + Status GetTransformedWeight(ComputeContext& context, const Tensor* original_weight, + const Tensor*& transformed_weight) const; }; Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 3df194217933e..402adaa0c49c5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -800,7 +800,8 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, context_{context}, preferred_data_layout_{config.data_layout}, force_cpu_node_names_{std::move(config.force_cpu_node_names)}, - enable_graph_capture_{config.enable_graph_capture} { + enable_graph_capture_{config.enable_graph_capture}, + weight_layout_transform_cache_{std::make_unique()} { // If graph capture is enabled, create a dedicated buffer manager for graph mode if (enable_graph_capture_) { // Create buffer manager for graph capture mode with appropriate cache modes @@ -948,6 +949,12 @@ std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s } WebGpuExecutionProvider::~WebGpuExecutionProvider() { + // Clear weight transform cache before releasing WebGPU resources + // This ensures cached GPU tensors are freed while BufferManager is still valid + if (weight_layout_transform_cache_) { + weight_layout_transform_cache_->Clear(); + } + // Release all resources associated with the captured graph if (!captured_commands_.empty()) { context_.ReleaseGraphResources(captured_commands_); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index a9282a028c803..36ee904ee5292 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -9,6 +9,7 @@ #include "core/graph/constants.h" #include "core/providers/providers.h" #include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/weight_layout_transform_cache.h" struct pthreadpool; namespace onnxruntime { @@ -85,6 +86,11 @@ class WebGpuExecutionProvider : public IExecutionProvider { Status ReplayGraph(int graph_annotation_id) override; webgpu::BufferManager& BufferManager() const; + // Get weight layout transform cache + webgpu::WeightLayoutTransformCache& GetWeightLayoutTransformCache() const { + return *weight_layout_transform_cache_; + } + private: bool IsGraphCaptureAllowed() const; void IncrementRegularRunCountBeforeGraphCapture(); @@ -105,6 +111,9 @@ class WebGpuExecutionProvider : public IExecutionProvider { // Store captured commands directly in the EP instead of in WebGpuContext std::vector captured_commands_; + + // Cache for transformed weights (e.g., OIHW -> HWIO) + std::unique_ptr weight_layout_transform_cache_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/weight_layout_transform.cc b/onnxruntime/core/providers/webgpu/weight_layout_transform.cc new file mode 100644 index 0000000000000..358fee4100c49 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/weight_layout_transform.cc @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/weight_layout_transform.h" +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/weight_layout_transform_cache.h" +#include "core/providers/webgpu/nn/conv.h" // For TransposeKernel + +namespace onnxruntime { +namespace webgpu { + +Status TransformWeightLayout( + ComputeContext& context, + const Tensor* weight, + const std::string& weight_name, + const std::string& format_descriptor, + WeightLayoutTransformCache& cache, + /*out*/ const Tensor*& transformed_weight) { + // Check cache first + const auto* cached = cache.GetTransformedWeight(weight_name, format_descriptor); + if (cached != nullptr) { + transformed_weight = cached; + return Status::OK(); + } + + // Not in cache, need to transform + + const auto& original_shape = weight->Shape(); + auto num_dims = original_shape.NumDimensions(); + + // Dispatch transformation based on format + Tensor output_tensor; + if (format_descriptor == "hwio") { + // For 3D tensors, extend to 4D before transposing + TensorShape input_shape_for_transpose = original_shape; + if (num_dims == 3) { + // Extend OIW [O, I, W] to OIHW [O, I, 1, W] + TensorShapeVector extended_shape = original_shape.AsShapeVector(); + extended_shape.insert(extended_shape.begin() + 2, 1); // Insert H=1 at position 2 + input_shape_for_transpose = TensorShape(extended_shape); + } + + // Use existing TransposeKernel: OIHW [O,I,H,W] -> HWIO [H,W,I,O] + // Permutation: [2, 3, 1, 0] means output[i] = input[perm[i]] + // TransposeKernel creates the output tensor internally + const InlinedVector perm = {2, 3, 1, 0}; + ORT_RETURN_IF_ERROR(TransposeKernel(context, weight, input_shape_for_transpose, + &output_tensor, perm)); + } else { + // Add more format implementations here + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Format not yet implemented: ", format_descriptor); + } + + // Add to cache + cache.AddTransformedWeight(weight_name, format_descriptor, std::move(output_tensor)); + + // Return cached tensor + const auto* cached_result = cache.GetTransformedWeight(weight_name, format_descriptor); + ORT_ENFORCE(cached_result != nullptr, "Failed to cache transformed weight"); + transformed_weight = cached_result; + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/weight_layout_transform.h b/onnxruntime/core/providers/webgpu/weight_layout_transform.h new file mode 100644 index 0000000000000..e3f44fd3f0922 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/weight_layout_transform.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/tensor.h" +#include + +namespace onnxruntime { +namespace webgpu { + +class ComputeContext; +class WeightLayoutTransformCache; + +// Transform weight tensor to specified format +// Returns the transformed tensor (either from cache or newly created) +Status TransformWeightLayout( + ComputeContext& context, + const Tensor* weight, + const std::string& weight_name, + const std::string& format_descriptor, + WeightLayoutTransformCache& cache, + /*out*/ const Tensor*& transformed_weight); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/weight_layout_transform_cache.cc b/onnxruntime/core/providers/webgpu/weight_layout_transform_cache.cc new file mode 100644 index 0000000000000..dfdef9f2eb8a3 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/weight_layout_transform_cache.cc @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/weight_layout_transform_cache.h" + +namespace onnxruntime { +namespace webgpu { + +const Tensor* WeightLayoutTransformCache::GetTransformedWeight( + const std::string& weight_name, + const std::string& format_descriptor) const { + std::lock_guard lock(mutex_); + std::string cache_key = MakeCacheKey(weight_name, format_descriptor); + auto it = cache_.find(cache_key); + if (it != cache_.end()) { + return &it->second; + } + return nullptr; +} + +void WeightLayoutTransformCache::AddTransformedWeight( + const std::string& weight_name, + const std::string& format_descriptor, + Tensor&& tensor) { + std::lock_guard lock(mutex_); + std::string cache_key = MakeCacheKey(weight_name, format_descriptor); + cache_[cache_key] = std::move(tensor); +} + +void WeightLayoutTransformCache::Clear() { + std::lock_guard lock(mutex_); + cache_.clear(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/weight_layout_transform_cache.h b/onnxruntime/core/providers/webgpu/weight_layout_transform_cache.h new file mode 100644 index 0000000000000..d2edfbe1dafa2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/weight_layout_transform_cache.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include "core/framework/tensor.h" +#include "core/common/common.h" + +namespace onnxruntime { +namespace webgpu { + +// Cache manager for transformed weights +// Owned by WebGpuExecutionProvider +class WeightLayoutTransformCache { + public: + WeightLayoutTransformCache() = default; + ~WeightLayoutTransformCache() = default; + + // Get transformed weight from cache (nullptr if not found) + const Tensor* GetTransformedWeight(const std::string& weight_name, + const std::string& format_descriptor) const; + + // Add transformed weight to cache + void AddTransformedWeight(const std::string& weight_name, + const std::string& format_descriptor, + Tensor&& tensor); + + // Clear cache (must be called before BufferManager is destroyed) + void Clear(); + + private: + std::string MakeCacheKey(const std::string& weight_name, + const std::string& format) const { + return weight_name + ":" + format; + } + + mutable std::mutex mutex_; + std::unordered_map cache_; +}; + +} // namespace webgpu +} // namespace onnxruntime