Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ struct webgpu_context_struct {
webgpu_pipeline memset_pipeline;
webgpu_pipeline mul_mat_pipeline[30][2];
webgpu_pipeline set_rows_pipeline;
webgpu_pipeline set_rows_f32_no_vec_pipeline;
webgpu_pipeline get_rows_pipeline[30];
webgpu_pipeline get_rows_f32_no_vec_pipeline;
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
Expand Down Expand Up @@ -767,7 +768,12 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;

return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs);
webgpu_pipeline pipeline = ctx->set_rows_pipeline;
// if not evenly divisble by 4, use the non-vectorized version
if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) {
pipeline = ctx->set_rows_f32_no_vec_pipeline;
}
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs);
}

static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
Expand Down Expand Up @@ -1613,7 +1619,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
}

static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
// create_pipeline(device, pipeline, shader_code, label, constants)
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_f32_no_vec_pipeline, wgsl_set_rows_f32, "set_rows_f32",
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows_f32_vec, "set_rows_f32_vec",
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,58 @@

#define(VARIANTS)

[
{
"SHADER_SUFFIX": "f32_vec",
"REPLS": {
"TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f16>",
"BLOCK_SIZE": 4
},
"DECLS": ["F32_VEC"]
},
{
"REPLS": {
"TYPE" : "f32",
"DST_TYPE": "f16",
"BLOCK_SIZE": 1
},
"DECLS": ["F32"]
}
]

#end(VARIANTS)

#define(DECLS)

#decl(F32_VEC)
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[(dst_base / 4) + offset] = vec4<f16>(src[(src_base / 4) + offset]);
}
#enddecl(F32_VEC)

#decl(F32)
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = f16(src[src_base + offset]);
}
#enddecl(F32)

#end(DECLS)

#define(SHADER)

enable f16;

DECLS

@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
var<storage, read_write> src: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;

@group(0) @binding(2)
var<storage, read_write> dst: array<f16>;
var<storage, read_write> dst: array<{{DST_TYPE}}>;

@group(0) @binding(3)
var<storage, read_write> error: atomic<u32>;
Expand Down Expand Up @@ -75,7 +120,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;

for (var i: u32 = 0; i < params.ne0; i++) {
dst[i_dst_row + i] = f16(src[i_src_row + i]);
for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
copy_elements(i_src_row, i_dst_row, i);
}
}

#end(SHADER)