Skip to content

Commit d006858

Browse files
authored
ggml-webgpu: move from parameter buffer pool to single buffer with offsets (#21278)
* Work towards removing bitcast * Move rest of existing types over * Add timeout back to wait and remove synchronous set_tensor/memset_tensor * move to unpackf16 for wider compatibility * cleanup * Remove deadlock condition in free_bufs * Start work on removing parameter buffer pools * Simplify and optimize further * simplify profile futures * Fix stride * Try using a single command buffer per batch * formatting
1 parent e439700 commit d006858

File tree

2 files changed

+379
-422
lines changed

2 files changed

+379
-422
lines changed

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -437,12 +437,18 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_
437437

438438
// Head-dim specializations used by the tuned vec f16 path.
439439
switch (key.head_dim_qk) {
440-
case 64: return 2u;
441-
case 96: return 4u;
442-
case 128: return 1u;
443-
case 192: return 2u;
444-
case 576: return 2u;
445-
default: return 1u;
440+
case 64:
441+
return 2u;
442+
case 96:
443+
return 4u;
444+
case 128:
445+
return 1u;
446+
case 192:
447+
return 2u;
448+
case 576:
449+
return 2u;
450+
default:
451+
return 1u;
446452
}
447453
}
448454

@@ -513,9 +519,9 @@ struct ggml_webgpu_flash_attn_blk_shader_lib_context {
513519
};
514520

515521
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
516-
pre_wgsl::Preprocessor & preprocessor,
517-
const char * shader_src,
518-
const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
522+
pre_wgsl::Preprocessor & preprocessor,
523+
const char * shader_src,
524+
const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
519525
std::vector<std::string> defines;
520526
std::string variant = "flash_attn_vec_blk";
521527

@@ -1857,9 +1863,8 @@ class ggml_webgpu_shader_lib {
18571863
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
18581864

18591865
uint32_t q_tile = context.sg_mat_m;
1860-
uint32_t kv_tile =
1861-
std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
1862-
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
1866+
uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
1867+
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
18631868
if (context.key.use_vec) {
18641869
q_tile = 1;
18651870
kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
@@ -1885,14 +1890,14 @@ class ggml_webgpu_shader_lib {
18851890
}
18861891
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
18871892

1888-
const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
1893+
const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
18891894
webgpu_pipeline pipeline =
18901895
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
1891-
auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
1892-
decisions->q_tile = q_tile;
1893-
decisions->kv_tile = kv_tile;
1894-
decisions->wg_size = wg_size;
1895-
pipeline.context = decisions;
1896+
auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
1897+
decisions->q_tile = q_tile;
1898+
decisions->kv_tile = kv_tile;
1899+
decisions->wg_size = wg_size;
1900+
pipeline.context = decisions;
18961901
flash_attn_pipelines[context.key] = pipeline;
18971902
return flash_attn_pipelines[context.key];
18981903
}
@@ -1905,7 +1910,7 @@ class ggml_webgpu_shader_lib {
19051910

19061911
ggml_webgpu_processed_shader processed =
19071912
ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
1908-
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
1913+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
19091914
flash_attn_blk_pipelines[context.key] = pipeline;
19101915
return flash_attn_blk_pipelines[context.key];
19111916
}

0 commit comments

Comments
 (0)