@@ -130,7 +130,7 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline get_rows_pipeline[30];
     wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
-    wgpu::ComputePipeline cpy_pipeline;
+    wgpu::ComputePipeline cpy_pipeline[2][2];  // src type, dst type
     wgpu::ComputePipeline add_pipeline[2][2];  // type, inplace
     wgpu::ComputePipeline sub_pipeline[2][2];  // type, inplace
     wgpu::ComputePipeline mul_pipeline[2][2];  // type, inplace
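
The new 2x2 table can be indexed directly by the tensor types because the relevant ggml enums are small integers (GGML_TYPE_F32 == 0 and GGML_TYPE_F16 == 1 in ggml.h). A self-contained sketch of that lookup idea, with the enum subset redeclared only to keep the example compilable:

#include <cstdio>

// Mirrors the first two values of enum ggml_type from ggml.h.
enum ggml_type_subset { GGML_TYPE_F32 = 0, GGML_TYPE_F16 = 1 };

// src type selects the row, dst type the column, exactly as in
// cpy_pipeline[src->type][dst->type] in the diff above.
const char * cpy_shader_name[2][2] = {
    { "cpy_f32_f32", "cpy_f32_f16" },  // src = F32
    { "cpy_f16_f32", "cpy_f16_f16" },  // src = F16
};

int main() {
    printf("%s\n", cpy_shader_name[GGML_TYPE_F16][GGML_TYPE_F32]);  // cpy_f16_f32
    return 0;
}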
@@ -491,8 +491,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
         (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
         (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
         (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-        // Logical shape — same for both tensors even if permuted
-        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
+        // Logical shapes
+        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
@@ -508,7 +509,8 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
 
     size_t    max_wg_size = ctx->max_wg_size_x;
     uint32_t  wg_x        = (ne + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
 }
 
 static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
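
Passing separate src and dst logical shapes matters because CPY/CONT may copy between views whose dimensions differ, e.g. a permuted source materialized into a contiguous destination, so a flat element index has to be decomposed once per tensor. A CPU reference of the index mapping the shader presumably performs (a sketch under that assumption, not the WGSL source; strides are in elements, matching the params above):

#include <cstdint>

// Decompose a flat element index i against a tensor's logical shape,
// then apply that tensor's strides. The fourth extent is implied by
// the total element count, so only ne[0..2] needs to be passed.
static uint32_t flat_to_offset(uint32_t i, const uint32_t ne[3], const uint32_t nb[4]) {
    uint32_t i0 = i % ne[0];
    uint32_t i1 = (i / ne[0]) % ne[1];
    uint32_t i2 = (i / (ne[0] * ne[1])) % ne[2];
    uint32_t i3 = i / (ne[0] * ne[1] * ne[2]);
    return i0 * nb[0] + i1 * nb[1] + i2 * nb[2] + i3 * nb[3];
}

// Per element: dst[flat_to_offset(i, dst_ne, dst_nb)] = convert(src[flat_to_offset(i, src_ne, src_nb)]);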
@@ -930,6 +932,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
         case GGML_OP_RESHAPE:
             return false;
         case GGML_OP_CPY:
+        case GGML_OP_CONT:
             ggml_webgpu_cpy(ctx, src0, node);
             break;
         case GGML_OP_SET_ROWS:
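
Routing GGML_OP_CONT through the same handler is natural: making a tensor contiguous is exactly a typed copy out of a strided view. A typical graph fragment that now stays on the WebGPU backend, using the standard ggml API (the helper name is illustrative, and ctx is assumed to be an initialized ggml_context *):

#include "ggml.h"

static struct ggml_tensor * make_contiguous_transpose(struct ggml_context * ctx) {
    struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
    // ggml_permute creates a non-contiguous view; ggml_cont materializes it,
    // producing a GGML_OP_CONT node that the backend lowers to the cpy shader.
    struct ggml_tensor * p = ggml_permute(ctx, t, 1, 0, 2, 3);
    return ggml_cont(ctx, p);
}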
@@ -1360,8 +1363,15 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
-                                ggml_webgpu_max_wg_size_entry(webgpu_ctx));
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
+                                wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
+                                wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
+                                wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
+                                wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
 }
 
 static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
@@ -1608,6 +1618,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                           (src1->type == op->type);
             break;
         case GGML_OP_CPY:
+        case GGML_OP_CONT:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                          (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+            break;
         case GGML_OP_SET_ROWS:
             supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32);
             break;