Skip to content

Commit 8c29230

Browse files
authored
Merge branch 'ggerganov:master' into avx_opt
2 parents b0e9b96 + b8deef0 commit 8c29230

File tree

4 files changed: 45 additions and 53 deletions.

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ jobs:
9292
name: llama-bin-macos-arm64.zip
9393

9494
macOS-latest-cmake-x64:
95-
runs-on: macos-12
95+
runs-on: macos-13
9696

9797
steps:
9898
- name: Clone

ggml/src/ggml-quants.c

Lines changed: 37 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -9133,10 +9133,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91339133

91349134
#elif defined __AVX__
91359135

9136-
const __m128i m4 = _mm_set1_epi8(0xF);
91379136
const __m128i m3 = _mm_set1_epi8(3);
9138-
const __m128i m32s = _mm_set1_epi8(32);
9139-
const __m128i m2 = _mm_set1_epi8(2);
9137+
const __m128i m15 = _mm_set1_epi8(15);
91409138

91419139
__m256 acc = _mm256_setzero_ps();
91429140

@@ -9148,39 +9146,47 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91489146
const uint8_t * restrict qh = x[i].qh;
91499147
const int8_t * restrict q8 = y[i].qs;
91509148

9149+
// handle the q6_k -32 offset separately using bsums
9150+
const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
9151+
const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
91519152
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
9153+
const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
9154+
const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
9155+
const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
9156+
const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
91529157

91539158
__m128i sumi_0 = _mm_setzero_si128();
91549159
__m128i sumi_1 = _mm_setzero_si128();
91559160

9156-
__m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
9161+
int is = 0;
9162+
91579163
for (int j = 0; j < QK_K/128; ++j) {
91589164

91599165
const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
91609166
const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
91619167

91629168
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
91639169
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
9164-
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
9165-
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
9166-
const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
9167-
const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
9168-
const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
9169-
const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
9170+
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
9171+
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
9172+
const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
9173+
const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
9174+
const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
9175+
const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
91709176

91719177
const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
91729178
const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
91739179
const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
91749180
const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
91759181

9176-
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
9177-
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
9178-
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
9179-
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
9180-
const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
9181-
const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
9182-
const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
9183-
const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
9182+
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
9183+
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
9184+
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
9185+
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
9186+
const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
9187+
const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
9188+
const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
9189+
const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
91849190

91859191
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
91869192
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
@@ -9191,15 +9197,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91919197
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
91929198
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
91939199

9194-
__m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
9195-
__m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
9196-
__m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
9197-
__m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
9198-
__m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
9199-
__m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
9200-
__m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
9201-
__m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
9202-
92039200
__m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
92049201
__m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
92059202
__m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
@@ -9209,32 +9206,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
92099206
__m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
92109207
__m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
92119208

9212-
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
9213-
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
9214-
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
9215-
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
9216-
p16_4 = _mm_sub_epi16(p16_4, q8s_4);
9217-
p16_5 = _mm_sub_epi16(p16_5, q8s_5);
9218-
p16_6 = _mm_sub_epi16(p16_6, q8s_6);
9219-
p16_7 = _mm_sub_epi16(p16_7, q8s_7);
9220-
9221-
const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
9222-
shuffle = _mm_add_epi8(shuffle, m2);
9223-
const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
9224-
shuffle = _mm_add_epi8(shuffle, m2);
9225-
const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
9226-
shuffle = _mm_add_epi8(shuffle, m2);
9227-
const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
9228-
shuffle = _mm_add_epi8(shuffle, m2);
9209+
const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
9210+
const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
9211+
const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
9212+
const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
9213+
is += 4;
92299214

92309215
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
9231-
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
9216+
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
92329217
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
9233-
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
9218+
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
92349219
p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
9235-
p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
9220+
p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
92369221
p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
9237-
p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
9222+
p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
92389223

92399224
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
92409225
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
@@ -9243,8 +9228,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
92439228

92449229
}
92459230

9246-
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
9247-
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
9231+
sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
9232+
sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
9233+
const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
9234+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
92489235
}
92499236

92509237
*s = hsum_float_8(acc);

ggml/src/ggml.c

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -395,6 +395,8 @@ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
395395
16)));
396396
}
397397
}
398+
#endif
399+
#if defined(__AVX2__)
398400
if (ggml_cpu_has_avx2()) {
399401
for (; i + 8 <= n; i += 8) {
400402
_mm256_storeu_ps(y + i,

src/llama.cpp

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21799,8 +21799,11 @@ static int32_t llama_chat_apply_template_internal(
2179921799
// IBM Granite template
2180021800
for (const auto & message : chat) {
2180121801
std::string role(message->role);
21802-
ss << "<|start_of_role|>" << role << "<|end_of_role|>"
21803-
<< message->content << "<|end_of_text|>\n";
21802+
ss << "<|start_of_role|>" << role << "<|end_of_role|>";
21803+
if (role == "assistant_tool_call") {
21804+
ss << "<|tool_call|>";
21805+
}
21806+
ss << message->content << "<|end_of_text|>\n";
2180421807
}
2180521808
if (add_ass) {
2180621809
ss << "<|start_of_role|>assistant<|end_of_role|>\n";

0 commit comments

Comments
 (0)