Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,9 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
case GGML_OP_SCALE:
ggml_webgpu_scale(ctx, src0, node);
break;
+ case GGML_OP_SOFT_MAX:
+ ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
+ break;
default:
return false;
}
Expand Down Expand Up @@ -1806,6 +1809,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_SCALE:
supports_op = op->type == GGML_TYPE_F32;
break;
+ case GGML_OP_SOFT_MAX:
+ supports_op = op->type == GGML_TYPE_F32;
+ break;
default:
break;
}
Expand Down Expand Up @@ -1949,6 +1955,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
ggml_webgpu_init_rope_pipeline(ctx);
ggml_webgpu_init_glu_pipeline(ctx);
ggml_webgpu_init_scale_pipeline(ctx);
+ ggml_webgpu_init_soft_max_pipeline(ctx);

#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
- let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

let elems = (params.ne0 + wg_size - 1) / wg_size;

Expand Down
Loading