Commit b8da8d4

Merge branch 'ggml-org:master' into tr/qwen3-vl-2
2 parents: 0171861 + 4b2dae3

42 files changed: +1812 −1043 lines

common/arg.cpp

Lines changed: 162 additions & 134 deletions
Large diffs are not rendered by default.

common/common.h

Lines changed: 1 addition & 1 deletion
@@ -426,7 +426,7 @@ struct common_params {
     int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
     int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
-    int32_t cache_ram_mib = 8192; // 0 = no limit, 1 = 1 MiB, etc.
+    int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.

     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
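
The new comment changes the sentinel values: -1 now means unlimited and 0 disables the cache entirely. A minimal sketch of how a caller might translate the setting into a byte budget (the helper below is hypothetical, not part of this commit):

#include <cstddef>
#include <cstdint>
#include <limits>

// Hypothetical helper illustrating the new cache_ram_mib semantics:
// -1 = no limit, 0 = cache disabled, otherwise a limit in MiB.
static size_t cache_ram_limit_bytes(int32_t cache_ram_mib) {
    if (cache_ram_mib < 0) {
        return std::numeric_limits<size_t>::max(); // no limit
    }
    return (size_t) cache_ram_mib * 1024 * 1024;   // 0 -> 0 bytes, i.e. disabled
}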

convert_hf_to_gguf.py

Lines changed: 2 additions & 10 deletions
@@ -6145,20 +6145,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class JambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.JAMBA

-    def get_vocab_base_pre(self, tokenizer) -> str:
-        del tokenizer # unused
-
-        return "gpt-2"
-
     def set_vocab(self):
         if (self.dir_model / "tokenizer.model").is_file():
-            # Using Jamba's tokenizer.json causes errors on model load
-            # (something about "byte not found in vocab"),
-            # but there's a working tokenizer.model
             self._set_vocab_sentencepiece()
         else:
-            # Some Jamba models only have a tokenizer.json, which works.
-            self._set_vocab_gpt2()
+            self._set_vocab_llama_hf()
+            self.gguf_writer.add_add_space_prefix(False)

     def set_gguf_parameters(self):
         d_model = self.find_hparam(["hidden_size", "mamba_d_model"])

ggml/src/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -145,6 +145,9 @@ endif()
 # which was introduced in POSIX.1-2008, forcing us to go higher
 if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
     add_compile_definitions(_XOPEN_SOURCE=700)
+elseif (CMAKE_SYSTEM_NAME MATCHES "AIX")
+    # Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default,
+    # in order to define _SC_PHYS_PAGES.
 else()
     add_compile_definitions(_XOPEN_SOURCE=600)
 endif()
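
For context, _SC_PHYS_PAGES is the sysconf() key used to query physical memory, and on AIX it is only exposed when _ALL_SOURCE is in effect. A minimal sketch of that kind of query (assuming this is how ggml sizes itself against system RAM):

#include <unistd.h>

// Total physical memory in bytes via sysconf(); returns -1 if unavailable.
// On AIX this needs _ALL_SOURCE (the default) rather than _XOPEN_SOURCE.
static long long total_phys_mem_bytes() {
    const long pages     = sysconf(_SC_PHYS_PAGES);
    const long page_size = sysconf(_SC_PAGESIZE);
    if (pages < 0 || page_size < 0) {
        return -1;
    }
    return (long long) pages * page_size;
}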

ggml/src/ggml-cpu/ops.cpp

Lines changed: 10 additions & 14 deletions
@@ -3467,31 +3467,27 @@ static void ggml_compute_forward_norm_f32(

     GGML_ASSERT(eps >= 0.0f);

-    // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

-                ggml_float sum = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    sum += (ggml_float)x[i00];
-                }
-
+                float sum = 0.0;
+                ggml_vec_sum_f32(ne00, &sum, x);
                 float mean = sum/ne00;

                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+                float variance = 0;

-                ggml_float sum2 = 0.0;
-                for (int64_t i00 = 0; i00 < ne00; i00++) {
-                    float v = x[i00] - mean;
-                    y[i00] = v;
-                    sum2 += (ggml_float)(v*v);
-                }
+#ifdef GGML_USE_ACCELERATE
+                mean = -mean;
+                vDSP_vsadd(x, 1, &mean, y, 1, ne00);
+                vDSP_measqv(y, 1, &variance, ne00);
+#else
+                variance = ggml_vec_cvar_f32(ne00, y, x, mean);
+#endif //GGML_USE_ACCELERATE

-                float variance = sum2/ne00;
                 const float scale = 1.0f/sqrtf(variance + eps);
-
                 ggml_vec_scale_f32(ne00, y, scale);
             }
         }
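
The new path computes each row's mean with ggml_vec_sum_f32, then centers the row and accumulates the variance in a single pass (ggml_vec_cvar_f32, or vDSP_vsadd/vDSP_measqv under Accelerate), and finally scales by 1/sqrt(variance + eps). A scalar reference sketch of that per-row math (for illustration only, not the SIMD code this commit ships):

#include <cmath>

// Scalar reference for one row of ggml_compute_forward_norm_f32.
static void norm_row_ref(int n, float * y, const float * x, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) sum += x[i];       // ggml_vec_sum_f32
    const float mean = sum / n;

    float variance = 0.0f;
    for (int i = 0; i < n; ++i) {                  // ggml_vec_cvar_f32: center y, accumulate squares
        const float v = x[i] - mean;
        y[i] = v;
        variance += v * v;
    }
    variance /= n;

    const float scale = 1.0f / sqrtf(variance + eps);
    for (int i = 0; i < n; ++i) y[i] *= scale;     // ggml_vec_scale_f32
}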

ggml/src/ggml-cpu/vec.cpp

Lines changed: 66 additions & 0 deletions
@@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
     }
 }

+ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
+    int i = 0;
+    ggml_float sum = 0;
+// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
+// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
+                                   _mm512_set1_ps(mean));
+        _mm512_storeu_ps(y + i, val);
+        sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
+                                   _mm256_set1_ps(mean));
+        _mm256_storeu_ps(y + i, val);
+        val = _mm256_mul_ps(val,val);
+        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+                                 _mm256_castps256_ps128(val));
+        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+        sum += (ggml_float)_mm_cvtss_f32(val2);
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
+                                _mm_set1_ps(mean));
+        _mm_storeu_ps(y + i, val);
+        val = _mm_mul_ps(val, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+        val = _mm_add_ps(val, tmp);
+        tmp = _mm_movehl_ps(tmp, val);
+        val = _mm_add_ss(val, tmp);
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+        sum += (ggml_float)_mm_cvtss_f32(val);
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        float32x4_t val = vsubq_f32(vld1q_f32(x + i),
+                                    vdupq_n_f32(mean));
+        vst1q_f32(y + i, val);
+        val = vmulq_f32(val, val);
+        sum += (ggml_float)vaddvq_f32(val);
+    }
+#elif defined(__VXE__) || defined(__VXE2__)
+    for (; i + 3 < n; i += 4) {
+        float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
+        vec_xst(val, 0, y + i);
+        val = vec_mul(val, val);
+        sum += (ggml_float)vec_hsum_f32x4(val);
+    }
+#endif
+    for (; i < n; ++i) {
+        float val = x[i] - mean;
+        val *= val;
+        sum += (ggml_float)val;
+        y[i] = val;
+    }
+    return sum/n;
+}
+
 ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
     int i = 0;
     ggml_float sum = 0;
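
The TODO at the top of the function concerns the scalar tail: with AVX2, for example, a row of n = 19 floats is handled as two 8-wide iterations (i = 0..15), and elements 16..18 fall through to the plain loop at the end; narrower SSE iterations could in principle mop up more of that remainder before the scalar fallback.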

ggml/src/ggml-cpu/vec.h

Lines changed: 6 additions & 4 deletions
@@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
 void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);

 void ggml_vec_silu_f32(const int n, float * y, const float * x);
+ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean )
 ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
 ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);

@@ -143,14 +144,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
     for (int i = 0; i < np; i += ggml_f16_step) {
         ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements

-        ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst
+        ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
         sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
         ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
         sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);

         ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements

-        ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements
+        ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
         sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
         ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
         sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);

@@ -159,7 +160,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG

         ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
         sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
-        ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
+        ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
         sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);

         ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);

@@ -819,7 +820,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f
 inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
 inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
     for (int i = 0; i < n; ++i) {
-        y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
+        const float v = GGML_CPU_FP16_TO_FP32(x[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));
     }
 }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
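
The f16 ELU previously applied expm1f to every element, unlike the f32 variant directly above it, which only does so for negative inputs. With the fix, for example, ELU(2.0) correctly returns 2.0, whereas the old path returned expm1f(2.0) ≈ 6.39.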

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,8 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")

     file(GLOB GGML_SOURCES_CUDA "*.cu")
+    file(GLOB SRCS "template-instances/fattn-tile*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB SRCS "template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB SRCS "template-instances/mmq*.cu")

ggml/src/ggml-cuda/common.cuh

Lines changed: 6 additions & 1 deletion
@@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
 }

 static bool fast_fp16_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return GGML_CUDA_CC_IS_AMD(cc) ||
+        (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
 }

 // To be used for feature selection of external libraries, e.g. cuBLAS.

@@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
 }

 // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
+// Important: do not use this function if dst and src both point at registers.
+// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
+// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
+// If dst and src point at different address spaces then they are guaranteed to not be aliased.
 template <int nbytes, int alignment = 0>
 static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
     if constexpr (alignment != 0) {
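
A sketch of what the new comment permits and forbids for ggml_cuda_memcpy_1; the kernel and buffer names below are hypothetical, only the copy helper itself comes from this header:

// Illustration only: intended use of ggml_cuda_memcpy_1 per the comment above.
__global__ void copy_example(const float * __restrict__ src_vram, float * __restrict__ dst_vram) {
    float vals[4]; // per-thread registers

    // OK: VRAM -> registers (different address spaces, so no aliasing is possible).
    ggml_cuda_memcpy_1<16>(vals, src_vram + 4*threadIdx.x);

    // OK: registers -> VRAM.
    ggml_cuda_memcpy_1<16>(dst_vram + 4*threadIdx.x, vals);

    // NOT OK: register -> register with mismatched types; strict aliasing lets the
    // compiler miscompile this, so use normal assignments or casts instead:
    //   int bits[4]; ggml_cuda_memcpy_1<16>(bits, vals);
}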

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 3 additions & 6 deletions
@@ -793,8 +793,6 @@ void launch_fattn(
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

-    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();

@@ -878,7 +876,7 @@ void launch_fattn(
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
     // multiple sequences of possibly different lengths.
-    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
         const int s31 = mask->nb[1] / sizeof(half2);
         const int s33 = mask->nb[3] / sizeof(half2);

@@ -916,8 +914,7 @@ void launch_fattn(

         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
-        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
-        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+        const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.

         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

@@ -946,7 +943,7 @@ void launch_fattn(

     blocks_num.x = ntiles_x;
     blocks_num.y = parallel_blocks;
-    blocks_num.z = Q->ne[2]*Q->ne[3];
+    blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];

     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
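
With the padding asserts removed, ntiles_KQ becomes a ceiling division, so a KV length that is not a multiple of KQ_row_granularity still yields a usable block count: with a hypothetical granularity of 256 and K->ne[1] = 1000, the old code would have tripped the assert, while the new expression gives (1000 + 255) / 256 = 4 tiles.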
