@@ -483,6 +483,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32;
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;
@@ -3046,6 +3047,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     // conv2d
     uint32_t conv2d_WG_SIZE = 256;
     uint32_t conv2d_BS_K = 128;
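
The new pipeline mirrors the AdamW one above but binds only 3 descriptors (weights, gradients, optimizer parameters) instead of 5, reusing the generic vk_op_push_constants layout with KX carrying the element count. The shader body itself is not part of this hunk; a minimal reference sketch of the per-element update it has to perform, assuming it mirrors the CPU backend's SGD-with-weight-decay semantics, is:

    #include <cstddef>

    // Reference sketch only; names are illustrative, not the shader's.
    // Assumes params[0] = alpha (learning rate), params[1] = wd (weight decay).
    static void opt_step_sgd_ref(float * w, const float * g, const float * params, size_t n) {
        const float alpha = params[0];
        const float keep  = 1.0f - alpha * params[1]; // decoupled weight decay factor
        for (size_t i = 0; i < n; ++i) {
            w[i] = w[i] * keep - alpha * g[i]; // in-place update of the weights
        }
    }
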
@@ -6954,7 +6957,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_OPT_STEP_SGD:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            // TODO
+            return ctx->device->pipeline_opt_step_sgd_f32;
         }
         return nullptr;
     case GGML_OP_LEAKY_RELU:
@@ -7430,6 +7433,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_OPT_STEP_SGD) {
+        // OPT_STEP_SGD updates src0 in place, so no separate dst buffer is needed
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
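
Because the op updates src0 in place, only the three source buffers are bound, matching the 3 descriptors declared at pipeline creation. For reference, vk_op_push_constants, which pc is an instance of here, is the small four-field block defined elsewhere in ggml-vulkan.cpp:

    struct vk_op_push_constants {
        uint32_t KX;
        uint32_t KY;
        float param1;
        float param2;
    };

For SGD only KX (the element count) is populated; the learning rate and weight decay travel in the src2 tensor rather than in the push constants.
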
@@ -7768,18 +7775,10 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
     );
 }
 
-static void ggml_vk_op_f32_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
-    GGML_ASSERT(0 && "SGD vulkan unimplemented"); // TODO
-}
-
-static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     const size_t n = ggml_nelements(dst->src[0]);
 
-    ggml_vk_op_f32_opt_step_sgd(
-        ctx, subctx, dst,
-        { (uint32_t)n, 0, 0.0f, 0.0f },
-        dryrun
-    );
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
 }
 
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
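
With the host wrapper routed through ggml_vk_op_f32, the op behaves like any other element-wise f32 operation. A hedged usage sketch at the graph level, assuming the ggml_opt_step_sgd() constructor exposed in ggml.h (src0 = weights, src1 = gradients, src2 = a 2-element f32 parameter tensor; tensor names are illustrative):

    // Hypothetical graph construction; ctx, gf, weights and grads are assumed
    // to already exist. params holds {learning_rate, weight_decay}.
    struct ggml_tensor * params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
    struct ggml_tensor * step   = ggml_opt_step_sgd(ctx, weights, grads, params);
    ggml_build_forward_expand(gf, step);
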
@@ -9313,6 +9312,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -9377,6 +9377,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_SGD:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
             // do the only thing needed for the dryrun.
@@ -9624,8 +9625,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
 
     case GGML_OP_OPT_STEP_SGD:
-        return false; // TODO
-        ggml_vk_opt_step_sgd(ctx, compute_ctx, node, dryrun);
+        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
         break;
     default:
@@ -9729,10 +9729,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         buf = tensor->buffer;
         break;
-    case GGML_OP_OPT_STEP_SGD:
-        return false;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
@@ -10860,6 +10859,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LEAKY_RELU:
+        case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
@@ -10881,11 +10883,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
-        case GGML_OP_LEAKY_RELU:
-        case GGML_OP_OPT_STEP_ADAMW:
             return true;
-        case GGML_OP_OPT_STEP_SGD:
-            return false;
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_CONV_2D: