Skip to content

Commit 43ad47b

Browse files
authored
Merge branch 'ggml-org:master' into master
2 parents 6ce873b + c1b1876 commit 43ad47b

16 files changed

+763
-305
lines changed

ggml/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,7 @@ option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
168168
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
169169
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
170170
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
171-
option(GGML_VXE "ggml: enable vxe" ON)
171+
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
172172

173173
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
174174
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,7 @@ if (CUDAToolkit_FOUND)
142142

143143
if (GGML_CUDA_DEBUG)
144144
list(APPEND CUDA_FLAGS -lineinfo)
145+
add_compile_definitions(GGML_CUDA_DEBUG)
145146
endif()
146147

147148
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3152,8 +3152,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
31523152

31533153
for (int i = 0; i < cgraph->n_nodes; i++) {
31543154
ggml_tensor * node = cgraph->nodes[i];
3155-
3156-
31573155
#ifdef GGML_CUDA_DEBUG
31583156
const int nodes_fused = i - prev_i - 1;
31593157
prev_i = i;
@@ -3302,6 +3300,13 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
33023300
continue;
33033301
}
33043302

3303+
// we don't support repeating adds
3304+
if (bias_op == GGML_OP_ADD &&
3305+
(!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
3306+
!ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
3307+
continue;
3308+
}
3309+
33053310
const ggml_tensor * src0 = up_n->src[0];
33063311
const ggml_tensor * src1 = up_n->src[1];
33073312
const ggml_tensor * ids = up_n->src[2];
@@ -3411,6 +3416,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
34113416
continue;
34123417
}
34133418

3419+
if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
3420+
continue;
3421+
}
3422+
34143423
ggml_cuda_mm_fusion_args_host fusion_data{};
34153424
fusion_data.x_bias = bias_tensor;
34163425

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 273 additions & 44 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@
33

44
#include "rte.glsl"
55
#include "utils.glsl"
6+
#if RMS_NORM_ROPE_FUSION
7+
#include "rope_params.glsl"
8+
#endif
69

710
layout (push_constant) uniform parameter
811
{
@@ -12,11 +15,16 @@ layout (push_constant) uniform parameter
1215
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
1316
uint misalign_offsets;
1417
float param1; float param2; int param3;
18+
#if RMS_NORM_ROPE_FUSION
19+
rope_params rope;
20+
#endif
1521
} p;
1622

23+
#if !RMS_NORM_ROPE_FUSION
1724
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1825
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
1926
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
27+
#endif
2028

2129
// true if src0/src1 are the same shape and the indices can be reused without additional modulus
2230
layout(constant_id = 0) const bool norepeat = false;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 31 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,6 @@ layout (push_constant) uniform parameter
100100
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
101101
layout (constant_id = 1) const uint BM = 64;
102102
layout (constant_id = 2) const uint BN = 64;
103-
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
104103
layout (constant_id = 4) const uint WM = 32;
105104
layout (constant_id = 5) const uint WN = 32;
106105
layout (constant_id = 6) const uint WMITER = 2;
@@ -109,6 +108,14 @@ layout (constant_id = 8) const uint TN = 2;
109108
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
110109
layout (constant_id = 10) const uint WARP = 32;
111110

111+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
112+
#define BK 32
113+
#define BK_STEP 4
114+
#else
115+
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
116+
#define BK_STEP 2
117+
#endif
118+
112119
#ifdef COOPMAT
113120
#define SHMEM_STRIDE (BK / 2 + 4)
114121
#else
@@ -244,8 +251,13 @@ void main() {
244251
}
245252
#else
246253
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
254+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
255+
FLOAT_TYPE_VEC4 cache_a[WMITER * TM];
256+
FLOAT_TYPE_VEC4 cache_b;
257+
#else
247258
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
248259
FLOAT_TYPE_VEC2 cache_b;
260+
#endif
249261

250262
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
251263
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
@@ -283,24 +295,41 @@ void main() {
283295
}
284296
}
285297
#else
286-
[[unroll]] for (uint i = 0; i < BK / 2; i++) {
298+
[[unroll]] for (uint i = 0; i < BK / BK_STEP; i++) {
287299
// Load from shared into cache
288300
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
289301
[[unroll]] for (uint j = 0; j < TM; j++) {
302+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
303+
cache_a[wsir * TM + j].xy = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i ];
304+
cache_a[wsir * TM + j].zw = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i + 1];
305+
#else
290306
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
307+
#endif
291308
}
292309
}
293310

294311
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
295312
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
313+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
314+
cache_b.xy = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i ];
315+
cache_b.zw = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i + 1];
316+
#else
296317
cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
318+
#endif
297319

298320
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
299321
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
300322
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
301323
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
324+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
325+
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y),
326+
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
327+
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
328+
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
329+
#else
302330
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
303331
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
332+
#endif
304333
}
305334
}
306335
}

ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp

Lines changed: 43 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,32 @@
33
#include "generic_binary_head.glsl"
44
#include "types.glsl"
55

6+
#if RMS_NORM_ROPE_FUSION
7+
8+
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
9+
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
10+
11+
// data is passed from rms_norm -> rope through shared memory.
12+
// rms_norm calls this data_d, rope calls this rope_data_a.
13+
// Binding 2 is not used
14+
shared FLOAT_TYPE rope_data_a[1024];
15+
#define data_d rope_data_a
16+
17+
layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];};
18+
layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];};
19+
layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];};
20+
layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows
21+
22+
#include "rope_params.glsl"
23+
#include "rope_funcs.glsl"
24+
25+
#define GGML_ROPE_TYPE_NORMAL 0
26+
#define GGML_ROPE_TYPE_NEOX 2
27+
#define GGML_ROPE_TYPE_MROPE 8
28+
#define GGML_ROPE_TYPE_VISION 24
29+
30+
#endif
31+
632
#extension GL_EXT_control_flow_attributes : enable
733
#define BLOCK_SIZE 512
834

@@ -28,8 +54,12 @@ void rms_norm(uint num_iters) {
2854

2955
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
3056
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
57+
#if RMS_NORM_ROPE_FUSION
58+
// Per-row offset in shared memory
59+
uint32_t d_offset = 0;
60+
#else
3161
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
32-
62+
#endif
3363
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
3464

3565
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
@@ -79,6 +109,18 @@ void rms_norm(uint num_iters) {
79109
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
80110
}
81111
}
112+
#if RMS_NORM_ROPE_FUSION
113+
barrier();
114+
rope_params rp = p.rope;
115+
uint rope_row = (samp*nchannels + channel)*nrows + row;
116+
for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
117+
if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
118+
rope_neox(t, rope_row, rp);
119+
} else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
120+
rope_norm(t, rope_row, rp);
121+
}
122+
}
123+
#endif
82124
}
83125

84126
void main() {

0 commit comments

Comments
 (0)