Skip to content

Commit 1aa40f1

Browse files
committed
Add q4_0_f16 matmul and fix device init
1 parent 688b51d commit 1aa40f1

File tree

2 files changed: 91 additions and 85 deletions.

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 80 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,6 @@ struct webgpu_context_struct {
125125

126126
std::recursive_mutex mutex;
127127

128-
bool device_init = false;
129-
130128
webgpu_buf_pool param_buf_pool;
131129
webgpu_buf_pool set_rows_error_buf_pool;
132130

@@ -454,6 +452,7 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
454452
ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
455453
return ctx->buffer;
456454
}
455+
457456
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
458457
size_t offset = ggml_webgpu_tensor_offset(t);
459458
return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
@@ -911,7 +910,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
911910
}
912911

913912
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
914-
webgpu_pipeline_info pipeline_infos[5] = {
913+
webgpu_pipeline_info pipeline_infos[6] = {
915914
{ .name = "mul_mat_f32_f32",
916915
.shader_code = wgsl_mul_mat_f32_f32,
917916
.src0_type = GGML_TYPE_F32,
@@ -928,10 +927,14 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
928927
.shader_code = wgsl_mul_mat_f16_f32,
929928
.src0_type = GGML_TYPE_F16,
930929
.src1_type = GGML_TYPE_F32 },
931-
{ .name = "mul_mat_q4_0_f32",
932-
.shader_code = wgsl_mul_mat_q4_0_f32,
933-
.src0_type = GGML_TYPE_Q4_0,
934-
.src1_type = GGML_TYPE_F32 }
930+
{ .name = "mul_mat_q4_0_f32",
931+
.shader_code = wgsl_mul_mat_q4_0_f32,
932+
.src0_type = GGML_TYPE_Q4_0,
933+
.src1_type = GGML_TYPE_F32 },
934+
{ .name = "mul_mat_q4_0_f16",
935+
.shader_code = wgsl_mul_mat_q4_0_f16,
936+
.src0_type = GGML_TYPE_Q4_0,
937+
.src1_type = GGML_TYPE_F16 }
935938
};
936939

937940
for (auto & pipeline_info : pipeline_infos) {
@@ -965,79 +968,6 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
965968
ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
966969
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
967970

968-
// Multiple threads may try to initialize the device
969-
std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
970-
if (!webgpu_ctx->device_init) {
971-
// Initialize device
972-
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
973-
wgpu::FeatureName::ImplicitDeviceSynchronization };
974-
wgpu::DeviceDescriptor dev_desc;
975-
dev_desc.requiredLimits = &webgpu_ctx->limits;
976-
dev_desc.requiredFeatures = required_features.data();
977-
dev_desc.requiredFeatureCount = required_features.size();
978-
dev_desc.SetDeviceLostCallback(
979-
wgpu::CallbackMode::AllowSpontaneous,
980-
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
981-
GGML_UNUSED(device);
982-
GGML_LOG_ERROR(
983-
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
984-
});
985-
dev_desc.SetUncapturedErrorCallback(
986-
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
987-
GGML_UNUSED(device);
988-
GGML_LOG_ERROR(
989-
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
990-
});
991-
webgpu_ctx->instance.WaitAny(
992-
webgpu_ctx->adapter.RequestDevice(
993-
&dev_desc,
994-
wgpu::CallbackMode::AllowSpontaneous,
995-
[webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
996-
if (status != wgpu::RequestDeviceStatus::Success) {
997-
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
998-
return;
999-
}
1000-
webgpu_ctx->device = std::move(device);
1001-
}),
1002-
UINT64_MAX);
1003-
GGML_ASSERT(webgpu_ctx->device != nullptr);
1004-
1005-
// Initialize (compute) queue
1006-
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
1007-
1008-
// Create buffer pool for shader parameters
1009-
webgpu_ctx->param_buf_pool.init(webgpu_ctx->device,
1010-
WEBGPU_NUM_PARAM_BUFS,
1011-
WEBGPU_PARAMS_BUF_SIZE_BYTES,
1012-
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
1013-
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
1014-
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device,
1015-
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
1016-
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
1017-
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
1018-
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
1019-
1020-
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
1021-
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
1022-
ggml_webgpu_init_set_rows_pipeline(webgpu_ctx);
1023-
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
1024-
1025-
#ifdef GGML_WEBGPU_DEBUG
1026-
// Initialize debug buffers
1027-
ggml_webgpu_create_buffer(webgpu_ctx->device,
1028-
webgpu_ctx->debug_host_buf,
1029-
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
1030-
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
1031-
"debug_host_buf");
1032-
ggml_webgpu_create_buffer(webgpu_ctx->device,
1033-
webgpu_ctx->debug_dev_buf,
1034-
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
1035-
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
1036-
"debug_dev_buf");
1037-
#endif
1038-
webgpu_ctx->device_init = true;
1039-
}
1040-
1041971
static ggml_backend_webgpu_context backend_ctx;
1042972
backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
1043973
backend_ctx.webgpu_ctx = webgpu_ctx;
@@ -1088,9 +1018,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
10881018
case GGML_OP_CPY | GGML_OP_SET_ROWS:
10891019
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
10901020
case GGML_OP_MUL_MAT:
1091-
return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1092-
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) ||
1093-
(op->src[0]->type == GGML_TYPE_Q4_0 && op->src[1]->type == GGML_TYPE_F32);
1021+
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
1022+
op->src[0]->type == GGML_TYPE_Q4_0) &&
1023+
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16);
10941024
default:
10951025
return false;
10961026
}
@@ -1157,6 +1087,73 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
11571087
wgpu::AdapterInfo info{};
11581088
ctx->adapter.GetInfo(&info);
11591089

