Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
488 changes: 414 additions & 74 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp

Large diffs are not rendered by default.

44 changes: 0 additions & 44 deletions ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl

This file was deleted.

41 changes: 0 additions & 41 deletions ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl

This file was deleted.

188 changes: 188 additions & 0 deletions ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
#define(VARIANTS)

[
{
"SHADER_NAME": "add_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "+"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "add_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "+"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "add_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "+"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "add_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "+"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "mul_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "*"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "mul_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "*"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "mul_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "*"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "mul_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "*"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sub_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "-"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sub_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "-"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sub_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "-"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sub_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "-"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "div_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "/"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "div_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "/"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "div_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "/"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "div_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "/"
},
"DECLS": ["INPLACE"]
}
]

#end(VARIANTS)

#define(DECLS)

#decl(NOT_INPLACE)

fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
}

@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(3)
var<uniform> params: Params;

#enddecl(NOT_INPLACE)

#decl(INPLACE)

fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
}

@group(0) @binding(2)
var<uniform> params: Params;

#enddecl(INPLACE)

#end(DECLS)


#define(SHADER)

enable f16;

#include "binary_head.tmpl"

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;

DECLS

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
}
}

#end(SHADER)
101 changes: 101 additions & 0 deletions ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#define(VARIANTS)

[
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "f32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "f16"
}
},
{
"REPLS": {
"SRC_TYPE": "f16",
"DST_TYPE": "f16"
}
},
{
"REPLS": {
"SRC_TYPE": "f16",
"DST_TYPE": "f32"
}
}
]

#end(VARIANTS)

#define(SHADER)
enable f16;

@group(0) @binding(0)
var<storage, read_write> src: array<{{SRC_TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> dst: array<{{DST_TYPE}}>;

struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements

// Strides (in elements) — may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,

stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,

// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,

dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32
};

@group(0) @binding(2)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}

var i = gid.x;
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
let i2 = i / (params.src_ne1 * params.src_ne0);
i = i % (params.src_ne1 * params.src_ne0);
let i1 = i / params.src_ne0;
let i0 = i % params.src_ne0;

var j = gid.x;
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
let j2 = j / (params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne1 * params.dst_ne0);
let j1 = j / params.dst_ne0;
let j0 = j % params.dst_ne0;

let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;

let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
j2 * params.stride_dst2 + j3 * params.stride_dst3;

dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
}
#end(SHADER)
Loading
Loading