Commit 03ea041

Authored by reeselevine, neha-ha, and Neha Abbas
ggml webgpu: minor set rows optimization (#16810)
* Add buffer label and enable dawn-specific toggles to turn off some checks
* Minor set_rows optimization (#4)
* updated optimization, fixed errors
* non vectorized version now dispatches one thread per element
* Simplify
* Change logic for set_rows pipelines
---------
Co-authored-by: Neha Abbas <[email protected]>
Co-authored-by: Neha Abbas <[email protected]>
Co-authored-by: Reese Levine <[email protected]>
* Comment on dawn toggles
* Remove some comments
* Implement overlap binary operators
* Revert "Implement overlap binary operators"
  This reverts commit ed710b36f51ab3f53fa13db15c1685dc8678a32a.
* Disable support for non-contiguous binary_op tensors and leave note for future support
---------
Co-authored-by: neha-ha <[email protected]>
Co-authored-by: Neha Abbas <[email protected]>
Co-authored-by: Neha Abbas <[email protected]>
1 parent cdabeb2 commit 03ea041

File tree: 2 files changed, +103 lines, -30 lines

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 65 additions & 23 deletions
@@ -248,7 +248,7 @@ struct webgpu_context_struct {
 
     webgpu_pipeline memset_pipeline;
     webgpu_pipeline mul_mat_pipeline[30][2];
-    webgpu_pipeline set_rows_pipeline;
+    webgpu_pipeline set_rows_pipeline[1][2];  // dst->type, vectorized
     webgpu_pipeline get_rows_pipeline[30];
     webgpu_pipeline get_rows_f32_no_vec_pipeline;
     webgpu_pipeline cpy_pipeline[2][2];  // src type, dst type
@@ -309,10 +309,12 @@ struct ggml_backend_webgpu_context {
 struct ggml_backend_webgpu_buffer_context {
     webgpu_context webgpu_ctx;
     wgpu::Buffer   buffer;
+    std::string    label;
 
-    ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
+    ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) :
         webgpu_ctx(std::move(ctx)),
-        buffer(std::move(buf)) {}
+        buffer(std::move(buf)),
+        label(std::move(lbl)) {}
 };
 
 /* End struct definitions */
@@ -764,10 +766,20 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
         { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
     };
 
-    size_t   max_wg_size = ctx->max_wg_size_x;
-    uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
+    size_t max_wg_size = ctx->max_wg_size_x;
+
+    int vectorized = src->ne[0] % 4 == 0;
+    webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][vectorized];
+    uint32_t threads;
+    if (vectorized) {
+        threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
+    } else {
+        threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
+    }
 
-    return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs);
+    uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size;
+
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs);
 }
 
 static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
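For intuition, the workgroup math in this hunk can be run in isolation. The following is a minimal standalone sketch, not part of the commit; the tensor shape and max_wg_size value are made up for illustration:

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical SET_ROWS source shape [ne0, ne1, ne2, ne3] and workgroup size limit.
    const int64_t ne[4]       = { 128, 64, 2, 1 };
    const size_t  max_wg_size = 256;

    const bool vectorized = ne[0] % 4 == 0;  // ne0 divisible by 4 -> vec4 shader variant
    uint32_t threads;
    if (vectorized) {
        threads = (uint32_t) ((ne[1] * ne[2] * ne[3]) * (ne[0] / 4));  // one thread per vec4
    } else {
        threads = (uint32_t) (ne[0] * ne[1] * ne[2] * ne[3]);          // one thread per scalar
    }
    const uint32_t wg_x = (uint32_t) ((threads + max_wg_size - 1) / max_wg_size);  // ceil division

    // With these numbers: threads = 64 * 2 * 1 * (128 / 4) = 4096, wg_x = 4096 / 256 = 16.
    printf("vectorized=%d threads=%u wg_x=%u\n", (int) vectorized, threads, wg_x);
    return 0;
}

Compared with the old code, which dispatched one thread per row (ne1 * ne2 * ne3 threads, each looping over ne0 columns), the new code spreads the copy across one thread per element, or per vec4 when ne0 is divisible by 4.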
@@ -1336,11 +1348,11 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
 
     WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
 
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
-                                                                 << offset << ", " << size << ")");
-
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
 
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
+                                                                 << ", " << offset << ", " << size << ")");
+
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
     // This is a trick to set all bytes of a u32 to the same 1 byte value.
@@ -1354,12 +1366,13 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                   const void * data,
                                                   size_t offset,
                                                   size_t size) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
-                                                              << offset << ", " << size << ")");
     WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
     webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
 
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
+                                                              << ", " << offset << ", " << size << ")");
+
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
     webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
@@ -1397,12 +1410,12 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                   void * data,
                                                   size_t offset,
                                                   size_t size) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
-                                                              << offset << ", " << size << ")");
     WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
-    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
-    webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
-    wgpu::Device device = webgpu_ctx->device;
+    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
+                                                              << ", " << offset << ", " << size << ")");
+    webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
+    wgpu::Device device = webgpu_ctx->device;
 
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
@@ -1473,16 +1486,20 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
 
 static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
                                                                           size_t size) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
+    static std::atomic<int> buffer_count;
+    int buffer_id = buffer_count++;
+    std::string buf_name = "tensor_buf" + std::to_string(buffer_id);
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
 
     wgpu::Buffer buf;
     ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
                               (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
                               wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
-                              "allocated_buffer");
+                              buf_name.c_str());
 
-    ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
+    ggml_backend_webgpu_buffer_context * buf_ctx =
+        new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name);
 
     return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
 }
@@ -1613,8 +1630,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
-                                ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16,
+                                "set_rows_f16", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16_vec,
+                                "set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
 }
 
 static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
@@ -1950,8 +1969,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
+            // see https://github.com/ggml-org/llama.cpp/pull/16857
             supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
-                          (src1->type == op->type);
+                          (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
             break;
         case GGML_OP_CPY:
         case GGML_OP_CONT:
@@ -2129,6 +2150,19 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     required_features.push_back(wgpu::FeatureName::TimestampQuery);
 #endif
 
+    // Enable Dawn-specific toggles to increase native performance
+    // TODO: Don't enable for WASM builds, they won't have an effect anyways
+    // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
+    // only for native performance?
+    const char * const deviceEnabledToggles[]  = { "skip_validation", "disable_robustness", "disable_workgroup_init",
+                                                   "disable_polyfills_on_integer_div_and_mod" };
+    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
+    wgpu::DawnTogglesDescriptor deviceTogglesDesc;
+    deviceTogglesDesc.enabledToggles      = deviceEnabledToggles;
+    deviceTogglesDesc.enabledToggleCount  = 4;
+    deviceTogglesDesc.disabledToggles     = deviceDisabledToggles;
+    deviceTogglesDesc.disabledToggleCount = 1;
+
     wgpu::DeviceDescriptor dev_desc;
     dev_desc.requiredLimits = &ctx->limits;
     dev_desc.requiredFeatures = required_features.data();
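The toggles above only take effect once the descriptor is chained into the device descriptor (the dev_desc.nextInChain assignment in the next hunk). A rough standalone sketch of that chaining pattern, assuming Dawn's webgpu_cpp.h C++ bindings and using only the fields that appear in this diff:

#include <webgpu/webgpu_cpp.h>  // Dawn's C++ bindings (assumed include path)

// Hypothetical helper, illustration only: chain a Dawn toggles descriptor onto a
// standard WebGPU device descriptor.
static void fill_device_descriptor(wgpu::DeviceDescriptor & desc,
                                   wgpu::DawnTogglesDescriptor & toggles) {
    // Toggle names are Dawn-specific strings, not part of core WebGPU.
    static const char * const enabled[]  = { "skip_validation", "disable_robustness",
                                             "disable_workgroup_init",
                                             "disable_polyfills_on_integer_div_and_mod" };
    static const char * const disabled[] = { "timestamp_quantization" };

    toggles.enabledToggles      = enabled;
    toggles.enabledToggleCount  = 4;
    toggles.disabledToggles     = disabled;
    toggles.disabledToggleCount = 1;

    // Chained struct: Dawn reads it when the device is requested; implementations
    // that don't know the struct simply ignore it.
    desc.nextInChain = &toggles;
}

Because the chained struct is read when the device request is processed, the toggles descriptor has to stay alive until that request completes, which is presumably why the commit declares it in the same scope as dev_desc.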
@@ -2146,6 +2180,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
         GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
                    std::string(message).c_str());
     });
+    dev_desc.nextInChain = &deviceTogglesDesc;
     ctx->instance.WaitAny(ctx->adapter.RequestDevice(
         &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
         [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
@@ -2243,11 +2278,18 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
     ctx.name = GGML_WEBGPU_NAME;
     ctx.device_count = 1;
 
+    const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
+
+    wgpu::DawnTogglesDescriptor instanceTogglesDesc;
+    instanceTogglesDesc.enabledToggles     = instanceEnabledToggles;
+    instanceTogglesDesc.enabledToggleCount = 1;
     wgpu::InstanceDescriptor instance_descriptor{};
     std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
     instance_descriptor.requiredFeatures = instance_features.data();
     instance_descriptor.requiredFeatureCount = instance_features.size();
-    webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
+    instance_descriptor.nextInChain = &instanceTogglesDesc;
+
+    webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
     GGML_ASSERT(webgpu_ctx->instance != nullptr);
 
     static ggml_backend_reg reg = {

ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl renamed to ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl

Lines changed: 38 additions & 7 deletions
@@ -1,13 +1,38 @@
+#define(VARIANTS)
+
+[
+  {
+    "SHADER_SUFFIX": "f16_vec",
+    "REPLS": {
+      "TYPE" : "vec4<f32>",
+      "DST_TYPE": "vec4<f16>",
+      "VEC_SIZE": 4
+    }
+  },
+  {
+    "SHADER_SUFFIX": "f16",
+    "REPLS": {
+      "TYPE" : "f32",
+      "DST_TYPE": "f16",
+      "VEC_SIZE": 1
+    }
+  }
+]
+
+#end(VARIANTS)
+
+#define(SHADER)
+
 enable f16;
 
 @group(0) @binding(0)
-var<storage, read_write> src: array<f32>;
+var<storage, read_write> src: array<{{TYPE}}>;
 
 @group(0) @binding(1)
 var<storage, read_write> idx: array<u32>;
 
 @group(0) @binding(2)
-var<storage, read_write> dst: array<f16>;
+var<storage, read_write> dst: array<{{DST_TYPE}}>;
 
 @group(0) @binding(3)
 var<storage, read_write> error: atomic<u32>;
@@ -47,10 +72,14 @@ var<uniform> params: Params;
 override wg_size: u32;
 @compute @workgroup_size(wg_size)
 fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
+    if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) {
         return;
     }
-    var i = gid.x;
+
+    // getting the row from gid
+    let elems_per_row = params.ne0 / {{VEC_SIZE}};
+    var i = gid.x / elems_per_row;
+
 
     let i_src3 = i / (params.ne2 * params.n_rows);
 
     i = i % (params.ne2 * params.n_rows);
@@ -75,7 +104,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
     let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
 
-    for (var i: u32 = 0; i < params.ne0; i++) {
-        dst[i_dst_row + i] = f16(src[i_src_row + i]);
-    }
+    let col_idx = (gid.x % elems_per_row);
+    dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]);
 }
+
+#end(SHADER)
+

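To make the new per-thread indexing concrete, here is a rough host-side mirror of the shader's index math in C++. It is an illustration only: the idx[] row lookup and bounds checking between these steps are elided, and the shape values are invented.

#include <cstdint>
#include <cstdio>

int main() {
    // Invented shape: ne0 = 8 elements per row, 3 rows, ne2 = ne3 = 1, f16_vec variant.
    const uint32_t ne0 = 8, n_rows = 3, ne2 = 1, ne3 = 1;
    const uint32_t vec_size      = 4;               // VEC_SIZE of the f16_vec variant
    const uint32_t elems_per_row = ne0 / vec_size;  // 2 vec4 stores per row
    const uint32_t total         = ne3 * ne2 * n_rows * elems_per_row;

    for (uint32_t gid = 0; gid < total; gid++) {
        uint32_t i = gid / elems_per_row;             // which logical row this thread works on
        const uint32_t i_src3  = i / (ne2 * n_rows);  // outermost batch index
        i                      = i % (ne2 * n_rows);  // remaining (ne2, row) index
        const uint32_t col_idx = gid % elems_per_row; // element (vec4) within the row
        printf("gid=%u -> i_src3=%u remaining=%u col=%u\n", gid, i_src3, i, col_idx);
    }
    return 0;
}

Each thread writes exactly one dst slot (a scalar in the f16 variant, a vec4 in the f16_vec variant), which is what replaces the old per-row loop over params.ne0.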