7 changes: 7 additions & 0 deletions onnxruntime/core/providers/webgpu/compute_context.h
@@ -139,6 +139,13 @@ class ComputeContext final {
return webgpu_context_.Run(*this, program);
}

//
// Get the execution provider.
//
inline const WebGpuExecutionProvider& GetExecutionProvider() const {
return ep_;
}

private:
WebGpuContext& webgpu_context_;
OpKernelContext& kernel_context_;
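
This accessor is what lets a WebGPU kernel reach execution-provider-level state (such as the weight layout cache introduced below) from its ComputeContext. A minimal sketch of the intended call pattern, assuming a hypothetical kernel named MyKernel and placeholder weight/layout strings; the accessor and cache APIs are the ones added in this change:

Status MyKernel::ComputeInternal(ComputeContext& context) const {
  // Reach the EP that owns the kernel, then its weight layout cache.
  const WebGpuExecutionProvider& ep = context.GetExecutionProvider();
  webgpu::WeightLayoutTransformCache& cache = ep.GetWeightLayoutTransformCache();
  // A cache miss returns nullptr; a kernel would then transform and insert the weight
  // (see TransformWeightLayout in conv.cc / weight_layout_transform.cc below).
  const Tensor* cached = cache.GetTransformedWeight("conv1.W", "hwio");
  if (cached != nullptr) {
    // Use the cached HWIO weight directly.
  }
  return Status::OK();
}
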
63 changes: 47 additions & 16 deletions onnxruntime/core/providers/webgpu/nn/conv.cc
@@ -8,6 +8,8 @@
#include "core/providers/webgpu/nn/grouped_conv.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/math/matmul.h"
#include "core/providers/webgpu/weight_layout_transform.h"
#include "core/providers/webgpu/webgpu_execution_provider.h"

namespace onnxruntime {
namespace webgpu {
@@ -25,6 +27,37 @@ Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const Tens
return Transpose::DoTranspose(context, perm, reshaped_kernel, *transposed_kernel);
}

template <bool is_channels_last, bool is_fused>
Status Conv<is_channels_last, is_fused>::GetTransformedWeight(ComputeContext& context,
const Tensor* original_weight,
const Tensor*& transformed_weight) const {
// Return cached weight if already transformed
if (transformed_weight_) {
transformed_weight = transformed_weight_;
return Status::OK();
}

// If transformation was attempted but failed, return error
if (weight_transform_attempted_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Weight transformation previously failed");
}

weight_transform_attempted_ = true;

// Use the weight input name extracted during construction
ORT_ENFORCE(!weight_name_.empty(), "Weight input name must be available for transformation caching");

// Get cache from execution provider
auto& cache = context.GetExecutionProvider().GetWeightLayoutTransformCache();

// Transform weight to HWIO layout
ORT_RETURN_IF_ERROR(TransformWeightLayout(context, original_weight, weight_name_,
"hwio", cache, transformed_weight_));

transformed_weight = transformed_weight_;
return Status::OK();
}

template <bool is_channels_last, bool is_fused>
Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context) const {
bool has_bias = context.InputCount() > 2;
@@ -104,11 +137,11 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2;
std::vector<uint32_t> updated_pads{pad0, pad1};
if (conv_attrs_.group > 1) {
Tensor transposed_kernel;
if (is_channels_last) {
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
inputs[1] = &transposed_kernel;
modified_input_output_shapes[1] = transposed_kernel.Shape();
const Tensor* kernel_to_use = nullptr;
ORT_RETURN_IF_ERROR(GetTransformedWeight(context, kernel, kernel_to_use));
inputs[1] = kernel_to_use;
modified_input_output_shapes[1] = kernel_to_use->Shape();
}
auto output_channels_per_group = output_channels / conv_attrs_.group;
auto components = static_cast<int>(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1);
@@ -138,17 +171,16 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context

const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0;
if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) {
Tensor transposed_kernel;
TensorShape input_reshape;
TensorShape kernel_reshape;
TensorShape matmul_output_shape;
std::vector<const Tensor*> matmul_inputs;
std::vector<TensorShape> matmul_input_reshapes;
if (is_channels_last) {
// Transpose weights

ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
inputs[1] = &transposed_kernel;
// Transform weights to HWIO layout (cached on first inference)
const Tensor* kernel_to_use = nullptr;
ORT_RETURN_IF_ERROR(GetTransformedWeight(context, kernel, kernel_to_use));
inputs[1] = kernel_to_use;
if (same_size) {
const auto shared_dim = input_height * input_width * input_channels;
input_reshape = TensorShape({1, batch, shared_dim});
@@ -160,7 +192,7 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels});
}
matmul_inputs.push_back(input);
matmul_inputs.push_back(&transposed_kernel);
matmul_inputs.push_back(kernel_to_use);
matmul_input_reshapes.push_back(input_reshape);
matmul_input_reshapes.push_back(kernel_reshape);
} else {
@@ -204,15 +236,14 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
return context.RunProgram(program);
}
}
// Transpose weights
Tensor transposed_kernel;
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
// Transpose weights - use cached transformation
const Tensor* kernel_to_use = nullptr;
ORT_RETURN_IF_ERROR(GetTransformedWeight(context, kernel, kernel_to_use));
inputs[1] = kernel_to_use;
modified_input_output_shapes[1] = kernel_to_use->Shape();
auto dim_a_outer = static_cast<uint32_t>(is_channels_last ? output_height * output_width : output_channels);
auto dim_b_outer = static_cast<uint32_t>(is_channels_last ? output_channels : output_height * output_width);
auto dim_inner = static_cast<uint32_t>(kernel_height * kernel_width * input_channels);
inputs[1] = &transposed_kernel;
TensorShape transposed_kernel_shape = transposed_kernel.Shape();
modified_input_output_shapes[1] = transposed_kernel.Shape();
Conv2dMMProgram conv2d_mm_program = CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, is_channels_last, modified_input_output_shapes);
return context.RunProgram(conv2d_mm_program);
}
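
All three transpose call sites in ComputeInternal now funnel through GetTransformedWeight, which layers a kernel-local cache (transformed_weight_) in front of the EP-wide cache keyed by weight name and layout. A sketch of the resulting behavior across inferences, with an illustrative weight shape (the [64, 3, 7, 7] shape is an example, not taken from this change):

// Inside Conv<is_channels_last, is_fused>::ComputeInternal (sketch).
// kernel: the OIHW weight input, e.g. shape [64, 3, 7, 7].
const Tensor* kernel_to_use = nullptr;
ORT_RETURN_IF_ERROR(GetTransformedWeight(context, kernel, kernel_to_use));
// First inference: one transpose program runs, the HWIO tensor [7, 7, 3, 64] is stored in the
// EP cache under "<weight name>:hwio", and transformed_weight_ remembers the pointer.
// Later inferences: the same pointer comes back with no GPU dispatch and no new allocation.
inputs[1] = kernel_to_use;
modified_input_output_shapes[1] = kernel_to_use->Shape();
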
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/conv.h
@@ -20,12 +20,26 @@ class Conv : public WebGpuKernel {
if (is_fused) {
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
}
// Extract weight input name (input index 1) for caching
const auto& input_defs = info.node().InputDefs();
if (input_defs.size() > 1 && input_defs[1]->Exists()) {
weight_name_ = input_defs[1]->Name();
}
}
Status ComputeInternal(ComputeContext& context) const override;

protected:
ConvAttributes conv_attrs_;
Activation activation_;
std::string weight_name_; // Name of weight input for cache key

// Cached transformed weight pointer (set on first inference)
mutable const Tensor* transformed_weight_ = nullptr;
mutable bool weight_transform_attempted_ = false;

// Get or create transformed weight (lazy transformation on first inference)
Status GetTransformedWeight(ComputeContext& context, const Tensor* original_weight,
const Tensor*& transformed_weight) const;
};

Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector<size_t>& perm);
onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -800,7 +800,8 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
context_{context},
preferred_data_layout_{config.data_layout},
force_cpu_node_names_{std::move(config.force_cpu_node_names)},
enable_graph_capture_{config.enable_graph_capture} {
enable_graph_capture_{config.enable_graph_capture},
weight_layout_transform_cache_{std::make_unique<WeightLayoutTransformCache>()} {
// If graph capture is enabled, create a dedicated buffer manager for graph mode
if (enable_graph_capture_) {
// Create buffer manager for graph capture mode with appropriate cache modes
@@ -948,6 +949,12 @@ std::optional<bool> WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s
}

WebGpuExecutionProvider::~WebGpuExecutionProvider() {
// Clear weight transform cache before releasing WebGPU resources
// This ensures cached GPU tensors are freed while BufferManager is still valid
if (weight_layout_transform_cache_) {
weight_layout_transform_cache_->Clear();
}

// Release all resources associated with the captured graph
if (!captured_commands_.empty()) {
context_.ReleaseGraphResources(captured_commands_);
9 changes: 9 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
@@ -9,6 +9,7 @@
#include "core/graph/constants.h"
#include "core/providers/providers.h"
#include "core/providers/webgpu/buffer_manager.h"
#include "core/providers/webgpu/weight_layout_transform_cache.h"

struct pthreadpool;
namespace onnxruntime {
@@ -85,6 +86,11 @@ class WebGpuExecutionProvider : public IExecutionProvider {
Status ReplayGraph(int graph_annotation_id) override;
webgpu::BufferManager& BufferManager() const;

// Get weight layout transform cache
webgpu::WeightLayoutTransformCache& GetWeightLayoutTransformCache() const {
return *weight_layout_transform_cache_;
}

private:
bool IsGraphCaptureAllowed() const;
void IncrementRegularRunCountBeforeGraphCapture();
@@ -105,6 +111,9 @@

// Store captured commands directly in the EP instead of in WebGpuContext
std::vector<webgpu::CapturedCommandInfo> captured_commands_;

// Cache for transformed weights (e.g., OIHW -> HWIO)
std::unique_ptr<webgpu::WeightLayoutTransformCache> weight_layout_transform_cache_;
};

} // namespace onnxruntime
67 changes: 67 additions & 0 deletions onnxruntime/core/providers/webgpu/weight_layout_transform.cc
@@ -0,0 +1,67 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/weight_layout_transform.h"
#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/weight_layout_transform_cache.h"
#include "core/providers/webgpu/nn/conv.h" // For TransposeKernel

namespace onnxruntime {
namespace webgpu {

Status TransformWeightLayout(
ComputeContext& context,
const Tensor* weight,
const std::string& weight_name,
const std::string& format_descriptor,
WeightLayoutTransformCache& cache,
/*out*/ const Tensor*& transformed_weight) {
// Check cache first
const auto* cached = cache.GetTransformedWeight(weight_name, format_descriptor);
if (cached != nullptr) {
transformed_weight = cached;
return Status::OK();
}

// Not in cache, need to transform

const auto& original_shape = weight->Shape();
auto num_dims = original_shape.NumDimensions();

// Dispatch transformation based on format
Tensor output_tensor;
if (format_descriptor == "hwio") {
// For 3D tensors, extend to 4D before transposing
TensorShape input_shape_for_transpose = original_shape;
if (num_dims == 3) {
// Extend OIW [O, I, W] to OIHW [O, I, 1, W]
TensorShapeVector extended_shape = original_shape.AsShapeVector();
extended_shape.insert(extended_shape.begin() + 2, 1); // Insert H=1 at position 2
input_shape_for_transpose = TensorShape(extended_shape);
}

// Use existing TransposeKernel: OIHW [O,I,H,W] -> HWIO [H,W,I,O]
// Permutation: [2, 3, 1, 0] means output[i] = input[perm[i]]
// TransposeKernel creates the output tensor internally
const InlinedVector<size_t> perm = {2, 3, 1, 0};
ORT_RETURN_IF_ERROR(TransposeKernel(context, weight, input_shape_for_transpose,
&output_tensor, perm));
} else {
// Add more format implementations here
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Format not yet implemented: ", format_descriptor);
}

// Add to cache
cache.AddTransformedWeight(weight_name, format_descriptor, std::move(output_tensor));

// Return cached tensor
const auto* cached_result = cache.GetTransformedWeight(weight_name, format_descriptor);
ORT_ENFORCE(cached_result != nullptr, "Failed to cache transformed weight");
transformed_weight = cached_result;

return Status::OK();
}

} // namespace webgpu
} // namespace onnxruntime
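
A shape-only worked example of the conversion implemented above, using the permutation semantics stated in the comment (output axis i takes input axis perm[i]); the concrete dimensions are illustrative:

// OIHW -> HWIO via perm {2, 3, 1, 0}: output_shape[i] = input_shape[perm[i]].
TensorShape oihw({64, 3, 7, 7});                  // [O, I, H, W]
const InlinedVector<size_t> perm = {2, 3, 1, 0};
TensorShapeVector hwio_dims;
for (size_t axis : perm) {
  hwio_dims.push_back(oihw[axis]);
}
TensorShape hwio(hwio_dims);                      // [7, 7, 3, 64] == [H, W, I, O]
// A 1D conv weight OIW [8, 4, 3] is first viewed as OIHW [8, 4, 1, 3] (H = 1 inserted at
// position 2), which the same permutation maps to HWIO [1, 3, 4, 8].
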
27 changes: 27 additions & 0 deletions onnxruntime/core/providers/webgpu/weight_layout_transform.h
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/tensor.h"
#include <string>

namespace onnxruntime {
namespace webgpu {

class ComputeContext;
class WeightLayoutTransformCache;

// Transform weight tensor to specified format
// Returns the transformed tensor (either from cache or newly created)
Status TransformWeightLayout(
ComputeContext& context,
const Tensor* weight,
const std::string& weight_name,
const std::string& format_descriptor,
WeightLayoutTransformCache& cache,
/*out*/ const Tensor*& transformed_weight);

} // namespace webgpu
} // namespace onnxruntime
36 changes: 36 additions & 0 deletions onnxruntime/core/providers/webgpu/weight_layout_transform_cache.cc
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/weight_layout_transform_cache.h"

namespace onnxruntime {
namespace webgpu {

const Tensor* WeightLayoutTransformCache::GetTransformedWeight(
const std::string& weight_name,
const std::string& format_descriptor) const {
std::lock_guard<std::mutex> lock(mutex_);
std::string cache_key = MakeCacheKey(weight_name, format_descriptor);
auto it = cache_.find(cache_key);
if (it != cache_.end()) {
return &it->second;
}
return nullptr;
}

void WeightLayoutTransformCache::AddTransformedWeight(
const std::string& weight_name,
const std::string& format_descriptor,
Tensor&& tensor) {
std::lock_guard<std::mutex> lock(mutex_);
std::string cache_key = MakeCacheKey(weight_name, format_descriptor);
cache_[cache_key] = std::move(tensor);
}

void WeightLayoutTransformCache::Clear() {
std::lock_guard<std::mutex> lock(mutex_);
cache_.clear();
}

} // namespace webgpu
} // namespace onnxruntime
45 changes: 45 additions & 0 deletions onnxruntime/core/providers/webgpu/weight_layout_transform_cache.h
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <unordered_map>
#include <mutex>
#include <string>
#include "core/framework/tensor.h"
#include "core/common/common.h"

namespace onnxruntime {
namespace webgpu {

// Cache manager for transformed weights
// Owned by WebGpuExecutionProvider
class WeightLayoutTransformCache {
public:
WeightLayoutTransformCache() = default;
~WeightLayoutTransformCache() = default;

// Get transformed weight from cache (nullptr if not found)
const Tensor* GetTransformedWeight(const std::string& weight_name,
const std::string& format_descriptor) const;

// Add transformed weight to cache
void AddTransformedWeight(const std::string& weight_name,
const std::string& format_descriptor,
Tensor&& tensor);

// Clear cache (must be called before BufferManager is destroyed)
void Clear();

private:
std::string MakeCacheKey(const std::string& weight_name,
const std::string& format) const {
return weight_name + ":" + format;
}

mutable std::mutex mutex_;
std::unordered_map<std::string, Tensor> cache_;
};

} // namespace webgpu
} // namespace onnxruntime
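
Cache keys are simply "<weight name>:<format>", so one initializer can be cached under several layouts, and the pointers handed out by GetTransformedWeight stay valid until Clear() because std::unordered_map does not invalidate references to mapped values on rehash. A minimal usage sketch; the default-constructed Tensor is only a stand-in for a real GPU-backed tensor produced by TransformWeightLayout:

WeightLayoutTransformCache cache;
const Tensor* hit = cache.GetTransformedWeight("conv1.W", "hwio");   // nullptr: nothing cached yet
Tensor placeholder;                                                  // stand-in for a transformed weight
cache.AddTransformedWeight("conv1.W", "hwio", std::move(placeholder));
hit = cache.GetTransformedWeight("conv1.W", "hwio");                 // non-null; entry key is "conv1.W:hwio"
cache.Clear();                                                       // drops all entries; 'hit' is now dangling
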