Skip to content

Commit 6464269

Browse files
Commit message: implement soft_max
1 parent: d64c810 · commit: 6464269

File tree

2 files changed: 8 additions and 1 deletion.

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1060,6 +1060,9 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
10601060
case GGML_OP_SCALE:
10611061
ggml_webgpu_scale(ctx, src0, node);
10621062
break;
1063+
case GGML_OP_SOFT_MAX:
1064+
ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
1065+
break;
10631066
default:
10641067
return false;
10651068
}
@@ -1806,6 +1809,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
18061809
case GGML_OP_SCALE:
18071810
supports_op = op->type == GGML_TYPE_F32;
18081811
break;
1812+
case GGML_OP_SOFT_MAX:
1813+
supports_op = op->type == GGML_TYPE_F32;
1814+
break;
18091815
default:
18101816
break;
18111817
}
@@ -1949,6 +1955,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
19491955
ggml_webgpu_init_rope_pipeline(ctx);
19501956
ggml_webgpu_init_glu_pipeline(ctx);
19511957
ggml_webgpu_init_scale_pipeline(ctx);
1958+
ggml_webgpu_init_soft_max_pipeline(ctx);
19521959

19531960
#ifdef GGML_WEBGPU_DEBUG
19541961
// Initialize debug buffers

ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
8484
let i2 = i / params.ne1;
8585
let i1 = i % params.ne1;
8686
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
87-
let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
87+
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
8888

8989
let elems = (params.ne0 + wg_size - 1) / wg_size;
9090

0 commit comments

Comments
 (0)