@@ -125,8 +125,6 @@ struct webgpu_context_struct {
125125
126126 std::recursive_mutex mutex;
127127
128- bool device_init = false ;
129-
130128 webgpu_buf_pool param_buf_pool;
131129 webgpu_buf_pool set_rows_error_buf_pool;
132130
@@ -454,6 +452,7 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
454452 ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer ->context ;
455453 return ctx->buffer ;
456454}
455+
457456static size_t ggml_webgpu_tensor_misalignment (webgpu_context & ctx, ggml_tensor * t) {
458457 size_t offset = ggml_webgpu_tensor_offset (t);
459458 return offset & (ctx->limits .minStorageBufferOffsetAlignment - 1 );
@@ -911,7 +910,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
911910}
912911
913912static void ggml_webgpu_init_mul_mat_pipeline (webgpu_context & webgpu_ctx) {
914- webgpu_pipeline_info pipeline_infos[5 ] = {
913+ webgpu_pipeline_info pipeline_infos[6 ] = {
915914 { .name = " mul_mat_f32_f32" ,
916915 .shader_code = wgsl_mul_mat_f32_f32,
917916 .src0_type = GGML_TYPE_F32,
@@ -928,10 +927,14 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
928927 .shader_code = wgsl_mul_mat_f16_f32,
929928 .src0_type = GGML_TYPE_F16,
930929 .src1_type = GGML_TYPE_F32 },
931- { .name = " mul_mat_q4_0_f32" ,
932- .shader_code = wgsl_mul_mat_q4_0_f32,
933- .src0_type = GGML_TYPE_Q4_0,
934- .src1_type = GGML_TYPE_F32 }
930+ { .name = " mul_mat_q4_0_f32" ,
931+ .shader_code = wgsl_mul_mat_q4_0_f32,
932+ .src0_type = GGML_TYPE_Q4_0,
933+ .src1_type = GGML_TYPE_F32 },
934+ { .name = " mul_mat_q4_0_f16" ,
935+ .shader_code = wgsl_mul_mat_q4_0_f16,
936+ .src0_type = GGML_TYPE_Q4_0,
937+ .src1_type = GGML_TYPE_F16 }
935938 };
936939
937940 for (auto & pipeline_info : pipeline_infos) {
@@ -965,79 +968,6 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
965968 ggml_backend_webgpu_device_context * dev_ctx = static_cast <ggml_backend_webgpu_device_context *>(dev->context );
966969 webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx ;
967970
968- // Multiple threads may try to initialize the device
969- std::lock_guard<std::recursive_mutex> lock (webgpu_ctx->mutex );
970- if (!webgpu_ctx->device_init ) {
971- // Initialize device
972- std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
973- wgpu::FeatureName::ImplicitDeviceSynchronization };
974- wgpu::DeviceDescriptor dev_desc;
975- dev_desc.requiredLimits = &webgpu_ctx->limits ;
976- dev_desc.requiredFeatures = required_features.data ();
977- dev_desc.requiredFeatureCount = required_features.size ();
978- dev_desc.SetDeviceLostCallback (
979- wgpu::CallbackMode::AllowSpontaneous,
980- [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
981- GGML_UNUSED (device);
982- GGML_LOG_ERROR (
983- " ggml_webgpu: Device lost! Reason: %d, Message: %s\n " , static_cast <int >(reason), message.data );
984- });
985- dev_desc.SetUncapturedErrorCallback (
986- [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
987- GGML_UNUSED (device);
988- GGML_LOG_ERROR (
989- " ggml_webgpu: Device error! Reason: %d, Message: %s\n " , static_cast <int >(reason), message.data );
990- });
991- webgpu_ctx->instance .WaitAny (
992- webgpu_ctx->adapter .RequestDevice (
993- &dev_desc,
994- wgpu::CallbackMode::AllowSpontaneous,
995- [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
996- if (status != wgpu::RequestDeviceStatus::Success) {
997- GGML_LOG_ERROR (" ggml_webgpu: Failed to get a device: %s\n " , message.data );
998- return ;
999- }
1000- webgpu_ctx->device = std::move (device);
1001- }),
1002- UINT64_MAX);
1003- GGML_ASSERT (webgpu_ctx->device != nullptr );
1004-
1005- // Initialize (compute) queue
1006- webgpu_ctx->queue = webgpu_ctx->device .GetQueue ();
1007-
1008- // Create buffer pool for shader parameters
1009- webgpu_ctx->param_buf_pool .init (webgpu_ctx->device ,
1010- WEBGPU_NUM_PARAM_BUFS,
1011- WEBGPU_PARAMS_BUF_SIZE_BYTES,
1012- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
1013- wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
1014- webgpu_ctx->set_rows_error_buf_pool .init (webgpu_ctx->device ,
1015- WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
1016- WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
1017- wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
1018- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
1019-
1020- ggml_webgpu_init_memset_pipeline (webgpu_ctx);
1021- ggml_webgpu_init_mul_mat_pipeline (webgpu_ctx);
1022- ggml_webgpu_init_set_rows_pipeline (webgpu_ctx);
1023- ggml_webgpu_init_cpy_pipeline (webgpu_ctx);
1024-
1025- #ifdef GGML_WEBGPU_DEBUG
1026- // Initialize debug buffers
1027- ggml_webgpu_create_buffer (webgpu_ctx->device ,
1028- webgpu_ctx->debug_host_buf ,
1029- WEBGPU_DEBUG_BUF_ELEMS * sizeof (uint32_t ),
1030- wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
1031- " debug_host_buf" );
1032- ggml_webgpu_create_buffer (webgpu_ctx->device ,
1033- webgpu_ctx->debug_dev_buf ,
1034- WEBGPU_DEBUG_BUF_ELEMS * sizeof (uint32_t ),
1035- wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
1036- " debug_dev_buf" );
1037- #endif
1038- webgpu_ctx->device_init = true ;
1039- }
1040-
1041971 static ggml_backend_webgpu_context backend_ctx;
1042972 backend_ctx.name = GGML_WEBGPU_NAME + std::string (" : " ) + dev_ctx->device_name ;
1043973 backend_ctx.webgpu_ctx = webgpu_ctx;
@@ -1088,9 +1018,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
10881018 case GGML_OP_CPY | GGML_OP_SET_ROWS:
10891019 return op->type == GGML_TYPE_F16 && op->src [0 ]->type == GGML_TYPE_F32;
10901020 case GGML_OP_MUL_MAT:
1091- return (( op->src [0 ]->type == GGML_TYPE_F32 || op->src [0 ]->type == GGML_TYPE_F16) &&
1092- (op-> src [ 1 ]-> type == GGML_TYPE_F32 || op->src [1 ]->type == GGML_TYPE_F16)) ||
1093- (op->src [0 ]->type == GGML_TYPE_Q4_0 && op->src [1 ]->type == GGML_TYPE_F32 );
1021+ return (op->src [0 ]->type == GGML_TYPE_F32 || op->src [0 ]->type == GGML_TYPE_F16 ||
1022+ op->src [0 ]->type == GGML_TYPE_Q4_0) &&
1023+ (op->src [1 ]->type == GGML_TYPE_F32 || op->src [1 ]->type == GGML_TYPE_F16 );
10941024 default :
10951025 return false ;
10961026 }
@@ -1157,6 +1087,73 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
11571087 wgpu::AdapterInfo info{};
11581088 ctx->adapter .GetInfo (&info);
11591089
1090+ // Initialize device
1091+ std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
1092+ wgpu::FeatureName::ImplicitDeviceSynchronization };
1093+ wgpu::DeviceDescriptor dev_desc;
1094+ dev_desc.requiredLimits = &ctx->limits ;
1095+ dev_desc.requiredFeatures = required_features.data ();
1096+ dev_desc.requiredFeatureCount = required_features.size ();
1097+ dev_desc.SetDeviceLostCallback (
1098+ wgpu::CallbackMode::AllowSpontaneous,
1099+ [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
1100+ GGML_UNUSED (device);
1101+ GGML_LOG_ERROR (
1102+ " ggml_webgpu: Device lost! Reason: %d, Message: %s\n " , static_cast <int >(reason), message.data );
1103+ });
1104+ dev_desc.SetUncapturedErrorCallback (
1105+ [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
1106+ GGML_UNUSED (device);
1107+ GGML_LOG_ERROR (
1108+ " ggml_webgpu: Device error! Reason: %d, Message: %s\n " , static_cast <int >(reason), message.data );
1109+ });
1110+ ctx->instance .WaitAny (ctx->adapter .RequestDevice (
1111+ &dev_desc,
1112+ wgpu::CallbackMode::AllowSpontaneous,
1113+ [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
1114+ if (status != wgpu::RequestDeviceStatus::Success) {
1115+ GGML_LOG_ERROR (" ggml_webgpu: Failed to get a device: %s\n " , message.data );
1116+ return ;
1117+ }
1118+ ctx->device = std::move (device);
1119+ }),
1120+ UINT64_MAX);
1121+ GGML_ASSERT (ctx->device != nullptr );
1122+
1123+ // Initialize (compute) queue
1124+ ctx->queue = ctx->device .GetQueue ();
1125+
1126+ // Create buffer pool for shader parameters
1127+ ctx->param_buf_pool .init (ctx->device ,
1128+ WEBGPU_NUM_PARAM_BUFS,
1129+ WEBGPU_PARAMS_BUF_SIZE_BYTES,
1130+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
1131+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
1132+ ctx->set_rows_error_buf_pool .init (ctx->device ,
1133+ WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
1134+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
1135+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
1136+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
1137+
1138+ ggml_webgpu_init_memset_pipeline (ctx);
1139+ ggml_webgpu_init_mul_mat_pipeline (ctx);
1140+ ggml_webgpu_init_set_rows_pipeline (ctx);
1141+ ggml_webgpu_init_cpy_pipeline (ctx);
1142+
1143+ #ifdef GGML_WEBGPU_DEBUG
1144+ // Initialize debug buffers
1145+ ggml_webgpu_create_buffer (ctx->device ,
1146+ ctx->debug_host_buf ,
1147+ WEBGPU_DEBUG_BUF_ELEMS * sizeof (uint32_t ),
1148+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
1149+ " debug_host_buf" );
1150+ ggml_webgpu_create_buffer (ctx->device ,
1151+ ctx->debug_dev_buf ,
1152+ WEBGPU_DEBUG_BUF_ELEMS * sizeof (uint32_t ),
1153+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
1154+ " debug_dev_buf" );
1155+ #endif
1156+
11601157 static ggml_backend_webgpu_device_context device_ctx;
11611158 device_ctx.webgpu_ctx = ctx;
11621159 device_ctx.device_name = GGML_WEBGPU_NAME;
0 commit comments