diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 36b23dc6d0d82..5028a9cebf260 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -237,6 +237,8 @@ #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 +// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726 +#define GGML_ROPE_TYPE_NORMAL 0 #define GGML_ROPE_TYPE_NEOX 2 #define GGML_ROPE_TYPE_MROPE 8 #define GGML_ROPE_TYPE_VISION 24 diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index cee4b08366d79..93200a4d29f53 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -130,13 +130,15 @@ struct webgpu_context_struct { wgpu::ComputePipeline set_rows_pipeline; wgpu::ComputePipeline get_rows_pipeline[30]; wgpu::ComputePipeline get_rows_f32_no_vec_pipeline; - wgpu::ComputePipeline cpy_pipeline; - wgpu::ComputePipeline add_pipeline[2]; - wgpu::ComputePipeline add_ip_pipeline[2]; - wgpu::ComputePipeline mul_pipeline[2]; - wgpu::ComputePipeline mul_ip_pipeline[2]; - wgpu::ComputePipeline rms_norm_pipeline; - wgpu::ComputePipeline rms_norm_ip_pipeline; + wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type + wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace + wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace + wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split + wgpu::ComputePipeline scale_pipeline[2]; // inplace size_t memset_bytes_per_thread; @@ -489,8 +491,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Logical shape — same for both tensors even if permuted - (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3] + // Logical shapes + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] }; std::vector<wgpu::BindGroupEntry> entries = { @@ -506,7 +509,8 @@ size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x, + ggml_op_name(dst->op)); } static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { @@ -649,7 +653,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * dst, wgpu::ComputePipeline & pipeline, - bool in_place) { + bool inplace) { std::vector<uint32_t> params = { (uint32_t) ggml_nelements(dst), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -678,7 +682,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, src1), .size =
ggml_webgpu_tensor_binding_size(ctx, src1) } }; - if (!in_place) { + if (!inplace) { entries.push_back({ .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), @@ -691,30 +695,23 @@ } static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool in_place = ggml_webgpu_tensor_equal(src, dst); - - uint32_t eps; - memcpy(&eps, dst->op_params, sizeof(float)); + int inplace = ggml_webgpu_tensor_equal(src, dst); std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader }; - if (!in_place) { - params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type))); - } - params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type))); - params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type))); - params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type))); - if (!in_place) { - params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type))); - params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type))); - params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type))); - } - params.push_back((uint32_t) src->ne[0]); - params.push_back((uint32_t) src->ne[1]); - params.push_back((uint32_t) src->ne[2]); - params.push_back((uint32_t) src->ne[3]); - params.push_back(eps); // epsilon, will be bitcast to float in shader std::vector<wgpu::BindGroupEntry> entries = { { .binding = 0, @@ -722,24 +719,199 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t .offset = ggml_webgpu_tensor_align_offset(ctx, src), .size = ggml_webgpu_tensor_binding_size(ctx, src) } }; - if (!in_place) { + if (!inplace) { entries.push_back({ .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - wgpu::ComputePipeline pipeline; - if (in_place) { - pipeline = ctx->rms_norm_ip_pipeline; - } else { - pipeline = ctx->rms_norm_pipeline; - } size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x, + ggml_op_name(dst->op)); +} + +static void ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_freq_factor = (src2 != nullptr); + + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale,
(int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + int sections[4]; + memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int)); + + float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(src0) / 2, + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + (uint32_t) n_dims, + (uint32_t) mode, + *(uint32_t *) &theta_scale, + *(uint32_t *) &attn_factor, + *(uint32_t *) &freq_scale, + *(uint32_t *) &ext_factor, + *(uint32_t *) &corr_dims[0], + *(uint32_t *) &corr_dims[1], + (uint32_t) sections[0], + (uint32_t) sections[1], + (uint32_t) sections[2], + (uint32_t) sections[3] + }; + + std::vector<wgpu::BindGroupEntry> entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) } + }; + uint32_t dst_binding = 2; + if (has_freq_factor) { + dst_binding = 3; + entries.push_back({ .binding = 2, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + } + if (!inplace) { + entries.push_back({ .binding = dst_binding, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } +static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + const int split = (src1 != nullptr); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + src1 != nullptr ?
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(dst), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai + *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + }; + + std::vector<wgpu::BindGroupEntry> entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + }; + uint32_t dst_binding = 1; + if (split) { + dst_binding = 2; + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + } + entries.push_back({ .binding = dst_binding, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + wgpu::ComputePipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); +} + +static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + int inplace = ggml_webgpu_tensor_equal(src, dst); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(dst), + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + *(uint32_t *) dst->op_params, // scale + *(uint32_t *) &dst->op_params[1] // bias + }; + + std::vector<wgpu::BindGroupEntry> entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) } + }; + if (!inplace) { + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) +
max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x, + ggml_op_name(dst->op)); +} + // Returns true if node has enqueued work into the queue, false otherwise static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { @@ -749,6 +921,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { ggml_tensor * src0 = node->src[0]; ggml_tensor * src1 = node->src[1]; + ggml_tensor * src2 = node->src[2]; switch (node->op) { // no-ops @@ -759,6 +932,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_RESHAPE: return false; case GGML_OP_CPY: + case GGML_OP_CONT: ggml_webgpu_cpy(ctx, src0, node); break; case GGML_OP_SET_ROWS: @@ -771,22 +945,41 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { ggml_webgpu_mul_mat(ctx, src0, src1, node); break; case GGML_OP_ADD: - if (ggml_webgpu_tensor_equal(src0, node)) { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true); - } else { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false); + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); + break; + } + case GGML_OP_SUB: + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); + break; } - break; case GGML_OP_MUL: - if (ggml_webgpu_tensor_equal(src0, node)) { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true); - } else { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false); + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); + break; + } + case GGML_OP_DIV: + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); + break; } - break; case GGML_OP_RMS_NORM: ggml_webgpu_rms_norm(ctx, src0, node); break; + case GGML_OP_ROPE: + ggml_webgpu_rope(ctx, src0, src1, src2, node); + break; + case GGML_OP_GLU: + ggml_webgpu_glu(ctx, src0, src1, node); + break; + case GGML_OP_SCALE: + ggml_webgpu_scale(ctx, src0, node); + break; default: return false; } @@ -1170,40 +1363,153 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", - ggml_webgpu_max_wg_size_entry(webgpu_ctx)); + std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], + wgsl_cpy_f32_f32, "cpy_f32_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16], + wgsl_cpy_f32_f16, "cpy_f32_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], + wgsl_cpy_f16_f32, "cpy_f16_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], + wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } static void
ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace, + "add_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace, + "add_f16_inplace", constants); +} + +static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { + std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32, - "add_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16, - "add_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace, + "sub_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace, + "sub_f16_inplace", constants); } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace, + "mul_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace, + "mul_f16_inplace", constants); +} + +static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { + std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32, - "mul_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16, -
"mul_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace, + "div_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace, + "div_f16_inplace", constants); } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place, - "rms_norm_in_place", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace, + "rms_norm_inplace", constants); +} + +static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32, + "rope_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1], + wgsl_rope_f32_inplace, "rope_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][0], wgsl_rope_f32_ff, + "rope_f32_ff", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][1], + wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0], wgsl_rope_f16, + "rope_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1], + wgsl_rope_f16_inplace, "rope_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][0], wgsl_rope_f16_ff, + "rope_f16_ff", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][1], + wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); +} + +static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + // reglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0], + wgsl_reglu_f32, "reglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0], + wgsl_reglu_f16, "reglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1], + wgsl_reglu_f32_split, "reglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1], + wgsl_reglu_f16_split, "reglu_f16_split", constants); + // geglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0], + wgsl_geglu_f32, "geglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0], + wgsl_geglu_f16, "geglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1], + wgsl_geglu_f32_split, "geglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1], + wgsl_geglu_f16_split, "geglu_f16_split", constants); + // swiglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0], + wgsl_swiglu_f32, "swiglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0], + wgsl_swiglu_f16, "swiglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1], + wgsl_swiglu_f32_split, "swiglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1], + wgsl_swiglu_f16_split, "swiglu_f16_split", constants); + // swiglu_oai + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0], + wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1], + wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); + // geglu_erf + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0], + wgsl_geglu_erf_f32, "geglu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0], + wgsl_geglu_erf_f16, "geglu_erf_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1], + wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1], + wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); + // geglu_quick + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0], + wgsl_geglu_quick_f32, "geglu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0], + wgsl_geglu_quick_f16, "geglu_quick_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1], + wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1], + wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); +} + +static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { + std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace, + "scale_f32_inplace", constants); } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { @@ -1287,6 +1593,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * src0 = op->src[0]; ggml_tensor * src1 = op->src[1]; + // on smaller devices (or CI), tensors may be larger than the max storage buffer size if
(ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || @@ -1304,28 +1611,34 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = true; break; case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) && - (op->src[1]->type == op->type); + case GGML_OP_DIV: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && + (src1->type == op->type); break; case GGML_OP_CPY: + case GGML_OP_CONT: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; case GGML_OP_SET_ROWS: supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64); break; case GGML_OP_GET_ROWS: - if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || - op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) { + if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 || + ggml_webgpu_supported_qtype(src0->type)) { supports_op = (op->type == GGML_TYPE_F32); } break; case GGML_OP_MUL_MAT: { - switch (op->src[1]->type) { + switch (src1->type) { case GGML_TYPE_F16: - supports_op = (op->src[0]->type == GGML_TYPE_F16); + supports_op |= (src0->type == GGML_TYPE_F16); break; case GGML_TYPE_F32: - switch (op->src[0]->type) { + switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -1358,7 +1671,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_RMS_NORM: - supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_ROPE: + supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; + break; + case GGML_GLU_OP_SWIGLU_OAI: + supports_op = op->type == GGML_TYPE_F32; + break; + default: + break; + } + break; + case GGML_OP_SCALE: + supports_op = op->type == GGML_TYPE_F32; break; default: break; @@ -1484,8 +1819,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_get_rows_pipeline(ctx); ggml_webgpu_init_cpy_pipeline(ctx); ggml_webgpu_init_add_pipeline(ctx); + ggml_webgpu_init_sub_pipeline(ctx); ggml_webgpu_init_mul_pipeline(ctx); + ggml_webgpu_init_div_pipeline(ctx); ggml_webgpu_init_rms_norm_pipeline(ctx); + ggml_webgpu_init_rope_pipeline(ctx); + ggml_webgpu_init_glu_pipeline(ctx); + ggml_webgpu_init_scale_pipeline(ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl deleted file mode 100644 index f261cbb553a82..0000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +++ /dev/null @@ -1,44 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; 
- -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl deleted file mode 100644 index 903f7bdbcc5f3..0000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +++ /dev/null @@ -1,41 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl new file mode 100644 index 0000000000000..1ce4d83fa8e50 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl @@ -0,0 +1,188 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "add_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "+" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "add_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "+" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "add_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "+" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "add_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "+" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "mul_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "*" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "mul_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "*" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "mul_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "*" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "mul_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "*" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sub_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "-" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sub_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "-" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sub_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "-" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sub_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "-" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "div_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "/" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "div_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "/" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "div_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "/" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "div_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "/" + }, + "DECLS": ["INPLACE"] + } +] + 
+#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE) + +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { + dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; +} + +@group(0) @binding(2) +var<storage, read_write> dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var<uniform> params: Params; + +#enddecl(NOT_INPLACE) + +#decl(INPLACE) + +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { + src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; +} + +@group(0) @binding(2) +var<uniform> params: Params; + +#enddecl(INPLACE) + +#end(DECLS) + + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var<storage, read_write> src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var<storage, read_write> src1: array<{{TYPE}}>; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x < params.ne) { + update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl new file mode 100644 index 0000000000000..db1aa34903b5d --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl @@ -0,0 +1,101 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "SRC_TYPE": "f32", + "DST_TYPE": "f32" + } + }, + { + "REPLS": { + "SRC_TYPE": "f32", + "DST_TYPE": "f16" + } + }, + { + "REPLS": { + "SRC_TYPE": "f16", + "DST_TYPE": "f16" + } + }, + { + "REPLS": { + "SRC_TYPE": "f16", + "DST_TYPE": "f32" + } + } +] + +#end(VARIANTS) + +#define(SHADER) +enable f16; + +@group(0) @binding(0) +var<storage, read_write> src: array<{{SRC_TYPE}}>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<{{DST_TYPE}}>; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +@group(0) @binding(2) +var<uniform> params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx])); +} #end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl deleted file mode 100644 index 6fe924c554cc3..0000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +++ /dev/null @@ -1,60 +0,0 @@ -enable f16; - -@group(0) @binding(0)
-var<storage, read_write> src: array<f32>; - -@group(0) @binding(1) -var<storage, read_write> dst: array<f16>; - struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shape (same for both tensors) - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, -}; - -@group(0) @binding(2) -var<uniform> params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 + - i2 * params.stride_dst2 + i3 * params.stride_dst3; - - dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index d9dfd7d6f4ffc..251051eaeca0f 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -88,15 +88,20 @@ def generate_variants(fname, input_dir, output_dir, outfile): raise ValueError(f"DECLS key '{key}' not found.") decls_code += decls_map[key] + "\n\n" - shader_variant = replace_placeholders(shader_template, variant["REPLS"]) - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant) + final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) + if "REPLS" in variant: + final_shader = replace_placeholders(final_shader, variant["REPLS"]) final_shader = expand_includes(final_shader, input_dir) - if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: + if "SHADER_NAME" in variant: + output_name = variant["SHADER_NAME"] + elif "SHADER_SUFFIX" in variant: + output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"] + elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) - elif "TYPE_SUFFIX" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"] - elif "TYPE" in variant["REPLS"]: + elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]: + output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]]) + elif "REPLS" in variant and "TYPE" in variant["REPLS"]: output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] else: output_name = shader_base_name diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl index e3fe311b26b83..f80ce1fc55060 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl @@ -2,9 +2,9 @@ [ { + "SHADER_SUFFIX": "f32_vec", "REPLS": { "TYPE" : "vec4<f32>", - "TYPE_SUFFIX": "f32_vec", "DST_TYPE": "vec4<f32>", "BLOCK_SIZE": 4 }, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl
b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl new file mode 100644 index 0000000000000..03fcd54868933 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl @@ -0,0 +1,323 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "reglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "geglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "swiglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_oai_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "SWIGLU_OAI"] + }, + { + "SHADER_NAME": "swiglu_oai_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "SWIGLU_OAI"] + }, + { + "SHADER_NAME": "geglu_erf_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_quick_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU_QUICK"] + }, +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(REGLU) +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return max(a, 0) * b; +} +#enddecl(REGLU) + +#decl(GEGLU) +const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876; +const GELU_COEF_A: {{TYPE}} = 0.044715; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); + return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b; +} +#enddecl(GEGLU) + +#decl(SWIGLU) +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return a / (1.0 + exp(-a)) * b; +} +#enddecl(SWIGLU) + +#decl(SWIGLU_OAI) +fn op(a: f32, b: f32) -> f32 { + let xi = min(a, params.limit); + let gi = max(min(b, params.limit), -params.limit); + var out_glu = xi / (1.0 + exp(-xi * params.alpha)); + out_glu = out_glu * (1.0 + gi); + return 
out_glu; +} +#enddecl(SWIGLU_OAI) + +#decl(GEGLU_ERF) +const p_erf: {{TYPE}} = 0.3275911; +const a1_erf: {{TYPE}} = 0.254829592; +const a2_erf: {{TYPE}} = -0.284496736; +const a3_erf: {{TYPE}} = 1.421413741; +const a4_erf: {{TYPE}} = -1.453152027; +const a5_erf: {{TYPE}} = 1.061405429; +const SQRT_2_INV: {{TYPE}} = 0.7071067811865476; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + let a_div_sqr2 = a * SQRT_2_INV; + let sign_x = sign(a_div_sqr2); + let x = abs(a_div_sqr2); + let t = 1.0 / (1.0 + p_erf * x); + let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); + let erf_approx = sign_x * y; + return 0.5 * a * (1.0 + erf_approx) * b; +} +#enddecl(GEGLU_ERF) + +#decl(GEGLU_QUICK) +const GELU_QUICK_COEF: {{TYPE}} = -1.702; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; +} +#enddecl(GEGLU_QUICK) + +#decl(NO_SPLIT) +@group(0) @binding(1) +var<storage, read_write> dst: array<{{TYPE}}>; + +@group(0) @binding(2) +var<uniform> params: Params; + +fn a_value(base: u32) -> {{TYPE}} { + let offset: u32 = select(0, params.ne0, params.swapped != 0); + return src0[base + offset]; +} + +fn b_value(base: u32) -> {{TYPE}} { + let offset: u32 = select(params.ne0, 0, params.swapped != 0); + return src0[base + offset]; +} +#enddecl(NO_SPLIT) + +#decl(SPLIT) +@group(0) @binding(1) +var<storage, read_write> src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var<storage, read_write> dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var<uniform> params: Params; + +fn a_value(base: u32) -> {{TYPE}} { + return src0[base]; +} + +fn b_value(base: u32) -> {{TYPE}} { + return src1[base]; +} +#enddecl(SPLIT) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + swapped: u32, + alpha: f32, + limit: f32, +} + +@group(0) @binding(0) +var<storage, read_write> src0: array<{{TYPE}}>; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; + let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + dst[i_dst] = op(a_value(i_a), b_value(i_b)); +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl deleted file mode 100644 index 12506e1420efe..0000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +++ /dev/null @@ -1,44 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var<storage, read_write> src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var<storage, read_write> src1: array<{{TYPE}}>; -
-@group(0) @binding(2) -var<storage, read_write> dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var<uniform> params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x < params.ne) { - dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl deleted file mode 100644 index e467e59edb48e..0000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +++ /dev/null @@ -1,41 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var<storage, read_write> src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var<storage, read_write> src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var<uniform> params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x < params.ne) { - src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl index f919a51336343..a275eeb9783da 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -1,9 +1,48 @@ -@group(0) @binding(0) -var<storage, read_write> src: array<f32>; +#define(VARIANTS) + +[ + { + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_SUFFIX": "inplace", + "DECLS": ["INPLACE"] + }, +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE) + +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + dst[dst_offset] = scale * src[src_offset]; +} @group(0) @binding(1) var<storage, read_write> dst: array<f32>; +@group(0) @binding(2) +var<uniform> params: Params; + +#enddecl(NOT_INPLACE) + +#decl(INPLACE) + +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + src[dst_offset] = scale * src[src_offset]; +} + +@group(0) @binding(1) +var<uniform> params: Params; + +#enddecl(INPLACE) + +#end(DECLS) + +#define(SHADER) + struct Params { offset_src: u32, // in elements offset_dst: u32, // in elements @@ -23,11 +62,13 @@ struct Params { ne2: u32, ne3: u32, - eps: u32 + eps: f32 }; -@group(0) @binding(2) -var<uniform> params: Params; +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +DECLS override wg_size: u32; @compute @workgroup_size(wg_size) @@ -49,9 +90,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { for (var j: u32 = 0; j < params.ne0; j++) { sum += src[i_src_row + j] * src[i_src_row + j]; } - let eps = bitcast<f32>(params.eps); - let scale = 1.0/sqrt(sum/f32(params.ne0) + eps); + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); for (var j: u32 = 0; j < params.ne0; j++) { - dst[i_dst_row + j] = scale * src[i_src_row + j]; + update(i_src_row + j, i_dst_row + j, scale); } } +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl deleted file mode 100644 index ae84f556d6013..0000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +++ /dev/null @@ -1,48 +0,0 @@ -@group(0) @binding(0) -var<storage, read_write> a: array<f32>; - struct Params { - offset: u32, // in elements - - // Strides (in elements) - stride1: u32, - stride2: u32, - stride3: u32, - - // Shape - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, - - eps: u32 -}; - -@group(0)
@binding(1) -var<uniform> params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne1 * params.ne2 * params.ne3) { - return; - } - - // one thread per row - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1); - i = i % (params.ne2 * params.ne1); - let i2 = i / params.ne1; - let i1 = i % params.ne1; - let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1; - - var sum = 0.0f; - for (var j: u32 = 0; j < params.ne0; j++) { - sum += a[i_row + j] * a[i_row + j]; - } - let eps = bitcast<f32>(params.eps); - let scale = 1.0/sqrt(sum/f32(params.ne0) + eps); - for (var j: u32 = 0; j < params.ne0; j++) { - a[i_row + j] = scale * a[i_row + j]; - } -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl new file mode 100644 index 0000000000000..9a6ff41128b6d --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl @@ -0,0 +1,282 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f32_inplace", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] + }, + { + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f16_inplace", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] + }, + { + "SHADER_SUFFIX": "f32_ff", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f32_ff_inplace", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] + }, + { + "SHADER_SUFFIX": "f16_ff", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f16_ff_inplace", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(ROTATE) +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + dst[i_dst0] = {{TYPE}}(out0); + dst[i_dst1] = {{TYPE}}(out1); +} +#enddecl(ROTATE) + +#decl(ROTATE_INPLACE) +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + src0[i_dst0] = {{TYPE}}(out0); + src0[i_dst1] = {{TYPE}}(out1); +} +#enddecl(ROTATE_INPLACE) + +#decl(NO_FF_FUNC) +fn freq_factor(i: u32) -> f32 { + return 1.0f; +} +#enddecl(NO_FF_FUNC) + +#decl(FF_FUNC) +fn freq_factor(i: u32) -> f32 { + return src2[params.offset_src2 + i/2]; +} +#enddecl(FF_FUNC) + +#decl(NO_FF_BINDINGS) + +@group(0) @binding(2) +var<storage, read_write> dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var<uniform> params: Params; + +#enddecl(NO_FF_BINDINGS) + +#decl(NO_FF_BINDINGS_INPLACE) + +@group(0) @binding(2) +var<uniform> params: Params; + +#enddecl(NO_FF_BINDINGS_INPLACE) + +#decl(FF_BINDINGS) + +@group(0) @binding(2) +var<storage, read_write> src2: array<f32>; + +@group(0) @binding(3) +var<storage, read_write> dst: array<{{TYPE}}>; + +@group(0) @binding(4) +var<uniform> params: Params; + +#enddecl(FF_BINDINGS) + +#decl(FF_BINDINGS_INPLACE) + +@group(0) @binding(2) +var<storage, read_write> src2: array<f32>; + +@group(0) @binding(3) +var<uniform> params: Params; + +#enddecl(FF_BINDINGS_INPLACE) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_src2: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_dst1: u32, + stride_dst2:
u32, + stride_dst3: u32, + + n_threads: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + n_dims: u32, + mode: u32, + theta_scale: f32, + attn_factor: f32, + freq_scale: f32, + ext_factor: f32, + corr_dim0: f32, + corr_dim1: f32, + sections0: u32, + sections1: u32, + sections2: u32, + sections3: u32 +}; + +@group(0) @binding(0) +var<storage, read_write> src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var<storage, read_write> src1: array<i32>; + +DECLS + +fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { + let y = (f32(i / 2) - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// returns vector of (cos_theta, sin_theta) +// TODO: check performance of instantiating once on the CPU and passing as a buffer, since it's repeated per-row +fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> { + var mscale = params.attn_factor; + var theta = params.freq_scale * theta_extrap; + if (params.ext_factor != 0.0f) { + let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; + theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; + mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale); + } + return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale); +} + +fn pair_base(i0: u32, div_2: bool) -> u32 { + if (div_2) { + return i0 / 2; + } else { + return i0; + } +} + +fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { + if (is_vision) { + return params.n_dims; + } else if (is_neox || is_mrope) { + return params.n_dims / 2; + } else { + return 1; + } +} + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + // two elements per thread + if (gid.x >= params.n_threads) { + return; + } + + let is_neox = bool(params.mode & 2); + let is_mrope = bool(params.mode & 8); + let is_vision = params.mode == 24; + + var i = gid.x * 2; // start index for this thread + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + if (i0 >= params.n_dims && !is_vision) { + let i_src = i_src_row + i0; + let i_dst = i_dst_row + i0; + rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1])); + return; + } + + var theta_base_mult: u32 = 0; + var theta_scale_pwr: u32 = i0 / 2; + if (is_mrope) { + let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3; + let sec_w = params.sections1 + params.sections0; + let sec_e = params.sections2 + sec_w; + let sector = (i0 / 2) % sect_dims; + if (sector >= params.sections0 && sector < sec_w) { + theta_base_mult = 1; + if (is_vision) { + theta_scale_pwr = sector - params.sections0; + } + } else if (sector >= sec_w && sector < sec_e) { + theta_base_mult = 2; + if (is_vision) { + theta_scale_pwr = sector - sec_w; + } + } else if (sector >= sec_e) { + if (is_vision) { + theta_scale_pwr = sector - sec_e; + theta_scale_pwr = (i0 / 2) % sec_e; + } + theta_base_mult = 3; + } else if (is_vision) { + theta_scale_pwr = sector; + } + } + let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr)); + let thetas = rope_yarn(theta_base/freq_factor(i0), i0); + + let i_src = i_src_row + pair_base(i0,
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
new file mode 100644
index 0000000000000..040e80dfea24a
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
@@ -0,0 +1,90 @@
+#define(VARIANTS)
+
+[
+    {
+        "SHADER_NAME": "scale_f32",
+        "DECLS": ["NOT_INPLACE"]
+    },
+    {
+        "SHADER_NAME": "scale_f32_inplace",
+        "DECLS": ["INPLACE"]
+    }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(NOT_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+    dst[offset] = val;
+}
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+@group(0) @binding(1)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+    src[offset] = val;
+}
+#enddecl(INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
+struct Params {
+    offset_src: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    scale: f32,
+    bias: f32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
+    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+    store_scale(src[i_src] * params.scale + params.bias, i_dst);
+}
+#end(SHADER)
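The scale kernel is a straight strided elementwise fused multiply-add, bounded by the flat element count `params.ne`. A CPU reference makes the index decomposition explicit; this is a sketch for a contiguous tensor, with the shape and scale/bias values chosen arbitrarily for illustration:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t ne0 = 4, ne1 = 3, ne2 = 2, ne3 = 2;
    const uint32_t ne  = ne0 * ne1 * ne2 * ne3;
    const float scale = 2.0f, bias = 1.0f;

    // contiguous strides in elements, matching stride_src1/2/3 for this shape
    const uint32_t s1 = ne0, s2 = ne0 * ne1, s3 = ne0 * ne1 * ne2;

    std::vector<float> src(ne), dst(ne);
    for (uint32_t i = 0; i < ne; i++) src[i] = (float) i;

    for (uint32_t gid = 0; gid < ne; gid++) { // one "thread" per element
        // unravel the flat index into (i3, i2, i1, i0), as the shader does
        uint32_t i  = gid;
        uint32_t i3 = i / (ne2 * ne1 * ne0); i %= ne2 * ne1 * ne0;
        uint32_t i2 = i / (ne1 * ne0);       i %= ne1 * ne0;
        uint32_t i1 = i / ne0;
        uint32_t i0 = i % ne0;
        uint32_t idx = i3 * s3 + i2 * s2 + i1 * s1 + i0;
        dst[idx] = src[idx] * scale + bias;
    }
    std::printf("dst[5] = %f\n", dst[5]); // 5 * 2 + 1 = 11
}
```

The separate src/dst strides and offsets exist so the same body also works when src and dst are differently laid-out views; the in-place variant simply drops the dst binding and routes `store_scale` back into `src`.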
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 5ab42f59e0635..d4b6e0c7c48dd 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2688,23 +2688,30 @@ struct test_scale : public test_case {
     const std::array<int64_t, 4> ne;
     float scale;
     float bias;
+    bool inplace;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, scale, bias);
+        return VARS_TO_STR5(type, ne, scale, bias, inplace);
     }
 
     test_scale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 10},
             float scale = 2.0f,
-            float bias = 0.0f)
-        : type(type), ne(ne), scale(scale), bias(bias) {}
+            float bias = 0.0f,
+            bool inplace = false)
+        : type(type), ne(ne), scale(scale), bias(bias), inplace(inplace) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(a);
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_scale_bias_inplace(ctx, a, scale, bias);
+        } else {
+            out = ggml_scale_bias(ctx, a, scale, bias);
+        }
         ggml_set_name(out, "out");
 
         return out;
@@ -2861,16 +2868,18 @@ struct test_rms_norm : public test_case {
     const std::array<int64_t, 4> ne;
     const bool v; // whether a is a non-contiguous view
    const float eps;
+    const bool inplace; // whether to do the operation inplace
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, v, eps);
+        return VARS_TO_STR5(type, ne, v, eps, inplace);
     }
 
     test_rms_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
             bool v = false,
-            float eps = 1e-6f)
-        : type(type), ne(ne), v(v), eps(eps) {}
+            float eps = 1e-6f,
+            bool inplace = false)
+        : type(type), ne(ne), v(v), eps(eps), inplace(inplace) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2882,7 +2891,12 @@ struct test_rms_norm : public test_case {
             ggml_set_name(a, "view of a");
         }
 
-        ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_rms_norm_inplace(ctx, a, eps);
+        } else {
+            out = ggml_rms_norm(ctx, a, eps);
+        }
         ggml_set_name(out, "out");
 
         return out;
@@ -3787,17 +3801,18 @@ struct test_rope : public test_case {
     bool ff;
     int v; // view (1 : non-contiguous a)
     bool forward;
+    bool inplace;
 
     std::string vars() override {
         // forward can be inferred from the op, does not need to be printed
-        return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
+        return VARS_TO_STR11(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v, inplace);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
-            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
-        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
+            int n_dims = 10, int mode = GGML_ROPE_TYPE_NORMAL, int n_ctx = 512, float fs = 1.0f,
+            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true, bool inplace = false)
+        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward), inplace(inplace) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a;
@@ -3842,7 +3857,11 @@ struct test_rope : public test_case {
                 GGML_ASSERT(n_dims/4 > 0);
                 int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
                 if (forward) {
-                    out = ggml_rope_multi    (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    if (inplace) {
+                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    } else {
+                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    }
                 } else {
                     out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                 }
@@ -3850,14 +3869,22 @@ struct test_rope : public test_case {
                 GGML_ASSERT(n_dims/3 > 0);
                 int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
                 if (forward) {
-                    out = ggml_rope_multi    (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    if (inplace) {
+                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    } else {
+                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    }
                 } else {
                     out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                 }
             }
         } else {
             if (forward) {
-                out = ggml_rope_ext     (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                if (inplace) {
+                    out = ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                } else {
+                    out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
             } else {
                 out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
             }
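The test structs above only toggle which graph-building call is used; at the ggml API level the in-place variants return a view of the input rather than a fresh output tensor, which is what routes the WebGPU backend to the separate *_inplace pipelines. A minimal sketch of that difference using the public API from this repo (the buffer size is arbitrary):

```cpp
#include "ggml.h"

int main() {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);

    // out-of-place: the result is a new tensor, `a` is left untouched
    struct ggml_tensor * out  = ggml_scale_bias(ctx, a, 2.0f, 1.0f);
    // in-place: the result aliases `a`, so the backend kernel reads and
    // writes the same buffer and needs no separate dst binding
    struct ggml_tensor * outi = ggml_scale_bias_inplace(ctx, a, 2.0f, 1.0f);

    (void) out; (void) outi;
    ggml_free(ctx);
}
```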
@@ -6138,9 +6165,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
     }
 
-    // single in-place tests, especially important for WebGPU backend since kernels for in-place vs. not are different
+    // single inplace tests, especially important for WebGPU backend since kernels for inplace vs. not are different
     test_cases.emplace_back(new test_bin_bcast(ggml_add_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
     test_cases.emplace_back(new test_bin_bcast(ggml_mul_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
+    test_cases.emplace_back(new test_bin_bcast(ggml_sub_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
+    test_cases.emplace_back(new test_bin_bcast(ggml_div_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
 
     // fusion
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
@@ -6155,6 +6184,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
+    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {100, 10, 10, 10}, 2.0f, 1.0f));
     test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
     test_cases.emplace_back(new test_silu_back());
@@ -6167,6 +6197,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
         test_cases.emplace_back(new test_l2_norm    (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
+
+    // inplace tests
+    test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));
+
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
@@ -6514,26 +6548,26 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
                     for (bool ff : {false, true}) { // freq_factors
                         for (float v : { 0, 1 }) {
-                            test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
+                            test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 7B
 
                             if (all) {
-                                test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
-                                test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
-                                test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
+                                test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
+                                test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
+                                test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
                             }
 
                             if (all) {
-                                test_cases.emplace_back(new test_rope(type, { 64,  1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
-                                test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
-                                test_cases.emplace_back(new test_rope(type, { 64,  8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                                test_cases.emplace_back(new test_rope(type, { 64,  1, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                                test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                                test_cases.emplace_back(new test_rope(type, { 64,  8, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
 
-                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw));
-                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
-                                test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
 
-                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
-                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
-                                test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                             }
 
                             if (all) {
@@ -6544,7 +6578,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                 test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
                             }
 
-                            test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                            test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
                         }
                     }
@@ -6555,6 +6589,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    // single inplace test per type/mode/ff
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) {
+            for (bool ff : {false, true}) {
+                test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
+            }
+        }
+    }
+
     for (int v : { 0, 1, 2, 3 }) {
         for (int dim : { 0, 1, 2, 3, }) {
             test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
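The new inplace loop exercises all four rope types against the shader's flag checks (`mode & 2`, `mode & 8`, `mode == 24`). A small compilable sketch of how those checks decode; the numeric values mirror the GGML_ROPE_TYPE_* constants (NORMAL = 0, NEOX = 2, MROPE = 8, VISION = 24 = 8 | 16):

```cpp
#include <cstdio>

int main() {
    const int modes[4] = { 0, 2, 8, 24 }; // NORMAL, NEOX, MROPE, VISION
    for (int mode : modes) {
        const bool is_neox   = (mode & 2) != 0; // set only for NEOX here
        const bool is_mrope  = (mode & 8) != 0; // set for MROPE and VISION
        const bool is_vision = mode == 24;      // VISION is matched exactly
        std::printf("mode=%2d neox=%d mrope=%d vision=%d\n",
                    mode, is_neox, is_mrope, is_vision);
    }
    return 0;
}
```

Note that VISION implies the MROPE bit, which is why the shader's vision paths sit inside the `is_mrope` branch and the exact `mode == 24` test only refines the sector handling.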