Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_opt_step_sgd_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32;
Expand Down Expand Up @@ -3085,6 +3086,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

// conv2d
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
uint32_t conv2d_WG_SIZE = 256;
Expand Down Expand Up @@ -7120,7 +7123,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_OPT_STEP_SGD:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
// TODO
return ctx->device->pipeline_opt_step_sgd_f32;
}
return nullptr;
case GGML_OP_LEAKY_RELU:
Expand Down Expand Up @@ -7599,6 +7602,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_OPT_STEP_SGD) {
// OPT_STEP_SGD works on src0, it does not need dst
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
} else if (use_src2) {
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
Expand Down Expand Up @@ -7937,18 +7944,10 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
);
}

static void ggml_vk_op_f32_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
GGML_ASSERT(0 && "SGD vulkan unimplemented"); // TODO
}

static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
const size_t n = ggml_nelements(dst->src[0]);

ggml_vk_op_f32_opt_step_sgd(
ctx, subctx, dst,
{ (uint32_t)n, 0, 0.0f, 0.0f },
dryrun
);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
}

static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
Expand Down Expand Up @@ -9489,6 +9488,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
break;
default:
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
Expand Down Expand Up @@ -9553,6 +9553,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_SGD:
{
// These operations all go through ggml_vk_op_f32, so short-circuit and
// do the only thing needed for the dryrun.
Expand Down Expand Up @@ -9800,8 +9801,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;

case GGML_OP_OPT_STEP_SGD:
return false; // TODO
ggml_vk_opt_step_sgd(ctx, compute_ctx, node, dryrun);
ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);

break;
default:
Expand Down Expand Up @@ -9905,10 +9905,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
buf = tensor->buffer;
break;
case GGML_OP_OPT_STEP_SGD:
return false;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_SILU:
Expand Down Expand Up @@ -11036,6 +11035,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE:
case GGML_OP_ACC:
Expand All @@ -11057,11 +11059,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
return true;
case GGML_OP_OPT_STEP_SGD:
return false;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
Expand Down
25 changes: 25 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) buffer X {A_TYPE data_x[];};
layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
layout (binding = 2) readonly buffer P {float data_params[2];};

void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

if (i >= p.KX) {
return;
}

const float alpha = data_params[0];
const float keep = data_params[1];

data_x[i] = data_x[i] * keep - alpha * data_grad[i];
}
1 change: 1 addition & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ void process_shaders() {
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
Expand Down
6 changes: 4 additions & 2 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1006,8 +1006,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
"OPT_STEP_ADAMW",
"GLU",
"OPT_STEP_SGD",

"GLU",
};

static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
Expand Down Expand Up @@ -1106,8 +1107,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
"adamw(x)",
"glu(x)",
"sgd(x)",

"glu(x)",
};

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
Expand Down
11 changes: 3 additions & 8 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5110,7 +5110,7 @@ static const ggml_type other_types[] = {
};

// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(bool test_sgd = true) {
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
std::vector<std::unique_ptr<test_case>> test_cases;
std::default_random_engine rng(0);

Expand Down Expand Up @@ -5912,8 +5912,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(bool test_sg
test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));

test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
if (test_sgd)
test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, { 10, 5, 4, 3 }));
test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, { 10, 5, 4, 3 }));

#if 0
// these tests are disabled to save execution time, sbut they can be handy for debugging
Expand Down Expand Up @@ -6051,10 +6050,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
};

char const* name = ggml_backend_name(backend);
bool const vulkan = strstr(name, "ulkan");
bool const sgd = !vulkan;

if (mode == MODE_TEST) {
auto test_cases = make_test_cases_eval();
filter_test_cases(test_cases, params_filter);
Expand All @@ -6080,7 +6075,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}

if (mode == MODE_GRAD) {
auto test_cases = make_test_cases_eval(sgd);
auto test_cases = make_test_cases_eval();
filter_test_cases(test_cases, params_filter);
size_t n_ok = 0;
for (auto & test : test_cases) {
Expand Down
Loading