@@ -696,23 +696,19 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
696696
697697 std::vector<uint32_t > params = {
698698 (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, src) / ggml_type_size (src->type )),
699+ (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, dst) / ggml_type_size (dst->type )),
700+ (uint32_t ) (src->nb [1 ] / ggml_type_size (src->type )),
701+ (uint32_t ) (src->nb [2 ] / ggml_type_size (src->type )),
702+ (uint32_t ) (src->nb [3 ] / ggml_type_size (src->type )),
703+ (uint32_t ) (dst->nb [1 ] / ggml_type_size (dst->type )),
704+ (uint32_t ) (dst->nb [2 ] / ggml_type_size (dst->type )),
705+ (uint32_t ) (dst->nb [3 ] / ggml_type_size (dst->type )),
706+ (uint32_t ) src->ne [0 ],
707+ (uint32_t ) src->ne [1 ],
708+ (uint32_t ) src->ne [2 ],
709+ (uint32_t ) src->ne [3 ],
710+ *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
699711 };
700- if (!in_place) {
701- params.push_back ((uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, dst) / ggml_type_size (dst->type )));
702- }
703- params.push_back ((uint32_t ) (src->nb [1 ] / ggml_type_size (src->type )));
704- params.push_back ((uint32_t ) (src->nb [2 ] / ggml_type_size (src->type )));
705- params.push_back ((uint32_t ) (src->nb [3 ] / ggml_type_size (src->type )));
706- if (!in_place) {
707- params.push_back ((uint32_t ) (dst->nb [1 ] / ggml_type_size (dst->type )));
708- params.push_back ((uint32_t ) (dst->nb [2 ] / ggml_type_size (dst->type )));
709- params.push_back ((uint32_t ) (dst->nb [3 ] / ggml_type_size (dst->type )));
710- }
711- params.push_back ((uint32_t ) src->ne [0 ]);
712- params.push_back ((uint32_t ) src->ne [1 ]);
713- params.push_back ((uint32_t ) src->ne [2 ]);
714- params.push_back ((uint32_t ) src->ne [3 ]);
715- params.push_back (*(uint32_t *) dst->op_params ); // epsilon, treated as f32 in the shader
716712
717713 std::vector<wgpu::BindGroupEntry> entries = {
718714 { .binding = 0 ,
@@ -1266,10 +1262,10 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
12661262 constants);
12671263 ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16], wgsl_add_f16, " add_f16" ,
12681264 constants);
1269- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_ip_pipeline [GGML_TYPE_F32], wgsl_add_in_place_f32 ,
1270- " add_in_place_f32 " , constants);
1271- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_ip_pipeline [GGML_TYPE_F16], wgsl_add_in_place_f16 ,
1272- " add_in_place_f16 " , constants);
1265+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_ip_pipeline [GGML_TYPE_F32], wgsl_add_f32_inplace ,
1266+ " add_f32_inplace " , constants);
1267+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_ip_pipeline [GGML_TYPE_F16], wgsl_add_f16_inplace ,
1268+ " add_f16_inplace " , constants);
12731269}
12741270
12751271static void ggml_webgpu_init_mul_pipeline (webgpu_context & webgpu_ctx) {
@@ -1278,18 +1274,18 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
12781274 constants);
12791275 ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16], wgsl_mul_f16, " mul_f16" ,
12801276 constants);
1281- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_ip_pipeline [GGML_TYPE_F32], wgsl_mul_in_place_f32 ,
1282- " mul_in_place_f32 " , constants);
1283- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_ip_pipeline [GGML_TYPE_F16], wgsl_mul_in_place_f16 ,
1284- " mul_in_place_f16 " , constants);
1277+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_ip_pipeline [GGML_TYPE_F32], wgsl_mul_f32_inplace ,
1278+ " mul_f32_inplace " , constants);
1279+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_ip_pipeline [GGML_TYPE_F16], wgsl_mul_f16_inplace ,
1280+ " mul_f16_inplace " , constants);
12851281}
12861282
// Creates the two RMS-norm compute pipelines held by the WebGPU context:
//   - webgpu_ctx->rms_norm_pipeline,    built from the wgsl_rms_norm shader source
//   - webgpu_ctx->rms_norm_ip_pipeline, built from the in-place shader source
// Both ggml_webgpu_create_pipeline calls are given the same `constants` vector,
// which is whatever ggml_webgpu_max_wg_size_entry(webgpu_ctx) returns
// (presumably the max-workgroup-size override constant - confirm against that helper).
12871283static void ggml_webgpu_init_rms_norm_pipeline (webgpu_context & webgpu_ctx) {
12881284 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry (webgpu_ctx);
12891285 ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rms_norm_pipeline , wgsl_rms_norm, " rms_norm" ,
12901286 constants);
// Removed side of the diff: the in-place shader symbol and pipeline label used the `_in_place` spelling.
1291- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rms_norm_ip_pipeline , wgsl_rms_norm_in_place ,
1292- " rms_norm_in_place " , constants);
// Added side of the diff: same call and same target pipeline; only the shader symbol and label
// change to the `_inplace` suffix, the same suffix the add/mul in-place pipelines use after this change.
1287+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->rms_norm_ip_pipeline , wgsl_rms_norm_inplace ,
1288+ " rms_norm_inplace " , constants);
12931289}
12941290
12951291static void ggml_webgpu_init_rope_pipeline (webgpu_context & webgpu_ctx) {
0 commit comments