diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 80e3b89d58054..7a9990ab54f61 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -507,6 +507,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
@@ -3085,6 +3086,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     // conv2d
     for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
         uint32_t conv2d_WG_SIZE = 256;
@@ -7120,7 +7123,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_OPT_STEP_SGD:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            // TODO
+            return ctx->device->pipeline_opt_step_sgd_f32;
         }
         return nullptr;
     case GGML_OP_LEAKY_RELU:
@@ -7599,6 +7602,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_OPT_STEP_SGD) {
+        // OPT_STEP_SGD works on src0, it does not need dst
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
@@ -7937,18 +7944,10 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
     );
 }
 
-static void ggml_vk_op_f32_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
-    GGML_ASSERT(0 && "SGD vulkan unimplemented"); // TODO
-}
-
-static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     const size_t n = ggml_nelements(dst->src[0]);
 
-    ggml_vk_op_f32_opt_step_sgd(
-        ctx, subctx, dst,
-        { (uint32_t)n, 0, 0.0f, 0.0f },
-        dryrun
-    );
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
 }
 
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
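A note on the plumbing above: OPT_STEP_SGD reuses the generic vk_op_push_constants block shared by the elementwise ops, and only the KX field carries information (the flattened element count n). The learning rate and weight decay are read from the src2 buffer bound at binding 2, not from push constants. A minimal sketch of that mapping; the struct mirrors the one defined earlier in ggml-vulkan.cpp, while the helper function is illustrative and not part of the patch:

    #include <cstdint>

    // Mirrors vk_op_push_constants in ggml-vulkan.cpp (declared to shaders via generic_head.comp).
    struct vk_op_push_constants {
        uint32_t KX;      // for OPT_STEP_SGD: n = ggml_nelements(dst->src[0])
        uint32_t KY;      // unused by this op
        float    param1;  // unused: alpha is read from the src2 params buffer
        float    param2;  // unused: wd is read from the src2 params buffer
    };

    // Illustrative helper showing the literal { (uint32_t)n, 0, 0.0f, 0.0f }
    // that ggml_vk_opt_step_sgd forwards through ggml_vk_op_f32.
    static vk_op_push_constants make_sgd_push_constants(uint32_t n) {
        return { n, 0, 0.0f, 0.0f };
    }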
@@ -9489,6 +9488,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -9553,6 +9553,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_SGD:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
             // do the only thing needed for the dryrun.
@@ -9800,8 +9801,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_OPT_STEP_SGD:
-        return false; // TODO
-        ggml_vk_opt_step_sgd(ctx, compute_ctx, node, dryrun);
+        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
         break;
     default:
@@ -9905,10 +9905,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         buf = tensor->buffer;
         break;
-    case GGML_OP_OPT_STEP_SGD:
-        return false;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
@@ -11036,6 +11035,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
+    case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         return op->src[0]->type == GGML_TYPE_F32;
     case GGML_OP_UPSCALE:
     case GGML_OP_ACC:
@@ -11057,11 +11059,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
-    case GGML_OP_LEAKY_RELU:
-    case GGML_OP_OPT_STEP_ADAMW:
         return true;
-    case GGML_OP_OPT_STEP_SGD:
-        return false;
     case GGML_OP_CONV_TRANSPOSE_1D:
         return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
     case GGML_OP_CONV_2D:
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
new file mode 100644
index 0000000000000..3d5e1d98fda48
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
@@ -0,0 +1,25 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer X {A_TYPE data_x[];};
+layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
+layout (binding = 2) readonly buffer P {float data_params[2];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float alpha = data_params[0];
+    const float keep  = 1.0f - alpha * data_params[1]; // data_params = { alpha, wd }
+
+    data_x[i] = data_x[i] * keep - alpha * data_grad[i];
+}
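The shader's update matches the SGD-with-weight-decay semantics of the CPU and CUDA backends: src2 holds the two floats { alpha, wd }, and each parameter is scaled by keep = 1 - alpha*wd before the gradient step. A host-side reference loop for checking the shader output element by element (a sketch only; the function name is illustrative, not ggml API):

    #include <cstddef>

    // One SGD step over n contiguous floats:
    //   x[i] = x[i] * (1 - alpha * wd) - alpha * grad[i]
    static void opt_step_sgd_reference(float * x, const float * grad,
                                       const float params[2], size_t n) {
        const float alpha = params[0];                 // learning rate
        const float keep  = 1.0f - alpha * params[1];  // params[1] is the weight decay wd
        for (size_t i = 0; i < n; ++i) {
            x[i] = x[i] * keep - alpha * grad[i];
        }
    }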
+ string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 688b57c1a00dc..244319a48a108 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1006,8 +1006,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", "OPT_STEP_ADAMW", - "GLU", "OPT_STEP_SGD", + + "GLU", }; static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); @@ -1106,8 +1107,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", "adamw(x)", - "glu(x)", "sgd(x)", + + "glu(x)", }; static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2b1bbb2b34549..11ad193f0c0ce 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5110,7 +5110,7 @@ static const ggml_type other_types[] = { }; // Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low -static std::vector> make_test_cases_eval(bool test_sgd = true) { +static std::vector> make_test_cases_eval() { std::vector> test_cases; std::default_random_engine rng(0); @@ -5912,8 +5912,7 @@ static std::vector> make_test_cases_eval(bool test_sg test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1})); test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); - if (test_sgd) - test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, { 10, 5, 4, 3 })); + test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, { 10, 5, 4, 3 })); #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging @@ -6051,10 +6050,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } }; - char const* name = ggml_backend_name(backend); - bool const vulkan = strstr(name, "ulkan"); - bool const sgd = !vulkan; - if (mode == MODE_TEST) { auto test_cases = make_test_cases_eval(); filter_test_cases(test_cases, params_filter); @@ -6080,7 +6075,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } if (mode == MODE_GRAD) { - auto test_cases = make_test_cases_eval(sgd); + auto test_cases = make_test_cases_eval(); filter_test_cases(test_cases, params_filter); size_t n_ok = 0; for (auto & test : test_cases) {