@@ -437,12 +437,18 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_
437437
438438 // Head-dim specializations used by the tuned vec f16 path.
439439 switch (key.head_dim_qk ) {
440- case 64 : return 2u ;
441- case 96 : return 4u ;
442- case 128 : return 1u ;
443- case 192 : return 2u ;
444- case 576 : return 2u ;
445- default : return 1u ;
440+ case 64 :
441+ return 2u ;
442+ case 96 :
443+ return 4u ;
444+ case 128 :
445+ return 1u ;
446+ case 192 :
447+ return 2u ;
448+ case 576 :
449+ return 2u ;
450+ default :
451+ return 1u ;
446452 }
447453}
448454
@@ -513,9 +519,9 @@ struct ggml_webgpu_flash_attn_blk_shader_lib_context {
513519};
514520
515521inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader (
516- pre_wgsl::Preprocessor & preprocessor,
517- const char * shader_src,
518- const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
522+ pre_wgsl::Preprocessor & preprocessor,
523+ const char * shader_src,
524+ const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
519525 std::vector<std::string> defines;
520526 std::string variant = " flash_attn_vec_blk" ;
521527
@@ -1857,9 +1863,8 @@ class ggml_webgpu_shader_lib {
18571863 defines.push_back (std::string (" SG_MAT_K=" ) + std::to_string (context.sg_mat_k ));
18581864
18591865 uint32_t q_tile = context.sg_mat_m ;
1860- uint32_t kv_tile =
1861- std::min (ggml_webgpu_flash_attn_max_kv_tile (context),
1862- context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
1866+ uint32_t kv_tile = std::min (ggml_webgpu_flash_attn_max_kv_tile (context),
1867+ context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
18631868 if (context.key .use_vec ) {
18641869 q_tile = 1 ;
18651870 kv_tile = std::max (context.sg_mat_n , std::min (32u , ggml_webgpu_flash_attn_max_kv_tile (context)));
@@ -1885,14 +1890,14 @@ class ggml_webgpu_shader_lib {
18851890 }
18861891 defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (wg_size));
18871892
1888- const char * shader_src = context.key .use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
1893+ const char * shader_src = context.key .use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
18891894 webgpu_pipeline pipeline =
18901895 ggml_webgpu_create_pipeline (device, preprocessor.preprocess (shader_src, defines), variant);
1891- auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
1892- decisions->q_tile = q_tile;
1893- decisions->kv_tile = kv_tile;
1894- decisions->wg_size = wg_size;
1895- pipeline.context = decisions;
1896+ auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
1897+ decisions->q_tile = q_tile;
1898+ decisions->kv_tile = kv_tile;
1899+ decisions->wg_size = wg_size;
1900+ pipeline.context = decisions;
18961901 flash_attn_pipelines[context.key ] = pipeline;
18971902 return flash_attn_pipelines[context.key ];
18981903 }
@@ -1905,7 +1910,7 @@ class ggml_webgpu_shader_lib {
19051910
19061911 ggml_webgpu_processed_shader processed =
19071912 ggml_webgpu_preprocess_flash_attn_blk_shader (preprocessor, wgsl_flash_attn_vec_blk, context);
1908- webgpu_pipeline pipeline = ggml_webgpu_create_pipeline (device, processed.wgsl , processed.variant );
1913+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline (device, processed.wgsl , processed.variant );
19091914 flash_attn_blk_pipelines[context.key ] = pipeline;
19101915 return flash_attn_blk_pipelines[context.key ];
19111916 }
0 commit comments