Skip to content

Commit 418807c

Browse files
committed
Fix WebGPU EP crash on exit
1 parent 65fb61b commit 418807c

File tree

2 files changed

+65
-15
lines changed

2 files changed

+65
-15
lines changed

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -924,10 +924,12 @@ void WebGpuContext::ReleaseGraphResources(std::vector<webgpu::CapturedCommandInf
924924
}
925925
}
926926

927-
std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo>* WebGpuContextFactory::contexts_ = nullptr;
928927
std::mutex WebGpuContextFactory::mutex_;
929928
std::once_flag WebGpuContextFactory::init_default_flag_;
930-
wgpu::Instance WebGpuContextFactory::default_instance_;
929+
930+
std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo>* WebGpuContextFactory::contexts_ = nullptr;
931+
WGPUInstance WebGpuContextFactory::default_instance_ = nullptr;
932+
bool WebGpuContextFactory::modules_dxc_loaded_ = false;
931933

932934
WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& config) {
933935
const int context_id = config.context_id;
@@ -960,28 +962,60 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
960962

961963
std::lock_guard<std::mutex> lock(mutex_);
962964

963-
// Lazy-allocate the contexts map on first use (heap-allocated to avoid static destruction crash).
964-
if (contexts_ == nullptr) {
965-
contexts_ = new std::unordered_map<int32_t, WebGpuContextInfo>();
966-
}
967-
968965
if (default_instance_ == nullptr) {
969966
// Create wgpu::Instance
970967
wgpu::InstanceFeatureName required_instance_features[] = {wgpu::InstanceFeatureName::TimedWaitAny};
971968
wgpu::InstanceDescriptor instance_desc{};
972969
instance_desc.requiredFeatures = required_instance_features;
973970
instance_desc.requiredFeatureCount = sizeof(required_instance_features) / sizeof(required_instance_features[0]);
974-
default_instance_ = wgpu::CreateInstance(&instance_desc);
971+
default_instance_ = wgpuCreateInstance(&instance_desc);
975972

976973
ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance.");
974+
975+
// Lazy-allocate the contexts map on first use (heap-allocated to avoid static destruction crash).
976+
if (contexts_ == nullptr) {
977+
contexts_ = new std::unordered_map<int32_t, WebGpuContextInfo>();
978+
}
979+
980+
// On Windows builds that do not use an external Dawn, the D3D12 backend requires dxil.dll and dxcompiler.dll.
981+
//
982+
// Dawn will try to load them later, but if the DLLs are loaded by Dawn, it could cause them to be unloaded earlier
983+
// than the destruction of WebGpuContextFactory, which will cause a crash because the resource release can potentially
984+
// call into dxcompiler.dll. By loading them here, we can make sure they are not unloaded too early.
985+
#if !defined(__wasm__) && defined(_WIN32) && !defined(USE_EXTERNAL_DAWN)
986+
if (config.backend_type == static_cast<int>(WebGpuBackendType::D3D12) && !modules_dxc_loaded_) {
987+
auto runtime_path = Env::Default().GetRuntimePath();
988+
if (!runtime_path.empty()) {
989+
if (modules_ == nullptr) {
990+
modules_ = new LibraryHandles();
991+
}
992+
993+
Status status;
994+
void* module_handle = nullptr;
995+
996+
PathString dxil_path = runtime_path + ToPathString(L"dxil.dll");
997+
status = Env::Default().LoadDynamicLibrary(dxil_path, false, &module_handle);
998+
if (status.IsOK() && module_handle != nullptr) {
999+
modules_->Add(dxil_path, module_handle);
1000+
}
1001+
1002+
PathString dxcompiler_path = runtime_path + ToPathString(L"dxcompiler.dll");
1003+
status = Env::Default().LoadDynamicLibrary(dxcompiler_path, false, &module_handle);
1004+
if (status.IsOK() && module_handle != nullptr) {
1005+
modules_->Add(dxcompiler_path, module_handle);
1006+
}
1007+
modules_dxc_loaded_ = true;
1008+
}
1009+
}
1010+
#endif
9771011
}
9781012

9791013
if (context_id == 0) {
9801014
// Context ID 0 is reserved for the default context. Users cannot use context ID 0 as a custom context.
9811015
ORT_ENFORCE(instance == nullptr && device == nullptr,
9821016
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device.");
9831017

984-
instance = default_instance_.Get();
1018+
instance = default_instance_;
9851019
} else {
9861020
// for context ID > 0, user must provide custom WebGPU instance and device.
9871021
ORT_ENFORCE(instance != nullptr && device != nullptr,
@@ -1034,9 +1068,18 @@ void WebGpuContextFactory::ReleaseContext(int context_id) {
10341068

10351069
void WebGpuContextFactory::Cleanup() {
10361070
std::lock_guard<std::mutex> lock(mutex_);
1037-
delete contexts_;
1038-
contexts_ = nullptr;
1039-
default_instance_ = nullptr;
1071+
if (contexts_ != nullptr) {
1072+
delete contexts_;
1073+
contexts_ = nullptr;
1074+
}
1075+
if (default_instance_ != nullptr) {
1076+
wgpuReleaseInstance(default_instance_);
1077+
default_instance_ = nullptr;
1078+
}
1079+
if (modules_ != nullptr) {
1080+
delete modules_;
1081+
modules_ = nullptr;
1082+
}
10401083
}
10411084

10421085
WebGpuContext& WebGpuContextFactory::DefaultContext() {

onnxruntime/core/providers/webgpu/webgpu_context.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "core/providers/webgpu/webgpu_external_header.h"
1111

1212
#include "core/common/common.h"
13+
#include "core/framework/library_handles.h"
1314
#include "core/providers/webgpu/buffer_manager.h"
1415
#include "core/providers/webgpu/program_manager.h"
1516
#include "core/providers/webgpu/webgpu_utils.h"
@@ -147,6 +148,9 @@ class WebGpuContextFactory {
147148
private:
148149
WebGpuContextFactory() {}
149150

151+
static std::mutex mutex_;
152+
static std::once_flag init_default_flag_;
153+
150154
// Use pointers to heap-allocated objects so that their destructors do NOT run
151155
// during static destruction at process exit. This avoids crashes when dependent
152156
// DLLs (e.g. dxcompiler.dll) have already been unloaded by the OS.
@@ -155,9 +159,12 @@ class WebGpuContextFactory {
155159
// it is reached from OrtEnv::~OrtEnv via CleanupWebGpuContexts().
156160
// On abnormal/process termination they simply leak, which is safe.
157161
static std::unordered_map<int32_t, WebGpuContextInfo>* contexts_;
158-
static std::mutex mutex_;
159-
static std::once_flag init_default_flag_;
160-
static wgpu::Instance default_instance_;
162+
static WGPUInstance default_instance_;
163+
164+
// Use a module manager to ensure that dependent DLLs (e.g. dxcompiler.dll) are not unloaded while WebGPU contexts are
165+
// still alive, which would cause crashes if the contexts try to call into those DLLs during their destruction.
166+
static LibraryHandles* modules_;
167+
static bool modules_dxc_loaded_;
161168
};
162169

163170
// Class WebGpuContext includes all necessary resources for the context.

0 commit comments

Comments
 (0)