1090+
// Initialize device
1091+
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
1092+
wgpu::FeatureName::ImplicitDeviceSynchronization };
1093+
wgpu::DeviceDescriptor dev_desc;
1094+
dev_desc.requiredLimits = &ctx->limits;
1095+
dev_desc.requiredFeatures = required_features.data();
1096+
dev_desc.requiredFeatureCount = required_features.size();
1097+
dev_desc.SetDeviceLostCallback(
1098+
wgpu::CallbackMode::AllowSpontaneous,
1099+
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
1100+
GGML_UNUSED(device);
1101+
GGML_LOG_ERROR(
1102+
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
1103+
});
1104+
dev_desc.SetUncapturedErrorCallback(
1105+
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
1106+
GGML_UNUSED(device);
1107+
GGML_LOG_ERROR(
1108+
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
1109+
});
1110+
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
1111+
&dev_desc,
1112+
wgpu::CallbackMode::AllowSpontaneous,
1113+
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
1114+
if (status != wgpu::RequestDeviceStatus::Success) {
1115+
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
1116+
return;
1117+
}
1118+
ctx->device = std::move(device);
1119+
}),
1120+
UINT64_MAX);
1121+
GGML_ASSERT(ctx->device != nullptr);
1122+
1123+
// Initialize (compute) queue
1124+
ctx->queue = ctx->device.GetQueue();
1125+
1126+
// Create buffer pool for shader parameters
1127+
ctx->param_buf_pool.init(ctx->device,
1128+
WEBGPU_NUM_PARAM_BUFS,
1129+
WEBGPU_PARAMS_BUF_SIZE_BYTES,
1130+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
1131+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
1132+
ctx->set_rows_error_buf_pool.init(ctx->device,
1133+
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
1134+
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
1135+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
1136+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
1137+
1138+
ggml_webgpu_init_memset_pipeline(ctx);
1139+
ggml_webgpu_init_mul_mat_pipeline(ctx);
1140+
ggml_webgpu_init_set_rows_pipeline(ctx);
1141+
ggml_webgpu_init_cpy_pipeline(ctx);
1142+
1143+
#ifdef GGML_WEBGPU_DEBUG
1144+
// Initialize debug buffers
1145+
ggml_webgpu_create_buffer(ctx->device,
1146+
ctx->debug_host_buf,
1147+
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
1148+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
1149+
"debug_host_buf");
1150+
ggml_webgpu_create_buffer(ctx->device,
1151+
ctx->debug_dev_buf,
1152+
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
1153+
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
1154+
"debug_dev_buf");
1155+
#endif
1156+
11601157
static ggml_backend_webgpu_device_context device_ctx;
11611158
device_ctx.webgpu_ctx = ctx;
11621159
device_ctx.device_name = GGML_WEBGPU_NAME;

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,16 @@
4040
"BLOCK_SIZE": 32
4141
},
4242
"DECLS": "Q4_0"
43+
},
44+
{
45+
"REPLS": {
46+
"SRC0_TYPE": "q4_0",
47+
"SRC1_TYPE": "f16",
48+
"BLOCK_SIZE": 32
49+
},
50+
"DECLS": "Q4_0"
4351
}
52+
4453
]
4554

4655
#end(VARIANTS)
@@ -70,8 +79,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
7079
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
7180
let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
7281
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
73-
sum += q_lo * src1[src1_offset];
74-
sum += q_hi * src1[src1_offset + 16];
82+
sum += q_lo * f32(src1[src1_offset]);
83+
sum += q_hi * f32(src1[src1_offset + 16]);
7584
}
7685
}
7786
return sum;

0 commit comments

Comments
 (0)