@@ -138,6 +138,7 @@ struct webgpu_context_struct {
138138 wgpu::ComputePipeline rms_norm_pipeline[2 ]; // inplace
139139 wgpu::ComputePipeline rope_pipeline[2 ][2 ][2 ]; // type, ff, inplace
140140 wgpu::ComputePipeline glu_pipeline[7 ][2 ][2 ]; // glu-op, type, split
141+ wgpu::ComputePipeline scale_pipeline[2 ]; // inplace
141142
142143 size_t memset_bytes_per_thread;
143144
@@ -840,9 +841,9 @@ static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
840841 (uint32_t ) dst->ne [0 ],
841842 (uint32_t ) dst->ne [1 ],
842843 (uint32_t ) dst->ne [2 ],
843- (uint32_t ) ((int32_t *) dst->op_params )[1 ], // swapped
844- *(uint32_t *) &dst->op_params [2 ], // alpha, for swiglu_oai
845- *(uint32_t *) &dst->op_params [3 ], // limit, for swiglu_oai
844+ (uint32_t ) ((int32_t *) dst->op_params )[1 ], // swapped
845+ *(uint32_t *) &dst->op_params [2 ], // alpha, for swiglu_oai
846+ *(uint32_t *) &dst->op_params [3 ], // limit, for swiglu_oai
846847 };
847848
848849 std::vector<wgpu::BindGroupEntry> entries = {
@@ -870,6 +871,45 @@ static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
870871 ggml_backend_webgpu_build_and_enqueue (ctx, pipeline, params, entries, wg_x, ggml_op_name (dst->op ));
871872}
872873
// Encodes a GGML_OP_SCALE node: packs the shader parameters and buffer bindings for
// `src` -> `dst` and hands them to ggml_backend_webgpu_build_and_enqueue() together with
// one of the two pipelines in ctx->scale_pipeline (index 0 or 1, chosen by `inplace`).
//   ctx - backend context; supplies scale_pipeline[] and max_wg_size_x.
//   src - input tensor, always bound at binding 0.
//   dst - output tensor; its op_params supply scale and bias. Bound at binding 1 only
//         when the op is not in place.
874+ static void ggml_webgpu_scale (webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
// Result of ggml_webgpu_tensor_equal(); used both as the pipeline index and to decide
// whether dst gets its own binding. NOTE(review): presumably non-zero when src and dst
// alias the same buffer region - confirm against the helper's definition.
875+ int inplace = ggml_webgpu_tensor_equal (src, dst);
876+
// Parameter words, in the order they are pushed: src/dst misalignment, src strides 1..3,
// dst strides 1..3, total element count of dst, src ne[0..2], scale, bias.
// Misalignments and nb[] strides are divided by the element size, so they are passed in
// units of elements (assuming nb[] holds byte strides - TODO confirm).
877+ std::vector<uint32_t > params = {
878+ (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, src) / ggml_type_size (src->type )),
879+ (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, dst) / ggml_type_size (dst->type )),
880+ (uint32_t ) (src->nb [1 ] / ggml_type_size (src->type )),
881+ (uint32_t ) (src->nb [2 ] / ggml_type_size (src->type )),
882+ (uint32_t ) (src->nb [3 ] / ggml_type_size (src->type )),
883+ (uint32_t ) (dst->nb [1 ] / ggml_type_size (dst->type )),
884+ (uint32_t ) (dst->nb [2 ] / ggml_type_size (dst->type )),
885+ (uint32_t ) (dst->nb [3 ] / ggml_type_size (dst->type )),
886+ (uint32_t ) ggml_nelements (dst),
887+ (uint32_t ) src->ne [0 ],
888+ (uint32_t ) src->ne [1 ],
889+ (uint32_t ) src->ne [2 ],
// The next two words are read through a uint32_t pointer, i.e. copied as raw 32-bit
// patterns with no numeric conversion. NOTE(review): presumably float bit patterns that
// the shader reinterprets - confirm against the WGSL source.
890+ *(uint32_t *) dst->op_params , // scale (op_params word 0)
891+ *(uint32_t *) &dst->op_params [1 ] // bias (op_params word 1)
892+ };
893+
// Binding 0 is always the source tensor's buffer, at its aligned offset and binding size.
894+ std::vector<wgpu::BindGroupEntry> entries = {
895+ { .binding = 0 ,
896+ .buffer = ggml_webgpu_tensor_buf (src),
897+ .offset = ggml_webgpu_tensor_align_offset (ctx, src),
898+ .size = ggml_webgpu_tensor_binding_size (ctx, src) }
899+ };
// Only the non-in-place case adds a second binding for dst; the in-place pipeline is
// given the single src binding above.
900+ if (!inplace) {
901+ entries.push_back ({ .binding = 1 ,
902+ .buffer = ggml_webgpu_tensor_buf (dst),
903+ .offset = ggml_webgpu_tensor_align_offset (ctx, dst),
904+ .size = ggml_webgpu_tensor_binding_size (ctx, dst) });
905+ }
906+
// wg_x = ceil(nelements(dst) / max_wg_size_x): enough groups of max_wg_size_x to cover
// every element of dst.
907+ size_t max_wg_size = ctx->max_wg_size_x ;
908+ uint32_t wg_x = (ggml_nelements (dst) + max_wg_size - 1 ) / max_wg_size;
// The op name from ggml_op_name(dst->op) is passed along as the final argument.
909+ ggml_backend_webgpu_build_and_enqueue (ctx, ctx->scale_pipeline [inplace], params, entries, wg_x,
910+ ggml_op_name (dst->op ));
911+ }
912+
873913// Returns true if node has enqueued work into the queue, false otherwise
874914static bool ggml_webgpu_encode_node (webgpu_context ctx, ggml_tensor * node) {
875915 if (ggml_is_empty (node)) {
@@ -934,6 +974,9 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
934974 case GGML_OP_GLU:
935975 ggml_webgpu_glu (ctx, src0, src1, node);
936976 break ;
977+ case GGML_OP_SCALE:
978+ ggml_webgpu_scale (ctx, src0, node);
979+ break ;
937980 default :
938981 return false ;
939982 }
@@ -1449,7 +1492,14 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
14491492 wgsl_geglu_quick_f32_split, " geglu_quick_f32_split" , constants);
14501493 ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->glu_pipeline [GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1 ],
14511494 wgsl_geglu_quick_f16_split, " geglu_quick_f16_split" , constants);
1495+ }
14521496
// Creates the two compute pipelines stored in webgpu_ctx->scale_pipeline:
//   [0] from wgsl_scale_f32          (used by ggml_webgpu_scale when inplace == 0)
//   [1] from wgsl_scale_f32_inplace  (used by ggml_webgpu_scale when inplace == 1)
// Both are built with the constant entries returned by ggml_webgpu_max_wg_size_entry().
1497+ static void ggml_webgpu_init_scale_pipeline (webgpu_context & webgpu_ctx) {
// Same constants vector is shared by both pipeline creations below.
1498+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry (webgpu_ctx);
1499+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->scale_pipeline [0 ], wgsl_scale_f32, " scale_f32" ,
1500+ constants);
1501+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->scale_pipeline [1 ], wgsl_scale_f32_inplace,
1502+ " scale_f32_inplace" , constants);
14531503}
14541504
14551505static ggml_backend_t ggml_backend_webgpu_device_init (ggml_backend_dev_t dev, const char * params) {
@@ -1628,6 +1678,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
16281678 break ;
16291679 }
16301680 break ;
1681+ case GGML_OP_SCALE:
1682+ supports_op = op->type == GGML_TYPE_F32;
1683+ break ;
16311684 default :
16321685 break ;
16331686 }
@@ -1758,6 +1811,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
17581811 ggml_webgpu_init_rms_norm_pipeline (ctx);
17591812 ggml_webgpu_init_rope_pipeline (ctx);
17601813 ggml_webgpu_init_glu_pipeline (ctx);
1814+ ggml_webgpu_init_scale_pipeline (ctx);
17611815
17621816#ifdef GGML_WEBGPU_DEBUG
17631817 // Initialize debug buffers
0 commit comments