Skip to content

Commit 2f0eb77

Browse files
fs-eire and Copilot
authored
[webgpu] refactor initialization of WebGPU Context params (#26855)
### Description Refactor the initialization of WebGPU context parameters. The refactor: - consolidates all WebGPU options into ~~3~~ 2 classes: - `WebGpuExecutionProviderConfig`: configuration that is passed to and stored inside the EP class. - ~~`WebGpuContextCreationParams`: configuration that is passed to the constructor of class `WebGpuContext`.~~ - ~~`WebGpuContextInitializationParams`: configuration that is passed to `WebGpuContext::Initialize()`.~~ - `WebGpuContextConfig`: configuration that is used to construct and initialize `WebGpuContext`. - ensures all instances of these classes are created with their default values initialized. - ensures each of the following happens in only one place: - setting default values - parsing options - adds `WebGpuContextFactory::DefaultContext` to allow "get or create" access to the default context. ### Motivation and Context - Make the code cleaner and more consistent. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 38355ba commit 2f0eb77

File tree

4 files changed

+208
-246
lines changed

4 files changed

+208
-246
lines changed

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@
3737
namespace onnxruntime {
3838
namespace webgpu {
3939

40-
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
41-
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() {
40+
void WebGpuContext::Initialize(const WebGpuContextConfig& config) {
41+
std::call_once(init_flag_, [this, &config]() {
4242
if (device_ == nullptr) {
4343
// Create wgpu::Adapter
4444
wgpu::RequestAdapterOptions req_adapter_options = {};
45-
req_adapter_options.backendType = static_cast<wgpu::BackendType>(backend_type);
46-
req_adapter_options.powerPreference = static_cast<wgpu::PowerPreference>(power_preference_);
45+
req_adapter_options.backendType = static_cast<wgpu::BackendType>(config.backend_type);
46+
req_adapter_options.powerPreference = static_cast<wgpu::PowerPreference>(config.power_preference);
4747

4848
#if !defined(__wasm__)
4949
auto enabled_adapter_toggles = GetEnabledAdapterToggles();
@@ -134,9 +134,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
134134

135135
// create buffer manager
136136
buffer_mgr_ = BufferManagerFactory::Create(*this,
137-
buffer_cache_config.storage.mode,
138-
buffer_cache_config.uniform.mode,
139-
buffer_cache_config.query_resolve.mode);
137+
config.buffer_cache_config.storage.mode,
138+
config.buffer_cache_config.uniform.mode,
139+
config.buffer_cache_config.query_resolve.mode);
140140

141141
// create initializer buffer manager. cache is always disabled for initializer buffer manager
142142
initializer_buffer_mgr_ = BufferManagerFactory::Create(*this,
@@ -161,7 +161,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
161161
} else {
162162
query_type_ = TimestampQueryType::None;
163163
}
164-
if (enable_pix_capture) {
164+
if (config.enable_pix_capture) {
165165
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
166166
// set pix frame generator
167167
pix_frame_generator_ = std::make_unique<WebGpuPIXFrameGenerator>(instance_,
@@ -979,15 +979,18 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
979979
device,
980980
config.validation_mode,
981981
config.preserve_device,
982-
config.max_storage_buffer_binding_size,
983-
config.power_preference));
982+
config.max_storage_buffer_binding_size));
984983
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
985984
} else if (context_id != 0) {
986985
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
987986
it->second.context->device_.Get() == device,
988987
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance or device.");
989988
}
990989
it->second.ref_count++;
990+
991+
// perform initialization
992+
it->second.context->Initialize(config);
993+
991994
return *it->second.context;
992995
}
993996

@@ -1017,6 +1020,11 @@ void WebGpuContextFactory::Cleanup() {
10171020
default_instance_ = nullptr;
10181021
}
10191022

1023+
WebGpuContext& WebGpuContextFactory::DefaultContext() {
1024+
WebGpuContextConfig config{};
1025+
return WebGpuContextFactory::CreateContext(config);
1026+
}
1027+
10201028
void CleanupWebGpuContexts() {
10211029
WebGpuContextFactory::Cleanup();
10221030
}

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,26 +34,51 @@ struct CapturedCommandInfo {
3434
WGPUBuffer indirect_buffer; // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
3535
};
3636

37-
struct WebGpuContextConfig {
38-
int context_id;
39-
WGPUInstance instance;
40-
WGPUDevice device;
41-
const void* dawn_proc_table;
42-
ValidationMode validation_mode;
43-
bool preserve_device;
44-
uint64_t max_storage_buffer_binding_size;
45-
int power_preference;
46-
};
47-
4837
struct WebGpuBufferCacheConfig {
4938
struct ConfigEntry {
5039
BufferCacheMode mode;
51-
std::string config_string;
40+
std::string config_string; // preserved for customized configuration, eg. bucket sizes
41+
};
42+
ConfigEntry storage{BufferCacheMode::Bucket, {}};
43+
ConfigEntry uniform{BufferCacheMode::Simple, {}};
44+
ConfigEntry query_resolve{BufferCacheMode::Disabled, {}};
45+
ConfigEntry default_entry{BufferCacheMode::Disabled, {}};
46+
};
47+
48+
/// <summary>
49+
/// Represents the configuration options for creating a WebGpuContext.
50+
/// </summary>
51+
struct WebGpuContextConfig {
52+
int context_id{0};
53+
WGPUInstance instance{nullptr};
54+
WGPUDevice device{nullptr};
55+
const void* dawn_proc_table{nullptr};
56+
ValidationMode validation_mode{
57+
#ifndef NDEBUG
58+
webgpu::ValidationMode::Full // for debug build, enable full validation by default
59+
#else
60+
webgpu::ValidationMode::Basic // for release build, enable basic validation by default
61+
#endif // !NDEBUG
5262
};
53-
ConfigEntry storage;
54-
ConfigEntry uniform;
55-
ConfigEntry query_resolve;
56-
ConfigEntry default_entry;
63+
bool preserve_device{false};
64+
uint64_t max_storage_buffer_binding_size{0};
65+
WebGpuBufferCacheConfig buffer_cache_config{};
66+
int power_preference{static_cast<int>(WGPUPowerPreference_HighPerformance)};
67+
int backend_type{
68+
#ifdef _WIN32
69+
// Setup Windows default backend type based on the build configuration
70+
#if defined(DAWN_ENABLE_D3D12)
71+
static_cast<int>(WGPUBackendType_D3D12)
72+
#elif defined(DAWN_ENABLE_VULKAN)
73+
static_cast<int>(WGPUBackendType_Vulkan)
74+
#else
75+
0
76+
#endif
77+
#else
78+
0
79+
#endif
80+
};
81+
bool enable_pix_capture{false};
5782
};
5883

5984
class WebGpuContextFactory {
@@ -63,13 +88,28 @@ class WebGpuContextFactory {
6388
int ref_count;
6489
};
6590

91+
/// <summary>
92+
/// Create a new WebGPU context for the specified context ID if not present, or return the existing one. (ref-count based)
93+
/// </summary>
6694
static WebGpuContext& CreateContext(const WebGpuContextConfig& config);
95+
96+
/// <summary>
97+
/// Get the WebGPU context for the specified context ID. Throw if not present.
98+
/// </summary>
6799
static WebGpuContext& GetContext(int context_id);
68100

101+
/// <summary>
102+
/// Release the WebGPU context. (ref-count based)
103+
/// </summary>
69104
static void ReleaseContext(int context_id);
70105

71106
static void Cleanup();
72107

108+
/// <summary>
109+
/// Return the default context. Create if not present.
110+
/// </summary>
111+
static WebGpuContext& DefaultContext();
112+
73113
private:
74114
WebGpuContextFactory() {}
75115

@@ -82,8 +122,6 @@ class WebGpuContextFactory {
82122
// Class WebGpuContext includes all necessary resources for the context.
83123
class WebGpuContext final {
84124
public:
85-
void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture);
86-
87125
Status Wait(wgpu::Future f);
88126

89127
const wgpu::Device& Device() const { return device_; }
@@ -190,20 +228,20 @@ class WebGpuContext final {
190228
WGPUDevice device,
191229
webgpu::ValidationMode validation_mode,
192230
bool preserve_device,
193-
uint64_t max_storage_buffer_binding_size,
194-
int power_preference = static_cast<int>(wgpu::PowerPreference::HighPerformance))
231+
uint64_t max_storage_buffer_binding_size)
195232
: instance_{instance},
196233
device_{device},
197234
validation_mode_{validation_mode},
198235
query_type_{TimestampQueryType::None},
199236
preserve_device_{preserve_device},
200-
max_storage_buffer_binding_size_{max_storage_buffer_binding_size},
201-
power_preference_{power_preference} {
237+
max_storage_buffer_binding_size_{max_storage_buffer_binding_size} {
202238
ORT_ENFORCE(max_storage_buffer_binding_size_ == 0 || max_storage_buffer_binding_size_ >= 134217728,
203239
"max_storage_buffer_binding_size must be 0 or at least 128MB");
204240
}
205241
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);
206242

243+
void Initialize(const WebGpuContextConfig& config);
244+
207245
void LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder,
208246
const std::vector<WGPUBuffer>& bind_buffers,
209247
const std::vector<uint32_t>& bind_buffers_segments,
@@ -292,7 +330,6 @@ class WebGpuContext final {
292330
bool is_profiling_ = false;
293331
bool preserve_device_;
294332
uint64_t max_storage_buffer_binding_size_;
295-
int power_preference_;
296333
GraphCaptureState graph_capture_state_{GraphCaptureState::Default};
297334

298335
// External vector to store captured commands, owned by EP

onnxruntime/core/providers/webgpu/webgpu_execution_provider.h

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,9 @@ struct CapturedCommandInfo;
2727
} // namespace webgpu
2828

2929
struct WebGpuExecutionProviderConfig {
30-
WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture, bool enable_pix_capture)
31-
: data_layout{data_layout},
32-
enable_graph_capture{enable_graph_capture},
33-
enable_pix_capture{enable_pix_capture} {}
34-
WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default;
35-
WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default;
36-
ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderConfig);
37-
38-
DataLayout data_layout;
39-
bool enable_graph_capture;
40-
bool enable_pix_capture;
41-
std::vector<std::string> force_cpu_node_names;
30+
DataLayout data_layout{DataLayout::NHWC}; // preferred layout is NHWC by default
31+
bool enable_graph_capture{false}; // graph capture feature is disabled by default
32+
std::vector<std::string> force_cpu_node_names{};
4233
};
4334

4435
class WebGpuExecutionProvider : public IExecutionProvider {

0 commit comments

Comments
 (0)