@@ -228,6 +228,8 @@ struct vk_device_struct {
228228    vk_pipeline pipeline_repeat_f32;
229229    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
230230    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
231+     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
232+     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
231233    vk_pipeline pipeline_norm_f32;
232234    vk_pipeline pipeline_group_norm_f32;
233235    vk_pipeline pipeline_rms_norm_f32;
@@ -1965,6 +1967,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
19651967    ggml_vk_create_pipeline (device, device->pipeline_contig_cpy_f32_f16 , " contig_cpy_f32_f16" " main" 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
19661968    ggml_vk_create_pipeline (device, device->pipeline_contig_cpy_f16_f16 , " contig_cpy_f16_f16" " main" 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
19671969
1970+     ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q4_0], " cpy_f32_q4_0" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_0), 1 , 1 }, {}, 1 );
1971+     ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q4_1], " cpy_f32_q4_1" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_1), 1 , 1 }, {}, 1 );
1972+     ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q5_0], " cpy_f32_q5_0" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_0), 1 , 1 }, {}, 1 );
1973+     ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q5_1], " cpy_f32_q5_1" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_1), 1 , 1 }, {}, 1 );
1974+     ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q8_0], " cpy_f32_q8_0" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q8_0), 1 , 1 }, {}, 1 );
1975+     ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_IQ4_NL], " cpy_f32_iq4_nl" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_IQ4_NL), 1 , 1 }, {}, 1 );
1976+ 
1977+     ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q4_0], " cpy_q4_0_f32" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_0), 1 , 1 }, {}, 1 );
1978+     ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q4_1], " cpy_q4_1_f32" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_1), 1 , 1 }, {}, 1 );
1979+     ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q5_0], " cpy_q5_0_f32" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_0), 1 , 1 }, {}, 1 );
1980+     ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q5_1], " cpy_q5_1_f32" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_1), 1 , 1 }, {}, 1 );
1981+     ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q8_0], " cpy_q8_0_f32" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q8_0), 1 , 1 }, {}, 1 );
1982+     ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_IQ4_NL], " cpy_iq4_nl_f32" " main" 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_IQ4_NL), 1 , 1 }, {}, 1 );
1983+ 
19681984    ggml_vk_create_pipeline (device, device->pipeline_add_f32 , " add_f32" " main" 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {0 }, 1 );
19691985    ggml_vk_create_pipeline (device, device->pipeline_add_f32_norepeat , " add_f32_norepeat" " main" 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {1 }, 1 );
19701986    ggml_vk_create_pipeline (device, device->pipeline_add_f16_f32_f16 , " add_f16_f32_f16" " main" 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {0 }, 1 );
@@ -3689,6 +3705,33 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
36893705            return  ctx->device ->pipeline_cpy_f16_f16 ;
36903706        }
36913707    }
3708+     if  (src->type  == GGML_TYPE_F32) {
3709+         switch  (to) {
3710+         case  GGML_TYPE_Q4_0:
3711+         case  GGML_TYPE_Q4_1:
3712+         case  GGML_TYPE_Q5_0:
3713+         case  GGML_TYPE_Q5_1:
3714+         case  GGML_TYPE_Q8_0:
3715+         case  GGML_TYPE_IQ4_NL:
3716+             return  ctx->device ->pipeline_cpy_f32_quant [to];
3717+         default :
3718+             break ;
3719+         }
3720+     }
3721+ 
3722+     if  (to == GGML_TYPE_F32) {
3723+         switch  (src->type ) {
3724+         case  GGML_TYPE_Q4_0:
3725+         case  GGML_TYPE_Q4_1:
3726+         case  GGML_TYPE_Q5_0:
3727+         case  GGML_TYPE_Q5_1:
3728+         case  GGML_TYPE_Q8_0:
3729+         case  GGML_TYPE_IQ4_NL:
3730+             return  ctx->device ->pipeline_cpy_quant_f32 [src->type ];
3731+         default :
3732+             break ;
3733+         }
3734+     }
36923735
36933736    std::cerr << " Missing CPY op for types: " ggml_type_name (src->type ) << "  " ggml_type_name (to) << std::endl;
36943737    GGML_ABORT (" fatal error" 
@@ -5160,7 +5203,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
51605203    }
51615204    std::cerr << " ), (" " , name=" name  << " , type=" type  << " , ne0=" ne [0 ] << " , ne1=" ne [1 ] << " , ne2=" ne [2 ] << " , ne3=" ne [3 ] << " , nb0=" nb [0 ] << " , nb1=" nb [1 ] << " , nb2=" nb [2 ] << " , nb3=" nb [3 ];
51625205    std::cerr << " ), " ggml_op_name (op) << " , " " dryrun" " " " )" 
5163-     GGML_ASSERT (op == GGML_OP_GET_ROWS || (!ggml_is_quantized (src0->type ) && (src1 == nullptr  || !ggml_is_quantized (src1->type ))));  //  NOLINT
5206+     GGML_ASSERT (op == GGML_OP_GET_ROWS || op == GGML_OP_CPY ||  (!ggml_is_quantized (src0->type ) && (src1 == nullptr  || !ggml_is_quantized (src1->type ))));  //  NOLINT
51645207    GGML_ASSERT (ggml_vk_op_supports_incontiguous (op) || ggml_vk_dim01_contiguous (src0));  //  NOLINT
51655208    GGML_ASSERT (dst->buffer  != nullptr );
51665209    const  uint64_t  ne00 = src0->ne [0 ];
@@ -7905,12 +7948,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
79057948            {
79067949                ggml_type src0_type = op->src [0 ]->type ;
79077950                ggml_type src1_type = op->src [1 ] != nullptr  ? op->src [1 ]->type  : src0_type;
7908-                 if  (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
7909-                     return  true ;
7951+ 
7952+                 if  (src0_type == GGML_TYPE_F32) {
7953+                     switch  (src1_type) {
7954+                     case  GGML_TYPE_F32:
7955+                     case  GGML_TYPE_F16:
7956+                     case  GGML_TYPE_Q4_0:
7957+                     case  GGML_TYPE_Q4_1:
7958+                     case  GGML_TYPE_Q5_0:
7959+                     case  GGML_TYPE_Q5_1:
7960+                     case  GGML_TYPE_Q8_0:
7961+                     case  GGML_TYPE_IQ4_NL:
7962+                         return  true ;
7963+                     default :
7964+                         break ;
7965+                     }
79107966                }
7911-                 if  (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
7912-                     return  true ;
7967+                 if  (src1_type == GGML_TYPE_F32) {
7968+                     switch  (src0_type) {
7969+                     case  GGML_TYPE_Q4_0:
7970+                     case  GGML_TYPE_Q4_1:
7971+                     case  GGML_TYPE_Q5_0:
7972+                     case  GGML_TYPE_Q5_1:
7973+                     case  GGML_TYPE_Q8_0:
7974+                     case  GGML_TYPE_IQ4_NL:
7975+                         return  true ;
7976+                     default :
7977+                         break ;
7978+                     }
79137979                }
7980+ 
79147981                if  (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
79157982                    return  true ;
79167983                }
0 commit comments