@@ -180,6 +180,7 @@ struct vk_device_struct {
180180 vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
181181 vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
182182 vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
183+ vk_pipeline pipeline_acc_f32;
183184 vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
184185 vk_pipeline pipeline_mul_f32;
185186 vk_pipeline pipeline_div_f32;
@@ -1687,6 +1688,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
16871688 ggml_vk_create_pipeline (device, device->pipeline_add_f32 , " add_f32" , add_f32_len, add_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
16881689 ggml_vk_create_pipeline (device, device->pipeline_add_f16_f32_f16 , " add_f16_f32_f16" , add_f16_f32_f16_len, add_f16_f32_f16_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
16891690
1691+ ggml_vk_create_pipeline (device, device->pipeline_acc_f32 , " acc_f32" , acc_f32_len, acc_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1692+
16901693 ggml_vk_create_pipeline (device, device->pipeline_mul_f32 , " mul_f32" , mul_f32_len, mul_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
16911694 ggml_vk_create_pipeline (device, device->pipeline_div_f32 , " div_f32" , div_f32_len, div_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
16921695
@@ -3971,6 +3974,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
39713974 return ctx->device ->pipeline_get_rows_f32 [src0->type ];
39723975 }
39733976 return nullptr ;
3977+ case GGML_OP_ACC:
3978+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3979+ return ctx->device ->pipeline_acc_f32 ;
3980+ }
3981+ return nullptr ;
39743982 case GGML_OP_ADD:
39753983 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
39763984 return ctx->device ->pipeline_add_f32 ;
@@ -4463,6 +4471,28 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
44634471 }, dryrun);
44644472}
44654473
4474+ static void ggml_vk_acc (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
4475+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra ;
4476+ const uint32_t src0_type_size = ggml_type_size (src0->type );
4477+ const uint32_t src1_type_size = ggml_type_size (src1->type );
4478+ const uint32_t dst_type_size = ggml_type_size (dst->type );
4479+ const uint32_t d_offset = ((extra->offset + dst->view_offs ) % ctx->device ->properties .limits .minStorageBufferOffsetAlignment ) / dst_type_size;
4480+
4481+ int nb1 = dst->op_params [0 ] / 4 ; // 4 bytes of float32
4482+ int nb2 = dst->op_params [1 ] / 4 ; // 4 bytes of float32
4483+ // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
4484+ int offset = dst->op_params [3 ] / 4 ; // offset in bytes
4485+
4486+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_ACC, {
4487+ (uint32_t )ggml_nelements (src0),
4488+ (uint32_t )src0->ne [0 ], (uint32_t )src0->ne [1 ], (uint32_t )src0->ne [2 ],(uint32_t )src0->ne [3 ], (uint32_t )src0->nb [0 ] / src0_type_size, (uint32_t )nb1, (uint32_t )nb2, (uint32_t )src0->nb [3 ] / src0_type_size,
4489+ (uint32_t )src1->ne [0 ], (uint32_t )src1->ne [1 ], (uint32_t )src1->ne [2 ],(uint32_t )src1->ne [3 ], (uint32_t )src1->nb [0 ] / src1_type_size, (uint32_t )src1->nb [1 ] / src1_type_size, (uint32_t )src1->nb [2 ] / src1_type_size, (uint32_t )src1->nb [3 ] / src1_type_size,
4490+ (uint32_t ) dst->ne [0 ], (uint32_t ) dst->ne [1 ], (uint32_t ) dst->ne [2 ],(uint32_t ) dst->ne [3 ], (uint32_t ) dst->nb [0 ] / dst_type_size, (uint32_t )nb1, (uint32_t )nb2, (uint32_t ) dst->nb [3 ] / dst_type_size,
4491+ d_offset,
4492+ 0 .0f , 0 .0f , offset,
4493+ }, dryrun);
4494+ }
4495+
44664496static void ggml_vk_add (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
44674497 const uint32_t src0_type_size = ggml_type_size (src0->type );
44684498 const uint32_t src1_type_size = ggml_type_size (src1->type );
@@ -5621,6 +5651,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
56215651 case GGML_OP_REPEAT:
56225652 case GGML_OP_GET_ROWS:
56235653 case GGML_OP_ADD:
5654+ case GGML_OP_ACC:
56245655 case GGML_OP_MUL:
56255656 case GGML_OP_DIV:
56265657 case GGML_OP_CONCAT:
@@ -5668,6 +5699,10 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
56685699 case GGML_OP_REPEAT:
56695700 ggml_vk_repeat (ctx, compute_ctx, src0, node, dryrun);
56705701
5702+ break ;
5703+ case GGML_OP_ACC:
5704+ ggml_vk_acc (ctx, compute_ctx, src0, src1, node, dryrun);
5705+
56715706 break ;
56725707 case GGML_OP_GET_ROWS:
56735708 ggml_vk_get_rows (ctx, compute_ctx, src0, src1, node, dryrun);
@@ -5808,6 +5843,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
58085843
58095844 switch (tensor->op ) {
58105845 case GGML_OP_ADD:
5846+ case GGML_OP_ACC:
58115847 case GGML_OP_GET_ROWS:
58125848 case GGML_OP_MUL:
58135849 case GGML_OP_DIV:
@@ -6539,6 +6575,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
65396575 case GGML_OP_GROUP_NORM:
65406576 case GGML_OP_RMS_NORM:
65416577 case GGML_OP_ADD:
6578+ case GGML_OP_ACC:
65426579 case GGML_OP_MUL:
65436580 case GGML_OP_DIV:
65446581 case GGML_OP_CONCAT:
@@ -6995,6 +7032,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
69957032 tensor_clone = ggml_repeat (ggml_ctx, src0_clone, src1_clone);
69967033 } else if (tensor->op == GGML_OP_ADD) {
69977034 tensor_clone = ggml_add (ggml_ctx, src0_clone, src1_clone);
7035+ } else if (tensor->op == GGML_OP_ACC) {
7036+ tensor_clone = ggml_acc (ggml_ctx, src0_clone, src1_clone, tensor->op_params [0 ], tensor->op_params [1 ], tensor->op_params [2 ], tensor->op_params [3 ]);
69987037 } else if (tensor->op == GGML_OP_NORM) {
69997038 tensor_clone = ggml_norm (ggml_ctx, src0_clone, *(float *)tensor->op_params );
70007039 } else if (tensor->op == GGML_OP_GROUP_NORM) {
0 commit comments