Skip to content

Commit b2dbfcd

Browse files
committed
Work on set rows
1 parent 6a6135c commit b2dbfcd

File tree

4 files changed

+83
-17
lines changed

4 files changed

+83
-17
lines changed

.github/workflows/build.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ jobs:
179179
- name: Test
180180
id: cmake_test
181181
run: |
182-
export LLAMA_SET_ROWS=0
183182
cd build
184183
ctest -L main --verbose --timeout 900
185184
@@ -438,7 +437,6 @@ jobs:
438437
- name: Test
439438
id: cmake_test
440439
run: |
441-
export LLAMA_SET_ROWS=0
442440
cd build
443441
# This is using llvmpipe and runs slower than other backends
444442
ctest -L main --verbose --timeout 3600

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -495,9 +495,9 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
495495
(uint32_t) src->ne[1],
496496
(uint32_t) src->ne[2],
497497
(uint32_t) src->ne[3],
498-
// broadcast shape of idx
499-
(uint32_t) (src->ne[2] / idx->ne[1]),
500-
(uint32_t) (src->ne[3] / idx->ne[2])
498+
// Shape of idx
499+
(uint32_t) (idx->ne[1]),
500+
(uint32_t) (idx->ne[2])
501501
};
502502

503503
std::vector<wgpu::BindGroupEntry> entries = {
@@ -512,18 +512,13 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
512512
{ .binding = 2,
513513
.buffer = ggml_backend_webgpu_tensor_buf(dst),
514514
.offset = ggml_backend_webgpu_tensor_offset(dst),
515-
.size = ggml_nbytes(dst) },
516-
{ .binding = 3,
517-
.buffer = ctx->debug_dev_buf,
518-
.offset = 0,
519-
.size = ctx->debug_dev_buf.GetSize() }
515+
.size = ggml_nbytes(dst) }
520516
};
521517

522518
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
523519
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
524520
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
525521
ggml_backend_webgpu_submit_queue(ctx);
526-
ggml_backend_webgpu_debug(ctx);
527522
}
528523

529524
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// SET_ROWS compute shader.
//
// Copies rows of an f32 source buffer into an f16 destination buffer. The
// destination row for each source row is looked up in an index buffer.
// One invocation handles one source row and copies `params.ne0` elements.

enable f16; // required for the f16 element type used by `dst`

// Source values; only read by this shader.
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

// Row indices; only read by this shader. Each logical index occupies two
// consecutive u32 words (presumably a 64-bit integer written by the host --
// TODO confirm against the host-side tensor type).
@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;

// Destination values; each copied element is converted from f32 to f16.
@group(0) @binding(2)
var<storage, read_write> dst: array<f16>;

// Uniform parameters. Field order defines the uniform layout and must match
// the order in which the host fills the parameter buffer.
struct Params {
    offset_src: u32, // in f32 elements of `src`
    offset_idx: u32, // in logical indices (two u32 words each)
    offset_dst: u32, // in f16 elements of `dst`

    // Strides (in elements)
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // idx strides are in logical indices, not u32 words
    stride_idx0: u32,
    stride_idx1: u32,
    stride_idx2: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Shape of src
    ne0: u32,    // elements copied per row
    n_rows: u32, // extent of src dimension 1
    ne2: u32,
    ne3: u32,

    // Shape of idx (dimensions 1 and 2); src dimensions 2 and 3 are wrapped
    // onto these with a modulo, so idx may be smaller than src there.
    idx1: u32,
    idx2: u32,
};

@group(0) @binding(3)
var<uniform> params: Params;

// Workgroup width; must be supplied as a pipeline-overridable constant.
override wg_size: u32;

// Entry point. `gid.x` is a flat index over (n_rows, ne2, ne3) of the source.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // The dispatch is rounded up to whole workgroups; drop the excess threads.
    if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
        return;
    }

    // Decompose the flat id into source coordinates (i_src1, i_src2, i_src3).
    var i = gid.x;
    let i_src3 = i / (params.ne2 * params.n_rows);

    i = i % (params.ne2 * params.n_rows);
    let i_src2 = i / params.n_rows;
    let i_src1 = i % params.n_rows;

    // Index coordinates: dimension 0 follows the source row, the outer two
    // dimensions wrap around the index shape.
    let i_idx2 = i_src3 % params.idx2;
    let i_idx1 = i_src2 % params.idx1;
    let i_idx0 = i_src1;

    // Position of the first u32 word of this logical index (hence the * 2).
    let idx_word = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;

    // Only the first word selects the destination row; the second word is
    // not consulted.
    // NOTE(review): if the second word carries the upper 32 bits of a 64-bit
    // index, values that do not fit in 32 bits are silently truncated here.
    let dst_row = idx[idx_word];

    let i_dst_row = params.offset_dst + dst_row * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
    let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;

    // Copy one row, narrowing each element from f32 to f16.
    for (var col: u32 = 0; col < params.ne0; col++) {
        dst[i_dst_row + col] = f16(src[i_src_row + col]);
    }
}

tests/test-backend-ops.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,12 +1213,12 @@ struct test_case {
12131213
double err = nmse(f1.data(), f2.data(), f1.size());
12141214
if (err > ud->max_err) {
12151215
//printf("Backends %s and %s mismatch: ", bn1, bn2);
1216-
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
1217-
for (int i = 0; i < (int) f1.size(); i++) {
1218-
printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
1219-
}
1220-
printf("\n");
1221-
exit(1);
1216+
//printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
1217+
//for (int i = 0; i < (int) f1.size(); i++) {
1218+
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
1219+
//}
1220+
//printf("\n");
1221+
//exit(1);
12221222
ud->ok = false;
12231223
}
12241224
return true;

0 commit comments

Comments
 (0)