@@ -510,6 +510,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
@@ -3120,6 +3121,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     // conv2d
     for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
         uint32_t conv2d_WG_SIZE  = 256;
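
The new pipeline is created with three storage-buffer bindings (weights, gradients, and the learning-rate/weight-decay pair) and the same 512-wide workgroup as the AdamW pipeline. As a rough scalar sketch of the per-element update an opt_step_sgd_f32 shader is expected to perform, assuming the SGD-with-weight-decay formulation used by the other ggml backends (names are illustrative, not taken from this diff):

    // Hypothetical reference of the update applied to each weight element;
    // lr and wd are read from the third (parameter) tensor.
    static void sgd_step_reference(float * w, const float * g, float lr, float wd, size_t n) {
        const float keep = 1.0f - lr * wd;   // decoupled weight decay factor
        for (size_t i = 0; i < n; ++i) {
            w[i] = w[i] * keep - lr * g[i];  // in-place SGD step
        }
    }
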
@@ -7169,7 +7172,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_OPT_STEP_SGD:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            // TODO
+            return ctx->device->pipeline_opt_step_sgd_f32;
         }
         return nullptr;
     case GGML_OP_LEAKY_RELU:
@@ -7671,6 +7674,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_OPT_STEP_SGD) {
+        // OPT_STEP_SGD works on src0; it does not need dst
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
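
In the new branch, d_X, d_Y and d_Z are the buffers behind src0 (the weights, updated in place), src1 (the gradients) and src2 (the lr/wd parameters); unlike the generic paths, no separate dst buffer is bound. The pc argument reuses the generic push-constant block, which is defined elsewhere in ggml-vulkan.cpp and looks roughly like the sketch below; only the element count is meaningful for OPT_STEP_SGD, which is why the remaining fields are passed as zero:

    // Approximate layout of the shared push-constant block (see its actual
    // definition earlier in ggml-vulkan.cpp); OPT_STEP_SGD only consumes KX.
    struct vk_op_push_constants {
        uint32_t KX;
        uint32_t KY;
        float    param1;
        float    param2;
    };
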
@@ -8024,18 +8031,10 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
     );
 }
 
-static void ggml_vk_op_f32_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
-    GGML_ASSERT(0 && "SGD vulkan unimplemented"); // TODO
-}
-
-static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     const size_t n = ggml_nelements(dst->src[0]);
 
-    ggml_vk_op_f32_opt_step_sgd(
-        ctx, subctx, dst,
-        { (uint32_t)n, 0, 0.0f, 0.0f },
-        dryrun
-    );
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
 }
 
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
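
On the graph side, the node that now reaches ggml_vk_opt_step_sgd is built with ggml_opt_step_sgd, whose operands mirror the src0/src1/src2 layout used above. A minimal host-side sketch, assuming a four-argument signature analogous to ggml_opt_step_adamw (check ggml.h for the exact form, and note that the weights tensor may additionally need to be flagged as a parameter via ggml_set_param):

    // Hypothetical usage, not part of this diff: build an SGD step node so
    // that GGML_OP_OPT_STEP_SGD is executed by the Vulkan backend.
    static void build_sgd_step(ggml_context * ctx, ggml_cgraph * gf, int64_t n) {
        ggml_tensor * weights = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
        ggml_tensor * grads   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
        ggml_tensor * params  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2); // {lr, wd}
        ggml_build_forward_expand(gf, ggml_opt_step_sgd(ctx, weights, grads, params));
    }
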
@@ -9591,6 +9590,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -9655,6 +9655,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_OP_CONV_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 // These operations all go through ggml_vk_op_f32, so short-circuit and
                 // do the only thing needed for the dryrun.
@@ -9907,8 +9908,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
 
     case GGML_OP_OPT_STEP_SGD:
-        return false; // TODO
-        ggml_vk_opt_step_sgd(ctx, compute_ctx, node, dryrun);
+        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
         break;
     default:
@@ -10013,10 +10013,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         buf = tensor->buffer;
         break;
-    case GGML_OP_OPT_STEP_SGD:
-        return false;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
@@ -11155,6 +11154,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LEAKY_RELU:
+        case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
@@ -11176,11 +11178,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
-        case GGML_OP_LEAKY_RELU:
-        case GGML_OP_OPT_STEP_ADAMW:
             return true;
-        case GGML_OP_OPT_STEP_SGD:
-            return false;
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_CONV_2D: