
#define UNUSED GGML_UNUSED

+#if defined(__VXE__) || defined(__VXE2__)
+#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s ) B7(c,s, c), B7(c,s, s)
+
+// precomputed tables for expanding 8bits to 8 bytes:
+static const __attribute__((aligned(16))) uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b ) << 4
+static const __attribute__((aligned(16))) uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
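+// indexing with one byte of qh expands it to eight bytes: byte k of the result is
+// 0x10 or 0x00 depending on bit k of the index (table_b2b_0 marks set bits,
+// table_b2b_1 marks cleared bits), turning the packed 5th bits into per-byte masks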
+
+// permute mask for byteswapping
+static const uint8x16_t v_kperm = (const uint8x16_t){
+     7,  6,  5,  4,  3,  2,  1,  0,
+    15, 14, 13, 12, 11, 10,  9,  8
+};
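+// (s390x is big-endian, so each uint64_t from the tables is stored most significant
+// byte first; reversing the two 8-byte halves restores the bit-k -> byte-k order)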
+#endif
+
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);

@@ -241,6 +262,301 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}

+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_0 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
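+    // block_q5_0: fp16 scale d, 32 high bits packed in qh[4], 32 low nibbles packed in qs[16]
+    // block_q8_0: fp16 scale d, 32 signed int8 quants in qs[32]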
+
+    int ib = 0;
+    float sumf = 0.0f;
+
+#if defined(__VXE__) || defined(__VXE2__)
+    float32x4_t v_sum0 = vec_splats(0.0f);
+    float32x4_t v_sum1 = vec_splats(0.0f);
+
+    uint32_t qh0, qh1;
+    uint64_t tmp0[4], tmp1[4];
+
+    const uint8x16_t v_m = vec_splats((uint8_t)0x0F);
+
+    #pragma GCC unroll 4
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q5_0 * GGML_RESTRICT x0 = &x[ib + 0];
+        const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];
+        const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
+        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
+
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];
+
+        int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));
+        int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));
+        int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));
+        int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));
+
+        // required for fixing the byteorder
+        v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);
+        v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);
+        v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);
+        v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);
+
+        const uint8x16_t v_x0 = vec_xl(0, (const uint8_t *)x0->qs);
+        const uint8x16_t v_x1 = vec_xl(0, (const uint8_t *)x1->qs);
+
+        int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
+        int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
+        int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
+        int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
+
+        const int8x16_t v_x0lf = vec_sub(v_x0l, v_qh0l);
+        const int8x16_t v_x0hf = vec_sub(v_x0h, v_qh0h);
+        const int8x16_t v_x1lf = vec_sub(v_x1l, v_qh1l);
+        const int8x16_t v_x1hf = vec_sub(v_x1h, v_qh1h);
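+        // q5_0 values live in [-16, 15]: v_qh* is 0x10 exactly where the 5th bit is
+        // clear, so nibble - v_qh* equals ((nibble | hbit << 4) - 16) without branching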
+
+        const int8x16_t v_y0l = vec_xl(0, (const int8_t *)y0->qs);
+        const int8x16_t v_y0h = vec_xl(QK8_0/2, (const int8_t *)y0->qs);
+        const int8x16_t v_y1l = vec_xl(0, (const int8_t *)y1->qs);
+        const int8x16_t v_y1h = vec_xl(QK8_0/2, (const int8_t *)y1->qs);
+
+        const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h);
+        const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);
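+        // each v_xy* now holds the block's 32-term int8 dot product spread across
+        // four int32 lanes; the remaining work is scaling by d_x * d_y in fp32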
+
+        const float32x4_t v_xy0f = vec_float(v_xy0);
+        const float32x4_t v_xy1f = vec_float(v_xy1);
+
+        const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
+        const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));
+
+        v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0);
+        v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
+    }
+
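+    // fold the four fp32 lanes of each accumulator into the scalar running sum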
+    sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1);
+
+    #pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
+        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
+
+        uint32_t qh;
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        uint64_t tmp[4];
+        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh >> 24)       ];
+
+        int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));
+        int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));
+
+        // required for fixing the byteorder
+        v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);
+        v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);
+
+        const uint8x16_t v_x = vec_xl(0, (const uint8_t *)x0->qs);
+        int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
+        int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
+
+        const int8x16_t v_xlf = vec_sub(v_xl, v_qhl);
+        const int8x16_t v_xhf = vec_sub(v_xh, v_qhh);
+
+        const int8x16_t v_yl = vec_xl(0, (const int8_t *)y0->qs);
+        const int8x16_t v_yh = vec_xl(QK8_0/2, (const int8_t *)y0->qs);
+
+        const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);
+        const float32x4_t v_xyf = vec_float(v_xy);
+
+        const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
+        const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f));
+
+        sumf += vec_hsum(v_acc);
+    }
+
+    *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_1);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_1 * GGML_RESTRICT x = vx;
+    const block_q8_1 * GGML_RESTRICT y = vy;
+
+    int ib = 0;
+    float sumf = 0.0f;
+
+#if defined(__VXE__) || defined(__VXE2__)
+    float32x4_t v_sum0 = vec_splats(0.0f);
+    float32x4_t v_sum1 = vec_splats(0.0f);
+
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    uint32_t qh0;
+    uint32_t qh1;
+
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    const uint8x16_t v_m = vec_splats((uint8_t)0x0F);
+
+    #pragma GCC unroll 4
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q5_1 * GGML_RESTRICT x0 = &x[ib + 0];
+        const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];
+        const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];
+        const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];
+
+        summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
+        summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);
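+        // q5_1 dequantizes as d*q + m, and block_q8_1 stores s = d * sum(q8), so the
+        // offset part of each block's dot product folds into the scalar m_x * s_y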
+
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];
+
+        int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));
+        int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));
+        int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));
+        int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));
+
+        // required for fixing the byteorder
+        v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);
+        v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);
+        v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);
+        v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);
+
+        const uint8x16_t v_x0 = vec_xl(0, x0->qs);
+        const uint8x16_t v_x1 = vec_xl(0, x1->qs);
+
+        const int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
+        const int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
+        const int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
+        const int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
+
+        const int8x16_t v_x0lf = vec_or(v_x0l, v_qh0l);
+        const int8x16_t v_x0hf = vec_or(v_x0h, v_qh0h);
+        const int8x16_t v_x1lf = vec_or(v_x1l, v_qh1l);
+        const int8x16_t v_x1hf = vec_or(v_x1h, v_qh1h);
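+        // q5_1 quants are unsigned (0..31): table_b2b_0 yields 0x10 where the 5th bit
+        // is set, so OR-ing it onto the low nibble rebuilds the full 5-bit value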
+
+        const int8x16_t v_y0l = vec_xl(0      , y0->qs);
+        const int8x16_t v_y0h = vec_xl(QK8_1/2, y0->qs);
+        const int8x16_t v_y1l = vec_xl(0      , y1->qs);
+        const int8x16_t v_y1h = vec_xl(QK8_1/2, y1->qs);
+
+        const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h);
+        const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);
+
+        const float32x4_t v_xy0f = vec_float(v_xy0);
+        const float32x4_t v_xy1f = vec_float(v_xy1);
+
+        const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
+        const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));
+
+        v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0);
+        v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
+    }
+
+    sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1) + summs0 + summs1;
+
+    #pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
+        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
+
+        float summs = GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
+
+        uint32_t qh;
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        uint64_t tmp[4];
+        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh >> 24)       ];
+
+        int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));
+        int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));
+
+        // required for fixing the byteorder
+        v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);
+        v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);
+
+        const uint8x16_t v_x = vec_xl(0, x0->qs);
+        const int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
+        const int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
+
+        const int8x16_t v_xlf = vec_or(v_xl, v_qhl);
+        const int8x16_t v_xhf = vec_or(v_xh, v_qhh);
+
+        const int8x16_t v_yl = vec_xl(0      , y0->qs);
+        const int8x16_t v_yh = vec_xl(QK8_1/2, y0->qs);
+
+        const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);
+        const float32x4_t v_xyf = vec_float(v_xy);
+
+        const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
+        const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f));
+
+        sumf += vec_hsum(v_acc) + summs;
+    }
+
+    *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;