Skip to content

Commit 6ce873b

Browse files
authored
Merge branch 'ggml-org:master' into master
2 parents 2c8b338 + e14e842 commit 6ce873b

File tree

11 files changed

+1250
-33
lines changed

11 files changed

+1250
-33
lines changed

.github/workflows/build.yml

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,16 @@ jobs:
161161
- name: Dawn Dependency
162162
id: dawn-depends
163163
run: |
164-
DAWN_VERSION="v1.0.0"
164+
DAWN_VERSION="v2.0.0"
165165
DAWN_OWNER="reeselevine"
166166
DAWN_REPO="dawn"
167-
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
167+
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
168168
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
169-
curl -L -o artifact.tar.gz \
169+
curl -L -o artifact.zip \
170170
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
171171
mkdir dawn
172-
tar -xvf artifact.tar.gz -C dawn --strip-components=1
172+
unzip artifact.zip
173+
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1
173174
174175
- name: Build
175176
id: cmake_build
@@ -521,15 +522,16 @@ jobs:
521522
id: dawn-depends
522523
run: |
523524
sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
524-
DAWN_VERSION="v1.0.0"
525+
DAWN_VERSION="v2.0.0"
525526
DAWN_OWNER="reeselevine"
526527
DAWN_REPO="dawn"
527-
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
528+
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
528529
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
529-
curl -L -o artifact.tar.gz \
530+
curl -L -o artifact.zip \
530531
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
531532
mkdir dawn
532-
tar -xvf artifact.tar.gz -C dawn --strip-components=1
533+
unzip artifact.zip
534+
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1
533535
534536
- name: Build
535537
id: cmake_build

ggml/src/ggml-cuda/cpy.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,14 @@ static void ggml_cpy_flt_cuda(
198198
if (transposed) {
199199
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
200200
int ne00n, ne01n, ne02n;
201-
if (nb00 < nb02) {
201+
if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
202202
ne00n = ne00;
203203
ne01n = ne01;
204204
ne02n = ne02;
205205
} else if (nb00 > nb02) {
206206
ne00n = ne00;
207207
ne01n = ne01*ne02;
208208
ne02n = 1;
209-
} else {
210-
GGML_ASSERT(false);
211209
}
212210

213211
dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3494,7 +3494,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
34943494
const int col_diff = col_high - col_low;
34953495

34963496
for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
3497-
ids_dst_shared[j] = ids_dst[col_low + j];
3497+
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
34983498
}
34993499
__syncthreads();
35003500

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 313 additions & 13 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,12 @@ def generate_variants(fname, input_dir, output_dir, outfile):
7272
except ValueError:
7373
decls_map = {}
7474

75-
with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f:
76-
common_decls = f.read()
77-
decls_map.update(parse_decls(common_decls))
75+
for fname in sorted(os.listdir(input_dir)):
76+
if fname.endswith(".tmpl"):
77+
tmpl_path = os.path.join(input_dir, fname)
78+
with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
79+
decls = f_tmpl.read()
80+
decls_map.update(parse_decls(decls))
7881

7982
shader_template = extract_block(text, "SHADER")
8083
for variant in variants:

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -864,8 +864,8 @@ struct MulMatParams {
864864
broadcast3: u32
865865
};
866866

867-
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // N rows, K columns
868-
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed)
867+
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
868+
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
869869
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
870870

871871
@group(0) @binding(3) var<uniform> params: MulMatParams;
@@ -891,8 +891,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
891891

892892
let dst2_rem = dst3_rem % dst2_stride;
893893

894-
let row = dst2_rem / params.n; // output row
895-
let col = dst2_rem % params.n; // output column
894+
let row = dst2_rem / params.m; // output row
895+
let col = dst2_rem % params.m; // output column
896896

897897
let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
898898
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
@@ -901,7 +901,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
901901
for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
902902
sum += multiply_add(src0_idx_base, src1_idx_base, i);
903903
}
904-
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
904+
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
905905
}
906906

907907
#end(SHADER)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#decl(SHMEM_VEC)
2+
fn store_shmem(val: vec4<f16>, idx: u32) {
3+
shmem[idx] = val.x;
4+
shmem[idx + 1] = val.y;
5+
shmem[idx + 2] = val.z;
6+
shmem[idx + 3] = val.w;
7+
}
8+
#enddecl(SHMEM_VEC)
9+
10+
#decl(SHMEM_SCALAR)
11+
fn store_shmem(val: f16, idx: u32) {
12+
shmem[idx] = val;
13+
}
14+
#enddecl(SHMEM_SCALAR)
15+
16+
#decl(INIT_SRC0_SHMEM_FLOAT)
17+
18+
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
19+
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
20+
let tile_m = elem_idx / TILE_K;
21+
let tile_k = elem_idx % TILE_K;
22+
let global_m = offset_m + tile_m;
23+
let global_k = k_outer + tile_k;
24+
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
25+
let src0_val = select( // taking a slight performance hit to avoid oob
26+
{{SRC0_TYPE}}(0.0),
27+
src0[src0_idx/{{VEC_SIZE}}],
28+
global_m < params.m && global_k < params.k);
29+
store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
30+
}
31+
}
32+
33+
#enddecl(INIT_SRC0_SHMEM_FLOAT)
34+
35+
#decl(INIT_SRC1_SHMEM)
36+
37+
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
38+
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
39+
let tile_n = elem_idx / TILE_K;
40+
let tile_k = elem_idx % TILE_K;
41+
let global_n = offset_n + tile_n;
42+
let global_k = k_outer + tile_k;
43+
let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
44+
let src1_val = select(
45+
{{SRC1_TYPE}}(0.0),
46+
src1[src1_idx/{{VEC_SIZE}}],
47+
global_n < params.n && global_k < params.k);
48+
store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
49+
}
50+
}
51+
52+
#enddecl(INIT_SRC1_SHMEM)
53+
54+
#decl(INIT_SRC0_SHMEM_Q4_0)
55+
56+
const BLOCK_SIZE = 32u;
57+
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
58+
override BLOCKS_K = TILE_K/BLOCK_SIZE;
59+
const NQ = 16u;
60+
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
61+
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
62+
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
63+
64+
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
65+
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
66+
let blck_idx = i / BLOCK_SIZE;
67+
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
68+
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
69+
70+
let tile_m = blck_idx / BLOCKS_K;
71+
let global_m = offset_m + tile_m;
72+
let block_k = blck_idx % BLOCKS_K;
73+
let global_k = k_outer / BLOCK_SIZE + block_k;
74+
75+
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
76+
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
77+
let scale_idx = src0_idx * F16_PER_BLOCK;
78+
let d = src0[scale_idx];
79+
80+
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
81+
let q_0 = src0[scale_idx + 1u + block_offset + j];
82+
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
83+
84+
let q_packed = bitcast<u32>(vec2(q_0, q_1));
85+
for (var k = 0u; k < 4u; k++) {
86+
let q_byte = get_byte(q_packed, k);
87+
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
88+
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
89+
shmem[shmem_idx + j * 2 + k] = q_lo;
90+
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
91+
}
92+
}
93+
}
94+
}
95+
}
96+
97+
#enddecl(INIT_SRC0_SHMEM_Q4_0)

0 commit comments

Comments
 (0)